diff --git a/client/BUILD.bazel b/client/BUILD.bazel index 8ee0f6d..e183137 100644 --- a/client/BUILD.bazel +++ b/client/BUILD.bazel @@ -81,6 +81,7 @@ cc_test( "@abseil-cpp//absl/status", "@abseil-cpp//absl/status:status_matchers", "@abseil-cpp//absl/status:statusor", + "@abseil-cpp//absl/strings:str_format", "@googletest//:gtest", "@coroutines//:co", ] + select({ diff --git a/client/client_channel.cc b/client/client_channel.cc index ff70418..8585e55 100644 --- a/client/client_channel.cc +++ b/client/client_channel.cc @@ -12,7 +12,7 @@ #include #include #include -#if defined(__ANDROID__) +#if SUBSPACE_SHMEM_MODE == SUBSPACE_SHMEM_MODE_ANDROID #include #ifndef MFD_CLOEXEC #define MFD_CLOEXEC 0x0001U diff --git a/client/client_test.cc b/client/client_test.cc index dd424be..4e6f7ed 100644 --- a/client/client_test.cc +++ b/client/client_test.cc @@ -10,18 +10,26 @@ #include "absl/flags/parse.h" #include "absl/hash/hash_testing.h" #include "absl/status/status.h" +#include "absl/strings/str_format.h" #include "common/system_info.h" #include "toolbelt/clock.h" #include "toolbelt/hexdump.h" #include "toolbelt/pipe.h" #include #include +#include #include #include #include #include #include #include +#if SUBSPACE_SHMEM_MODE == SUBSPACE_SHMEM_MODE_ANDROID +#include +#ifndef MFD_CLOEXEC +#define MFD_CLOEXEC 0x0001U +#endif +#endif #include ABSL_FLAG(bool, start_server, true, "Start the subspace server"); @@ -81,6 +89,28 @@ uint64_t ExpectedSplitBufferVirtualMemoryUsage(int num_slots, AlignPage(slot_size) * static_cast(num_slots); } +#if SUBSPACE_SHMEM_MODE == SUBSPACE_SHMEM_MODE_ANDROID +absl::StatusOr CreateTestMemfd(const char *name, + size_t size) { +#ifdef __NR_memfd_create + int fd = static_cast( + syscall(__NR_memfd_create, name, static_cast(MFD_CLOEXEC))); + if (fd == -1) { + return absl::InternalError(absl::StrFormat( + "Failed to create test memfd %s: %s", name, strerror(errno))); + } + toolbelt::FileDescriptor result(fd); + if (ftruncate(result.Fd(), static_cast(size)) == -1) { + return absl::InternalError(absl::StrFormat( + "Failed to size test memfd %s: %s", name, strerror(errno))); + } + return result; +#else + return absl::UnimplementedError("memfd_create is not available"); +#endif +} +#endif + subspace::SplitBufferCallbacks MakeTestSplitBufferCallbacks( std::shared_ptr state) { subspace::SplitBufferCallbacks callbacks; @@ -198,6 +228,64 @@ TEST_F(ClientTest, Resize1) { ASSERT_EQ(512, pub->SlotSize()); } +#if SUBSPACE_SHMEM_MODE == SUBSPACE_SHMEM_MODE_ANDROID +TEST(AndroidBufferRegistrationTest, FailedRegistrationRollsBackNumBuffers) { + constexpr int kNumSlots = 2; + absl::StatusOr scb_fd = + CreateTestMemfd("subspace_test_scb", sizeof(subspace::SystemControlBlock)); + ASSERT_OK(scb_fd); + absl::StatusOr ccb_fd = + CreateTestMemfd("subspace_test_ccb", subspace::CcbSize(kNumSlots)); + ASSERT_OK(ccb_fd); + absl::StatusOr bcb_fd = CreateTestMemfd( + "subspace_test_bcb", sizeof(subspace::BufferControlBlock)); + ASSERT_OK(bcb_fd); + + subspace::PublisherOptions options; + subspace::details::PublisherImpl publisher( + "android_registration_rollback", kNumSlots, /*channel_id=*/0, + /*publisher_id=*/0, /*vchan_id=*/-1, /*session_id=*/123, "", + options, [](subspace::Channel *) { return false; }, + /*user_id=*/0, /*group_id=*/0); + ASSERT_OK(publisher.Map( + subspace::SharedMemoryFds(std::move(*ccb_fd), std::move(*bcb_fd)), + *scb_fd)); + + int failed_registration_attempts = 0; + publisher.SetClientBufferRegistrationCallback( + [&](const subspace::ClientBufferHandleMetadata &metadata, + const toolbelt::FileDescriptor *fd) { + failed_registration_attempts++; + EXPECT_EQ(0u, metadata.buffer_index); + EXPECT_NE(nullptr, fd); + EXPECT_TRUE(fd->Valid()); + return absl::InternalError("injected registration failure"); + }); + + absl::Status status = publisher.CreateOrAttachBuffers(/*slot_size=*/128); + EXPECT_FALSE(status.ok()); + EXPECT_EQ(1, failed_registration_attempts); + EXPECT_EQ(0, publisher.GetCcb()->num_buffers.load(std::memory_order_relaxed)); + EXPECT_TRUE(publisher.GetBuffers().empty()); + + std::vector registered_indices; + publisher.SetClientBufferRegistrationCallback( + [&](const subspace::ClientBufferHandleMetadata &metadata, + const toolbelt::FileDescriptor *fd) { + EXPECT_NE(nullptr, fd); + EXPECT_TRUE(fd->Valid()); + registered_indices.push_back(metadata.buffer_index); + return absl::OkStatus(); + }); + + ASSERT_OK(publisher.CreateOrAttachBuffers(/*slot_size=*/128)); + ASSERT_EQ(1u, registered_indices.size()); + EXPECT_EQ(0u, registered_indices[0]); + EXPECT_EQ(1, publisher.GetCcb()->num_buffers.load(std::memory_order_relaxed)); + EXPECT_EQ(1u, publisher.GetBuffers().size()); +} +#endif + TEST_F(ClientTest, ResizeCallback) { subspace::Client client; InitClient(client); diff --git a/client/publisher.cc b/client/publisher.cc index 56aa8e6..1ceb437 100644 --- a/client/publisher.cc +++ b/client/publisher.cc @@ -7,6 +7,7 @@ #include "client_channel.h" #include "common/client_buffer.h" #include "toolbelt/clock.h" +#include #include #include namespace subspace { @@ -121,6 +122,22 @@ absl::Status PublisherImpl::CreateOrAttachBuffers(uint64_t final_slot_size) { if (absl::Status status = client_buffer_registration_callback_(metadata, &buffer.fd); !status.ok()) { + int expected_num_buffers = new_num_buffers; + if (ccb_->num_buffers.compare_exchange_strong( + expected_num_buffers, old_num_buffers, + std::memory_order_acq_rel, std::memory_order_relaxed)) { + for (int j = old_num_buffers; j < new_num_buffers; j++) { + bcb_->sizes[j].store(0, std::memory_order_relaxed); + } + } + const size_t rollback_to = + std::min(static_cast(old_num_buffers), + buffers_.size()); + for (size_t j = rollback_to; j < buffers_.size(); j++) { + UnmapBufferSet(j, *buffers_[j], + /*destroy_owned_buffers=*/false); + } + buffers_.resize(rollback_to); return status; } } diff --git a/common/channel.h b/common/channel.h index 1d94dfe..021c4ae 100644 --- a/common/channel.h +++ b/common/channel.h @@ -27,7 +27,9 @@ namespace subspace { #define SUBSPACE_SHMEM_MODE_LINUX 2 #define SUBSPACE_SHMEM_MODE_ANDROID 3 -// Change this if you want to use a different shared memory mode. +// Change this if you want to use a different shared memory mode. Builds may +// define SUBSPACE_SHMEM_MODE explicitly to exercise a non-default backend. +#ifndef SUBSPACE_SHMEM_MODE #if defined(__ANDROID__) // Android does not have /dev/shm; use anonymous fd-backed shared memory. #define SUBSPACE_SHMEM_MODE SUBSPACE_SHMEM_MODE_ANDROID @@ -39,6 +41,7 @@ namespace subspace { // memory. #define SUBSPACE_SHMEM_MODE SUBSPACE_SHMEM_MODE_POSIX #endif +#endif // Flag for flags field in MessagePrefix. constexpr int kMessageActivate = 1; // This is a reliable activation message. diff --git a/server/server_channel.cc b/server/server_channel.cc index 1ca8094..3286922 100644 --- a/server/server_channel.cc +++ b/server/server_channel.cc @@ -7,7 +7,7 @@ #include "server/server.h" #include #include -#if defined(__ANDROID__) +#if SUBSPACE_SHMEM_MODE == SUBSPACE_SHMEM_MODE_ANDROID #include #ifndef MFD_CLOEXEC #define MFD_CLOEXEC 0x0001U