From 22f15ff4e6a9ec7689573a357dae94a15de78184 Mon Sep 17 00:00:00 2001 From: Kurt McMillan Date: Mon, 3 Nov 2025 22:24:12 +0000 Subject: [PATCH 1/2] rocfile: Move file inspection in DriverState::registerFile(...) into UnregisteredFile --- rocfile/src/file.cpp | 43 ++++++++++++++++++++++++++++----- rocfile/src/file.h | 53 +++++++++++++++++++++++++++++++---------- rocfile/src/rocfile.cpp | 12 ++++------ rocfile/src/state.cpp | 10 ++------ rocfile/src/state.h | 4 ++-- rocfile/test/driver.cpp | 6 ++--- rocfile/test/mfile.h | 4 +--- rocfile/test/mstate.h | 2 +- 8 files changed, 91 insertions(+), 43 deletions(-) diff --git a/rocfile/src/file.cpp b/rocfile/src/file.cpp index 1b1cf1e2..66f8b11c 100644 --- a/rocfile/src/file.cpp +++ b/rocfile/src/file.cpp @@ -18,14 +18,45 @@ using std::vector; namespace rocFile { +UnregisteredFile::UnregisteredFile(int fd) + : m_fd(fd), m_stat{Context::get()->fstat(fd)}, m_flags{Context::get()->fcntl(fd, F_GETFL, 0)}, + m_mountinfo{Context::get()->getMountInfo(m_stat.st_dev)} +{ +} + +int +UnregisteredFile::getFd() const noexcept +{ + return m_fd; +} + +struct stat +UnregisteredFile::getStat() const noexcept +{ + return m_stat; +} + +int +UnregisteredFile::getFlags() const noexcept +{ + return m_flags; +} + +optional +UnregisteredFile::getMountInfo() const noexcept +{ + return m_mountinfo; +} + rocFileHandle_t IFile::getHandle() const { return reinterpret_cast(const_cast(this)); } -File::File(int _fd, const struct stat &fstat, int _status_flags, optional _mountinfo) - : fd{_fd}, device{fstat.st_dev}, mode{fstat.st_mode}, status_flags{_status_flags}, mountinfo{_mountinfo} +File::File(const UnregisteredFile &uf) + : fd{uf.getFd()}, device{uf.getStat().st_dev}, mode{uf.getStat().st_mode}, status_flags{uf.getFlags()}, + mountinfo{uf.getMountInfo()} { } @@ -71,14 +102,14 @@ FileMap::getFile(rocFileHandle_t fh) } rocFileHandle_t -FileMap::registerFile(int fd, struct stat &fstat, int status_flags, optional mountinfo) +FileMap::registerFile(const UnregisteredFile &uf) { - if (from_fd.end() != from_fd.find(fd)) { + if (from_fd.end() != from_fd.find(uf.getFd())) { throw FileAlreadyRegistered(); } - auto file = std::shared_ptr(new File(fd, fstat, status_flags, mountinfo)); - from_fd[fd] = file; + auto file = std::shared_ptr(new File(uf)); + from_fd[file->getFd()] = file; from_fh[file->getHandle()] = file; return file->getHandle(); diff --git a/rocfile/src/file.h b/rocfile/src/file.h index 0697e96d..0b5dff6b 100644 --- a/rocfile/src/file.h +++ b/rocfile/src/file.h @@ -38,6 +38,42 @@ struct FileOperationsOutstanding : public std::runtime_error { } }; +class UnregisteredFile { +public: + /// @brief Construct an unregistered file + /// + /// During construction of an unregistered file, information about the file + /// is collected from the system. + /// + /// @param fd A valid file descriptor + UnregisteredFile(int fd); + + /// @return Returns the file descriptor + int getFd() const noexcept; + + /// @return Returns the information provided by fstat (2) + struct stat getStat() const noexcept; + + /// @return Returns the flags provided by fcntl (2) + int getFlags() const noexcept; + + /// @brief Returns information obtained from /proc/self/mountinfo + std::optional getMountInfo() const noexcept; + +private: + /// @brief The file descriptor + int m_fd; + + /// @brief Information provided by fstat (2) + struct stat m_stat; + + /// @brief Flags provided by fcntl(2) + int m_flags; + + /// @brief Information obtained from /proc/self/mountinfo + std::optional m_mountinfo; +}; + class IFile { public: virtual ~IFile() = default; @@ -73,11 +109,9 @@ class File : public IFile { virtual std::optional getMountInfo() const override; private: - /// @brief Construct a file object - /// @param fd The file descriptor for the file - /// @param fstat The struct stat value obtained by calling fstat(2) with fd - /// @param mountinfo Mount information for the filesystem backing fd - File(int fd, const struct stat &fstat, int status_flags, std::optional mountinfo); + /// @brief Construct a registered file + /// @param uf An unregistered file + File(const UnregisteredFile &uf); /// @brief The file descriptor int fd; @@ -109,13 +143,8 @@ class FileMap { /// @brief Registers a file. Files must be registered before they can be used with rocFile IO APIs /// @attention A unique_lock on RocFileMutex must be held - /// @param fd An open file descriptor - /// @param fstat The struct stat value obtained by calling fstat(2) with fd - /// @param status_flags The fd's status flags - /// @param mountinfo Mount information for the filesystem backing fd - /// @return A handle to be used when calling rocFile IO APIs - virtual rocFileHandle_t registerFile(int fd, struct stat &fstat, int status_flags, - std::optional mountinfo); + /// @param uf An unregistered file + virtual rocFileHandle_t registerFile(const UnregisteredFile &uf); /// @brief Deregisters the file associated with the provided file handle /// @attention A unique_lock on RocFileMutex must be held diff --git a/rocfile/src/rocfile.cpp b/rocfile/src/rocfile.cpp index b52d900e..307463a0 100644 --- a/rocfile/src/rocfile.cpp +++ b/rocfile/src/rocfile.cpp @@ -132,15 +132,11 @@ try { } switch (descr->type) { - case rocFileHandleTypeOpaqueFD: - // Validate - if (descr->handle.fd < 0) { - return {rocFileInvalidValue, hipSuccess}; - } - - // Register - *fh = Context::get()->registerFile(descr->handle.fd); + case rocFileHandleTypeOpaqueFD: { + UnregisteredFile uf{descr->handle.fd}; + *fh = Context::get()->registerFile(uf); return {rocFileSuccess, hipSuccess}; + } case rocFileHandleTypeOpaqueWin32: case rocFileHandleTypeUserspaceFS: default: diff --git a/rocfile/src/state.cpp b/rocfile/src/state.cpp index fdfe9526..0cf0ef53 100644 --- a/rocfile/src/state.cpp +++ b/rocfile/src/state.cpp @@ -116,14 +116,8 @@ DriverState::getBuffer(const void *buf, size_t length, int flags) // rocFileHandle_t -DriverState::registerFile(int fd) +DriverState::registerFile(const UnregisteredFile &uf) { - // Get file information outside of the state mutex to avoid potentially - // stalling other IO threads - auto fstat{Context::get()->fstat(fd)}; - auto status_flags{Context::get()->fcntl(fd, F_GETFL, 0)}; - auto mountinfo{Context::get()->getMountInfo(fstat.st_dev)}; - unique_lock ulock{state_mutex}; // For NVIDIA cuFile compatibility, implicitly "initialize" @@ -132,7 +126,7 @@ DriverState::registerFile(int fd) ref_count++; } - return file_map->registerFile(fd, fstat, status_flags, mountinfo); + return file_map->registerFile(uf); } void diff --git a/rocfile/src/state.h b/rocfile/src/state.h index 260894a1..f939c39d 100644 --- a/rocfile/src/state.h +++ b/rocfile/src/state.h @@ -106,9 +106,9 @@ class DriverState { // /// @brief Registers a file. Files must be registered before they can be used with rocFile IO APIs - /// @param [in] fd An open file descriptor + /// @param [in] uf An unregistered file /// @return A handle to be used when calling rocFile IO APIs - virtual rocFileHandle_t registerFile(int fd); + virtual rocFileHandle_t registerFile(const UnregisteredFile &uf); /// @brief Deregisters the file associated with the provided file handle /// @param [in] fh The handle of the file to deregister diff --git a/rocfile/test/driver.cpp b/rocfile/test/driver.cpp index c36f42a7..15d49e0b 100644 --- a/rocfile/test/driver.cpp +++ b/rocfile/test/driver.cpp @@ -19,8 +19,7 @@ #include using namespace rocFile; - -using ::testing::StrictMock; +using namespace testing; // Put tests inside the macros to suppress the global constructor // warnings @@ -109,7 +108,8 @@ TEST_F(RocFileDriverAdmin, HandleRegisterBadFD) descr.handle.fd = -1; descr.type = rocFileHandleTypeOpaqueFD; - descr.fs_ops = nullptr; + + EXPECT_CALL(msys, fstat).WillOnce(Throw(Sys::RuntimeError(EBADF))); ASSERT_EQ(rocFileUseCount(), 0); ASSERT_NE(rocFileHandleRegister(&handle, &descr), ROCFILE_SUCCESS); diff --git a/rocfile/test/mfile.h b/rocfile/test/mfile.h index 0a80b42a..e1d54bc0 100644 --- a/rocfile/test/mfile.h +++ b/rocfile/test/mfile.h @@ -31,9 +31,7 @@ class MFileMap : public FileMap { MFileMap() { } - MOCK_METHOD(rocFileHandle_t, registerFile, - (int fd, struct stat &fstat, int _status_flags, std::optional mountinfo), - (override)); + MOCK_METHOD(rocFileHandle_t, registerFile, (const UnregisteredFile &uf), (override)); MOCK_METHOD(void, deregisterFile, (rocFileHandle_t fh), (override)); MOCK_METHOD(std::shared_ptr, getFile, (rocFileHandle_t), (override)); MOCK_METHOD(void, clear, (), (override)); diff --git a/rocfile/test/mstate.h b/rocfile/test/mstate.h index 64321cb4..86181d8a 100644 --- a/rocfile/test/mstate.h +++ b/rocfile/test/mstate.h @@ -34,7 +34,7 @@ class MDriverState : public DriverState { MOCK_METHOD(void, deregisterBuffer, (const void *buf), (override)); MOCK_METHOD(std::shared_ptr, getBuffer, (const void *buf), (override)); MOCK_METHOD(std::shared_ptr, getBuffer, (const void *buf, size_t length, int flags), (override)); - MOCK_METHOD(rocFileHandle_t, registerFile, (int fd), (override)); + MOCK_METHOD(rocFileHandle_t, registerFile, (const UnregisteredFile &uf), (override)); MOCK_METHOD(void, deregisterFile, (rocFileHandle_t fh), (override)); MOCK_METHOD(std::shared_ptr, getFile, (rocFileHandle_t fh), (override)); MOCK_METHOD(file_buffer_pair, getFileAndBuffer, From d72c66f827a0eff5c9807de0f6e7eb95fa0a9c0b Mon Sep 17 00:00:00 2001 From: Kurt McMillan Date: Tue, 4 Nov 2025 16:50:00 +0000 Subject: [PATCH 2/2] review: Add missing headers --- rocfile/src/file.cpp | 2 ++ rocfile/src/sys.h | 1 + 2 files changed, 3 insertions(+) diff --git a/rocfile/src/file.cpp b/rocfile/src/file.cpp index 66f8b11c..01d8d42e 100644 --- a/rocfile/src/file.cpp +++ b/rocfile/src/file.cpp @@ -3,8 +3,10 @@ * SPDX-License-Identifier: MIT */ +#include "context.h" #include "file.h" #include "state.h" +#include "sys.h" #include #include diff --git a/rocfile/src/sys.h b/rocfile/src/sys.h index 73d9e39b..ba275440 100644 --- a/rocfile/src/sys.h +++ b/rocfile/src/sys.h @@ -11,6 +11,7 @@ #include #include +#include #include #include #include