Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 24 additions & 6 deletions cpp/include/rapidsmpf/shuffler/shuffler.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -122,8 +122,8 @@ class Shuffler {
* @param mpe Optional custom metadata payload exchange. If not provided,
* uses the default tag-based implementation.
*
* @note It is safe to reuse the `op_id` as soon as `wait` has completed
* locally.
* @note It is safe to reuse the `op_id` as soon as `wait_reusable()` has
* completed locally, or after the shuffler has been shut down or destroyed.
*
* @note The caller promises that inserted buffers are stream-ordered with respect
* to their own stream, and extracted buffers are likewise guaranteed to be stream-
Expand All @@ -143,14 +143,16 @@ class Shuffler {
* @brief Construct a new shuffler for a single shuffle.
*
* @param comm The communicator to use.
* @param op_id The operation ID of the shuffle. This ID is unique for this operation,
* and should not be reused until all nodes has called `Shuffler::shutdown()`.
* @param op_id The operation ID of the shuffle.
* @param total_num_partitions Total number of partitions in the shuffle.
* @param br Buffer resource used to allocate temporary buffers and the shuffle result.
* @param partition_owner Function to determine partition ownership.
* @param mpe Optional custom metadata payload exchange. If not provided,
* uses the default tag-based implementation.
*
* @note It is safe to reuse the `op_id` as soon as `wait_reusable()` has
* completed locally, or after the shuffler has been shut down or destroyed.
*
* @note The caller promises that inserted buffers are stream-ordered with respect
* to their own stream, and extracted buffers are likewise guaranteed to be stream-
* ordered with respect to their own stream.
Expand Down Expand Up @@ -240,12 +242,27 @@ class Shuffler {
/**
* @brief Wait for all partitions to finish (blocking).
*
* Once this returns, local partitions are ready to extract. Internal send and
* protocol cleanup may still be progressing in the background.
*
* @param timeout Optional timeout (ms) to wait.
*
* @throws std::runtime_error if the timeout is reached.
*/
void wait(std::optional<std::chrono::milliseconds> timeout = {});

/**
* @brief Wait until this shuffle is fully drained and its op_id is reusable.
*
* This is stronger than `wait()`: it also waits for all internal communication
* state to drain, including protocol messages used to protect op_id reuse.
*
* @param timeout Optional timeout (ms) to wait.
*
* @throws std::runtime_error if the timeout is reached.
*/
void wait_reusable(std::optional<std::chrono::milliseconds> timeout = {});

/**
* @brief Spills data to device if necessary.
*
Expand Down Expand Up @@ -325,9 +342,10 @@ class Shuffler {
std::atomic<bool> active_{true};
// Have we called `insert_finished()` on this rank.
std::atomic<bool> locally_finished_{false};
// Flipped to true exactly once when partitions are ready for extraction and we've
// posted all sends we're going to
// Flipped to true exactly once when partitions are ready for extraction.
bool can_extract_{false};
// Flipped to true exactly once when all internal communication has drained.
bool can_reuse_{false};
detail::ChunksToSend to_send_; ///< Storage for chunks to send to other ranks.
detail::ReceivedChunks received_; ///< Storage for received chunks that are
///< ready to be extracted by the user.
Expand Down
45 changes: 35 additions & 10 deletions cpp/src/shuffler/shuffler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -182,20 +182,31 @@ class Shuffler::Progress {
// Finished and shuffler is no longer active.
bool const is_done = !shuffler_.active_.load(std::memory_order_acquire)
&& is_finished && containers_empty;
// Signal can_extract_ when all chunks have been received and all internal
// containers are drained. If we own no partitions we "can-extract" immediately,
// but we only wake a waiter once we've drained internal containers so that we can
// reuse the op_id for a subsequent shuffle.
if (!shuffler_.can_extract_ && is_finished && containers_empty) {
{
std::lock_guard lock(shuffler_.mutex_);
// Signal can_extract_ as soon as local outputs are complete so extraction can
// overlap with remaining send/protocol cleanup. Signal can_reuse_ only after all
// internal state is drained.
bool const ready_to_extract = is_finished;
bool const ready_to_reuse = is_finished && containers_empty;
bool notify = false;
FinishedCallback callback;
{
std::lock_guard lock(shuffler_.mutex_);
if (!shuffler_.can_extract_ && ready_to_extract) {
shuffler_.can_extract_ = true;
callback = std::move(shuffler_.finished_callback_);
notify = true;
}
shuffler_.cv_.notify_all();
if (auto callback = std::move(shuffler_.finished_callback_)) {
callback();
if (!shuffler_.can_reuse_ && ready_to_reuse) {
shuffler_.can_reuse_ = true;
notify = true;
}
}
if (notify) {
shuffler_.cv_.notify_all();
}
if (callback) {
callback();
}
return is_done ? ProgressThread::ProgressState::Done
: ProgressThread::ProgressState::InProgress;
}
Expand Down Expand Up @@ -435,6 +446,20 @@ void Shuffler::wait(std::optional<std::chrono::milliseconds> timeout) {
}
}

void Shuffler::wait_reusable(std::optional<std::chrono::milliseconds> timeout) {
RAPIDSMPF_NVTX_FUNC_RANGE();
std::unique_lock lock(mutex_);
if (timeout.has_value()) {
RAPIDSMPF_EXPECTS(
cv_.wait_for(lock, *timeout, [&] { return can_reuse_; }),
"wait timeout reached",
std::runtime_error
);
} else {
cv_.wait(lock, [&] { return can_reuse_; });
}
}

std::size_t Shuffler::spill(std::optional<std::size_t> amount) {
RAPIDSMPF_NVTX_FUNC_RANGE();
std::size_t spill_need{0};
Expand Down
8 changes: 4 additions & 4 deletions cpp/tests/test_shuffler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -918,8 +918,8 @@ TEST(Shuffler, concurrent_wait) {
std::ranges::for_each(futures, [](auto& f) { f.get(); });
}

// Test that reusing an OpID after a completed shuffle doesn't cause cross-matching of
// messages between the old and new shuffle.
// Test that reusing an OpID after a fully drained shuffle doesn't cause cross-matching
// of messages between the old and new shuffle.
//
// On rank 0 we inject a stream-ordered delay into device allocations so that received
// chunks stay "not ready" in the event loop. With small messages, other ranks can finish
Expand Down Expand Up @@ -1017,7 +1017,7 @@ TEST(Shuffler, opid_reuse) {
);
insert_data(shuffle1, 42);
shuffle1.insert_finished();
EXPECT_NO_THROW(shuffle1.wait(wait_timeout));
EXPECT_NO_THROW(shuffle1.wait_reusable(wait_timeout));

rapidsmpf::shuffler::Shuffler shuffle2(
comm, op_id, total_num_partitions, shuffler_br
Expand Down Expand Up @@ -1114,7 +1114,7 @@ TEST(Shuffler, opid_reuse_with_empty_partitions) {
);
insert_data(shuffle1, 42);
shuffle1.insert_finished();
EXPECT_NO_THROW(shuffle1.wait(wait_timeout));
EXPECT_NO_THROW(shuffle1.wait_reusable(wait_timeout));

rapidsmpf::shuffler::Shuffler shuffle2(
comm, op_id, total_num_partitions, shuffler_br
Expand Down
Loading