diff --git a/cpp/include/rapidsmpf/memory/spill_manager.hpp b/cpp/include/rapidsmpf/memory/spill_manager.hpp index 43892fb7a..773a61d5e 100644 --- a/cpp/include/rapidsmpf/memory/spill_manager.hpp +++ b/cpp/include/rapidsmpf/memory/spill_manager.hpp @@ -8,6 +8,7 @@ #include #include #include +#include <shared_mutex> #include #include @@ -113,7 +114,11 @@ class SpillManager { std::size_t spill_to_make_headroom(std::int64_t headroom = 0); private: - mutable std::mutex mutex_; + // `spill()` takes a shared lock so spill work runs in parallel; add/remove take + // exclusive, which drains in-flight spillers before returning -- callers rely on this + // for safe teardown. SpillFunctions must be safe to invoke concurrently + // with themselves. + std::shared_mutex mutex_; BufferResource* br_; std::size_t spill_function_id_counter_{0}; std::map<SpillFunctionID, SpillFunction> spill_functions_; diff --git a/cpp/src/memory/spill_manager.cpp b/cpp/src/memory/spill_manager.cpp index 6b20150d7..8c8f2ca6e 100644 --- a/cpp/src/memory/spill_manager.cpp +++ b/cpp/src/memory/spill_manager.cpp @@ -33,7 +33,7 @@ SpillManager::~SpillManager() { SpillManager::SpillFunctionID SpillManager::add_spill_function( SpillFunction spill_function, int priority ) { - std::lock_guard lock(mutex_); + std::unique_lock lock(mutex_); auto const id = spill_function_id_counter_++; RAPIDSMPF_EXPECTS( spill_functions_.insert({id, std::move(spill_function)}).second, @@ -49,7 +49,7 @@ SpillManager::SpillFunctionID SpillManager::add_spill_function( } void SpillManager::remove_spill_function(SpillFunctionID fid) { - std::lock_guard lock(mutex_); + std::unique_lock lock(mutex_); auto& prio = spill_function_priorities_; for (auto it = prio.begin(); it != prio.end(); ++it) { if (it->second == fid) { @@ -68,7 +68,7 @@ void SpillManager::remove_spill_function(SpillFunctionID fid) { std::size_t SpillManager::spill(std::size_t amount) { RAPIDSMPF_NVTX_FUNC_RANGE(); std::size_t spilled{0}; - std::unique_lock lock(mutex_); + std::shared_lock lock(mutex_); for (auto 
const [_, fid] : spill_function_priorities_) { if (spilled >= amount) { break; diff --git a/cpp/tests/test_spill_manager.cpp b/cpp/tests/test_spill_manager.cpp index e4cbc541e..ac46bf010 100644 --- a/cpp/tests/test_spill_manager.cpp +++ b/cpp/tests/test_spill_manager.cpp @@ -4,6 +4,13 @@ */ +#include <chrono> +#include <condition_variable> +#include <future> +#include <mutex> +#include <numeric> +#include <vector> + #include #include @@ -74,3 +81,56 @@ TEST(SpillManager, SpillFunction) { EXPECT_EQ(br.spill_manager().spill_to_make_headroom(-100_KiB), 0); EXPECT_EQ(br.memory_available(MemoryType::DEVICE)(), 100_KiB); } + +// Verify that multiple concurrent `spill()` calls execute the spill function in +// parallel. The spill function blocks at a rendezvous point until all worker +// threads have entered it; with an exclusive lock this would deadlock and the +// test would time out. +TEST(SpillManager, ConcurrentSpill) { + constexpr int num_threads = 4; + constexpr auto rendezvous_timeout = std::chrono::seconds(5); + + BufferResource br{ + cudf::get_current_device_resource_ref(), + rapidsmpf::PinnedMemoryResource::Disabled, + {{MemoryType::DEVICE, []() -> std::int64_t { return 0; }}} + }; + + std::mutex m; + std::condition_variable cv; + int entered = 0; + + SpillManager::SpillFunction func = [&](std::size_t amount) -> std::size_t { + std::unique_lock lock(m); + ++entered; + if (entered == num_threads) { + cv.notify_all(); + } + // If spill() calls run in parallel, all threads quickly reach + // `entered == num_threads` and proceed. If they're serialized, every + // thread but the last times out here, so a successful (non-timeout) + // wait_for is proof that the calls actually overlapped. 
+ EXPECT_TRUE(cv.wait_for(lock, rendezvous_timeout, [&] { + return entered == num_threads; + })) << "spill() calls did not run concurrently"; + return amount; + }; + + br.spill_manager().add_spill_function(func, /* priority = */ 0); + + std::vector<std::future<std::size_t>> futures; + futures.reserve(num_threads); + for (int i = 0; i < num_threads; ++i) { + futures.emplace_back(std::async(std::launch::async, [&] { + return br.spill_manager().spill(1_KiB); + })); + } + auto const total_spilled = std::accumulate( + futures.begin(), + futures.end(), + std::size_t{0}, + [](std::size_t sum, std::future<std::size_t>& f) { return sum + f.get(); } + ); + + EXPECT_EQ(total_spilled, num_threads * 1_KiB); +}