Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 39 additions & 6 deletions rocfile/src/file.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@
* SPDX-License-Identifier: MIT
*/

#include "context.h"
#include "file.h"
#include "state.h"
#include "sys.h"

#include <algorithm>
#include <cstdlib>
Expand All @@ -18,14 +20,45 @@ using std::vector;

namespace rocFile {

UnregisteredFile::UnregisteredFile(int fd)
Comment thread
derobins marked this conversation as resolved.
: m_fd(fd), m_stat{Context<Sys>::get()->fstat(fd)}, m_flags{Context<Sys>::get()->fcntl(fd, F_GETFL, 0)},
m_mountinfo{Context<LibMountHelper>::get()->getMountInfo(m_stat.st_dev)}
{
}

int
UnregisteredFile::getFd() const noexcept
{
return m_fd;
}

struct stat
UnregisteredFile::getStat() const noexcept
{
return m_stat;
}

int
UnregisteredFile::getFlags() const noexcept
{
return m_flags;
}

optional<MountInfo>
UnregisteredFile::getMountInfo() const noexcept
{
return m_mountinfo;
}

rocFileHandle_t
IFile::getHandle() const
{
return reinterpret_cast<rocFileHandle_t>(const_cast<IFile *>(this));
}

File::File(int _fd, const struct stat &fstat, int _status_flags, optional<MountInfo> _mountinfo)
: fd{_fd}, device{fstat.st_dev}, mode{fstat.st_mode}, status_flags{_status_flags}, mountinfo{_mountinfo}
File::File(const UnregisteredFile &uf)
: fd{uf.getFd()}, device{uf.getStat().st_dev}, mode{uf.getStat().st_mode}, status_flags{uf.getFlags()},
mountinfo{uf.getMountInfo()}
{
}

Expand Down Expand Up @@ -71,14 +104,14 @@ FileMap::getFile(rocFileHandle_t fh)
}

rocFileHandle_t
FileMap::registerFile(int fd, struct stat &fstat, int status_flags, optional<MountInfo> mountinfo)
FileMap::registerFile(const UnregisteredFile &uf)
{
if (from_fd.end() != from_fd.find(fd)) {
if (from_fd.end() != from_fd.find(uf.getFd())) {
throw FileAlreadyRegistered();
}

auto file = std::shared_ptr<IFile>(new File(fd, fstat, status_flags, mountinfo));
from_fd[fd] = file;
auto file = std::shared_ptr<IFile>(new File(uf));
from_fd[file->getFd()] = file;
from_fh[file->getHandle()] = file;

return file->getHandle();
Expand Down
53 changes: 41 additions & 12 deletions rocfile/src/file.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,42 @@ struct FileOperationsOutstanding : public std::runtime_error {
}
};

class UnregisteredFile {
public:
/// @brief Construct an unregistered file
///
/// During construction of an unregistered file, information about the file
/// is collected from the system.
///
/// @param fd A valid file descriptor
UnregisteredFile(int fd);

/// @return Returns the file descriptor
int getFd() const noexcept;

/// @return Returns the information provided by fstat (2)
struct stat getStat() const noexcept;

/// @return Returns the flags provided by fcntl (2)
int getFlags() const noexcept;

/// @brief Returns information obtained from /proc/self/mountinfo
std::optional<MountInfo> getMountInfo() const noexcept;

private:
/// @brief The file descriptor
int m_fd;

/// @brief Information provided by fstat (2)
struct stat m_stat;

/// @brief Flags provided by fcntl(2)
int m_flags;

/// @brief Information obtained from /proc/self/mountinfo
std::optional<MountInfo> m_mountinfo;
};

class IFile {
public:
virtual ~IFile() = default;
Expand Down Expand Up @@ -73,11 +109,9 @@ class File : public IFile {
virtual std::optional<MountInfo> getMountInfo() const override;

private:
/// @brief Construct a file object
/// @param fd The file descriptor for the file
/// @param fstat The struct stat value obtained by calling fstat(2) with fd
/// @param mountinfo Mount information for the filesystem backing fd
File(int fd, const struct stat &fstat, int status_flags, std::optional<MountInfo> mountinfo);
/// @brief Construct a registered file
/// @param uf An unregistered file
File(const UnregisteredFile &uf);

/// @brief The file descriptor
int fd;
Expand Down Expand Up @@ -109,13 +143,8 @@ class FileMap {

/// @brief Registers a file. Files must be registered before they can be used with rocFile IO APIs
/// @attention A unique_lock on RocFileMutex must be held
/// @param fd An open file descriptor
/// @param fstat The struct stat value obtained by calling fstat(2) with fd
/// @param status_flags The fd's status flags
/// @param mountinfo Mount information for the filesystem backing fd
/// @return A handle to be used when calling rocFile IO APIs
virtual rocFileHandle_t registerFile(int fd, struct stat &fstat, int status_flags,
std::optional<MountInfo> mountinfo);
/// @param uf An unregistered file
virtual rocFileHandle_t registerFile(const UnregisteredFile &uf);

/// @brief Deregisters the file associated with the provided file handle
/// @attention A unique_lock on RocFileMutex must be held
Expand Down
12 changes: 4 additions & 8 deletions rocfile/src/rocfile.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -132,15 +132,11 @@ try {
}

switch (descr->type) {
case rocFileHandleTypeOpaqueFD:
// Validate
if (descr->handle.fd < 0) {
return {rocFileInvalidValue, hipSuccess};
}

// Register
*fh = Context<DriverState>::get()->registerFile(descr->handle.fd);
case rocFileHandleTypeOpaqueFD: {
UnregisteredFile uf{descr->handle.fd};
*fh = Context<DriverState>::get()->registerFile(uf);
return {rocFileSuccess, hipSuccess};
}
case rocFileHandleTypeOpaqueWin32:
case rocFileHandleTypeUserspaceFS:
default:
Expand Down
10 changes: 2 additions & 8 deletions rocfile/src/state.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -116,14 +116,8 @@ DriverState::getBuffer(const void *buf, size_t length, int flags)
//

rocFileHandle_t
DriverState::registerFile(int fd)
DriverState::registerFile(const UnregisteredFile &uf)
{
// Get file information outside of the state mutex to avoid potentially
// stalling other IO threads
auto fstat{Context<Sys>::get()->fstat(fd)};
auto status_flags{Context<Sys>::get()->fcntl(fd, F_GETFL, 0)};
auto mountinfo{Context<LibMountHelper>::get()->getMountInfo(fstat.st_dev)};

unique_lock<shared_mutex> ulock{state_mutex};

// For NVIDIA cuFile compatibility, implicitly "initialize"
Expand All @@ -132,7 +126,7 @@ DriverState::registerFile(int fd)
ref_count++;
}

return file_map->registerFile(fd, fstat, status_flags, mountinfo);
return file_map->registerFile(uf);
}

void
Expand Down
4 changes: 2 additions & 2 deletions rocfile/src/state.h
Original file line number Diff line number Diff line change
Expand Up @@ -106,9 +106,9 @@ class DriverState {
//

/// @brief Registers a file. Files must be registered before they can be used with rocFile IO APIs
/// @param [in] fd An open file descriptor
/// @param [in] uf An unregistered file
/// @return A handle to be used when calling rocFile IO APIs
virtual rocFileHandle_t registerFile(int fd);
virtual rocFileHandle_t registerFile(const UnregisteredFile &uf);

/// @brief Deregisters the file associated with the provided file handle
/// @param [in] fh The handle of the file to deregister
Expand Down
1 change: 1 addition & 0 deletions rocfile/src/sys.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include <optional>
#include <stdexcept>

#include <fcntl.h>
#include <sys/stat.h>
#include <sys/types.h>
#include <unistd.h>
Expand Down
6 changes: 3 additions & 3 deletions rocfile/test/driver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,7 @@
#include <cstdint>

using namespace rocFile;

using ::testing::StrictMock;
using namespace testing;

// Put tests inside the macros to suppress the global constructor
// warnings
Expand Down Expand Up @@ -109,7 +108,8 @@ TEST_F(RocFileDriverAdmin, HandleRegisterBadFD)

descr.handle.fd = -1;
descr.type = rocFileHandleTypeOpaqueFD;
descr.fs_ops = nullptr;

EXPECT_CALL(msys, fstat).WillOnce(Throw(Sys::RuntimeError(EBADF)));

ASSERT_EQ(rocFileUseCount(), 0);
ASSERT_NE(rocFileHandleRegister(&handle, &descr), ROCFILE_SUCCESS);
Expand Down
4 changes: 1 addition & 3 deletions rocfile/test/mfile.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,7 @@ class MFileMap : public FileMap {
MFileMap()
{
}
MOCK_METHOD(rocFileHandle_t, registerFile,
(int fd, struct stat &fstat, int _status_flags, std::optional<MountInfo> mountinfo),
(override));
MOCK_METHOD(rocFileHandle_t, registerFile, (const UnregisteredFile &uf), (override));
MOCK_METHOD(void, deregisterFile, (rocFileHandle_t fh), (override));
MOCK_METHOD(std::shared_ptr<IFile>, getFile, (rocFileHandle_t), (override));
MOCK_METHOD(void, clear, (), (override));
Expand Down
2 changes: 1 addition & 1 deletion rocfile/test/mstate.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ class MDriverState : public DriverState {
MOCK_METHOD(void, deregisterBuffer, (const void *buf), (override));
MOCK_METHOD(std::shared_ptr<IBuffer>, getBuffer, (const void *buf), (override));
MOCK_METHOD(std::shared_ptr<IBuffer>, getBuffer, (const void *buf, size_t length, int flags), (override));
MOCK_METHOD(rocFileHandle_t, registerFile, (int fd), (override));
MOCK_METHOD(rocFileHandle_t, registerFile, (const UnregisteredFile &uf), (override));
MOCK_METHOD(void, deregisterFile, (rocFileHandle_t fh), (override));
MOCK_METHOD(std::shared_ptr<IFile>, getFile, (rocFileHandle_t fh), (override));
MOCK_METHOD(file_buffer_pair, getFileAndBuffer,
Expand Down