Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion cpp/include/rapidsmpf/memory/spill_manager.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include <map>
#include <mutex>
#include <optional>
#include <shared_mutex>

#include <rapidsmpf/pausable_thread_loop.hpp>
#include <rapidsmpf/utils/misc.hpp>
Expand Down Expand Up @@ -113,7 +114,11 @@ class SpillManager {
std::size_t spill_to_make_headroom(std::int64_t headroom = 0);

private:
mutable std::mutex mutex_;
// `spill()` takes a shared lock so spill work runs in parallel; add/remove take
// exclusive, which drains in-flight spillers before returning -- callers rely on this
// for safe teardown. SpillFunctions must be safe to invoke concurrently
// with themselves.
std::shared_mutex mutex_;
BufferResource* br_;
std::size_t spill_function_id_counter_{0};
std::map<SpillFunctionID, SpillFunction> spill_functions_;
Expand Down
6 changes: 3 additions & 3 deletions cpp/src/memory/spill_manager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ SpillManager::~SpillManager() {
SpillManager::SpillFunctionID SpillManager::add_spill_function(
SpillFunction spill_function, int priority
) {
std::lock_guard<std::mutex> lock(mutex_);
std::unique_lock lock(mutex_);
auto const id = spill_function_id_counter_++;
RAPIDSMPF_EXPECTS(
spill_functions_.insert({id, std::move(spill_function)}).second,
Expand All @@ -49,7 +49,7 @@ SpillManager::SpillFunctionID SpillManager::add_spill_function(
}

void SpillManager::remove_spill_function(SpillFunctionID fid) {
std::lock_guard<std::mutex> lock(mutex_);
std::unique_lock lock(mutex_);
auto& prio = spill_function_priorities_;
for (auto it = prio.begin(); it != prio.end(); ++it) {
if (it->second == fid) {
Expand All @@ -68,7 +68,7 @@ void SpillManager::remove_spill_function(SpillFunctionID fid) {
std::size_t SpillManager::spill(std::size_t amount) {
RAPIDSMPF_NVTX_FUNC_RANGE();
std::size_t spilled{0};
std::unique_lock<std::mutex> lock(mutex_);
std::shared_lock lock(mutex_);
for (auto const [_, fid] : spill_function_priorities_) {
if (spilled >= amount) {
break;
Expand Down
60 changes: 60 additions & 0 deletions cpp/tests/test_spill_manager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,13 @@
*/


#include <chrono>
#include <condition_variable>
#include <future>
#include <mutex>
#include <numeric>
#include <vector>

#include <gtest/gtest.h>

#include <cudf_test/base_fixture.hpp>
Expand Down Expand Up @@ -74,3 +81,56 @@ TEST(SpillManager, SpillFunction) {
EXPECT_EQ(br.spill_manager().spill_to_make_headroom(-100_KiB), 0);
EXPECT_EQ(br.memory_available(MemoryType::DEVICE)(), 100_KiB);
}

// Ensure that several simultaneous `spill()` invocations run the registered
// spill function at the same time. The function acts as a barrier: every
// worker blocks inside it until all workers have arrived. Were spill() to
// take an exclusive lock instead of a shared one, the barrier could never
// be satisfied and the waits below would time out, failing the test.
TEST(SpillManager, ConcurrentSpill) {
    constexpr int num_threads = 4;
    constexpr auto rendezvous_timeout = std::chrono::seconds(5);

    BufferResource br{
        cudf::get_current_device_resource_ref(),
        rapidsmpf::PinnedMemoryResource::Disabled,
        {{MemoryType::DEVICE, []() -> std::int64_t { return 0; }}}
    };

    // Barrier state shared by all workers entering the spill function.
    std::mutex barrier_mutex;
    std::condition_variable barrier_cv;
    int arrived = 0;

    SpillManager::SpillFunction func = [&](std::size_t amount) -> std::size_t {
        std::unique_lock<std::mutex> lock(barrier_mutex);
        if (++arrived == num_threads) {
            barrier_cv.notify_all();
        }
        // A successful (non-timeout) wait proves the calls overlapped: when
        // spill() runs in parallel every thread quickly observes
        // `arrived == num_threads`, whereas serialized calls leave every
        // thread but the last stuck here until the timeout expires.
        EXPECT_TRUE(barrier_cv.wait_for(lock, rendezvous_timeout, [&] {
            return arrived == num_threads;
        })) << "spill() calls did not run concurrently";
        return amount;
    };

    br.spill_manager().add_spill_function(func, /* priority = */ 0);

    // Launch one asynchronous spill() per worker; futures are joined below,
    // so reference captures cannot dangle.
    std::vector<std::future<std::size_t>> futures;
    futures.reserve(num_threads);
    for (int i = 0; i < num_threads; ++i) {
        futures.emplace_back(std::async(std::launch::async, [&] {
            return br.spill_manager().spill(1_KiB);
        }));
    }

    // Each spill() call should report exactly the requested amount.
    std::size_t total_spilled{0};
    for (auto& fut : futures) {
        total_spilled += fut.get();
    }

    EXPECT_EQ(total_spilled, num_threads * 1_KiB);
}
Loading