Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 24 additions & 14 deletions cpp/benchmarks/bench_shuffle.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ class ArgumentParser {
}
try {
int option;
while ((option = getopt(argc, argv, "C:r:w:c:n:p:o:m:l:LigsbxhM:")) != -1) {
while ((option = getopt(argc, argv, "C:r:w:c:n:p:o:m:l:LigsbdxhM:")) != -1) {
switch (option) {
case 'h':
{
Expand All @@ -77,7 +77,9 @@ class ArgumentParser {
" unlimited)\n"
<< " -g Use pre-partitioned (hash) input tables "
"(default: unset, hash partition during insertion)\n"
<< " -s Discard output chunks to simulate streaming "
<< " -s Discard output chunks after extract and "
"concat to simulate streaming (default: disabled)\n"
<< " -d Discard result after shuffle completes "
"(default: disabled)\n"
<< " -b Disallow memory overbooking when generating "
"input data (default: allow memory overbooking)\n"
Expand Down Expand Up @@ -159,6 +161,9 @@ class ArgumentParser {
case 's':
enable_output_discard = true;
break;
case 'd':
just_shuffle = true;
break;
case 'b':
input_data_allow_overbooking = rapidsmpf::AllowOverbooking::NO;
break;
Expand Down Expand Up @@ -235,6 +240,9 @@ class ArgumentParser {
if (enable_output_discard) {
ss << " -s (enable output discard to simulate streaming)\n";
}
if (just_shuffle) {
ss << " -d (only shuffle, no extraction)\n";
}
if (input_data_allow_overbooking == rapidsmpf::AllowOverbooking::NO) {
ss << " -b (disallow memory overbooking when generating input data)\n";
}
Expand Down Expand Up @@ -271,6 +279,7 @@ class ArgumentParser {
std::int64_t device_mem_limit_mb{-1};
bool pinned_mem_disable{false};
bool enable_cupti_monitoring{false};
bool just_shuffle{false};
std::string cupti_csv_prefix;
};

Expand Down Expand Up @@ -315,20 +324,21 @@ rapidsmpf::Duration do_run(
shuffle_insert_fn(shuffler);

shuffler.wait();
for (auto finished_partition : shuffler.local_partitions()) {
auto packed_chunks = shuffler.extract(finished_partition);
auto output_partition = rapidsmpf::unpack_and_concat(
rapidsmpf::unspill_partitions(
std::move(packed_chunks), br, rapidsmpf::AllowOverbooking::YES
),
stream,
br
);
if (!args.enable_output_discard) {
output_partitions.emplace_back(std::move(output_partition));
if (!args.just_shuffle) {
for (auto finished_partition : shuffler.local_partitions()) {
auto packed_chunks = shuffler.extract(finished_partition);
auto output_partition = rapidsmpf::unpack_and_concat(
rapidsmpf::unspill_partitions(
std::move(packed_chunks), br, rapidsmpf::AllowOverbooking::YES
),
stream,
br
);
if (!args.enable_output_discard) {
output_partitions.emplace_back(std::move(output_partition));
}
}
}
stream.synchronize();
}

auto const elapsed = rapidsmpf::Clock::now() - t0_elapsed;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,7 @@ class MetadataPayloadExchange {
* @return `true` if the communication layer is idle; `false` if activity is ongoing.
*/
[[nodiscard]] virtual bool is_idle() const = 0;
[[nodiscard]] virtual bool finished_polling() const = 0;
};


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ class TagMetadataPayloadExchange : public MetadataPayloadExchange {
* @copydoc MetadataPayloadExchange::is_idle
*/
bool is_idle() const override;
bool finished_polling() const override;

private:
/**
Expand Down
19 changes: 6 additions & 13 deletions cpp/include/rapidsmpf/shuffler/chunk.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -228,24 +228,17 @@ class Chunk {
* @brief Create a chunk by deserializing a metadata message.
*
* @param msg The metadata message received from another rank.
* @param br Buffer resource for allocating the data buffer of the deserialized
* message. Ignored when @p data is provided.
* @param br Buffer resource for allocating a the data buffer of the deserialized
* message.
* @param validate Whether to validate the metadata buffer.
* @param data Optional pre-existing data buffer to use instead of allocating a new
* one. When non-null, the buffer resource allocation is skipped entirely, avoiding
* unnecessary memory pressure from a temporary allocation.
* @return The chunk.
*
* @throws std::logic_error if the chunk is not a control message and neither @p data
* nor @p br is provided.
* @throws std::runtime_error if the metadata buffer does not follow the expected
* format and @p validate is true.
* @throws std::logic_error if the chunk is not a control message and no buffer
* resource is provided. @throws std::runtime_error if the metadata buffer does not
* follow the expected format and `validate` is true.
*/
static Chunk deserialize(
std::vector<std::uint8_t> const& msg,
BufferResource* br,
bool validate = true,
std::unique_ptr<Buffer> data = nullptr
std::vector<std::uint8_t> const& msg, BufferResource* br, bool validate = true
);

/**
Expand Down
23 changes: 4 additions & 19 deletions cpp/include/rapidsmpf/shuffler/shuffler.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
#include <vector>

#include <rapidsmpf/communicator/communicator.hpp>
#include <rapidsmpf/communicator/metadata_payload_exchange/tag.hpp>
#include <rapidsmpf/error.hpp>
#include <rapidsmpf/memory/buffer_resource.hpp>
#include <rapidsmpf/memory/packed_data.hpp>
Expand Down Expand Up @@ -119,8 +118,6 @@ class Shuffler {
* @param br Buffer resource used to allocate temporary and the shuffle result.
* @param finished_callback Callback to notify when all partitions are finished.
* @param partition_owner Function to determine partition ownership.
* @param mpe Optional custom metadata payload exchange. If not provided,
* uses the default tag-based implementation.
*
* @note It is safe to reuse the `op_id` as soon as `wait` has completed
* locally.
Expand All @@ -135,8 +132,7 @@ class Shuffler {
PartID total_num_partitions,
BufferResource* br,
FinishedCallback&& finished_callback,
PartitionOwner partition_owner = round_robin,
std::unique_ptr<communicator::MetadataPayloadExchange> mpe = nullptr
PartitionOwner partition_owner = round_robin
);

/**
Expand All @@ -148,8 +144,6 @@ class Shuffler {
* @param total_num_partitions Total number of partitions in the shuffle.
* @param br Buffer resource used to allocate temporary and the shuffle result.
* @param partition_owner Function to determine partition ownership.
* @param mpe Optional custom metadata payload exchange. If not provided,
* uses the default tag-based implementation.
*
* @note The caller promises that inserted buffers are stream-ordered with respect
* to their own stream, and extracted buffers are likewise guaranteed to be stream-
Expand All @@ -160,18 +154,9 @@ class Shuffler {
OpID op_id,
PartID total_num_partitions,
BufferResource* br,
PartitionOwner partition_owner = round_robin,
std::unique_ptr<communicator::MetadataPayloadExchange> mpe = nullptr
PartitionOwner partition_owner = round_robin
)
: Shuffler(
comm,
op_id,
total_num_partitions,
br,
nullptr,
partition_owner,
std::move(mpe)
) {}
: Shuffler(comm, op_id, total_num_partitions, br, nullptr, partition_owner) {}

~Shuffler();

Expand Down Expand Up @@ -328,12 +313,12 @@ class Shuffler {
// Flipped to true exactly once when partitions are ready for extraction and we've
// posted all sends we're going to
bool can_extract_{false};
OpID const op_id_;
detail::ChunksToSend to_send_; ///< Storage for chunks to send to other ranks.
detail::ReceivedChunks received_; ///< Storage for received chunks that are
///< ready to be extracted by the user.

std::shared_ptr<Communicator> comm_;
std::unique_ptr<communicator::MetadataPayloadExchange> mpe_;
ProgressThread::FunctionID progress_thread_function_id_;

SpillManager::SpillFunctionID spill_function_id_;
Expand Down
25 changes: 20 additions & 5 deletions cpp/src/communicator/metadata_payload_exchange/tag.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,8 @@ void TagMetadataPayloadExchange::finish() {
// exactly how many application messages we sent to it, so the peer can
// stop receiving once it has them all.
// Format: [sentinel=UINT64_MAX (8 bytes)][message_count (8 bytes)]
for (Rank peer = 0; peer < nranks_; ++peer) {
for (Rank i = 0; i < nranks_; ++i) {
auto const peer = (i + rank_) % nranks_;
if (peer == rank_) {
continue;
}
Expand All @@ -158,6 +159,22 @@ void TagMetadataPayloadExchange::finish() {
}
}

bool TagMetadataPayloadExchange::finished_polling() const {
if (finished_) {
for (Rank peer = 0; peer < nranks_; ++peer) {
if (peer == rank_) {
continue;
}
auto const p = safe_cast<std::size_t>(peer);
if (!peer_terminated_[p] || peer_received_[p] < peer_expected_[p]) {
return false;
}
}
return true;
}
return false;
}

bool TagMetadataPayloadExchange::is_idle() const {
bool const io_idle = fire_and_forget_.empty() && incoming_messages_.empty()
&& in_transit_messages_.empty() && in_transit_futures_.empty();
Expand All @@ -184,7 +201,8 @@ void TagMetadataPayloadExchange::receive_metadata() {

// Use per-peer recv_from to avoid consuming messages belonging to a future
// collective on the same tag (see rapidsai/rapidsmpf#927).
for (Rank peer = 0; peer < nranks_; ++peer) {
for (Rank i = 0; i < nranks_; ++i) {
auto const peer = (i + rank_) % nranks_;
if (peer == rank_) {
continue;
}
Expand Down Expand Up @@ -316,9 +334,6 @@ TagMetadataPayloadExchange::setup_data_receives() {
);
// Store in per-rank vector to maintain order
in_transit_messages_[src].push_back(std::move(tag_message));
// Break to ensure we don't return later messages before this one
// completes
break;
} else {
// Control/metadata-only message
// Only return if there are no earlier in-transit messages from this rank
Expand Down
8 changes: 3 additions & 5 deletions cpp/src/shuffler/chunk.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,10 +55,7 @@ Chunk Chunk::from_finished_partition(
}

Chunk Chunk::deserialize(
std::vector<std::uint8_t> const& msg,
BufferResource* br,
bool validate,
std::unique_ptr<Buffer> data
std::vector<std::uint8_t> const& msg, BufferResource* br, bool validate
) {
if (validate) {
RAPIDSMPF_EXPECTS(
Expand Down Expand Up @@ -91,7 +88,8 @@ Chunk Chunk::deserialize(
msg.begin() + safe_cast<std::int64_t>(offset), msg.end()
);

if (!data && expected_num_chunks == 0) {
std::unique_ptr<Buffer> data;
if (expected_num_chunks == 0) {
RAPIDSMPF_EXPECTS(
br != nullptr, "Deserializing non-control Chunk requires a BufferResource"
);
Expand Down
Loading
Loading