diff --git a/hipfile/src/amd_detail/batch/batch.cpp b/hipfile/src/amd_detail/batch/batch.cpp index 7ff7c14f..b4133526 100644 --- a/hipfile/src/amd_detail/batch/batch.cpp +++ b/hipfile/src/amd_detail/batch/batch.cpp @@ -113,7 +113,7 @@ BatchContext::submit_operations(const hipFileIOParams_t *params, unsigned num_pa throw std::invalid_argument(msg.str()); } - std::vector> pending_ops{}; + std::vector> pending_ops{}; // It would be more performant to be able to perform multiple lookups // rather than waiting to lock the DriverState lock for each lookup. @@ -124,7 +124,7 @@ BatchContext::submit_operations(const hipFileIOParams_t *params, unsigned num_pa // file flags. auto [_file, _buffer] = Context::get()->getFileAndBuffer( param_copy->fh, param_copy->u.batch.devPtr_base, param_copy->u.batch.size, 0); - auto op = std::make_shared(std::move(param_copy), _buffer, _file); + auto op = std::shared_ptr{new BatchOperation{std::move(param_copy), _buffer, _file}}; pending_ops.push_back(op); } @@ -133,6 +133,11 @@ BatchContext::submit_operations(const hipFileIOParams_t *params, unsigned num_pa outstanding_ops.insert(pending_ops.begin(), pending_ops.end()); } +std::unordered_set>& +BatchContextAccessor::get_ops_set(BatchContext& _context){ + return _context.outstanding_ops; +} + void BatchContextMap::clear() { diff --git a/hipfile/src/amd_detail/batch/batch.h b/hipfile/src/amd_detail/batch/batch.h index f8372a17..62d914e6 100644 --- a/hipfile/src/amd_detail/batch/batch.h +++ b/hipfile/src/amd_detail/batch/batch.h @@ -28,8 +28,13 @@ struct InvalidBatchHandle : public std::invalid_argument { } }; +class IBatchOperation { +public: + virtual ~IBatchOperation() = default; +}; + /// @brief Represents a single IO Request -class BatchOperation { +class BatchOperation : public IBatchOperation { public: /// @brief Create an operation to handle and track an IO request. /// @param [in] params IO parameters @@ -89,13 +94,27 @@ class BatchContext : public IBatchContext { /// but is not yet complete or completed but not yet retrieved by the /// application. /// shared_ptr as it may need to be passed to a backend. - std::unordered_set> outstanding_ops; + std::unordered_set> outstanding_ops; BatchContext(unsigned capacity); + friend class BatchContextAccessor; friend class BatchContextMap; }; +/* + * Friend class of BatchContext + * + * Can be used to peer into BatchContext's hidden members. + * Should not be used in production. + */ +class BatchContextAccessor { +public: + // Return a reference to the unordered_set to modify what ops are loaded + // in the context. + std::unordered_set>& get_ops_set(BatchContext& _context); +}; + class BatchContextMap { public: /*! diff --git a/hipfile/test/amd_detail/batch/batch.cpp b/hipfile/test/amd_detail/batch/batch.cpp index 1c1c0044..e0c4bb70 100644 --- a/hipfile/test/amd_detail/batch/batch.cpp +++ b/hipfile/test/amd_detail/batch/batch.cpp @@ -10,6 +10,7 @@ #include "hipfile-test.h" #include "hipfile-warnings.h" #include "invalid-enum.h" +#include "mbatch.h" #include "mbuffer.h" #include "mfile.h" #include "mstate.h" @@ -348,4 +349,17 @@ TEST_F(HipFileBatchContext, SubmitSingleBadParamModeInvalid) ASSERT_THROW(_context->submit_operations(&bad_io_params, 1), std::invalid_argument); } +// Not a real test - proof of concept +TEST_F(HipFileBatchContext, _InsertMBatchOperationIntoContext) +{ + BatchContextAccessor bca; + auto ops = bca.get_ops_set(*std::dynamic_pointer_cast(_context)); + + std::shared_ptr mock_op = std::make_unique(); + + ops.insert(mock_op); + + ASSERT_EQ(1, ops.size()); +} + HIPFILE_WARN_NO_GLOBAL_CTOR_ON diff --git a/hipfile/test/amd_detail/mbatch.h b/hipfile/test/amd_detail/mbatch.h index c59ba908..e5d21bed 100644 --- a/hipfile/test/amd_detail/mbatch.h +++ b/hipfile/test/amd_detail/mbatch.h @@ -15,6 +15,9 @@ namespace hipFile { +class MBatchOperation : public IBatchOperation { +}; + class MBatchContext : public IBatchContext { public: MOCK_METHOD(unsigned, get_capacity, (), (const, noexcept, override));