-
Notifications
You must be signed in to change notification settings - Fork 55
Accelerate the mean operations without axis #589
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -42,6 +42,8 @@ | |
| #include <functional> | ||
| #include <numeric> | ||
| #include <algorithm> | ||
| #include <thread> | ||
| #include <vector> | ||
|
|
||
| #if defined(_MSC_VER) | ||
| #include <BaseTsd.h> | ||
|
|
@@ -141,6 +143,250 @@ struct select_real_t<Complex<U>> | |
| using type = U; | ||
| }; | ||
|
|
||
| template <typename A, typename T> | ||
| class SimpleArrayMixinSum | ||
| { | ||
|
|
||
| private: | ||
|
|
||
| using internal_types = detail::SimpleArrayInternalTypes<T>; | ||
|
|
||
| public: | ||
|
|
||
| using value_type = typename internal_types::value_type; | ||
|
|
||
| value_type sum() const | ||
| { | ||
| auto athis = static_cast<A const *>(this); | ||
| if (athis->is_c_contiguous()) | ||
| { | ||
| return sum_contiguous(); | ||
| } | ||
| else | ||
| { | ||
| return sum_non_contiguous(); | ||
| } | ||
| } | ||
|
|
||
| value_type sum_contiguous() const | ||
| { | ||
| auto athis = static_cast<A const *>(this); | ||
| value_type result; | ||
| if constexpr (is_complex_v<value_type>) | ||
| { | ||
| result = value_type{}; | ||
| } | ||
| else | ||
| { | ||
| result = 0; | ||
| } | ||
| sum_unrolled_generic(athis->data(), athis->size(), 1, result); | ||
| return result; | ||
| } | ||
|
Comment on lines
+171
to
+185
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I forget to implement simd for common data type. Would it become a seperate pull request? |
||
|
|
||
| private: | ||
| value_type sum_non_contiguous() const | ||
| { | ||
| auto athis = static_cast<A const *>(this); | ||
| const size_t ndim = athis->ndim(); | ||
| const auto & shape = athis->shape(); | ||
| const auto & stride = athis->stride(); | ||
|
|
||
| // Calculate the size of the last dimension for loop unrolling | ||
| const size_t last_dim_size = shape[ndim - 1]; | ||
| const size_t last_stride = stride[ndim - 1]; | ||
|
|
||
| // Calculate total prefix combinations to decide on parallelization | ||
| const size_t total_combinations = calculate_total_combinations(shape, ndim - 1); | ||
|
|
||
| // Use parallel processing for large arrays | ||
| if (total_combinations > 10000) | ||
| { | ||
| return sum_non_contiguous_parallel(athis, shape, stride, last_dim_size, last_stride); | ||
| } | ||
| else | ||
| { | ||
| return sum_non_contiguous_sequential(athis, shape, stride, last_dim_size, last_stride); | ||
| } | ||
| } | ||
|
|
||
| value_type sum_non_contiguous_sequential(A const * athis, const small_vector<size_t> & shape, const small_vector<size_t> & stride, size_t last_dim_size, size_t last_stride) const | ||
| { | ||
| value_type result; | ||
| if constexpr (is_complex_v<value_type>) | ||
| { | ||
| result = value_type{}; | ||
| } | ||
| else | ||
| { | ||
| result = 0; | ||
| } | ||
|
|
||
| const size_t ndim = shape.size(); | ||
| small_vector<size_t> prefix_idx(ndim - 1, 0); | ||
|
|
||
| do | ||
| { | ||
| size_t base_offset = 0; | ||
| for (size_t i = 0; i < ndim - 1; ++i) | ||
| { | ||
| base_offset += prefix_idx[i] * stride[i]; | ||
| } | ||
|
|
||
| const value_type * data_ptr = athis->data() + base_offset; | ||
| sum_unrolled_generic(data_ptr, last_dim_size, last_stride, result); | ||
|
|
||
| } while (next_cartesian_product_prefix(prefix_idx, shape, ndim - 1)); | ||
|
|
||
| return result; | ||
| } | ||
|
|
||
| value_type sum_non_contiguous_parallel(A const * athis, const small_vector<size_t> & shape, const small_vector<size_t> & stride, size_t last_dim_size, size_t last_stride) const | ||
| { | ||
| const size_t ndim = shape.size(); | ||
| const size_t prefix_len = ndim - 1; | ||
| const size_t total_combinations = calculate_total_combinations(shape, prefix_len); | ||
|
|
||
| const size_t num_threads = static_cast<size_t>(std::thread::hardware_concurrency()); | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We are not ready for using threads. Without a system to control how to use threads from outside the computing kernel here, the performance and resource consumption are not predictable. |
||
| const size_t combinations_per_thread = (total_combinations + num_threads - 1) / num_threads; | ||
|
|
||
| small_vector<value_type> thread_results(num_threads); | ||
| std::vector<std::thread> threads; | ||
|
|
||
| for (size_t t = 0; t < num_threads; ++t) | ||
| { | ||
| const size_t start_idx = t * combinations_per_thread; | ||
| const size_t end_idx = std::min(start_idx + combinations_per_thread, total_combinations); | ||
|
|
||
| if (start_idx < total_combinations) | ||
| { | ||
| threads.emplace_back([this, athis, &shape, &stride, last_dim_size, last_stride, start_idx, end_idx, &thread_results, t, prefix_len]() | ||
| { | ||
| value_type local_result; | ||
| if constexpr (is_complex_v<value_type>) | ||
| { | ||
| local_result = value_type{}; | ||
| } | ||
| else | ||
| { | ||
| local_result = 0; | ||
| } | ||
|
|
||
| for (size_t combo_idx = start_idx; combo_idx < end_idx; ++combo_idx) | ||
| { | ||
| small_vector<size_t> prefix_idx(prefix_len); | ||
| size_t temp_idx = combo_idx; | ||
| for (size_t i = 0; i < prefix_len; ++i) | ||
| { | ||
| prefix_idx[i] = temp_idx % shape[i]; | ||
| temp_idx /= shape[i]; | ||
| } | ||
|
|
||
| size_t base_offset = 0; | ||
| for (size_t i = 0; i < prefix_len; ++i) | ||
| { | ||
| base_offset += prefix_idx[i] * stride[i]; | ||
| } | ||
|
|
||
| const value_type* data_ptr = athis->data() + base_offset; | ||
| sum_unrolled_generic(data_ptr, last_dim_size, last_stride, local_result); | ||
| } | ||
|
|
||
| thread_results[t] = local_result; }); | ||
| } | ||
| } | ||
|
|
||
| for (auto & thread : threads) | ||
| { | ||
| thread.join(); | ||
| } | ||
|
|
||
| value_type result; | ||
| if constexpr (is_complex_v<value_type>) | ||
| { | ||
| result = value_type{}; | ||
| } | ||
| else | ||
| { | ||
| result = 0; | ||
| } | ||
|
|
||
| for (const auto & thread_result : thread_results) | ||
| { | ||
| result += thread_result; | ||
| } | ||
|
|
||
| return result; | ||
| } | ||
|
|
||
| size_t calculate_total_combinations(const small_vector<size_t> & shape, size_t prefix_len) const | ||
| { | ||
| size_t total = 1; | ||
| for (size_t i = 0; i < prefix_len; ++i) | ||
| { | ||
| total *= shape[i]; | ||
| } | ||
| return total; | ||
| } | ||
|
|
||
| void sum_unrolled_generic(const value_type * data_ptr, size_t size, size_t stride, value_type & result) const | ||
|
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm not sure whether it is really unroll the loop.
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It's hard to tell. If you are not sure about it, why adding it? |
||
| { | ||
| if constexpr (!std::is_same_v<bool, std::remove_const_t<value_type>>) | ||
| { | ||
| const int unroll = (sizeof(value_type) <= 1 ? 16 : sizeof(value_type) <= 2 ? 8 | ||
| : sizeof(value_type) <= 4 ? 4 | ||
| : 2); | ||
|
|
||
| size_t i = 0; | ||
| for (; i + unroll <= size; i += unroll) | ||
| { | ||
| for (int u = 0; u < unroll; ++u) | ||
| { | ||
| result += data_ptr[(i + u) * stride]; | ||
| } | ||
| } | ||
|
|
||
| for (; i < size; ++i) | ||
| { | ||
| result += data_ptr[i * stride]; | ||
| } | ||
| } | ||
| else | ||
| { | ||
| const int unroll = 8; | ||
| size_t i = 0; | ||
|
|
||
| for (; i + unroll <= size; i += unroll) | ||
| { | ||
| for (int u = 0; u < unroll; ++u) | ||
| { | ||
| result |= data_ptr[(i + u) * stride]; | ||
| } | ||
| } | ||
|
|
||
| for (; i < size; ++i) | ||
| { | ||
| result |= data_ptr[i * stride]; | ||
| } | ||
| } | ||
| } | ||
|
|
||
| bool next_cartesian_product_prefix(small_vector<size_t> & idx, | ||
| const small_vector<size_t> & shape, | ||
| size_t prefix_len) const | ||
| { | ||
| for (size_t i = prefix_len; i > 0; --i) | ||
| { | ||
| if (++idx[i - 1] < shape[i - 1]) | ||
| { | ||
| return true; | ||
| } | ||
| idx[i - 1] = 0; | ||
| } | ||
| return false; | ||
| } | ||
| }; /* end class SimpleArrayMixinSum */ | ||
|
|
||
| template <typename A, typename T> | ||
| class SimpleArrayMixinCalculators | ||
| { | ||
|
|
@@ -342,14 +588,8 @@ class SimpleArrayMixinCalculators | |
| value_type mean() const | ||
| { | ||
| auto athis = static_cast<A const *>(this); | ||
| auto sidx = athis->first_sidx(); | ||
| value_type sum = 0; | ||
| int64_t total = 0; | ||
| do | ||
| { | ||
| sum += athis->at(sidx); | ||
| ++total; | ||
| } while (athis->next_sidx(sidx)); | ||
| int64_t total = athis->size(); | ||
| value_type sum = athis->sum(); | ||
| return sum / static_cast<value_type>(total); | ||
| } | ||
|
|
||
|
|
@@ -464,36 +704,6 @@ class SimpleArrayMixinCalculators | |
| return initial; | ||
| } | ||
|
|
||
| value_type sum() const | ||
| { | ||
| value_type initial; | ||
| if constexpr (is_complex_v<value_type>) | ||
| { | ||
| initial = value_type(); | ||
| } | ||
| else | ||
| { | ||
| initial = 0; | ||
| } | ||
|
|
||
| auto athis = static_cast<A const *>(this); | ||
| if constexpr (!std::is_same_v<bool, std::remove_const_t<value_type>>) | ||
| { | ||
| for (size_t i = 0; i < athis->size(); ++i) | ||
| { | ||
| initial += athis->data(i); | ||
| } | ||
| } | ||
| else | ||
| { | ||
| for (size_t i = 0; i < athis->size(); ++i) | ||
| { | ||
| initial |= athis->data(i); | ||
| } | ||
| } | ||
| return initial; | ||
| } | ||
|
|
||
| A abs() const | ||
| { | ||
| auto athis = static_cast<A const *>(this); | ||
|
|
@@ -1014,6 +1224,7 @@ class SimpleArrayMixinSearch | |
| template <typename T> | ||
| class SimpleArray | ||
| : public detail::SimpleArrayMixinModifiers<SimpleArray<T>, T> | ||
| , public detail::SimpleArrayMixinSum<SimpleArray<T>, T> | ||
| , public detail::SimpleArrayMixinCalculators<SimpleArray<T>, T> | ||
| , public detail::SimpleArrayMixinSort<SimpleArray<T>, T> | ||
| , public detail::SimpleArrayMixinSearch<SimpleArray<T>, T> | ||
|
|
@@ -1464,21 +1675,33 @@ class SimpleArray | |
| value_type const * body() const { return m_body; } | ||
| value_type * body() { return m_body; } | ||
|
|
||
| bool is_c_contiguous() const { return is_c_contiguous(m_shape, m_stride); } | ||
|
|
||
| private: | ||
| void check_c_contiguous(small_vector<size_t> const & shape, | ||
| small_vector<size_t> const & stride) const | ||
| bool is_c_contiguous(small_vector<size_t> const & shape, | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This can be static. |
||
| small_vector<size_t> const & stride) const | ||
| { | ||
| if (stride[stride.size() - 1] != 1) | ||
| { | ||
| throw std::runtime_error("SimpleArray: C contiguous stride must end with 1"); | ||
| return false; | ||
| } | ||
| for (size_t it = 0; it < shape.size() - 1; ++it) | ||
| { | ||
| if (stride[it] != shape[it + 1] * stride[it + 1]) | ||
| { | ||
| throw std::runtime_error("SimpleArray: C contiguous stride must match shape"); | ||
| return false; | ||
| } | ||
| } | ||
| return true; | ||
| } | ||
|
|
||
| void check_c_contiguous(small_vector<size_t> const & shape, | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This can be static. |
||
| small_vector<size_t> const & stride) const | ||
| { | ||
| if (!is_c_contiguous(shape, stride)) | ||
| { | ||
| throw std::runtime_error("SimpleArray: C contiguous stride must match shape and end with 1"); | ||
| } | ||
| } | ||
|
|
||
| void check_f_contiguous(small_vector<size_t> const & shape, | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Move sum operation to a seperate class because of complex optimization.