diff --git a/cpp/include/rapidsmpf/shuffler/shuffler.hpp b/cpp/include/rapidsmpf/shuffler/shuffler.hpp index f0d8d00ad..14e755b75 100644 --- a/cpp/include/rapidsmpf/shuffler/shuffler.hpp +++ b/cpp/include/rapidsmpf/shuffler/shuffler.hpp @@ -122,8 +122,8 @@ class Shuffler { * @param mpe Optional custom metadata payload exchange. If not provided, * uses the default tag-based implementation. * - * @note It is safe to reuse the `op_id` as soon as `wait` has completed - * locally. + * @note It is safe to reuse the `op_id` as soon as `wait_reusable()` has + * completed locally, or after the shuffler has been shut down or destroyed. * * @note The caller promises that inserted buffers are stream-ordered with respect * to their own stream, and extracted buffers are likewise guaranteed to be stream- @@ -143,14 +143,16 @@ class Shuffler { * @brief Construct a new shuffler for a single shuffle. * * @param comm The communicator to use. - * @param op_id The operation ID of the shuffle. This ID is unique for this operation, - * and should not be reused until all nodes has called `Shuffler::shutdown()`. + * @param op_id The operation ID of the shuffle. * @param total_num_partitions Total number of partitions in the shuffle. * @param br Buffer resource used to allocate temporary and the shuffle result. * @param partition_owner Function to determine partition ownership. * @param mpe Optional custom metadata payload exchange. If not provided, * uses the default tag-based implementation. * + * @note It is safe to reuse the `op_id` as soon as `wait_reusable()` has + * completed locally, or after the shuffler has been shut down or destroyed. + * * @note The caller promises that inserted buffers are stream-ordered with respect * to their own stream, and extracted buffers are likewise guaranteed to be stream- * ordered with respect to their own stream. @@ -240,12 +242,27 @@ class Shuffler { /** * @brief Wait for all partitions to finish (blocking). * + * Once this returns, local partitions are ready to extract. Internal send and + * protocol cleanup may still be progressing in the background. + * * @param timeout Optional timeout (ms) to wait. * * @throws std::runtime_error if the timeout is reached. */ void wait(std::optional timeout = {}); + /** + * @brief Wait until this shuffle is fully drained and its op_id is reusable. + * + * This is stronger than `wait()`: it also waits for all internal communication + * state to drain, including protocol messages used to protect op_id reuse. + * + * @param timeout Optional timeout (ms) to wait. + * + * @throws std::runtime_error if the timeout is reached. + */ + void wait_reusable(std::optional timeout = {}); + /** * @brief Spills data to device if necessary. * @@ -325,9 +342,10 @@ class Shuffler { std::atomic active_{true}; // Have we called `insert_finished()` on this rank. std::atomic locally_finished_{false}; - // Flipped to true exactly once when partitions are ready for extraction and we've - // posted all sends we're going to + // Flipped to true exactly once when partitions are ready for extraction. bool can_extract_{false}; + // Flipped to true exactly once when all internal communication has drained. + bool can_reuse_{false}; detail::ChunksToSend to_send_; ///< Storage for chunks to send to other ranks. detail::ReceivedChunks received_; ///< Storage for received chunks that are ///< ready to be extracted by the user. diff --git a/cpp/src/shuffler/shuffler.cpp b/cpp/src/shuffler/shuffler.cpp index 2c81a40e1..aaafcadde 100644 --- a/cpp/src/shuffler/shuffler.cpp +++ b/cpp/src/shuffler/shuffler.cpp @@ -182,20 +182,31 @@ class Shuffler::Progress { // Finished and shuffler is no longer active. bool const is_done = !shuffler_.active_.load(std::memory_order_acquire) && is_finished && containers_empty; - // Signal can_extract_ when all chunks have been received and all internal - // containers are drained. If we own no partitions we "can-extract" immediately, - // but we only wake a waiter once we've drained internal containers so that we can - // reuse the op_id for a subsequent shuffle. - if (!shuffler_.can_extract_ && is_finished && containers_empty) { - { - std::lock_guard lock(shuffler_.mutex_); + // Signal can_extract_ as soon as local outputs are complete so extraction can + // overlap with remaining send/protocol cleanup. Signal can_reuse_ only after all + // internal state is drained. + bool const ready_to_extract = is_finished; + bool const ready_to_reuse = is_finished && containers_empty; + bool notify = false; + FinishedCallback callback; + { + std::lock_guard lock(shuffler_.mutex_); + if (!shuffler_.can_extract_ && ready_to_extract) { shuffler_.can_extract_ = true; + callback = std::move(shuffler_.finished_callback_); + notify = true; } - shuffler_.cv_.notify_all(); - if (auto callback = std::move(shuffler_.finished_callback_)) { - callback(); + if (!shuffler_.can_reuse_ && ready_to_reuse) { + shuffler_.can_reuse_ = true; + notify = true; } } + if (notify) { + shuffler_.cv_.notify_all(); + } + if (callback) { + callback(); + } return is_done ? ProgressThread::ProgressState::Done : ProgressThread::ProgressState::InProgress; } @@ -435,6 +446,20 @@ void Shuffler::wait(std::optional timeout) { } } +void Shuffler::wait_reusable(std::optional timeout) { + RAPIDSMPF_NVTX_FUNC_RANGE(); + std::unique_lock lock(mutex_); + if (timeout.has_value()) { + RAPIDSMPF_EXPECTS( + cv_.wait_for(lock, *timeout, [&] { return can_reuse_; }), + "wait timeout reached", + std::runtime_error + ); + } else { + cv_.wait(lock, [&] { return can_reuse_; }); + } +} + std::size_t Shuffler::spill(std::optional amount) { RAPIDSMPF_NVTX_FUNC_RANGE(); std::size_t spill_need{0}; diff --git a/cpp/tests/test_shuffler.cpp b/cpp/tests/test_shuffler.cpp index 51f7030f0..d975a6f2b 100644 --- a/cpp/tests/test_shuffler.cpp +++ b/cpp/tests/test_shuffler.cpp @@ -918,8 +918,8 @@ TEST(Shuffler, concurrent_wait) { std::ranges::for_each(futures, [](auto& f) { f.get(); }); } -// Test that reusing an OpID after a completed shuffle doesn't cause cross-matching of -// messages between the old and new shuffle. +// Test that reusing an OpID after a fully drained shuffle doesn't cause cross-matching +// of messages between the old and new shuffle. // // On rank 0 we inject a stream-ordered delay into device allocations so that received // chunks stay "not ready" in the event loop. With small messages, other ranks can finish @@ -1017,7 +1017,7 @@ TEST(Shuffler, opid_reuse) { ); insert_data(shuffle1, 42); shuffle1.insert_finished(); - EXPECT_NO_THROW(shuffle1.wait(wait_timeout)); + EXPECT_NO_THROW(shuffle1.wait_reusable(wait_timeout)); rapidsmpf::shuffler::Shuffler shuffle2( comm, op_id, total_num_partitions, shuffler_br @@ -1114,7 +1114,7 @@ TEST(Shuffler, opid_reuse_with_empty_partitions) { ); insert_data(shuffle1, 42); shuffle1.insert_finished(); - EXPECT_NO_THROW(shuffle1.wait(wait_timeout)); + EXPECT_NO_THROW(shuffle1.wait_reusable(wait_timeout)); rapidsmpf::shuffler::Shuffler shuffle2( comm, op_id, total_num_partitions, shuffler_br