From bbb6f7339acd6b940b0936d372ab14c259277d57 Mon Sep 17 00:00:00 2001 From: Cedric AUGONNET Date: Thu, 5 Mar 2026 15:53:33 +0100 Subject: [PATCH 01/10] Remove unused exec_affinity.cuh include from place_partition.cuh place_partition.cuh does not use exec_affinity; the include was a leftover. Made-with: Cursor --- cudax/include/cuda/experimental/__stf/places/place_partition.cuh | 1 - 1 file changed, 1 deletion(-) diff --git a/cudax/include/cuda/experimental/__stf/places/place_partition.cuh b/cudax/include/cuda/experimental/__stf/places/place_partition.cuh index a5af120bac9..32f7e5f93f3 100644 --- a/cudax/include/cuda/experimental/__stf/places/place_partition.cuh +++ b/cudax/include/cuda/experimental/__stf/places/place_partition.cuh @@ -25,7 +25,6 @@ # pragma system_header #endif // no system header -#include #include #include #include From 7fd0545ea8b134c1c51837c734f3450bd1865b9d Mon Sep 17 00:00:00 2001 From: Cedric AUGONNET Date: Fri, 6 Mar 2026 16:09:50 +0100 Subject: [PATCH 02/10] Add missing changes --- .../cuda/experimental/__stf/places/places.cuh | 136 ++++++++---------- 1 file changed, 63 insertions(+), 73 deletions(-) diff --git a/cudax/include/cuda/experimental/__stf/places/places.cuh b/cudax/include/cuda/experimental/__stf/places/places.cuh index 1b6f738cb65..d13105422a8 100644 --- a/cudax/include/cuda/experimental/__stf/places/places.cuh +++ b/cudax/include/cuda/experimental/__stf/places/places.cuh @@ -27,7 +27,6 @@ # pragma system_header #endif // no system header -#include #include #include #include @@ -50,6 +49,9 @@ namespace cuda::experimental::stf { +template +class place_indexed_container; + class exec_place; class exec_place_host; class exec_place_grid; @@ -336,7 +338,7 @@ public: return ::std::hash()(devid); } - decorated_stream getDataStream(async_resources_handle& async_resources) const; + decorated_stream getDataStream(place_indexed_container<::std::pair>& stream_pools) const; private: /** @@ -710,23 +712,19 @@ public: return 
device_ordinal(affine) < device_ordinal(rhs.affine); } - /* Return the pool associated to this place + /* Return a pointer to a locally-owned stream pool, or nullptr. + * + * If this place owns its own stream pool (e.g. green contexts, CUDA stream + * places), the override returns a non-null pointer. Otherwise the base + * implementation returns nullptr, which tells exec_place::get_stream_pool to + * fall back to a container-based lookup. * - * If the stream is expected to perform computation, the - * for_computation should be true. If we plan to use this stream for data - * transfers, or other means (graph capture) we set the value to false. (This - * flag is intended for performance matters, not correctness) + * @param for_computation true for a computation pool, false for data transfer + * @return A pointer to the locally-owned pool, or nullptr. */ - virtual stream_pool& get_stream_pool(async_resources_handle& async_resources, bool for_computation) const + virtual stream_pool* get_local_stream_pool(bool) const { - if (!affine.is_device()) - { - fprintf(stderr, "Error: get_stream_pool virtual method is not implemented for this exec place.\n"); - abort(); - } - - int dev_id = device_ordinal(affine); - return async_resources.get_device_stream_pool(dev_id, for_computation); + return nullptr; } /** @@ -868,10 +866,8 @@ public: pimpl->set_affine_data_place(mv(place)); } - stream_pool& get_stream_pool(async_resources_handle& async_resources, bool for_computation) const - { - return pimpl->get_stream_pool(async_resources, for_computation); - } + stream_pool& get_stream_pool(place_indexed_container<::std::pair>& pools, + bool for_computation) const; /** * @brief Get a decorated stream from the stream pool associated to this execution place. @@ -880,19 +876,16 @@ public: * a CUDASTF context. This is useful when you want to use CUDASTF's place abstractions * (devices, green contexts) for stream management without the full task-based model. 
* - * @note If you are using a CUDASTF context, use `ctx.async_resources()` to ensure the - * same stream pools are shared between your code and the context's internal operations. - * - * @param async_resources Handle managing the stream pools. Create a standalone - * `async_resources_handle` for context-free usage, or use `ctx.async_resources()` - * when working alongside a CUDASTF context. + * @param stream_pools Container mapping each place to a pair of stream pools + * (computation pool, data transfer pool). * @param for_computation Hint for selecting which pool to use. When true, returns a stream * from the computation pool; when false, returns a stream from the data transfer pool. * Using separate pools for computation and transfers can improve overlapping. * This is a performance hint and does not affect correctness. * @return A decorated_stream containing the CUDA stream and metadata (device ID, pool index) */ - decorated_stream getStream(async_resources_handle& async_resources, bool for_computation) const; + decorated_stream getStream(place_indexed_container<::std::pair>& stream_pools, + bool for_computation) const; /** * @brief Create a stream valid for execution on this place. @@ -914,49 +907,38 @@ public: * a CUDASTF context. This is useful when you want to use CUDASTF's place abstractions * (devices, green contexts) for stream management without the full task-based model. 
* - * Example usage without a context: - * @code - * async_resources_handle resources; - * exec_place place = exec_place::device(0); - * cudaStream_t stream = place.pick_stream(resources); - * myKernel<<>>(...); - * @endcode - * * Example usage with a context (sharing resources): * @code * stream_ctx ctx; * exec_place place = exec_place::device(0); - * cudaStream_t stream = place.pick_stream(ctx.async_resources()); - * // Stream comes from the same pool used by ctx internally + * cudaStream_t stream = place.pick_stream(ctx.async_resources().stream_pools()); * @endcode * - * @note If you are using a CUDASTF context, use `ctx.async_resources()` to ensure the - * same stream pools are shared between your code and the context's internal operations. - * - * @param async_resources Handle managing the stream pools. Create a standalone - * `async_resources_handle` for context-free usage, or use `ctx.async_resources()` - * when working alongside a CUDASTF context. + * @param stream_pools Container mapping each place to a pair of stream pools + * (computation pool, data transfer pool). * @param for_computation Hint for selecting which pool to use. When true, returns a stream * from the computation pool; when false, returns a stream from the data transfer pool. * Using separate pools for computation and transfers can improve overlapping. * This is a performance hint and does not affect correctness. Defaults to true. * @return A CUDA stream associated with this execution place */ - cudaStream_t pick_stream(async_resources_handle& async_resources, bool for_computation = true) const + cudaStream_t pick_stream(place_indexed_container<::std::pair>& stream_pools, + bool for_computation = true) const { - return getStream(async_resources, for_computation).stream; + return getStream(stream_pools, for_computation).stream; } /** * @brief Get the number of streams available in the pool for this execution place. 
* - * @param async_resources Handle managing the stream pools + * @param pools Container mapping each place to a pair of stream pools * @param for_computation Hint for selecting which pool to query (computation or transfer pool) * @return The number of stream slots in the pool */ - size_t stream_pool_size(async_resources_handle& async_resources, bool for_computation = true) const + size_t stream_pool_size(place_indexed_container<::std::pair>& pools, + bool for_computation = true) const { - return get_stream_pool(async_resources, for_computation).size(); + return get_stream_pool(pools, for_computation).size(); } /** @@ -966,14 +948,14 @@ public: * created lazily, so calling this method will create any streams that haven't been * created yet. * - * @param async_resources Handle managing the stream pools + * @param pools Container mapping each place to a pair of stream pools * @param for_computation Hint for selecting which pool to use (computation or transfer pool) * @return A vector of CUDA streams from the pool */ - ::std::vector - pick_all_streams(async_resources_handle& async_resources, bool for_computation = true) const + ::std::vector pick_all_streams(place_indexed_container<::std::pair>& pools, + bool for_computation = true) const { - auto& pool = get_stream_pool(async_resources, for_computation); + auto& pool = get_stream_pool(pools, for_computation); ::std::vector result; result.reserve(pool.size()); for (size_t i = 0; i < pool.size(); ++i) @@ -1231,9 +1213,10 @@ inline decorated_stream stream_pool::next(const exec_place& place) return result; } -inline decorated_stream exec_place::getStream(async_resources_handle& async_resources, bool for_computation) const +inline decorated_stream exec_place::getStream( + place_indexed_container<::std::pair>& stream_pools, bool for_computation) const { - return get_stream_pool(async_resources, for_computation).next(*this); + return get_stream_pool(stream_pools, for_computation).next(*this); } /** @@ -1266,12 +1249,6 @@ 
public: { return data_place::host(); } - virtual stream_pool& get_stream_pool(async_resources_handle& async_resources, bool for_computation) const override - { - // There is no pool attached to the host itself, so we use the pool attached to the execution place of the - // current device - return exec_place::current_device().get_stream_pool(async_resources, for_computation); - } }; static ::std::shared_ptr make() @@ -1574,17 +1551,6 @@ public: return coords_to_place(p_index); } - virtual stream_pool& get_stream_pool(async_resources_handle& async_resources, bool for_computation) const override - { - // We "arbitrarily" select a pool from one of the place in the - // grid, which can be suffiicent for a data transfer, but we do not - // want to allow this for computation where we expect a more - // accurate placement. - assert(!for_computation); - assert(places.size() > 0); - return places[0].get_stream_pool(async_resources, for_computation); - } - private: // What is the execution place at theses coordinates in the exec place grid ? const exec_place& coords_to_place(size_t c0, size_t c1 = 0, size_t c2 = 0, size_t c3 = 0) const @@ -1967,9 +1933,10 @@ inline exec_place data_place::affine_exec_place() const return exec_place::device(devid); } -inline decorated_stream data_place::getDataStream(async_resources_handle& async_resources) const +inline decorated_stream +data_place::getDataStream(place_indexed_container<::std::pair>& stream_pools) const { - return affine_exec_place().getStream(async_resources, false); + return affine_exec_place().getStream(stream_pools, false); } inline const exec_place_grid& data_place::get_grid() const @@ -2329,6 +2296,29 @@ struct hash } }; +#include + +// We need the implementation of exec_place_grid, exec_place_device, ... 
+inline stream_pool& exec_place::get_stream_pool(place_indexed_container<::std::pair>& pools, + bool for_computation) const +{ + if (auto* local_pool = pimpl->get_local_stream_pool(for_computation)) + { + return *local_pool; + } + if (pimpl->is_host()) + { + return exec_place::current_device().get_stream_pool(pools, for_computation); + } + if (pimpl->is_grid()) + { + const auto& places = as_grid().get_places(); + _CCCL_ASSERT(places.size() > 0, "exec_place_grid has no places"); + return places[0].get_stream_pool(pools, for_computation); + } + return for_computation ? pools[*this].first : pools[*this].second; +} + #ifdef UNITTESTED_FILE UNITTEST("Data place as unordered_map key") { From b4f7c9d91025d53a7a56f6c3b7749102e3dd41ec Mon Sep 17 00:00:00 2001 From: Cedric AUGONNET Date: Fri, 6 Mar 2026 16:11:39 +0100 Subject: [PATCH 03/10] Revert "Add missing changes" (not intended for this branch) This reverts commit 7fd0545ea8b134c1c51837c734f3450bd1865b9d. --- .../cuda/experimental/__stf/places/places.cuh | 136 ++++++++++-------- 1 file changed, 73 insertions(+), 63 deletions(-) diff --git a/cudax/include/cuda/experimental/__stf/places/places.cuh b/cudax/include/cuda/experimental/__stf/places/places.cuh index d13105422a8..1b6f738cb65 100644 --- a/cudax/include/cuda/experimental/__stf/places/places.cuh +++ b/cudax/include/cuda/experimental/__stf/places/places.cuh @@ -27,6 +27,7 @@ # pragma system_header #endif // no system header +#include #include #include #include @@ -49,9 +50,6 @@ namespace cuda::experimental::stf { -template -class place_indexed_container; - class exec_place; class exec_place_host; class exec_place_grid; @@ -338,7 +336,7 @@ public: return ::std::hash()(devid); } - decorated_stream getDataStream(place_indexed_container<::std::pair>& stream_pools) const; + decorated_stream getDataStream(async_resources_handle& async_resources) const; private: /** @@ -712,19 +710,23 @@ public: return device_ordinal(affine) < device_ordinal(rhs.affine); } - /* Return a 
pointer to a locally-owned stream pool, or nullptr. - * - * If this place owns its own stream pool (e.g. green contexts, CUDA stream - * places), the override returns a non-null pointer. Otherwise the base - * implementation returns nullptr, which tells exec_place::get_stream_pool to - * fall back to a container-based lookup. + /* Return the pool associated to this place * - * @param for_computation true for a computation pool, false for data transfer - * @return A pointer to the locally-owned pool, or nullptr. + * If the stream is expected to perform computation, the + * for_computation should be true. If we plan to use this stream for data + * transfers, or other means (graph capture) we set the value to false. (This + * flag is intended for performance matters, not correctness) */ - virtual stream_pool* get_local_stream_pool(bool) const + virtual stream_pool& get_stream_pool(async_resources_handle& async_resources, bool for_computation) const { - return nullptr; + if (!affine.is_device()) + { + fprintf(stderr, "Error: get_stream_pool virtual method is not implemented for this exec place.\n"); + abort(); + } + + int dev_id = device_ordinal(affine); + return async_resources.get_device_stream_pool(dev_id, for_computation); } /** @@ -866,8 +868,10 @@ public: pimpl->set_affine_data_place(mv(place)); } - stream_pool& get_stream_pool(place_indexed_container<::std::pair>& pools, - bool for_computation) const; + stream_pool& get_stream_pool(async_resources_handle& async_resources, bool for_computation) const + { + return pimpl->get_stream_pool(async_resources, for_computation); + } /** * @brief Get a decorated stream from the stream pool associated to this execution place. @@ -876,16 +880,19 @@ public: * a CUDASTF context. This is useful when you want to use CUDASTF's place abstractions * (devices, green contexts) for stream management without the full task-based model. 
* - * @param stream_pools Container mapping each place to a pair of stream pools - * (computation pool, data transfer pool). + * @note If you are using a CUDASTF context, use `ctx.async_resources()` to ensure the + * same stream pools are shared between your code and the context's internal operations. + * + * @param async_resources Handle managing the stream pools. Create a standalone + * `async_resources_handle` for context-free usage, or use `ctx.async_resources()` + * when working alongside a CUDASTF context. * @param for_computation Hint for selecting which pool to use. When true, returns a stream * from the computation pool; when false, returns a stream from the data transfer pool. * Using separate pools for computation and transfers can improve overlapping. * This is a performance hint and does not affect correctness. * @return A decorated_stream containing the CUDA stream and metadata (device ID, pool index) */ - decorated_stream getStream(place_indexed_container<::std::pair>& stream_pools, - bool for_computation) const; + decorated_stream getStream(async_resources_handle& async_resources, bool for_computation) const; /** * @brief Create a stream valid for execution on this place. @@ -907,38 +914,49 @@ public: * a CUDASTF context. This is useful when you want to use CUDASTF's place abstractions * (devices, green contexts) for stream management without the full task-based model. 
* + * Example usage without a context: + * @code + * async_resources_handle resources; + * exec_place place = exec_place::device(0); + * cudaStream_t stream = place.pick_stream(resources); + * myKernel<<>>(...); + * @endcode + * * Example usage with a context (sharing resources): * @code * stream_ctx ctx; * exec_place place = exec_place::device(0); - * cudaStream_t stream = place.pick_stream(ctx.async_resources().stream_pools()); + * cudaStream_t stream = place.pick_stream(ctx.async_resources()); + * // Stream comes from the same pool used by ctx internally * @endcode * - * @param stream_pools Container mapping each place to a pair of stream pools - * (computation pool, data transfer pool). + * @note If you are using a CUDASTF context, use `ctx.async_resources()` to ensure the + * same stream pools are shared between your code and the context's internal operations. + * + * @param async_resources Handle managing the stream pools. Create a standalone + * `async_resources_handle` for context-free usage, or use `ctx.async_resources()` + * when working alongside a CUDASTF context. * @param for_computation Hint for selecting which pool to use. When true, returns a stream * from the computation pool; when false, returns a stream from the data transfer pool. * Using separate pools for computation and transfers can improve overlapping. * This is a performance hint and does not affect correctness. Defaults to true. * @return A CUDA stream associated with this execution place */ - cudaStream_t pick_stream(place_indexed_container<::std::pair>& stream_pools, - bool for_computation = true) const + cudaStream_t pick_stream(async_resources_handle& async_resources, bool for_computation = true) const { - return getStream(stream_pools, for_computation).stream; + return getStream(async_resources, for_computation).stream; } /** * @brief Get the number of streams available in the pool for this execution place. 
* - * @param pools Container mapping each place to a pair of stream pools + * @param async_resources Handle managing the stream pools * @param for_computation Hint for selecting which pool to query (computation or transfer pool) * @return The number of stream slots in the pool */ - size_t stream_pool_size(place_indexed_container<::std::pair>& pools, - bool for_computation = true) const + size_t stream_pool_size(async_resources_handle& async_resources, bool for_computation = true) const { - return get_stream_pool(pools, for_computation).size(); + return get_stream_pool(async_resources, for_computation).size(); } /** @@ -948,14 +966,14 @@ public: * created lazily, so calling this method will create any streams that haven't been * created yet. * - * @param pools Container mapping each place to a pair of stream pools + * @param async_resources Handle managing the stream pools * @param for_computation Hint for selecting which pool to use (computation or transfer pool) * @return A vector of CUDA streams from the pool */ - ::std::vector pick_all_streams(place_indexed_container<::std::pair>& pools, - bool for_computation = true) const + ::std::vector + pick_all_streams(async_resources_handle& async_resources, bool for_computation = true) const { - auto& pool = get_stream_pool(pools, for_computation); + auto& pool = get_stream_pool(async_resources, for_computation); ::std::vector result; result.reserve(pool.size()); for (size_t i = 0; i < pool.size(); ++i) @@ -1213,10 +1231,9 @@ inline decorated_stream stream_pool::next(const exec_place& place) return result; } -inline decorated_stream exec_place::getStream( - place_indexed_container<::std::pair>& stream_pools, bool for_computation) const +inline decorated_stream exec_place::getStream(async_resources_handle& async_resources, bool for_computation) const { - return get_stream_pool(stream_pools, for_computation).next(*this); + return get_stream_pool(async_resources, for_computation).next(*this); } /** @@ -1249,6 +1266,12 @@ 
public: { return data_place::host(); } + virtual stream_pool& get_stream_pool(async_resources_handle& async_resources, bool for_computation) const override + { + // There is no pool attached to the host itself, so we use the pool attached to the execution place of the + // current device + return exec_place::current_device().get_stream_pool(async_resources, for_computation); + } }; static ::std::shared_ptr make() @@ -1551,6 +1574,17 @@ public: return coords_to_place(p_index); } + virtual stream_pool& get_stream_pool(async_resources_handle& async_resources, bool for_computation) const override + { + // We "arbitrarily" select a pool from one of the place in the + // grid, which can be suffiicent for a data transfer, but we do not + // want to allow this for computation where we expect a more + // accurate placement. + assert(!for_computation); + assert(places.size() > 0); + return places[0].get_stream_pool(async_resources, for_computation); + } + private: // What is the execution place at theses coordinates in the exec place grid ? const exec_place& coords_to_place(size_t c0, size_t c1 = 0, size_t c2 = 0, size_t c3 = 0) const @@ -1933,10 +1967,9 @@ inline exec_place data_place::affine_exec_place() const return exec_place::device(devid); } -inline decorated_stream -data_place::getDataStream(place_indexed_container<::std::pair>& stream_pools) const +inline decorated_stream data_place::getDataStream(async_resources_handle& async_resources) const { - return affine_exec_place().getStream(stream_pools, false); + return affine_exec_place().getStream(async_resources, false); } inline const exec_place_grid& data_place::get_grid() const @@ -2296,29 +2329,6 @@ struct hash } }; -#include - -// We need the implementation of exec_place_grid, exec_place_device, ... 
-inline stream_pool& exec_place::get_stream_pool(place_indexed_container<::std::pair>& pools, - bool for_computation) const -{ - if (auto* local_pool = pimpl->get_local_stream_pool(for_computation)) - { - return *local_pool; - } - if (pimpl->is_host()) - { - return exec_place::current_device().get_stream_pool(pools, for_computation); - } - if (pimpl->is_grid()) - { - const auto& places = as_grid().get_places(); - _CCCL_ASSERT(places.size() > 0, "exec_place_grid has no places"); - return places[0].get_stream_pool(pools, for_computation); - } - return for_computation ? pools[*this].first : pools[*this].second; -} - #ifdef UNITTESTED_FILE UNITTEST("Data place as unordered_map key") { From 72eb9ffd95c15cca5cdc92dd632bbcd78fff7238 Mon Sep 17 00:00:00 2001 From: Cedric AUGONNET Date: Mon, 9 Mar 2026 18:21:27 +0100 Subject: [PATCH 04/10] Move loop_dispatch.cuh to a more STF specific header --- .../experimental/__stf/{places => internal}/loop_dispatch.cuh | 0 cudax/test/stf/loop_dispatch/loop_dispatch.cu | 2 +- cudax/test/stf/loop_dispatch/nested_loop_dispatch.cu | 2 +- 3 files changed, 2 insertions(+), 2 deletions(-) rename cudax/include/cuda/experimental/__stf/{places => internal}/loop_dispatch.cuh (100%) diff --git a/cudax/include/cuda/experimental/__stf/places/loop_dispatch.cuh b/cudax/include/cuda/experimental/__stf/internal/loop_dispatch.cuh similarity index 100% rename from cudax/include/cuda/experimental/__stf/places/loop_dispatch.cuh rename to cudax/include/cuda/experimental/__stf/internal/loop_dispatch.cuh diff --git a/cudax/test/stf/loop_dispatch/loop_dispatch.cu b/cudax/test/stf/loop_dispatch/loop_dispatch.cu index 86369203896..586e4114b37 100644 --- a/cudax/test/stf/loop_dispatch/loop_dispatch.cu +++ b/cudax/test/stf/loop_dispatch/loop_dispatch.cu @@ -8,7 +8,7 @@ // //===----------------------------------------------------------------------===// -#include +#include #include using namespace cuda::experimental::stf; diff --git 
a/cudax/test/stf/loop_dispatch/nested_loop_dispatch.cu b/cudax/test/stf/loop_dispatch/nested_loop_dispatch.cu index 29c8a455cb7..c4d855ccea5 100644 --- a/cudax/test/stf/loop_dispatch/nested_loop_dispatch.cu +++ b/cudax/test/stf/loop_dispatch/nested_loop_dispatch.cu @@ -10,7 +10,7 @@ #include -#include +#include #include using namespace cuda::experimental::stf; From 3f54633ecf9bc5332dbcadfd5a0855f565828edf Mon Sep 17 00:00:00 2001 From: Cedric AUGONNET Date: Mon, 9 Mar 2026 18:43:16 +0100 Subject: [PATCH 05/10] Move places/inner_shape.cuh in internal/inner_shape.cuh --- cudax/include/cuda/experimental/__stf/internal/context.cuh | 2 +- .../experimental/__stf/{places => internal}/inner_shape.cuh | 0 cudax/include/cuda/experimental/stf.cuh | 2 +- cudax/test/stf/CMakeLists.txt | 2 +- 4 files changed, 3 insertions(+), 3 deletions(-) rename cudax/include/cuda/experimental/__stf/{places => internal}/inner_shape.cuh (100%) diff --git a/cudax/include/cuda/experimental/__stf/internal/context.cuh b/cudax/include/cuda/experimental/__stf/internal/context.cuh index a4af539b9d5..58b97a42dc2 100644 --- a/cudax/include/cuda/experimental/__stf/internal/context.cuh +++ b/cudax/include/cuda/experimental/__stf/internal/context.cuh @@ -28,7 +28,7 @@ #include #include #include -#include +#include #include #include diff --git a/cudax/include/cuda/experimental/__stf/places/inner_shape.cuh b/cudax/include/cuda/experimental/__stf/internal/inner_shape.cuh similarity index 100% rename from cudax/include/cuda/experimental/__stf/places/inner_shape.cuh rename to cudax/include/cuda/experimental/__stf/internal/inner_shape.cuh diff --git a/cudax/include/cuda/experimental/stf.cuh b/cudax/include/cuda/experimental/stf.cuh index 664e76b666e..82ce5f2fa01 100644 --- a/cudax/include/cuda/experimental/stf.cuh +++ b/cudax/include/cuda/experimental/stf.cuh @@ -29,6 +29,6 @@ #include #include #include -#include +#include #include #include diff --git a/cudax/test/stf/CMakeLists.txt 
b/cudax/test/stf/CMakeLists.txt index 877f4259cd1..dd78f1c5f15 100644 --- a/cudax/test/stf/CMakeLists.txt +++ b/cudax/test/stf/CMakeLists.txt @@ -193,7 +193,7 @@ set( cuda/experimental/__stf/internal/slice.cuh cuda/experimental/__stf/internal/thread_hierarchy.cuh cuda/experimental/__stf/places/cyclic_shape.cuh - cuda/experimental/__stf/places/inner_shape.cuh + cuda/experimental/__stf/internal/inner_shape.cuh cuda/experimental/__stf/places/places.cuh cuda/experimental/__stf/places/tiled_partition.cuh cuda/experimental/__stf/stream/stream_ctx.cuh From 7e1ff11d18952b8e20949151237ca909bb1dce8a Mon Sep 17 00:00:00 2001 From: Cedric AUGONNET Date: Mon, 9 Mar 2026 21:08:18 +0100 Subject: [PATCH 06/10] clang-format --- cudax/include/cuda/experimental/__stf/internal/context.cuh | 2 +- cudax/include/cuda/experimental/stf.cuh | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/cudax/include/cuda/experimental/__stf/internal/context.cuh b/cudax/include/cuda/experimental/__stf/internal/context.cuh index 58b97a42dc2..af7cdf6b775 100644 --- a/cudax/include/cuda/experimental/__stf/internal/context.cuh +++ b/cudax/include/cuda/experimental/__stf/internal/context.cuh @@ -23,12 +23,12 @@ #include #include #include +#include #include #include #include #include #include -#include #include #include diff --git a/cudax/include/cuda/experimental/stf.cuh b/cudax/include/cuda/experimental/stf.cuh index 82ce5f2fa01..25d8751dd85 100644 --- a/cudax/include/cuda/experimental/stf.cuh +++ b/cudax/include/cuda/experimental/stf.cuh @@ -23,12 +23,12 @@ #include // #include #include +#include #include #include #include #include #include #include -#include #include #include From 67b6511c070cc85225d0f869dff97227ba1279f4 Mon Sep 17 00:00:00 2001 From: Cedric AUGONNET Date: Mon, 9 Mar 2026 21:08:58 +0100 Subject: [PATCH 07/10] add a missing header --- cudax/include/cuda/experimental/__stf/places/place_partition.cuh | 1 + 1 file changed, 1 insertion(+) diff --git 
a/cudax/include/cuda/experimental/__stf/places/place_partition.cuh b/cudax/include/cuda/experimental/__stf/places/place_partition.cuh index 9c6b5a0e496..d7cd70dca54 100644 --- a/cudax/include/cuda/experimental/__stf/places/place_partition.cuh +++ b/cudax/include/cuda/experimental/__stf/places/place_partition.cuh @@ -25,6 +25,7 @@ # pragma system_header #endif // no system header +#include #include #include #include From a99f9d116abf6f51434758abfd636c05fcaf585a Mon Sep 17 00:00:00 2001 From: Cedric AUGONNET Date: Mon, 9 Mar 2026 21:52:11 +0100 Subject: [PATCH 08/10] Use snake_case for get_stream instead of getStream --- .../cuda/experimental/__stf/graph/graph_task.cuh | 6 +++--- .../cuda/experimental/__stf/places/places.cuh | 14 +++++++------- .../experimental/__stf/stream/interfaces/slice.cuh | 2 +- .../__stf/stream/internal/event_types.cuh | 2 +- .../cuda/experimental/__stf/stream/reduction.cuh | 4 ++-- .../cuda/experimental/__stf/stream/stream_ctx.cuh | 6 +++--- .../cuda/experimental/__stf/stream/stream_task.cuh | 4 ++-- cudax/test/stf/cpp/test_pick_stream.cu | 10 +++++----- .../test/stf/cpp/test_pick_stream_green_context.cu | 4 ++-- 9 files changed, 26 insertions(+), 26 deletions(-) diff --git a/cudax/include/cuda/experimental/__stf/graph/graph_task.cuh b/cudax/include/cuda/experimental/__stf/graph/graph_task.cuh index 39375f6db2d..20080bd6d81 100644 --- a/cudax/include/cuda/experimental/__stf/graph/graph_task.cuh +++ b/cudax/include/cuda/experimental/__stf/graph/graph_task.cuh @@ -103,7 +103,7 @@ public: if (is_capture_enabled()) { // Select a stream from the pool - capture_stream = get_exec_place().getStream(true).stream; + capture_stream = get_exec_place().get_stream(true).stream; // Use relaxed capture mode to allow capturing workloads that lazily initialize // resources (e.g., set up memory pools) cuda_safe_call(cudaStreamBeginCapture(capture_stream, cudaStreamCaptureModeRelaxed)); @@ -366,7 +366,7 @@ public: // // Get a stream from the pool associated to 
the execution place - capture_stream = get_exec_place().getStream(true).stream; + capture_stream = get_exec_place().get_stream(true).stream; cudaGraph_t childGraph = nullptr; // Use relaxed capture mode to allow capturing workloads that lazily initialize @@ -628,7 +628,7 @@ public: auto lock = lock_ctx_graph(); // Get a stream from the pool associated to the execution place - cudaStream_t capture_stream = get_exec_place().getStream(true).stream; + cudaStream_t capture_stream = get_exec_place().get_stream(true).stream; cudaGraph_t childGraph = nullptr; // Use relaxed capture mode to allow capturing workloads that lazily initialize diff --git a/cudax/include/cuda/experimental/__stf/places/places.cuh b/cudax/include/cuda/experimental/__stf/places/places.cuh index 9a9a0d06c18..1cf32ea166b 100644 --- a/cudax/include/cuda/experimental/__stf/places/places.cuh +++ b/cudax/include/cuda/experimental/__stf/places/places.cuh @@ -333,7 +333,7 @@ public: return ::std::hash()(devid); } - decorated_stream getDataStream() const; + decorated_stream get_data_stream() const; private: /** @@ -867,13 +867,13 @@ public: /** * @brief Get a decorated stream from the stream pool associated to this execution place. */ - decorated_stream getStream(bool for_computation) const; + decorated_stream get_stream(bool for_computation) const; /** * @brief Create a stream valid for execution on this place. * * Call only when the place is already activated (e.g. inside exec_place_guard). - * For getting a stream from the pool, use getStream() / pick_stream() instead. + * For getting a stream from the pool, use get_stream() / pick_stream() instead. * * @return A CUDA stream valid for this execution place */ @@ -884,7 +884,7 @@ public: cudaStream_t pick_stream(bool for_computation = true) const { - return getStream(for_computation).stream; + return get_stream(for_computation).stream; } // TODO make protected ! 
@@ -1135,7 +1135,7 @@ inline decorated_stream stream_pool::next(const exec_place& place) return result; } -inline decorated_stream exec_place::getStream(bool for_computation) const +inline decorated_stream exec_place::get_stream(bool for_computation) const { return get_stream_pool(for_computation).next(*this); } @@ -1883,9 +1883,9 @@ inline exec_place data_place::affine_exec_place() const return exec_place::device(devid); } -inline decorated_stream data_place::getDataStream() const +inline decorated_stream data_place::get_data_stream() const { - return affine_exec_place().getStream(false); + return affine_exec_place().get_stream(false); } inline const exec_place_grid& data_place::get_grid() const diff --git a/cudax/include/cuda/experimental/__stf/stream/interfaces/slice.cuh b/cudax/include/cuda/experimental/__stf/stream/interfaces/slice.cuh index 21ac9da53a6..f2bcab12454 100644 --- a/cudax/include/cuda/experimental/__stf/stream/interfaces/slice.cuh +++ b/cudax/include/cuda/experimental/__stf/stream/interfaces/slice.cuh @@ -203,7 +203,7 @@ public: // static_assert(dimensions <= 2, "unsupported yet."); //_CCCL_ASSERT(dimensions <= 2, "unsupported yet."); - auto decorated_s = dst_memory_node.getDataStream(); + auto decorated_s = dst_memory_node.get_data_stream(); auto op = stream_async_op(bctx, decorated_s, prereqs); if (bctx.generate_event_symbols()) diff --git a/cudax/include/cuda/experimental/__stf/stream/internal/event_types.cuh b/cudax/include/cuda/experimental/__stf/stream/internal/event_types.cuh index 4322a2e0559..533c7f19c91 100644 --- a/cudax/include/cuda/experimental/__stf/stream/internal/event_types.cuh +++ b/cudax/include/cuda/experimental/__stf/stream/internal/event_types.cuh @@ -298,7 +298,7 @@ public: { // We did not select a stream yet, so we take one in the pools in // the async_resource_handle object associated to the context - dstream = place.getDataStream(); + dstream = place.get_data_stream(); } // Note that if we had stream_dev_id = -1 (eg. 
host memory), the device diff --git a/cudax/include/cuda/experimental/__stf/stream/reduction.cuh b/cudax/include/cuda/experimental/__stf/stream/reduction.cuh index d8d1c6168ad..2679dd5a615 100644 --- a/cudax/include/cuda/experimental/__stf/stream/reduction.cuh +++ b/cudax/include/cuda/experimental/__stf/stream/reduction.cuh @@ -74,7 +74,7 @@ public: const exec_place& ep, event_list& prereqs) override { - auto dstream = inout_memory_node.getDataStream(); + auto dstream = inout_memory_node.get_data_stream(); auto async_op = stream_async_op(d.get_ctx(), dstream, prereqs); if (d.get_ctx().generate_event_symbols()) { @@ -95,7 +95,7 @@ public: const exec_place& ep, event_list& prereqs) override { - auto dstream = out_memory_node.getDataStream(); + auto dstream = out_memory_node.get_data_stream(); auto async_op = stream_async_op(d.get_ctx(), dstream, prereqs); if (d.get_ctx().generate_event_symbols()) { diff --git a/cudax/include/cuda/experimental/__stf/stream/stream_ctx.cuh b/cudax/include/cuda/experimental/__stf/stream/stream_ctx.cuh index 8a771a930d7..5c3be88b96c 100644 --- a/cudax/include/cuda/experimental/__stf/stream/stream_ctx.cuh +++ b/cudax/include/cuda/experimental/__stf/stream/stream_ctx.cuh @@ -61,7 +61,7 @@ public: void* allocate(backend_ctx_untyped& ctx, const data_place& memory_node, ::std::ptrdiff_t& s, event_list& prereqs) override { - auto dstream = memory_node.getDataStream(); + auto dstream = memory_node.get_data_stream(); if (!memory_node.allocation_is_stream_ordered()) { @@ -83,7 +83,7 @@ public: void deallocate( backend_ctx_untyped& ctx, const data_place& memory_node, event_list& prereqs, void* ptr, size_t sz) override { - auto dstream = memory_node.getDataStream(); + auto dstream = memory_node.get_data_stream(); if (!memory_node.allocation_is_stream_ordered()) { @@ -219,7 +219,7 @@ public: decorated_stream dstream = (user_dstream.has_value()) ? 
user_dstream.value() - : exec_place::current_device().getStream(true /* stream for computation */); + : exec_place::current_device().get_stream(true /* stream for computation */); auto prereqs = get_state().insert_fence(*get_dot()); diff --git a/cudax/include/cuda/experimental/__stf/stream/stream_task.cuh b/cudax/include/cuda/experimental/__stf/stream/stream_task.cuh index 4176e74b01d..98114ee147b 100644 --- a/cudax/include/cuda/experimental/__stf/stream/stream_task.cuh +++ b/cudax/include/cuda/experimental/__stf/stream/stream_task.cuh @@ -128,7 +128,7 @@ public: const auto& places = grid.get_places(); for (const exec_place& p : places) { - stream_grid.push_back(p.getStream(true)); + stream_grid.push_back(p.get_stream(true)); } EXPECT(stream_grid.size() > 0UL); @@ -180,7 +180,7 @@ public: if (!found) { - dstream = e_place.getStream(true); + dstream = e_place.get_stream(true); // fprintf(stderr, "COULD NOT REUSE ... selected stream ID %ld\n", dstream.id); } } diff --git a/cudax/test/stf/cpp/test_pick_stream.cu b/cudax/test/stf/cpp/test_pick_stream.cu index a34b7470cce..28f207cc9b1 100644 --- a/cudax/test/stf/cpp/test_pick_stream.cu +++ b/cudax/test/stf/cpp/test_pick_stream.cu @@ -50,13 +50,13 @@ int main() } // ========================================================================== - // Test exec_place::getStream() - returns decorated_stream with metadata + // Test exec_place::get_stream() - returns decorated_stream with metadata // ========================================================================== { exec_place place = exec_place::current_device(); - // getStream() returns a decorated_stream with additional metadata - decorated_stream dstream = place.getStream(resources, true); + // get_stream() returns a decorated_stream with additional metadata + decorated_stream dstream = place.get_stream(resources, true); EXPECT(dstream.stream != nullptr); EXPECT(dstream.dev_id == current_device); EXPECT(get_device_from_stream(dstream.stream) == current_device); @@ 
-99,8 +99,8 @@ int main() EXPECT(stream != nullptr); EXPECT(get_device_from_stream(stream) == test_device); - // getStream returns more metadata - decorated_stream dstream = dev_place.getStream(resources, true); + // get_stream returns more metadata + decorated_stream dstream = dev_place.get_stream(resources, true); EXPECT(dstream.stream != nullptr); EXPECT(dstream.dev_id == test_device); } diff --git a/cudax/test/stf/cpp/test_pick_stream_green_context.cu b/cudax/test/stf/cpp/test_pick_stream_green_context.cu index c7beedbb454..fd125b289ac 100644 --- a/cudax/test/stf/cpp/test_pick_stream_green_context.cu +++ b/cudax/test/stf/cpp/test_pick_stream_green_context.cu @@ -110,8 +110,8 @@ int main() EXPECT(gc_stream != gc_stream1); } - // getStream() provides additional metadata if needed - decorated_stream dstream = gc_place0.getStream(resources, true); + // get_stream() provides additional metadata if needed + decorated_stream dstream = gc_place0.get_stream(resources, true); EXPECT(dstream.stream != nullptr); EXPECT(dstream.dev_id == current_device); From 7372e478b08eb89b38dd8c0ce9c2244325f6269d Mon Sep 17 00:00:00 2001 From: Andrei Alexandrescu Date: Tue, 10 Mar 2026 00:24:52 -0400 Subject: [PATCH 09/10] [STF] Simplify and harden exec_place implementation - Refactor data_place::to_string() to use cascaded ternary operators - Replace assert() calls with _CCCL_ASSERT for consistent diagnostics - Remove exec_place::impl::create_stream() and inline into stream_pool::next() - Simplify exec_place::operator->* to use exec_place_guard RAII pattern - Minor cleanup: remove redundant parentheses, reorder operator< logic --- .../cuda/experimental/__stf/places/places.cuh | 179 +++++------------- .../experimental/__stf/places/stream_pool.cuh | 8 +- .../stf/cpp/test_pick_stream_green_context.cu | 5 +- 3 files changed, 58 insertions(+), 134 deletions(-) diff --git a/cudax/include/cuda/experimental/__stf/places/places.cuh b/cudax/include/cuda/experimental/__stf/places/places.cuh index 
1cf32ea166b..7773f940691 100644 --- a/cudax/include/cuda/experimental/__stf/places/places.cuh +++ b/cudax/include/cuda/experimental/__stf/places/places.cuh @@ -165,14 +165,7 @@ public: bool operator<(const data_place& rhs) const { // Not implemented for composite places - EXPECT(!is_composite()); - EXPECT(!rhs.is_composite()); - - // If both are extensions, delegate to the extension - if (is_extension() && rhs.is_extension()) - { - return *extension < *rhs.extension; - } + EXPECT((!is_composite() && !rhs.is_composite()), "Ordering of composite places is not implemented."); // Extensions sort after non-extensions if (is_extension() != rhs.is_extension()) @@ -180,6 +173,13 @@ public: return rhs.is_extension(); // non-extension < extension } + // If both are extensions, delegate to the extension + if (is_extension()) + { + // rhs.is_extension() is true due to previous test + return *extension < *rhs.extension; + } + // For simple places, compare devid return devid < rhs.devid; } @@ -204,14 +204,14 @@ public: { // If the devid indicates composite_devid then we must have a descriptor _CCCL_ASSERT(devid != composite_devid || composite_desc != nullptr, "invalid state"); - return (devid == composite_devid); + return devid == composite_devid; } /// checks if this data place has an extension (green context, etc.) bool is_extension() const { _CCCL_ASSERT(devid != extension_devid || extension != nullptr, "invalid state"); - return (devid == extension_devid); + return devid == extension_devid; } bool is_invalid() const @@ -238,7 +238,7 @@ public: bool is_device() const { // All other type of data places have a specific negative devid value. 
- return (devid >= 0); + return devid >= 0; } bool is_device_auto() const @@ -248,34 +248,13 @@ public: ::std::string to_string() const { - if (devid == host_devid) - { - return "host"; - } - if (devid == managed_devid) - { - return "managed"; - } - if (devid == device_auto_devid) - { - return "auto"; - } - if (devid == invalid_devid) - { - return "invalid"; - } - - if (is_extension()) - { - return extension->to_string(); - } - - if (is_composite()) - { - return "composite" + ::std::to_string(devid); - } - - return "dev" + ::std::to_string(devid); + return devid == host_devid ? "host" + : devid == managed_devid ? "managed" + : devid == device_auto_devid ? "auto" + : devid == invalid_devid ? "invalid" + : is_extension() ? extension->to_string() + : is_composite() ? "composite" + ::std::to_string(devid) + : "dev" + ::std::to_string(devid); } /** @@ -286,7 +265,7 @@ public: { EXPECT(p.devid >= -2, "Data place with device id ", p.devid, " does not refer to a device."); // This is not strictly a problem in this function, but it's not legit either. So let's assert. 
- assert(p.devid < cuda_try()); + _CCCL_ASSERT(p.devid < cuda_try(), "Invalid device id"); return p.devid + 2; } @@ -617,19 +596,18 @@ public: virtual exec_place activate() const { - if (affine.is_device()) + if (!affine.is_device()) { - auto old_dev_id = cuda_try(); - auto new_dev_id = device_ordinal(affine); - if (old_dev_id != new_dev_id) - { - cuda_safe_call(cudaSetDevice(new_dev_id)); - } - - auto old_dev = data_place::device(old_dev_id); - return exec_place(mv(old_dev)); + return exec_place(); } - return exec_place(); + auto old_dev_id = cuda_try(); + auto new_dev_id = device_ordinal(affine); + if (old_dev_id != new_dev_id) + { + cuda_safe_call(cudaSetDevice(new_dev_id)); + } + auto old_dev = data_place::device(old_dev_id); + return exec_place(mv(old_dev)); } virtual void deactivate(const exec_place& prev) const @@ -685,11 +663,6 @@ public: return affine == rhs.affine; } - bool operator!=(const impl& rhs) const - { - return !(*this == rhs); - } - virtual size_t hash() const { return affine.hash(); @@ -718,21 +691,6 @@ public: return for_computation ? pool_compute : pool_data; } - /** - * @brief Create a stream valid for execution on this place. - * - * Expected to be called with this exec place already activated (e.g. from - * stream_pool::next(place) which uses exec_place_guard). Creates a new stream - * in the current context via cudaStreamCreateWithFlags(..., cudaStreamNonBlocking). - * The caller (e.g. stream_pool::next) builds a decorated_stream from the result. - */ - cudaStream_t create_stream() const - { - cudaStream_t stream = nullptr; - cuda_safe_call(cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking)); - return stream; - } - static constexpr size_t pool_size = 4; static constexpr size_t data_pool_size = 4; @@ -869,19 +827,6 @@ public: */ decorated_stream get_stream(bool for_computation) const; - /** - * @brief Create a stream valid for execution on this place. - * - * Call only when the place is already activated (e.g. 
inside exec_place_guard). - * For getting a stream from the pool, use get_stream() / pick_stream() instead. - * - * @return A CUDA stream valid for this execution place - */ - cudaStream_t create_stream() const - { - return pimpl->create_stream(); - } - cudaStream_t pick_stream(bool for_computation = true) const { return get_stream(for_computation).stream; @@ -1001,36 +946,7 @@ public: * */ template - auto operator->*(Fun&& fun) const - { - const int new_device = device_ordinal(pimpl->affine); - if (new_device >= 0) - { - // We're on a device - // Change device only if necessary. - const int old_device = cuda_try(); - if (new_device != old_device) - { - cuda_safe_call(cudaSetDevice(new_device)); - } - - SCOPE(exit) - { - // It is the responsibility of the client to ensure that any change of the current device in this - // section was reverted. - if (new_device != old_device) - { - cuda_safe_call(cudaSetDevice(old_device)); - } - }; - return ::std::forward(fun)(); - } - else - { - // We're on the host, just call the function with no further ado. 
- return ::std::forward(fun)(); - } - } + auto operator->*(Fun&& fun) const; public: exec_place(::std::shared_ptr pimpl) @@ -1109,6 +1025,13 @@ private: exec_place prev_; }; +template +auto exec_place::operator->*(Fun&& fun) const +{ + exec_place_guard guard(*this); + return ::std::forward(fun)(); +} + inline decorated_stream stream_pool::next(const exec_place& place) { _CCCL_ASSERT(pimpl, "stream_pool::next called on empty pool"); @@ -1120,7 +1043,7 @@ inline decorated_stream stream_pool::next(const exec_place& place) if (!result.stream) { exec_place_guard guard(place); - result.stream = place.create_stream(); + cuda_safe_call(cudaStreamCreateWithFlags(&result.stream, cudaStreamNonBlocking)); result.id = get_stream_id(result.stream); result.dev_id = get_device_from_stream(result.stream); } @@ -1414,9 +1337,9 @@ public: stream_pool& get_stream_pool(bool for_computation) const override { - assert(!for_computation); + _CCCL_ASSERT(!for_computation, "Expected data transfer stream pool"); const auto& v = get_places(); - assert(v.size() > 0); + _CCCL_ASSERT(v.size() > 0, "Grid must have at least one place"); return v[0].get_stream_pool(for_computation); } @@ -1597,7 +1520,7 @@ public: ::std::shared_ptr get_impl() const { - assert(::std::dynamic_pointer_cast(exec_place::get_impl())); + _CCCL_ASSERT(::std::dynamic_pointer_cast(exec_place::get_impl()), "Invalid exec_place_grid impl"); return ::std::static_pointer_cast(exec_place::get_impl()); } @@ -2091,15 +2014,15 @@ interpreted_execution_policy::interpreted_execution_policy( } // Make sure we have computed the width if that was implicit - assert(l0_size > 0); + _CCCL_ASSERT(l0_size > 0, "Level 0 size must be positive"); - assert(grid_size > 0); - assert(block_size <= kernel_limits.max_block_size); + _CCCL_ASSERT(grid_size > 0, "Grid size must be positive"); + _CCCL_ASSERT(block_size <= kernel_limits.max_block_size, "Block size exceeds max block size"); - assert(l0_size % ndevs == 0); - assert(l0_size % (ndevs * 
block_size) == 0); + _CCCL_ASSERT(l0_size % ndevs == 0, "Level 0 size must be divisible by number of devices"); + _CCCL_ASSERT(l0_size % (ndevs * block_size) == 0, "Level 0 size must be divisible by ndevs * block_size"); - assert(ndevs * grid_size * block_size == l0_size); + _CCCL_ASSERT(ndevs * grid_size * block_size == l0_size, "Dimension mismatch: ndevs * grid_size * block_size != l0_size"); this->add_level({::std::make_pair(hw_scope::device, ndevs), ::std::make_pair(hw_scope::block, grid_size), @@ -2142,9 +2065,9 @@ interpreted_execution_policy::interpreted_execution_policy( } // Enforce the resource limits in the number of threads per block - assert(int(l1_size) <= kernel_limits.block_size_limit); + _CCCL_ASSERT(int(l1_size) <= kernel_limits.block_size_limit, "Level 1 size exceeds block size limit"); - assert(l0_size % ndevs == 0); + _CCCL_ASSERT(l0_size % ndevs == 0, "Level 0 size must be divisible by number of devices"); /* Merge blocks and devices */ this->add_level({::std::make_pair(hw_scope::device, ndevs), ::std::make_pair(hw_scope::block, l0_size / ndevs)}); @@ -2197,8 +2120,8 @@ interpreted_execution_policy::interpreted_execution_policy( } // Enforce the resource limits in the number of threads per block - assert(int(l2_size) <= kernel_limits.block_size_limit); - assert(int(l0_size) <= ndevs); + _CCCL_ASSERT(int(l2_size) <= kernel_limits.block_size_limit, "Level 2 size exceeds block size limit"); + _CCCL_ASSERT(int(l0_size) <= ndevs, "Level 0 size exceeds number of devices"); /* Merge blocks and devices */ this->add_level({::std::make_pair(hw_scope::device, l0_size)}); diff --git a/cudax/include/cuda/experimental/__stf/places/stream_pool.cuh b/cudax/include/cuda/experimental/__stf/places/stream_pool.cuh index 9af9193f279..cbeb8736d8e 100644 --- a/cudax/include/cuda/experimental/__stf/places/stream_pool.cuh +++ b/cudax/include/cuda/experimental/__stf/places/stream_pool.cuh @@ -118,8 +118,8 @@ struct decorated_stream * This class uses a PIMPL idiom so 
that it is copyable and movable with shared * semantics: copies refer to the same underlying pool of streams. * - * When a slot is empty, next(place) activates the place (RAII guard) and calls - * place.create_stream(). Defined in places.cuh. + * When a slot is empty, next(place) activates the place (RAII guard) and creates + * a new stream. Defined in places.cuh. */ class stream_pool { @@ -159,8 +159,8 @@ public: stream_pool& operator=(stream_pool&&) = default; /** - * @brief Get the next stream in the pool; when a slot is empty, activate the place (RAII guard) and call - * place.create_stream(). Defined in places.cuh so the pool can use exec_place_guard and exec_place::create_stream(). + * @brief Get the next stream in the pool; when a slot is empty, activate the place (RAII guard) and create + * a new stream. Defined in places.cuh so the pool can use exec_place_guard. */ decorated_stream next(const exec_place& place); diff --git a/cudax/test/stf/cpp/test_pick_stream_green_context.cu b/cudax/test/stf/cpp/test_pick_stream_green_context.cu index fd125b289ac..f269beb76a4 100644 --- a/cudax/test/stf/cpp/test_pick_stream_green_context.cu +++ b/cudax/test/stf/cpp/test_pick_stream_green_context.cu @@ -115,10 +115,11 @@ int main() EXPECT(dstream.stream != nullptr); EXPECT(dstream.dev_id == current_device); - // create_stream() returns cudaStream_t; call with place activated so the stream is in the green context + // Verify that streams created while a green context place is activated belong to that green context { exec_place_guard guard(gc_place0); - cudaStream_t created = gc_place0.create_stream(); + cudaStream_t created = nullptr; + cuda_safe_call(cudaStreamCreateWithFlags(&created, cudaStreamNonBlocking)); EXPECT(created != nullptr); EXPECT(get_device_from_stream(created) == current_device); verify_stream_green_context(created, view0.g_ctx); From b30ffa3ea6166a8a50585b736e8fd0610a0d1c5c Mon Sep 17 00:00:00 2001 From: Elias Stehle 
<3958403+elstehle@users.noreply.github.com> Date: Tue, 10 Mar 2026 05:21:11 +0100 Subject: [PATCH 10/10] Add support for small, variable-size segments to `DeviceBatchedTopK` (#7926) * prepare tests for var segments * adds tests for small variable-size segments * restores tests for types * reverts now-superfluous segment fixup test code * variable segment size tests * split up test tus, restore utils * adds tests for variable-size keys * fixes tu splitup * style fix --- cub/cub/agent/agent_batched_topk.cuh | 5 +- .../device/dispatch/dispatch_batched_topk.cuh | 21 +- .../catch2_test_device_segmented_topk_keys.cu | 105 +++++++- ...catch2_test_device_segmented_topk_pairs.cu | 241 +++++++++++++++--- cub/test/catch2_test_device_topk_common.cuh | 147 +++++++++-- 5 files changed, 451 insertions(+), 68 deletions(-) diff --git a/cub/cub/agent/agent_batched_topk.cuh b/cub/cub/agent/agent_batched_topk.cuh index c8f081e3717..948f2b4ffa2 100644 --- a/cub/cub/agent/agent_batched_topk.cuh +++ b/cub/cub/agent/agent_batched_topk.cuh @@ -150,8 +150,9 @@ struct agent_batched_topk_worker_per_segment // Resolve Segment Parameters const auto segment_size = segment_sizes.get_param(segment_id); - const auto k = k_param.get_param(segment_id); - const auto direction = select_directions.get_param(segment_id); + const auto k = ::cuda::std::min( + k_param.get_param(segment_id), static_cast(segment_size)); + const auto direction = select_directions.get_param(segment_id); // Determine padding key based on direction const key_t padding_key = diff --git a/cub/cub/device/dispatch/dispatch_batched_topk.cuh b/cub/cub/device/dispatch/dispatch_batched_topk.cuh index 49916a007ea..c2a7cfba220 100644 --- a/cub/cub/device/dispatch/dispatch_batched_topk.cuh +++ b/cub/cub/device/dispatch/dispatch_batched_topk.cuh @@ -198,14 +198,10 @@ struct dispatch_batched_topk static constexpr bool keys_only = ::cuda::std::is_same_v; template - CUB_RUNTIME_FUNCTION _CCCL_FORCEINLINE cudaError_t invoke_fixed_segment_size() 
+ CUB_RUNTIME_FUNCTION _CCCL_FORCEINLINE cudaError_t invoke_one_worker_per_segment() { using max_policy_t = typename SelectedPolicy::max_policy; - // Currently, only uniform segment sizes are supported - static_assert(!params::is_per_segment_param_v, - "Only uniform segment sizes are currently supported."); - // Instantiate the kernel with the selected policy and check shared memory requirements using topk_policy_t = ActiveWorkerPerSegmentPolicyTPolicyT; @@ -281,17 +277,14 @@ struct dispatch_batched_topk SelectDirectionParameterT, NumSegmentsParameterT>; - // Currently, we only support fixed-size segments that fit into shared memory + // Currently, we only support segments that fit into shared memory // TODO (elstehle): extend support for variable-size segments - static_assert( - !params::is_per_segment_param_v - && find_smallest_covering_policy_t::supports_one_worker_per_segment, - "Currently only small, fixed-size segments are supported, where each segment can be processed by a single thread " - "block."); - if constexpr (!params::is_per_segment_param_v - && find_smallest_covering_policy_t::supports_one_worker_per_segment) + static_assert(find_smallest_covering_policy_t::supports_one_worker_per_segment, + "Currently only small segments are supported, where each segment can be processed by a single thread " + "block."); + if constexpr (find_smallest_covering_policy_t::supports_one_worker_per_segment) { - return invoke_fixed_segment_size(); + return invoke_one_worker_per_segment(); } else { diff --git a/cub/test/catch2_test_device_segmented_topk_keys.cu b/cub/test/catch2_test_device_segmented_topk_keys.cu index 7ab81602485..c9a1e891990 100644 --- a/cub/test/catch2_test_device_segmented_topk_keys.cu +++ b/cub/test/catch2_test_device_segmented_topk_keys.cu @@ -8,6 +8,7 @@ #include #include +#include #include #include @@ -168,11 +169,111 @@ C2H_TEST("DeviceBatchedTopK::{Min,Max}Keys work with small fixed-size segments", 
cub::detail::batched_topk::num_segments_uniform<>{num_segments}, cub::detail::batched_topk::total_num_items_guarantee{num_segments * segment_size}); // Prepare expected results - segmented_sort_keys(expected_keys, num_segments, segment_size, direction); + fixed_size_segmented_sort_keys(expected_keys, num_segments, segment_size, direction); compact_sorted_keys_to_topk(expected_keys, segment_size, k); // Since the results of top-k are unordered, sort output segments before comparison. - segmented_sort_keys(keys_out_buffer, num_segments, k, direction); + fixed_size_segmented_sort_keys(keys_out_buffer, num_segments, k, direction); + + REQUIRE(expected_keys == keys_out_buffer); +} + +C2H_TEST("DeviceBatchedTopK::{Min,Max}Keys work with small variable-size segments", + "[keys][segmented][topk][device]", + key_types, + max_segment_size_list, + max_num_k_list) +{ + using segment_size_t = cuda::std::int64_t; + using segment_index_t = cuda::std::int64_t; + + using key_t = c2h::get<0, TestType>; + + // Statically constrained maximum segment size and k + constexpr segment_size_t static_max_segment_size = c2h::get<1, TestType>::value; + constexpr segment_size_t static_max_k = c2h::get<2, TestType>::value; + + // Test both directions (as runtime value) + const auto direction = GENERATE_COPY(cub::detail::topk::select::min, cub::detail::topk::select::max); + + constexpr segment_size_t min_items = 1; + constexpr segment_size_t max_items = 1000000; + + // Number of items + const segment_size_t num_items = GENERATE_COPY( + take(2, random(min_items, max_items)), + values({ + min_items, + max_items, + })); + + // Generate segment sizes + constexpr segment_size_t min_segment_size = 1; + constexpr auto max_segment_size = static_max_segment_size; + c2h::device_vector segment_offsets = + c2h::gen_uniform_offsets(C2H_SEED(3), num_items, min_segment_size, max_segment_size); + const segment_index_t num_segments = static_cast(segment_offsets.size() - 1); + auto segment_offsets_it = 
thrust::raw_pointer_cast(segment_offsets.data()); + auto segment_size_it = cuda::make_transform_iterator( + cuda::make_counting_iterator(segment_index_t{0}), segment_size_op{segment_offsets_it}); + + // Set the k value + const segment_size_t k = + GENERATE_COPY(values({segment_size_t{1}, static_max_k}), take(3, random(segment_size_t{1}, static_max_k))); + + // Capture test parameters + CAPTURE(c2h::type_name(), + c2h::type_name(), + c2h::type_name(), + static_max_segment_size, + static_max_k, + k, + num_segments, + direction); + + // Compute compacted output offsets: + // Each output segment holds exactly min(k, segment_size[i]) items, tightly packed. + auto compacted_output_sizes_it = cuda::make_transform_iterator( + cuda::make_counting_iterator(segment_index_t{0}), + get_output_size_op{segment_offsets.cbegin(), cuda::constant_iterator(k)}); + c2h::device_vector compacted_offsets(num_segments + 1, thrust::no_init); + thrust::exclusive_scan( + compacted_output_sizes_it, compacted_output_sizes_it + num_segments + 1, compacted_offsets.begin()); + segment_size_t total_output_size = compacted_offsets.back(); + + // Prepare keys input & output + c2h::device_vector keys_in_buffer(num_items, thrust::no_init); + c2h::device_vector keys_out_buffer(total_output_size, thrust::no_init); + const int num_key_seeds = 1; + c2h::gen(C2H_SEED(num_key_seeds), keys_in_buffer); + auto d_keys_in_ptr = thrust::raw_pointer_cast(keys_in_buffer.data()); + auto d_keys_out_ptr = thrust::raw_pointer_cast(keys_out_buffer.data()); + auto d_keys_in = + cuda::make_permutation_iterator(cuda::make_counting_iterator(d_keys_in_ptr), segment_offsets.cbegin()); + auto d_keys_out = + cuda::make_permutation_iterator(cuda::make_counting_iterator(d_keys_out_ptr), compacted_offsets.cbegin()); + + // Copy input for verification + c2h::device_vector expected_keys(keys_in_buffer); + + // Run the top-k algorithm + batched_topk_keys( + d_keys_in, + d_keys_out, + 
cub::detail::batched_topk::segment_size_per_segment{ + segment_size_it}, + cub::detail::batched_topk::k_uniform<1, static_max_k>{k}, + cub::detail::batched_topk::select_direction_uniform{direction}, + cub::detail::batched_topk::num_segments_uniform<>{num_segments}, + cub::detail::batched_topk::total_num_items_guarantee{num_items}); + + // Verify keys are returned correctly: sort each segment of the expected input, then compact the top-k + segmented_sort_keys(expected_keys, num_segments, segment_offsets.cbegin(), segment_offsets.cbegin() + 1, direction); + expected_keys = compact_to_topk_batched(expected_keys, segment_offsets, k); + + // Since the results of top-k are unordered, sort compacted output segments before comparison + segmented_sort_keys( + keys_out_buffer, num_segments, compacted_offsets.cbegin(), compacted_offsets.cbegin() + 1, direction); REQUIRE(expected_keys == keys_out_buffer); } diff --git a/cub/test/catch2_test_device_segmented_topk_pairs.cu b/cub/test/catch2_test_device_segmented_topk_pairs.cu index 05d8afe65d7..6e50386fe3c 100644 --- a/cub/test/catch2_test_device_segmented_topk_pairs.cu +++ b/cub/test/catch2_test_device_segmented_topk_pairs.cu @@ -7,6 +7,8 @@ #include #include +#include +#include #include @@ -16,25 +18,35 @@ #include #include -// Function object used to flag duplicate items within a segment -template -struct flag_duplicates_in_segment +// Maps an item index to its segment id for fixed-size segments +struct fixed_stride_segment_id_op +{ + cuda::std::int64_t stride; + + template + __device__ IndexT operator()(IndexT idx) const + { + return static_cast(idx / stride); + } +}; + +// Flags adjacent duplicate items that belong to the same segment +template +struct flag_intra_segment_duplicates { ItemItT d_sorted_items; - cuda::std::int64_t segment_size; + SegIdItT d_segment_ids; - bool __device__ operator()(cuda::std::int64_t idx) const + template + __device__ bool operator()(IndexT idx) const { - // Only flag if items at i and i+1 are 
in the same segment - bool same_segment = ((idx + 1) % segment_size != 0); - if (same_segment) - { - return d_sorted_items[idx] == d_sorted_items[idx + 1]; - } - return false; + return d_segment_ids[idx] == d_segment_ids[idx + 1] && d_sorted_items[idx] == d_sorted_items[idx + 1]; } }; +template +flag_intra_segment_duplicates(ItemItT, SegIdItT) -> flag_intra_segment_duplicates; + template ; // Segment size: static, uniform using max_num_k_list = c2h::enum_type_list; -using key_types = c2h::type_list< // cuda::std::uint8_t, - float //, - // cuda::std::uint64_t - // // clang-format off - // #if TEST_HALF_T() - // , half_t - // #endif // TEST_HALF_T() - // #if TEST_BF_T() - // , bfloat16_t - // #endif // TEST_BF_T() +// %PARAM% TEST_TYPES types 0:1:2 + +#if TEST_TYPES == 0 +using key_types = + c2h::type_list; // clang-format on +#elif TEST_TYPES == 1 +using key_types = c2h::type_list; +#elif TEST_TYPES == 2 +using key_types = c2h::type_list; +#endif + +// Unsigned integer types used for the radix-pass boundary distribution test +using uint_key_types = c2h::type_list; // Consistency check: ensures values remain associated with their corresponding keys template @@ -124,26 +146,54 @@ bool verify_pairs_consistency(const c2h::device_vector& keys_in, template bool verify_unique_indices(c2h::device_vector& values_out, cuda::std::int64_t num_segments, cuda::std::int64_t k) { - // Make a copy to sort - c2h::device_vector sorted_values = values_out; + // Make a copy & sort + c2h::device_vector sorted_values{values_out}; + fixed_size_segmented_sort_keys(sorted_values, num_segments, k, cub::detail::topk::select::min); - // Sort the values within each segment for subsequent duplicate check - segmented_sort_keys(sorted_values, num_segments, k, cub::detail::topk::select::min); + auto num_items = sorted_values.size(); + auto counting_it = cuda::make_counting_iterator(cuda::std::int64_t{0}); + auto seg_ids = cuda::make_transform_iterator(counting_it, fixed_stride_segment_id_op{k}); + 
flag_intra_segment_duplicates flag_op{sorted_values.cbegin(), seg_ids}; + auto num_duplicates = thrust::count_if(counting_it, counting_it + (num_items - 1), flag_op); - // Check for adjacent duplicates within segment boundaries - auto d_sorted_values = thrust::raw_pointer_cast(sorted_values.data()); - auto num_items = sorted_values.size(); + return num_duplicates == 0; +} - flag_duplicates_in_segment flag_op{d_sorted_values, k}; +// Overload for variable-size segments: sorts compacted values within each segment and checks for duplicates +template +bool verify_unique_indices(const c2h::device_vector& values_compacted, + const c2h::device_vector& compacted_offsets, + cuda::std::int64_t num_segments) +{ + c2h::device_vector sorted_values = values_compacted; + segmented_sort_keys( + sorted_values, + num_segments, + compacted_offsets.cbegin(), + compacted_offsets.cbegin() + 1, + cub::detail::topk::select::min); + + auto num_items = sorted_values.size(); + + // Generate segment ids via scatter + inclusive_scan: scatter a 1 at each interior segment + // boundary, then prefix-sum to produce monotonic group ids + c2h::device_vector segment_ids(num_items, OffsetT{0}); + thrust::scatter(cuda::constant_iterator(1), + cuda::constant_iterator(1) + (num_segments - 1), + compacted_offsets.cbegin() + 1, + segment_ids.begin()); + thrust::inclusive_scan(segment_ids.begin(), segment_ids.end(), segment_ids.begin()); + + flag_intra_segment_duplicates flag_op{sorted_values.cbegin(), segment_ids.cbegin()}; auto num_duplicates = - thrust::count_if(thrust::make_counting_iterator(size_t{0}), thrust::make_counting_iterator(num_items - 1), flag_op); + thrust::count_if(cuda::make_counting_iterator(size_t{0}), cuda::make_counting_iterator(num_items - 1), flag_op); return num_duplicates == 0; } C2H_TEST("DeviceBatchedTopK::{Min,Max}Pairs work with small fixed-size segments", - "[keys][segmented][topk][device]", + "[pairs][segmented][topk][device]", key_types, max_segment_size_list, 
max_num_k_list) @@ -234,12 +284,133 @@ C2H_TEST("DeviceBatchedTopK::{Min,Max}Pairs work with small fixed-size segments" // This catches the case where we just returned a valid value multiple times REQUIRE(verify_unique_indices(values_out_buffer, num_segments, k) == true); - // Verify keys are sorted correctly - segmented_sort_keys(expected_keys, num_segments, segment_size, direction); + // Verify keys are returned correctly + fixed_size_segmented_sort_keys(expected_keys, num_segments, segment_size, direction); compact_sorted_keys_to_topk(expected_keys, segment_size, k); // Since the results of top-k are unordered, sort output segments before comparison. - segmented_sort_keys(keys_out_buffer, num_segments, k, direction); + fixed_size_segmented_sort_keys(keys_out_buffer, num_segments, k, direction); + + REQUIRE(expected_keys == keys_out_buffer); +} + +C2H_TEST("DeviceBatchedTopK::{Min,Max}Pairs work with small variable-size segments", + "[pairs][segmented][topk][device]", + key_types, + max_segment_size_list, + max_num_k_list) +{ + using segment_size_t = cuda::std::int64_t; + using segment_index_t = cuda::std::int64_t; + + using key_t = c2h::get<0, TestType>; + using val_t = cuda::std::int32_t; + + // Statically constrained maximum segment size and k + constexpr segment_size_t static_max_segment_size = c2h::get<1, TestType>::value; + constexpr segment_size_t static_max_k = c2h::get<2, TestType>::value; + + // Test both directions (as runtime value) + const auto direction = GENERATE_COPY(cub::detail::topk::select::min, cub::detail::topk::select::max); + + constexpr segment_size_t min_items = 1; + constexpr segment_size_t max_items = 1000000; + + // Number of items + const segment_size_t num_items = GENERATE_COPY( + take(2, random(min_items, max_items)), + values({ + min_items, + max_items, + })); + + // Generate segment sizes + constexpr segment_size_t min_segment_size = 1; + constexpr auto max_segment_size = static_max_segment_size; + c2h::device_vector 
segment_offsets = + c2h::gen_uniform_offsets(C2H_SEED(3), num_items, min_segment_size, max_segment_size); + const segment_index_t num_segments = static_cast(segment_offsets.size() - 1); + auto segment_offsets_it = thrust::raw_pointer_cast(segment_offsets.data()); + auto segment_size_it = cuda::make_transform_iterator( + cuda::make_counting_iterator(segment_index_t{0}), segment_size_op{segment_offsets_it}); + + // Set the k value + const segment_size_t k = + GENERATE_COPY(values({segment_size_t{1}, static_max_k}), take(3, random(segment_size_t{1}, static_max_k))); + + // Capture test parameters + CAPTURE(c2h::type_name(), + c2h::type_name(), + c2h::type_name(), + static_max_segment_size, + static_max_k, + k, + num_segments, + direction); + + // Compute compacted output offsets: + // Each output segment holds exactly min(k, segment_size[i]) items, tightly packed. + auto compacted_output_sizes_it = cuda::make_transform_iterator( + cuda::make_counting_iterator(segment_index_t{0}), + get_output_size_op{segment_offsets.cbegin(), cuda::constant_iterator(k)}); + c2h::device_vector compacted_offsets(num_segments + 1, thrust::no_init); + thrust::exclusive_scan( + compacted_output_sizes_it, compacted_output_sizes_it + num_segments + 1, compacted_offsets.begin()); + segment_size_t total_output_size = compacted_offsets.back(); + + // Prepare keys input & output + c2h::device_vector keys_in_buffer(num_items, thrust::no_init); + c2h::device_vector keys_out_buffer(total_output_size, thrust::no_init); + const int num_key_seeds = 1; + c2h::gen(C2H_SEED(num_key_seeds), keys_in_buffer); + auto d_keys_in_ptr = thrust::raw_pointer_cast(keys_in_buffer.data()); + auto d_keys_out_ptr = thrust::raw_pointer_cast(keys_out_buffer.data()); + auto d_keys_in = + cuda::make_permutation_iterator(cuda::make_counting_iterator(d_keys_in_ptr), segment_offsets.cbegin()); + auto d_keys_out = + cuda::make_permutation_iterator(cuda::make_counting_iterator(d_keys_out_ptr), compacted_offsets.cbegin()); + + 
// Prepare values input & output + auto values_in_it = cuda::make_counting_iterator(val_t{0}); + c2h::device_vector values_out_buffer(total_output_size, thrust::no_init); + auto d_values_out_ptr = thrust::raw_pointer_cast(values_out_buffer.data()); + auto d_values_in = + cuda::make_permutation_iterator(cuda::make_counting_iterator(values_in_it), segment_offsets.cbegin()); + auto d_values_out = + cuda::make_permutation_iterator(cuda::make_counting_iterator(d_values_out_ptr), compacted_offsets.cbegin()); + + // Copy input for verification + c2h::device_vector expected_keys(keys_in_buffer); + + // Run the top-k algorithm + batched_topk_pairs( + d_keys_in, + d_keys_out, + d_values_in, + d_values_out, + cub::detail::batched_topk::segment_size_per_segment{ + segment_size_it}, + cub::detail::batched_topk::k_uniform<1, static_max_k>{k}, + cub::detail::batched_topk::select_direction_uniform{direction}, + cub::detail::batched_topk::num_segments_uniform<>{num_segments}, + cub::detail::batched_topk::total_num_items_guarantee{num_items}); + + // Verification: + // - We verify correct top-k selection through the keys + // - We verify that values were permuted along correctly by making sure values remain associated with their keys and + // making sure we do not duplicate values + REQUIRE(verify_pairs_consistency(expected_keys, keys_out_buffer, values_out_buffer) == true); + + // Verify values don't appear more than once in the returned results + REQUIRE(verify_unique_indices(values_out_buffer, compacted_offsets, num_segments) == true); + + // Verify keys are returned correctly: sort each segment of the expected input, then compact the top-k + segmented_sort_keys(expected_keys, num_segments, segment_offsets.cbegin(), segment_offsets.cbegin() + 1, direction); + expected_keys = compact_to_topk_batched(expected_keys, segment_offsets, k); + + // Since the results of top-k are unordered, sort compacted output segments before comparison + segmented_sort_keys( + keys_out_buffer, 
num_segments, compacted_offsets.cbegin(), compacted_offsets.cbegin() + 1, direction); REQUIRE(expected_keys == keys_out_buffer); } diff --git a/cub/test/catch2_test_device_topk_common.cuh b/cub/test/catch2_test_device_topk_common.cuh index 8d6a1d1c47b..a66e7aa86d8 100644 --- a/cub/test/catch2_test_device_topk_common.cuh +++ b/cub/test/catch2_test_device_topk_common.cuh @@ -3,6 +3,7 @@ #pragma once +#include #include #include // topk::select::{min, max} @@ -43,6 +44,52 @@ struct inc_t } }; +template +struct segment_size_op +{ + OffsetItT d_offsets; + + template + __host__ __device__ __forceinline__ auto operator()(IndexT segment_id) const + { + return d_offsets[segment_id + 1] - d_offsets[segment_id]; + } +}; + +template +struct get_output_size_op +{ + OffsetItT offset_it; + KSizesItT k_it; + + __device__ __forceinline__ cuda::std::int64_t operator()(cuda::std::int64_t segment_id) const + { + const auto segment_size = offset_it[segment_id + 1] - offset_it[segment_id]; + return ::cuda::std::min(static_cast(k_it[segment_id]), segment_size); + } +}; + +template +get_output_size_op(OffsetItT, KSizesItT) -> get_output_size_op; + +template +struct offset_iterator_op +{ + IteratorT base_it; + OffsetItT offset_it; + + offset_iterator_op(IteratorT base_it, OffsetItT offset_it) + : base_it(base_it) + , offset_it(offset_it) + {} + + template + __device__ __forceinline__ IteratorT operator()(IndexT segment_id) const + { + return base_it + offset_it[segment_id]; + } +}; + template using direction_to_comparator_t = cuda::std::conditional_t, cuda::std::greater<>>; @@ -181,11 +228,53 @@ void compact_sorted_keys_to_topk( d_keys_in.resize(new_end - d_keys_in.begin()); } -template -void segmented_sort_keys(c2h::device_vector& d_keys_in, - cuda::std::int64_t num_segments, - cuda::std::int64_t segment_size, - cub::detail::topk::select direction) +// Stream-compacts each segment to only contain the top-k elements +template +c2h::device_vector compact_to_topk_batched( + 
c2h::device_vector& d_keys_in, const c2h::device_vector& d_offsets, cuda::std::int64_t k)
+{
+  // Expects d_offsets to hold num_segments + 1 boundary offsets (begin/end offset per segment)
+  const auto num_segments = d_offsets.size() - 1;
+
+  // Maps segments to source pointers: d_keys_in.data() + offset[i]
+  auto src_ptrs_it = cuda::make_transform_iterator(
+    cuda::make_counting_iterator(0), offset_iterator_op{d_keys_in.cbegin(), d_offsets.cbegin()});
+
+  // Calculates the output sizes (if segment size is smaller than k, then output size is segment size, otherwise k)
+  auto copy_sizes_it = cuda::make_transform_iterator(
+    cuda::make_counting_iterator(0), get_output_size_op{d_offsets.cbegin(), cuda::constant_iterator(k)});
+
+  // Calculate destination offsets via prefix sum
+  c2h::device_vector d_output_offsets(num_segments + 1, thrust::no_init);
+  thrust::exclusive_scan(copy_sizes_it, copy_sizes_it + num_segments + 1, d_output_offsets.begin());
+
+  OffsetT total_compacted_size = d_output_offsets.back();
+  c2h::device_vector d_keys_out(total_compacted_size, thrust::no_init);
+
+  // Map segments to destination pointers: d_keys_out.data() + new_offset[i]
+  auto dst_ptrs_it = cuda::make_transform_iterator(
+    cuda::make_counting_iterator(0), offset_iterator_op{d_keys_out.begin(), d_output_offsets.cbegin()});
+
+  // Query temporary storage size
+  void* d_temp_storage = nullptr;
+  size_t temp_storage_bytes = 0;
+  cub::DeviceCopy::Batched(d_temp_storage, temp_storage_bytes, src_ptrs_it, dst_ptrs_it, copy_sizes_it, num_segments);
+  c2h::device_vector d_temp(temp_storage_bytes, thrust::no_init);
+  d_temp_storage = thrust::raw_pointer_cast(d_temp.data());
+
+  // Run batched copy to compact the top-k elements of each segment into the tightly-packed output buffer
+  cub::DeviceCopy::Batched(d_temp_storage, temp_storage_bytes, src_ptrs_it, dst_ptrs_it, copy_sizes_it, num_segments);
+
+  return d_keys_out;
+}
+
+template
+void segmented_sort_keys(
+  c2h::device_vector& d_keys_in,
+  cuda::std::int64_t num_segments,
+  OffsetItT d_segment_offsets_begin_it,
+  OffsetItT 
d_segment_offsets_end_it, + cub::detail::topk::select direction) { cuda::std::int64_t num_items = d_keys_in.size(); @@ -194,16 +283,18 @@ void segmented_sort_keys(c2h::device_vector& d_keys_in, cub::DoubleBuffer d_keys( thrust::raw_pointer_cast(d_keys_in.data()), thrust::raw_pointer_cast(d_keys_alt.data())); - // Prepare segment offsets - auto segment_offsets_it = - cuda::make_strided_iterator(cuda::make_counting_iterator(0), segment_size); - // Query temporary storage size size_t temp_storage_bytes = 0; if (direction == cub::detail::topk::select::min) { cub::DeviceSegmentedSort::SortKeys( - nullptr, temp_storage_bytes, d_keys, num_items, num_segments, segment_offsets_it, (segment_offsets_it + 1)); + nullptr, + temp_storage_bytes, + d_keys, + num_items, + num_segments, + d_segment_offsets_begin_it, + d_segment_offsets_end_it); // Allocate temporary storage c2h::device_vector d_temp_storage(temp_storage_bytes, thrust::no_init); @@ -215,13 +306,19 @@ void segmented_sort_keys(c2h::device_vector& d_keys_in, d_keys, num_items, num_segments, - segment_offsets_it, - (segment_offsets_it + 1)); + d_segment_offsets_begin_it, + d_segment_offsets_end_it); } else { cub::DeviceSegmentedSort::SortKeysDescending( - nullptr, temp_storage_bytes, d_keys, num_items, num_segments, segment_offsets_it, (segment_offsets_it + 1)); + nullptr, + temp_storage_bytes, + d_keys, + num_items, + num_segments, + d_segment_offsets_begin_it, + d_segment_offsets_end_it); // Allocate temporary storage c2h::device_vector d_temp_storage(temp_storage_bytes, thrust::no_init); @@ -233,8 +330,8 @@ void segmented_sort_keys(c2h::device_vector& d_keys_in, d_keys, num_items, num_segments, - segment_offsets_it, - (segment_offsets_it + 1)); + d_segment_offsets_begin_it, + d_segment_offsets_end_it); } // Make sure the result is returned in the original buffer @@ -243,3 +340,23 @@ void segmented_sort_keys(c2h::device_vector& d_keys_in, thrust::copy(d_keys.Current(), d_keys.Current() + num_items, d_keys_in.begin()); 
} } + +template +void fixed_size_segmented_sort_keys( + c2h::device_vector& d_keys_in, + cuda::std::int64_t num_segments, + cuda::std::int64_t segment_size, + cub::detail::topk::select direction) +{ + auto segment_offsets_it = + cuda::make_strided_iterator(cuda::make_counting_iterator(0), segment_size); + + // We materialize the offsets to reduce the number of kernel template specializations + c2h::device_vector d_segment_offsets(num_segments + 1); + thrust::copy(segment_offsets_it, segment_offsets_it + (num_segments + 1), d_segment_offsets.begin()); + + // Perform segmented sort + auto d_segment_offsets_begin_it = d_segment_offsets.cbegin(); + auto d_segment_offsets_end_it = d_segment_offsets_begin_it + 1; + segmented_sort_keys(d_keys_in, num_segments, d_segment_offsets_begin_it, d_segment_offsets_end_it, direction); +}