diff --git a/cudax/include/cuda/experimental/__stf/internal/logical_data.cuh b/cudax/include/cuda/experimental/__stf/internal/logical_data.cuh index c07d834d832..59bb057c504 100644 --- a/cudax/include/cuda/experimental/__stf/internal/logical_data.cuh +++ b/cudax/include/cuda/experimental/__stf/internal/logical_data.cuh @@ -1467,8 +1467,7 @@ public: continue; } - // TODO THIS MAY BE A BUG: do we care about managed devices or host? - const auto memory_node = data_place::device(static_cast(n - 2)); + const auto memory_node = from_index(n); // Skip the target memory node in this step if (memory_node == target_memory_node) { diff --git a/cudax/include/cuda/experimental/__stf/places/data_place_extension.cuh b/cudax/include/cuda/experimental/__stf/places/data_place_extension.cuh deleted file mode 100644 index e72bfccec56..00000000000 --- a/cudax/include/cuda/experimental/__stf/places/data_place_extension.cuh +++ /dev/null @@ -1,217 +0,0 @@ -//===----------------------------------------------------------------------===// -// -// Part of CUDASTF in CUDA C++ Core Libraries, -// under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. -// -//===----------------------------------------------------------------------===// - -/** - * @file - * @brief Base class for data_place extensions, enabling custom place types - * - * This extension mechanism allows custom data place types (like green contexts) - * to be defined without modifying the core data_place class. - * Extensions provide virtual methods for place-specific behavior like memory - * allocation and string representation. 
- */ - -#pragma once - -#include - -#if defined(_CCCL_IMPLICIT_SYSTEM_HEADER_GCC) -# pragma GCC system_header -#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_CLANG) -# pragma clang system_header -#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_MSVC) -# pragma system_header -#endif // no system header - -#include - -#include -#include - -#include -#include - -namespace cuda::experimental::stf -{ -// Forward declarations -class exec_place; - -/** - * @brief Base class for data_place extensions - * - * Custom data place types inherit from this class and override virtual methods - * to provide place-specific behavior. This enables extensibility without - * modifying the core data_place class. - * - * Example usage for a custom place type: - * @code - * class my_custom_extension : public data_place_extension { - * public: - * exec_place affine_exec_place() const override { ... } - * int get_device_ordinal() const override { return my_device_id; } - * ::std::string to_string() const override { return "my_custom_place"; } - * size_t hash() const override { return std::hash{}(my_device_id); } - * bool operator==(const data_place_extension& other) const override { ... } - * }; - * @endcode - */ -class data_place_extension -{ -public: - virtual ~data_place_extension() = default; - - /** - * @brief Get the affine execution place for this data place - * - * Returns the exec_place that should be used for computation on data - * stored at this place. The exec_place may have its own virtual methods - * (e.g., activate/deactivate) for execution-specific behavior. - */ - virtual exec_place affine_exec_place() const = 0; - - /** - * @brief Get the device ordinal for this place - * - * Returns the CUDA device ID associated with this place. - * For host-only places, this should return -1. - */ - virtual int get_device_ordinal() const = 0; - - /** - * @brief Get a string representation of this place - * - * Used for debugging and logging purposes. 
- */ - virtual ::std::string to_string() const = 0; - - /** - * @brief Compute a hash value for this place - * - * Used for storing data_place in hash-based containers. - */ - virtual size_t hash() const = 0; - - /** - * @brief Check equality with another extension - * - * @param other The other extension to compare with - * @return true if the extensions represent the same place - */ - virtual bool operator==(const data_place_extension& other) const = 0; - - /** - * @brief Compare ordering with another extension - * - * @param other The other extension to compare with - * @return true if this extension is less than the other - */ - virtual bool operator<(const data_place_extension& other) const = 0; - - /** - * @brief Create a physical memory allocation for this place (VMM API) - * - * This method is used by localized arrays (composite_slice) to create physical - * memory segments that are then mapped into a contiguous virtual address space. - * Custom place types can override this method to provide specialized memory - * allocation behavior. - * - * @note Managed memory is not supported by the VMM API. 
- * - * @param handle Output parameter for the allocation handle - * @param size Size of the allocation in bytes - * @return CUresult indicating success or failure - * - * @see allocate() for regular memory allocation - */ - virtual CUresult mem_create(CUmemGenericAllocationHandle* handle, size_t size) const - { - int dev_ordinal = get_device_ordinal(); - - CUmemAllocationProp prop = {}; - prop.type = CU_MEM_ALLOCATION_TYPE_PINNED; - if (dev_ordinal >= 0) - { - // Device memory - prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE; - prop.location.id = dev_ordinal; - } -#if _CCCL_CTK_AT_LEAST(12, 2) - else if (dev_ordinal == -1) - { - // Host memory (device ordinal -1) - // CU_MEM_LOCATION_TYPE_HOST requires CUDA 12.2+ - prop.location.type = CU_MEM_LOCATION_TYPE_HOST; - prop.location.id = 0; - } - else - { - // Managed memory (-2) is not supported by the VMM API - _CCCL_ASSERT(false, "mem_create: managed memory is not supported by the VMM API"); - return CUDA_ERROR_NOT_SUPPORTED; - } -#else // ^^^ _CCCL_CTK_AT_LEAST(12, 2) ^^^ / vvv _CCCL_CTK_BELOW(12, 2) vvv - else if (dev_ordinal == -1) - { - // Host VMM requires CU_MEM_LOCATION_TYPE_HOST which is only available in CUDA 12.2+ - _CCCL_ASSERT(false, "mem_create: host VMM requires CUDA 12.2+ (CU_MEM_LOCATION_TYPE_HOST not available)"); - return CUDA_ERROR_NOT_SUPPORTED; - } - else - { - // Managed memory (-2) is not supported by the VMM API - _CCCL_ASSERT(false, "mem_create: managed memory is not supported by the VMM API"); - return CUDA_ERROR_NOT_SUPPORTED; - } -#endif // _CCCL_CTK_AT_LEAST(12, 2) - return cuMemCreate(handle, size, &prop, 0); - } - - /** - * @brief Allocate memory for this place (raw allocation) - * - * This is the low-level allocation interface. For stream-ordered allocations - * (where allocation_is_stream_ordered() returns true), the allocation will - * be ordered with respect to other operations on the stream. For immediate - * allocations, the stream parameter is ignored. 
- * - * @param size Size of the allocation in bytes - * @param stream CUDA stream for stream-ordered allocations (ignored for immediate allocations) - * @return Pointer to allocated memory - */ - virtual void* allocate(::std::ptrdiff_t size, cudaStream_t stream) const = 0; - - /** - * @brief Deallocate memory for this place (raw deallocation) - * - * @param ptr Pointer to memory to deallocate - * @param size Size of the allocation - * @param stream CUDA stream for stream-ordered deallocations (ignored for immediate deallocations) - */ - virtual void deallocate(void* ptr, size_t size, cudaStream_t stream) const = 0; - - /** - * @brief Returns true if allocation/deallocation is stream-ordered - * - * When this returns true, the allocation uses stream-ordered APIs like - * cudaMallocAsync, and allocators should use stream_async_op to synchronize - * prerequisites before allocation. - * - * When this returns false, the allocation is immediate (like cudaMallocHost) - * and the stream parameter is ignored. Note that immediate deallocations - * (e.g., cudaFree) may or may not introduce implicit synchronization. - * - * Default is true since most GPU-based extensions use cudaMallocAsync. - */ - virtual bool allocation_is_stream_ordered() const - { - return true; - } -}; -} // end namespace cuda::experimental::stf diff --git a/cudax/include/cuda/experimental/__stf/places/data_place_impl.cuh b/cudax/include/cuda/experimental/__stf/places/data_place_impl.cuh new file mode 100644 index 00000000000..ab699cb48d6 --- /dev/null +++ b/cudax/include/cuda/experimental/__stf/places/data_place_impl.cuh @@ -0,0 +1,412 @@ +//===----------------------------------------------------------------------===// +// +// Part of CUDASTF in CUDA C++ Core Libraries, +// under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. +// +//===----------------------------------------------------------------------===// + +/** + * @file + * @brief Concrete implementations of data_place_interface + * + * This file contains implementations for standard data place types: + * host, managed, device, invalid, affine, and device_auto. + */ + +#pragma once + +#include + +#if defined(_CCCL_IMPLICIT_SYSTEM_HEADER_GCC) +# pragma GCC system_header +#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_CLANG) +# pragma clang system_header +#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_MSVC) +# pragma system_header +#endif // no system header + +#include +#include +#include + +namespace cuda::experimental::stf +{ +/** + * @brief Implementation for the invalid data place + */ +class data_place_invalid final : public data_place_interface +{ +public: + bool is_invalid() const override + { + return true; + } + + int get_device_ordinal() const override + { + return data_place_interface::invalid; + } + + ::std::string to_string() const override + { + return "invalid"; + } + + size_t hash() const override + { + return ::std::hash()(data_place_interface::invalid); + } + + int cmp(const data_place_interface& other) const override + { + if (typeid(*this) != typeid(other)) + { + return typeid(*this).before(typeid(other)) ? 
-1 : 1; + } + return 0; + } + + void* allocate(::std::ptrdiff_t, cudaStream_t) const override + { + throw ::std::logic_error("Cannot allocate from invalid data_place"); + } + + void deallocate(void*, size_t, cudaStream_t) const override + { + throw ::std::logic_error("Cannot deallocate from invalid data_place"); + } + + bool allocation_is_stream_ordered() const override + { + return false; + } +}; + +/** + * @brief Implementation for the host (pinned memory) data place + */ +class data_place_host final : public data_place_interface +{ +public: + bool is_host() const override + { + return true; + } + + int get_device_ordinal() const override + { + return data_place_interface::host; + } + + ::std::string to_string() const override + { + return "host"; + } + + size_t hash() const override + { + return ::std::hash()(data_place_interface::host); + } + + int cmp(const data_place_interface& other) const override + { + if (typeid(*this) != typeid(other)) + { + return typeid(*this).before(typeid(other)) ? 
-1 : 1; + } + return 0; + } + + void* allocate(::std::ptrdiff_t size, cudaStream_t) const override + { + void* result = nullptr; + cuda_safe_call(cudaMallocHost(&result, size)); + return result; + } + + void deallocate(void* ptr, size_t, cudaStream_t) const override + { + cuda_safe_call(cudaFreeHost(ptr)); + } + + bool allocation_is_stream_ordered() const override + { + return false; + } + + CUresult mem_create(CUmemGenericAllocationHandle* handle, size_t size) const override + { +#if _CCCL_CTK_AT_LEAST(12, 2) + CUmemAllocationProp prop = {}; + prop.type = CU_MEM_ALLOCATION_TYPE_PINNED; + prop.location.type = CU_MEM_LOCATION_TYPE_HOST; + prop.location.id = 0; + return cuMemCreate(handle, size, &prop, 0); +#else + (void) handle; + (void) size; + return CUDA_ERROR_NOT_SUPPORTED; +#endif + } +}; + +/** + * @brief Implementation for managed memory data place + */ +class data_place_managed final : public data_place_interface +{ +public: + bool is_managed() const override + { + return true; + } + + int get_device_ordinal() const override + { + return data_place_interface::managed; + } + + ::std::string to_string() const override + { + return "managed"; + } + + size_t hash() const override + { + return ::std::hash()(data_place_interface::managed); + } + + int cmp(const data_place_interface& other) const override + { + if (typeid(*this) != typeid(other)) + { + return typeid(*this).before(typeid(other)) ? 
-1 : 1; + } + return 0; + } + + void* allocate(::std::ptrdiff_t size, cudaStream_t) const override + { + void* result = nullptr; + cuda_safe_call(cudaMallocManaged(&result, size)); + return result; + } + + void deallocate(void* ptr, size_t, cudaStream_t) const override + { + cuda_safe_call(cudaFree(ptr)); + } + + bool allocation_is_stream_ordered() const override + { + return false; + } +}; + +/** + * @brief Implementation for a specific CUDA device data place + */ +class data_place_device final : public data_place_interface +{ +public: + explicit data_place_device(int device_id) + : device_id_(device_id) + { + _CCCL_ASSERT(device_id >= 0, "Device ID must be non-negative"); + } + + bool is_device() const override + { + return true; + } + + int get_device_ordinal() const override + { + return device_id_; + } + + ::std::string to_string() const override + { + return "dev" + ::std::to_string(device_id_); + } + + size_t hash() const override + { + return ::std::hash()(device_id_); + } + + int cmp(const data_place_interface& other) const override + { + if (typeid(*this) != typeid(other)) + { + return typeid(*this).before(typeid(other)) ? 
-1 : 1; + } + return (device_id_ > static_cast(other).device_id_) + - (device_id_ < static_cast(other).device_id_); + } + + void* allocate(::std::ptrdiff_t size, cudaStream_t stream) const override + { + void* result = nullptr; + const int prev_dev = cuda_try(); + + if (prev_dev != device_id_) + { + cuda_safe_call(cudaSetDevice(device_id_)); + } + + SCOPE(exit) + { + if (prev_dev != device_id_) + { + cuda_safe_call(cudaSetDevice(prev_dev)); + } + }; + + cuda_safe_call(cudaMallocAsync(&result, size, stream)); + return result; + } + + void deallocate(void* ptr, size_t, cudaStream_t stream) const override + { + const int prev_dev = cuda_try(); + + if (prev_dev != device_id_) + { + cuda_safe_call(cudaSetDevice(device_id_)); + } + + SCOPE(exit) + { + if (prev_dev != device_id_) + { + cuda_safe_call(cudaSetDevice(prev_dev)); + } + }; + + cuda_safe_call(cudaFreeAsync(ptr, stream)); + } + + bool allocation_is_stream_ordered() const override + { + return true; + } + + CUresult mem_create(CUmemGenericAllocationHandle* handle, size_t size) const override + { + CUmemAllocationProp prop = {}; + prop.type = CU_MEM_ALLOCATION_TYPE_PINNED; + prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE; + prop.location.id = device_id_; + return cuMemCreate(handle, size, &prop, 0); + } + +private: + int device_id_; +}; + +/** + * @brief Implementation for the affine data place (uses exec_place's affine data place) + */ +class data_place_affine final : public data_place_interface +{ +public: + bool is_affine() const override + { + return true; + } + + int get_device_ordinal() const override + { + return data_place_interface::affine; + } + + ::std::string to_string() const override + { + return "affine"; + } + + size_t hash() const override + { + return ::std::hash()(data_place_interface::affine); + } + + int cmp(const data_place_interface& other) const override + { + if (typeid(*this) != typeid(other)) + { + return typeid(*this).before(typeid(other)) ? 
-1 : 1; + } + return 0; + } + + void* allocate(::std::ptrdiff_t, cudaStream_t) const override + { + throw ::std::logic_error("Cannot allocate from affine data_place directly"); + } + + void deallocate(void*, size_t, cudaStream_t) const override + { + throw ::std::logic_error("Cannot deallocate from affine data_place directly"); + } + + bool allocation_is_stream_ordered() const override + { + return false; + } +}; + +/** + * @brief Implementation for device_auto data place (auto-select device) + */ +class data_place_device_auto final : public data_place_interface +{ +public: + bool is_device_auto() const override + { + return true; + } + + int get_device_ordinal() const override + { + return data_place_interface::device_auto; + } + + ::std::string to_string() const override + { + return "auto"; + } + + size_t hash() const override + { + return ::std::hash()(data_place_interface::device_auto); + } + + int cmp(const data_place_interface& other) const override + { + if (typeid(*this) != typeid(other)) + { + return typeid(*this).before(typeid(other)) ? 
-1 : 1; + } + return 0; + } + + void* allocate(::std::ptrdiff_t, cudaStream_t) const override + { + throw ::std::logic_error("Cannot allocate from device_auto data_place directly"); + } + + void deallocate(void*, size_t, cudaStream_t) const override + { + throw ::std::logic_error("Cannot deallocate from device_auto data_place directly"); + } + + bool allocation_is_stream_ordered() const override + { + return true; + } +}; +} // end namespace cuda::experimental::stf diff --git a/cudax/include/cuda/experimental/__stf/places/data_place_interface.cuh b/cudax/include/cuda/experimental/__stf/places/data_place_interface.cuh new file mode 100644 index 00000000000..84a3b68950d --- /dev/null +++ b/cudax/include/cuda/experimental/__stf/places/data_place_interface.cuh @@ -0,0 +1,258 @@ +//===----------------------------------------------------------------------===// +// +// Part of CUDASTF in CUDA C++ Core Libraries, +// under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. +// +//===----------------------------------------------------------------------===// + +/** + * @file + * @brief Abstract interface for data_place implementations + * + * This interface defines the contract that all data_place implementations must satisfy. + * It enables a clean polymorphic design where host, managed, device, composite, and + * extension-based places all implement a common interface. 
+ */ + +#pragma once + +#include + +#if defined(_CCCL_IMPLICIT_SYSTEM_HEADER_GCC) +# pragma GCC system_header +#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_CLANG) +# pragma clang system_header +#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_MSVC) +# pragma system_header +#endif // no system header + +#include + +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +namespace cuda::experimental::stf +{ +// Forward declarations +class exec_place; +class exec_place_grid; +class pos4; +class dim4; + +//! Function type for computing executor placement from data coordinates +using get_executor_func_t = pos4 (*)(pos4, dim4, dim4); + +/** + * @brief Abstract interface for data_place implementations + * + * All data_place types (host, managed, device, composite, extensions) implement + * this interface. The data_place class holds a shared_ptr to this interface + * and delegates all operations to it. + */ +class data_place_interface +{ +public: + virtual ~data_place_interface() = default; + + /** + * @brief Special device ordinal values for non-device places + * + * Returned by get_device_ordinal() for places that don't correspond + * to a specific CUDA device. 
+ */ + enum ord : int + { + invalid = ::std::numeric_limits::min(), + composite = -5, + device_auto = -4, + affine = -3, + managed = -2, + host = -1, + }; + + // === Type identification === + + /** + * @brief Check if this is the host (pinned memory) place + */ + virtual bool is_host() const + { + return false; + } + + /** + * @brief Check if this is the managed memory place + */ + virtual bool is_managed() const + { + return false; + } + + /** + * @brief Check if this is a specific device place + */ + virtual bool is_device() const + { + return false; + } + + /** + * @brief Check if this is the invalid place + */ + virtual bool is_invalid() const + { + return false; + } + + /** + * @brief Check if this is the affine place (uses exec_place's affine data place) + */ + virtual bool is_affine() const + { + return false; + } + + /** + * @brief Check if this is the device_auto place (auto-select device) + */ + virtual bool is_device_auto() const + { + return false; + } + + /** + * @brief Check if this is a composite place + */ + virtual bool is_composite() const + { + return false; + } + + /** + * @brief Check if this is an extension-based place (green context, etc.) 
+ */ + virtual bool is_extension() const + { + return false; + } + + // === Core properties === + + /** + * @brief Get the device ordinal for this place + * + * Returns: + * - >= 0 for specific CUDA devices + * - data_place_interface::host (-1) for host + * - data_place_interface::managed (-2) for managed + * - data_place_interface::affine (-3) for affine + * - data_place_interface::device_auto (-4) for device_auto + * - data_place_interface::composite (-5) for composite + * - data_place_interface::invalid for invalid + */ + virtual int get_device_ordinal() const = 0; + + /** + * @brief Get a string representation of this place + */ + virtual ::std::string to_string() const = 0; + + /** + * @brief Compute a hash value for this place + */ + virtual size_t hash() const = 0; + + /** + * @brief Three-way comparison with another place + * + * @return -1 if *this < other, 0 if *this == other, 1 if *this > other + */ + virtual int cmp(const data_place_interface& other) const = 0; + + // === Memory allocation === + + /** + * @brief Allocate memory at this place + * + * @param size Size of the allocation in bytes + * @param stream CUDA stream for stream-ordered allocations + * @return Pointer to allocated memory + * @throws std::logic_error if allocation is not supported for this place type + */ + virtual void* allocate(::std::ptrdiff_t size, cudaStream_t stream) const = 0; + + /** + * @brief Deallocate memory at this place + * + * @param ptr Pointer to memory to deallocate + * @param size Size of the allocation + * @param stream CUDA stream for stream-ordered deallocations + */ + virtual void deallocate(void* ptr, size_t size, cudaStream_t stream) const = 0; + + /** + * @brief Returns true if allocation/deallocation is stream-ordered + */ + virtual bool allocation_is_stream_ordered() const = 0; + + /** + * @brief Create a physical memory allocation for this place (VMM API) + * + * Default implementation returns CUDA_ERROR_NOT_SUPPORTED. 
+ * Subclasses that support VMM should override this. + * + * @param handle Output parameter for the allocation handle + * @param size Size of the allocation in bytes + * @return CUresult indicating success or failure + */ + virtual CUresult mem_create(CUmemGenericAllocationHandle*, size_t) const + { + return CUDA_ERROR_NOT_SUPPORTED; + } + + // === Extension support === + + /** + * @brief Get the implementation for the affine exec_place (for extensions) + * + * Extensions override this to provide their affine exec_place implementation. + * Returns nullptr by default (non-extensions). + * The returned shared_ptr should be castable to shared_ptr. + */ + virtual ::std::shared_ptr get_affine_exec_impl() const + { + return nullptr; + } + + // === Composite-specific (throw by default) === + + /** + * @brief Get the grid for composite places + * @throws std::logic_error if not a composite place + */ + virtual const exec_place_grid& get_grid() const + { + throw ::std::logic_error("get_grid() called on non-composite data_place"); + } + + /** + * @brief Get the partitioner function for composite places + * @throws std::logic_error if not a composite place + */ + virtual const get_executor_func_t& get_partitioner() const + { + throw ::std::logic_error("get_partitioner() called on non-composite data_place"); + } +}; +} // end namespace cuda::experimental::stf diff --git a/cudax/include/cuda/experimental/__stf/places/exec/green_context.cuh b/cudax/include/cuda/experimental/__stf/places/exec/green_context.cuh index 2ce9691a0ba..4fc1847ca5f 100644 --- a/cudax/include/cuda/experimental/__stf/places/exec/green_context.cuh +++ b/cudax/include/cuda/experimental/__stf/places/exec/green_context.cuh @@ -25,7 +25,7 @@ # pragma system_header #endif // no system header -#include +#include #include #include #include @@ -52,7 +52,7 @@ public: /** * @brief Extension implementation for green context data places */ - class extension : public data_place_extension + class extension : public 
data_place_interface { public: /** @@ -63,7 +63,10 @@ public: : view_(mv(view)) {} - exec_place affine_exec_place() const override; + bool is_extension() const override + { + return true; + } int get_device_ordinal() const override { @@ -80,24 +83,18 @@ public: return hash_all(view_.g_ctx, view_.devid); } - bool operator==(const data_place_extension& other) const override + int cmp(const data_place_interface& other) const override { if (typeid(*this) != typeid(other)) { - return false; + return typeid(*this).before(typeid(other)) ? -1 : 1; } - const auto& other_gc = static_cast(other); - return view_ == other_gc.view_; - } - - bool operator<(const data_place_extension& other) const override - { - if (typeid(*this) != typeid(other)) + const auto& o = static_cast(other); + if (view_ < o.view_) { - return typeid(*this).before(typeid(other)); + return -1; } - const auto& other_gc = static_cast(other); - return view_ < other_gc.view_; + return (view_ == o.view_) ? 0 : 1; } /** @@ -127,6 +124,14 @@ public: cuda_safe_call(cudaFreeAsync(ptr, stream)); } + bool allocation_is_stream_ordered() const override + { + return true; + } + + // Provide affine exec place implementation for data_place::affine_exec_place() + ::std::shared_ptr get_affine_exec_impl() const override; + private: green_ctx_view view_; }; @@ -418,9 +423,9 @@ inline exec_place exec_place::green_ctx(const green_ctx_view& gc_view, bool use_ return exec_place_green_ctx(gc_view, use_green_ctx_data_place); } -inline exec_place green_ctx_data_place::extension::affine_exec_place() const +inline ::std::shared_ptr green_ctx_data_place::extension::get_affine_exec_impl() const { - return exec_place::green_ctx(view_); + return exec_place::green_ctx(view_).get_impl(); } inline data_place data_place::green_ctx(const green_ctx_view& gc_view) diff --git a/cudax/include/cuda/experimental/__stf/places/places.cuh b/cudax/include/cuda/experimental/__stf/places/places.cuh index d2de042452e..6a79fedda61 100644 --- 
a/cudax/include/cuda/experimental/__stf/places/places.cuh +++ b/cudax/include/cuda/experimental/__stf/places/places.cuh @@ -27,7 +27,8 @@ # pragma system_header #endif // no system header -#include +#include +#include #include #include @@ -60,31 +61,44 @@ class exec_place_green_ctx; //! Function type for computing executor placement from data coordinates using get_executor_func_t = pos4 (*)(pos4, dim4, dim4); +// Forward declaration for composite implementation +class data_place_composite; + /** * @brief Designates where data will be stored (CPU memory vs. on device 0 (first GPU), device 1 (second GPU), ...) * - * This typed `enum` is aligned with CUDA device ordinals but does not implicitly convert to `int`. See `device_ordinal` - * below. + * This class uses a polymorphic design where all place types (host, managed, device, + * composite, extensions) implement a common data_place_interface. The data_place class + * holds a shared_ptr to this interface and delegates operations to it. */ class data_place { - // Constructors and factory functions below forward to this. - explicit data_place(int devid) - : devid(devid) + // Private constructor from interface pointer + explicit data_place(::std::shared_ptr impl) + : pimpl_(mv(impl)) {} + template + static ::std::shared_ptr make_static_instance() + { + static T instance; + return ::std::shared_ptr(&instance, [](data_place_interface*) {}); + } + public: /** * @brief Default constructor. The object is initialized as invalid. */ - data_place() = default; + data_place() + : pimpl_(make_static_instance()) + {} /** * @brief Represents an invalid `data_place` object. 
*/ static data_place invalid() { - return data_place(invalid_devid); + return data_place(make_static_instance()); } /** @@ -93,7 +107,7 @@ public: */ static data_place host() { - return data_place(host_devid); + return data_place(make_static_instance()); } /** @@ -101,14 +115,14 @@ public: */ static data_place managed() { - return data_place(managed_devid); + return data_place(make_static_instance()); } /// This actually does not define a data_place, but means that we should use /// the data place affine to the execution place static data_place affine() { - return data_place(affine_devid); + return data_place(make_static_instance()); } /** @@ -117,12 +131,10 @@ public: */ static data_place device_auto() { - return data_place(device_auto_devid); + return data_place(make_static_instance()); } - /** @brief Data is placed on device with index dev_id. Two relaxations are allowed: -1 can be passed to create a - * placeholder for the host, and -2 can be used to create a placeholder for a managed device. - */ + /** @brief Data is placed on device with index dev_id. 
*/ static data_place device(int dev_id = 0) { static int const ndevs = [] { @@ -131,8 +143,17 @@ public: return result; }(); - EXPECT((dev_id >= managed_devid && dev_id < ndevs), "Invalid device ID ", dev_id); - return data_place(dev_id); + EXPECT((dev_id >= 0 && dev_id < ndevs), "Invalid device ID ", dev_id); + + static data_place_device* impls = [] { + auto* result = static_cast(::operator new[](ndevs * sizeof(data_place_device))); + for (int i = 0; i < ndevs; ++i) + { + new (result + i) data_place_device(i); + } + return result; + }(); + return data_place(::std::shared_ptr(&impls[dev_id], [](data_place_interface*) {})); } /** @@ -153,7 +174,15 @@ public: static data_place green_ctx(const green_ctx_view& gc_view); #endif // _CCCL_CTK_AT_LEAST(12, 4) - bool operator==(const data_place& rhs) const; + bool operator==(const data_place& rhs) const + { + // Same pointer means same place + if (pimpl_.get() == rhs.pimpl_.get()) + { + return true; + } + return pimpl_->cmp(*rhs.pimpl_) == 0; + } bool operator!=(const data_place& rhs) const { @@ -164,23 +193,8 @@ public: bool operator<(const data_place& rhs) const { // Not implemented for composite places - EXPECT(!is_composite()); - EXPECT(!rhs.is_composite()); - - // If both are extensions, delegate to the extension - if (is_extension() && rhs.is_extension()) - { - return *extension < *rhs.extension; - } - - // Extensions sort after non-extensions - if (is_extension() != rhs.is_extension()) - { - return rhs.is_extension(); // non-extension < extension - } - - // For simple places, compare devid - return devid < rhs.devid; + EXPECT((!is_composite() && !rhs.is_composite()), "Ordering of composite places is not implemented."); + return pimpl_->cmp(*rhs.pimpl_) < 0; } bool operator>(const data_place& rhs) const @@ -201,80 +215,49 @@ public: /// checks if this data place is a composite data place bool is_composite() const { - // If the devid indicates composite_devid then we must have a descriptor - _CCCL_ASSERT(devid != 
composite_devid || composite_desc != nullptr, "invalid state"); - return (devid == composite_devid); + return pimpl_->is_composite(); } /// checks if this data place has an extension (green context, etc.) bool is_extension() const { - _CCCL_ASSERT(devid != extension_devid || extension != nullptr, "invalid state"); - return (devid == extension_devid); + return pimpl_->is_extension(); } bool is_invalid() const { - return devid == invalid_devid; + return pimpl_->is_invalid(); } bool is_host() const { - return devid == host_devid; + return pimpl_->is_host(); } bool is_managed() const { - return devid == managed_devid; + return pimpl_->is_managed(); } bool is_affine() const { - return devid == affine_devid; + return pimpl_->is_affine(); } /// checks if this data place corresponds to a specific device bool is_device() const { - // All other type of data places have a specific negative devid value. - return (devid >= 0); + return pimpl_->is_device(); } bool is_device_auto() const { - return devid == device_auto_devid; + return pimpl_->is_device_auto(); } ::std::string to_string() const { - if (devid == host_devid) - { - return "host"; - } - if (devid == managed_devid) - { - return "managed"; - } - if (devid == device_auto_devid) - { - return "auto"; - } - if (devid == invalid_devid) - { - return "invalid"; - } - - if (is_extension()) - { - return extension->to_string(); - } - - if (is_composite()) - { - return "composite" + ::std::to_string(devid); - } - - return "dev" + ::std::to_string(devid); + return pimpl_->to_string(); } /** @@ -283,10 +266,27 @@ public: */ friend inline size_t to_index(const data_place& p) { - EXPECT(p.devid >= -2, "Data place with device id ", p.devid, " does not refer to a device."); - // This is not strictly a problem in this function, but it's not legit either. So let's assert. 
- assert(p.devid < cuda_try()); - return p.devid + 2; + int devid = p.pimpl_->get_device_ordinal(); + EXPECT(devid >= -2, "Data place with device id ", devid, " does not refer to a device."); + _CCCL_ASSERT(devid < cuda_try(), "Invalid device id"); + return devid + 2; + } + + /** + * @brief Inverse of `to_index`: converts an index back to a `data_place`. + * Index 0 -> managed, 1 -> host, 2 -> device(0), 3 -> device(1), ... + */ + friend inline data_place from_index(size_t n) + { + if (n == 0) + { + return data_place::managed(); + } + if (n == 1) + { + return data_place::host(); + } + return data_place::device(static_cast(n - 2)); } /** @@ -295,21 +295,20 @@ public: */ friend inline int device_ordinal(const data_place& p) { - if (p.is_extension()) - { - return p.extension->get_device_ordinal(); - } + return p.pimpl_->get_device_ordinal(); + } - // TODO: restrict this function, i.e. sometimes it's called with invalid places. - // EXPECT(p != invalid, "Invalid device id ", p.devid, " for data place."); - // EXPECT(p.devid >= -2, "Data place with device id ", p.devid, " does not refer to a device."); - // assert(p.devid < cuda_try()); - return p.devid; + const exec_place_grid& get_grid() const + { + return pimpl_->get_grid(); } - const exec_place_grid& get_grid() const; - const get_executor_func_t& get_partitioner() const; + const get_executor_func_t& get_partitioner() const + { + return pimpl_->get_partitioner(); + } + // Defined later after exec_place is complete exec_place affine_exec_place() const; /** @@ -319,277 +318,79 @@ public: */ size_t hash() const { - // Not implemented for composite places - EXPECT(!is_composite()); - - // Extensions provide their own hash - if (is_extension()) - { - return extension->hash(); - } - - // For simple places, hash the devid directly - return ::std::hash()(devid); + return pimpl_->hash(); } decorated_stream getDataStream() const; -private: - /** - * @brief Store the fields specific to a composite data place - * Definition 
comes later to avoid cyclic dependencies. - */ - class composite_state; - - //{ state - int devid = invalid_devid; // invalid by default - // Stores the fields specific to composite data places - ::std::shared_ptr composite_desc; - // Extension for custom place types (green contexts, etc.) - ::std::shared_ptr extension; - //} state - -public: /** * @brief Check if this data place has a custom extension */ bool has_extension() const { - return extension != nullptr; + return is_extension(); } /** - * @brief Get the extension (may be nullptr for standard place types) + * @brief Get the underlying interface pointer + * + * This is primarily for internal use and backward compatibility. */ - const ::std::shared_ptr& get_extension() const + const ::std::shared_ptr& get_impl() const { - return extension; + return pimpl_; } /** * @brief Create a data_place from an extension * * This factory method allows custom place types to be created from - * data_place_extension implementations. + * data_place_interface implementations that return true from is_extension(). */ - static data_place from_extension(::std::shared_ptr ext) + static data_place from_extension(::std::shared_ptr ext) { - data_place result(extension_devid); - result.extension = mv(ext); - return result; + return data_place(mv(ext)); } /** * @brief Create a physical memory allocation for this place (VMM API) - * - * This method is used by localized arrays (composite_slice) to create physical - * memory segments that are then mapped into a contiguous virtual address space. - * It delegates to the extension's mem_create if present (enabling custom place - * types to override memory allocation), otherwise creates a standard pinned - * allocation on this place's device or host. - * - * Managed memory is not supported by the VMM API. - * - * @note For regular memory allocation (not VMM-based), use the allocate() method - * instead, which provides stream-ordered allocation via cudaMallocAsync. 
- * - * @param handle Output parameter for the allocation handle - * @param size Size of the allocation in bytes - * @return CUresult indicating success or failure - * - * @see allocate() for regular memory allocation */ CUresult mem_create(CUmemGenericAllocationHandle* handle, size_t size) const { - if (extension) - { - return extension->mem_create(handle, size); - } - - int dev_ordinal = device_ordinal(*this); - - CUmemAllocationProp prop = {}; - prop.type = CU_MEM_ALLOCATION_TYPE_PINNED; - if (dev_ordinal >= 0) - { - // Device memory - prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE; - prop.location.id = dev_ordinal; - } -#if _CCCL_CTK_AT_LEAST(12, 2) - else if (dev_ordinal == -1) - { - // Host memory (device ordinal -1) - // CU_MEM_LOCATION_TYPE_HOST requires CUDA 12.2+ - prop.location.type = CU_MEM_LOCATION_TYPE_HOST; - prop.location.id = 0; - } - else - { - // Managed memory (-2) is not supported by the VMM API - _CCCL_ASSERT(false, "mem_create: managed memory is not supported by the VMM API"); - return CUDA_ERROR_NOT_SUPPORTED; - } -#else // ^^^ _CCCL_CTK_AT_LEAST(12, 2) ^^^ / vvv _CCCL_CTK_BELOW(12, 2) vvv - else if (dev_ordinal == -1) - { - // Host VMM requires CU_MEM_LOCATION_TYPE_HOST which is only available in CUDA 12.2+ - _CCCL_ASSERT(false, "mem_create: host VMM requires CUDA 12.2+ (CU_MEM_LOCATION_TYPE_HOST not available)"); - return CUDA_ERROR_NOT_SUPPORTED; - } - else - { - // Managed memory (-2) is not supported by the VMM API - _CCCL_ASSERT(false, "mem_create: managed memory is not supported by the VMM API"); - return CUDA_ERROR_NOT_SUPPORTED; - } -#endif // _CCCL_CTK_AT_LEAST(12, 2) - return cuMemCreate(handle, size, &prop, 0); + return pimpl_->mem_create(handle, size); } /** * @brief Allocate memory at this data place (raw allocation) - * - * This is the low-level allocation interface that handles all place types: - * - For extensions: delegates to extension->allocate() - * - For host: uses cudaMallocHost (immediate, stream ignored) - * - For 
managed: uses cudaMallocManaged (immediate, stream ignored) - * - For device: uses cudaMallocAsync (stream-ordered) - * - * @param size Size of the allocation in bytes - * @param stream CUDA stream for stream-ordered allocations (ignored for immediate allocations, defaults to nullptr) - * @return Pointer to allocated memory */ void* allocate(::std::ptrdiff_t size, cudaStream_t stream = nullptr) const { - // Delegate to extension if present - if (extension) - { - return extension->allocate(size, stream); - } - - void* result = nullptr; - - if (is_host()) - { - cuda_safe_call(cudaMallocHost(&result, size)); - } - else if (is_managed()) - { - cuda_safe_call(cudaMallocManaged(&result, size)); - } - else - { - // Device allocation - EXPECT(!is_composite(), "Composite places don't support direct allocation"); - const int prev_dev = cuda_try(); - const int target_dev = devid; - - if (prev_dev != target_dev) - { - cuda_safe_call(cudaSetDevice(target_dev)); - } - - SCOPE(exit) - { - if (prev_dev != target_dev) - { - cuda_safe_call(cudaSetDevice(prev_dev)); - } - }; - - cuda_safe_call(cudaMallocAsync(&result, size, stream)); - } - - return result; + return pimpl_->allocate(size, stream); } /** * @brief Deallocate memory at this data place (raw deallocation) - * - * For immediate deallocations (host, managed), the stream is ignored. - * Note that cudaFree (used for managed memory) may introduce implicit synchronization. 
- * - * @param ptr Pointer to memory to deallocate - * @param size Size of the allocation - * @param stream CUDA stream for stream-ordered deallocations (ignored for immediate deallocations, defaults to - * nullptr) */ void deallocate(void* ptr, size_t size, cudaStream_t stream = nullptr) const { - // Delegate to extension if present - if (extension) - { - extension->deallocate(ptr, size, stream); - return; - } - - if (is_host()) - { - cuda_safe_call(cudaFreeHost(ptr)); - } - else if (is_managed()) - { - cuda_safe_call(cudaFree(ptr)); - } - else - { - // Device deallocation - const int prev_dev = cuda_try(); - const int target_dev = devid; - - if (prev_dev != target_dev) - { - cuda_safe_call(cudaSetDevice(target_dev)); - } - - SCOPE(exit) - { - if (prev_dev != target_dev) - { - cuda_safe_call(cudaSetDevice(prev_dev)); - } - }; - - cuda_safe_call(cudaFreeAsync(ptr, stream)); - } + pimpl_->deallocate(ptr, size, stream); } /** * @brief Returns true if allocation/deallocation is stream-ordered - * - * When this returns true, the allocation uses stream-ordered APIs like - * cudaMallocAsync, and allocators should use stream_async_op to synchronize - * prerequisites before allocation. - * - * When this returns false, the allocation is immediate (like cudaMallocHost) - * and the stream parameter is ignored. Note that immediate deallocations - * (e.g., cudaFree) may introduce implicit synchronization. */ bool allocation_is_stream_ordered() const { - if (extension) - { - return extension->allocation_is_stream_ordered(); - } - // Host and managed are immediate (stream ignored), device is stream-ordered - return !is_host() && !is_managed(); + return pimpl_->allocation_is_stream_ordered(); } private: - /* Constants to implement data_place::invalid(), data_place::host(), etc. 
*/ - enum devid : int - { - invalid_devid = ::std::numeric_limits::min(), - extension_devid = -6, // For any custom extension-based place - composite_devid = -5, - device_auto_devid = -4, - affine_devid = -3, - managed_devid = -2, - host_devid = -1, - }; + ::std::shared_ptr pimpl_; }; +/** Declaration for unqualified lookup (friend is only found via ADL when a \c data_place argument is present). */ +inline data_place from_index(size_t n); + /** * @brief Indicates where a computation takes place (CPU, dev0, dev1, ...) * @@ -616,19 +417,18 @@ public: virtual exec_place activate() const { - if (affine.is_device()) + if (!affine.is_device()) { - auto old_dev_id = cuda_try(); - auto new_dev_id = device_ordinal(affine); - if (old_dev_id != new_dev_id) - { - cuda_safe_call(cudaSetDevice(new_dev_id)); - } - - auto old_dev = data_place::device(old_dev_id); - return exec_place(mv(old_dev)); + return exec_place(); } - return exec_place(); + auto old_dev_id = cuda_try(); + auto new_dev_id = device_ordinal(affine); + if (old_dev_id != new_dev_id) + { + cuda_safe_call(cudaSetDevice(new_dev_id)); + } + auto old_dev = data_place::device(old_dev_id); + return exec_place(mv(old_dev)); } virtual void deactivate(const exec_place& prev) const @@ -684,11 +484,6 @@ public: return affine == rhs.affine; } - bool operator!=(const impl& rhs) const - { - return !(*this == rhs); - } - virtual size_t hash() const { return affine.hash(); @@ -717,21 +512,6 @@ public: return for_computation ? pool_compute : pool_data; } - /** - * @brief Create a stream valid for execution on this place. - * - * Expected to be called with this exec place already activated (e.g. from - * stream_pool::next(place) which uses exec_place_guard). Creates a new stream - * in the current context via cudaStreamCreateWithFlags(..., cudaStreamNonBlocking). - * The caller (e.g. stream_pool::next) builds a decorated_stream from the result. 
- */ - cudaStream_t create_stream() const - { - cudaStream_t stream = nullptr; - cuda_safe_call(cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking)); - return stream; - } - static constexpr size_t pool_size = 4; static constexpr size_t data_pool_size = 4; @@ -868,19 +648,6 @@ public: */ decorated_stream getStream(bool for_computation) const; - /** - * @brief Create a stream valid for execution on this place. - * - * Call only when the place is already activated (e.g. inside exec_place_guard). - * For getting a stream from the pool, use getStream() / pick_stream() instead. - * - * @return A CUDA stream valid for this execution place - */ - cudaStream_t create_stream() const - { - return pimpl->create_stream(); - } - cudaStream_t pick_stream(bool for_computation = true) const { return getStream(for_computation).stream; @@ -1000,36 +767,7 @@ public: * */ template - auto operator->*(Fun&& fun) const - { - const int new_device = device_ordinal(pimpl->affine); - if (new_device >= 0) - { - // We're on a device - // Change device only if necessary. - const int old_device = cuda_try(); - if (new_device != old_device) - { - cuda_safe_call(cudaSetDevice(new_device)); - } - - SCOPE(exit) - { - // It is the responsibility of the client to ensure that any change of the current device in this - // section was reverted. - if (new_device != old_device) - { - cuda_safe_call(cudaSetDevice(old_device)); - } - }; - return ::std::forward(fun)(); - } - else - { - // We're on the host, just call the function with no further ado. 
- return ::std::forward(fun)(); - } - } + auto operator->*(Fun&& fun) const; public: exec_place(::std::shared_ptr pimpl) @@ -1108,6 +846,13 @@ private: exec_place prev_; }; +template +auto exec_place::operator->*(Fun&& fun) const +{ + exec_place_guard guard(*this); + return ::std::forward(fun)(); +} + inline decorated_stream stream_pool::next(const exec_place& place) { _CCCL_ASSERT(pimpl, "stream_pool::next called on empty pool"); @@ -1119,7 +864,7 @@ inline decorated_stream stream_pool::next(const exec_place& place) if (!result.stream) { exec_place_guard guard(place); - result.stream = place.create_stream(); + cuda_safe_call(cudaStreamCreateWithFlags(&result.stream, cudaStreamNonBlocking)); result.id = get_stream_id(result.stream); result.dev_id = get_device_from_stream(result.stream); } @@ -1413,9 +1158,9 @@ public: stream_pool& get_stream_pool(bool for_computation) const override { - assert(!for_computation); + _CCCL_ASSERT(!for_computation, "Expected data transfer stream pool"); const auto& v = get_places(); - assert(v.size() > 0); + _CCCL_ASSERT(v.size() > 0, "Grid must have at least one place"); return v[0].get_stream_pool(for_computation); } @@ -1596,7 +1341,7 @@ public: ::std::shared_ptr get_impl() const { - assert(::std::dynamic_pointer_cast(exec_place::get_impl())); + _CCCL_ASSERT(::std::dynamic_pointer_cast(exec_place::get_impl()), "Invalid exec_place_grid impl"); return ::std::static_pointer_cast(exec_place::get_impl()); } @@ -1629,6 +1374,47 @@ inline exec_place_grid make_grid(::std::vector places) return make_grid(mv(places), grid_dim); } +// === data_place::affine_exec_place implementation === +// Defined here after exec_place_grid is complete + +inline exec_place data_place::affine_exec_place() const +{ + if (is_host()) + { + return exec_place::host(); + } + + // Managed memory uses host exec_place (debatable but follows original behavior) + if (is_managed()) + { + return exec_place::host(); + } + + if (is_composite()) + { + // Return the grid of 
places associated to this composite data place + // exec_place_grid inherits from exec_place, so this works via slicing + return get_grid(); + } + + if (is_extension()) + { + // Extensions provide their own affine exec_place via get_affine_exec_impl() + auto impl = pimpl_->get_affine_exec_impl(); + _CCCL_ASSERT(impl != nullptr, "Extension must provide affine exec_place implementation"); + return exec_place(::std::static_pointer_cast(impl)); + } + + if (is_device()) + { + // This must be a specific device + return exec_place::device(pimpl_->get_device_ordinal()); + } + + // For invalid, affine, device_auto - throw + throw ::std::logic_error("affine_exec_place() not meaningful for this data_place type"); +} + /// Implementation deferred because we need the definition of exec_place_grid inline exec_place exec_place::iterator::operator*() { @@ -1803,123 +1589,105 @@ inline exec_place_grid partition_tile(const exec_place_grid& e_place, dim4 tile_ return make_grid(mv(places), size); } -/* - * This is defined here so that we avoid cyclic dependencies. +/** + * @brief Implementation for composite data places + * + * Composite places represent data distributed across multiple devices, + * using a grid of execution places and a partitioner function. 
*/ -class data_place::composite_state +class data_place_composite final : public data_place_interface { public: - composite_state() = default; - - composite_state(exec_place_grid grid, get_executor_func_t partitioner_func) - : grid(mv(grid)) - , partitioner_func(mv(partitioner_func)) + data_place_composite(exec_place_grid grid, get_executor_func_t partitioner_func) + : grid_(mv(grid)) + , partitioner_func_(mv(partitioner_func)) {} - const exec_place_grid& get_grid() const + bool is_composite() const override { - return grid; + return true; } - const get_executor_func_t& get_partitioner() const + + int get_device_ordinal() const override { - return partitioner_func; + return data_place_interface::composite; } -private: - exec_place_grid grid; - get_executor_func_t partitioner_func; -}; - -inline data_place data_place::composite(get_executor_func_t f, const exec_place_grid& grid) -{ - data_place result; - - // Flags this is a composite data place - result.devid = composite_devid; - - // Save the state that is specific to a composite data place into the - // data_place object. - result.composite_desc = ::std::make_shared(grid, f); - - return result; -} - -// User-visible API when the same partitioner as the one of the grid -template -data_place data_place::composite(partitioner_t, const exec_place_grid& g) -{ - return data_place::composite(&partitioner_t::get_executor, g); -} - -inline exec_place data_place::affine_exec_place() const -{ - // EXPECT(*this != affine); - // EXPECT(*this != data_place::invalid()); - - if (is_host()) + ::std::string to_string() const override { - return exec_place::host(); + return "composite"; } - // This is debatable ! 
- if (is_managed()) + size_t hash() const override { - return exec_place::host(); + // Composite places don't support hashing + throw ::std::logic_error("hash() not supported for composite data_place"); } - if (is_composite()) + int cmp(const data_place_interface& other) const override { - // Return the grid of places associated to that composite data place - return get_grid(); + if (typeid(*this) != typeid(other)) + { + return typeid(*this).before(typeid(other)) ? -1 : 1; + } + const auto& o = static_cast(other); + if (get_partitioner() != o.get_partitioner()) + { + return ::std::less{}(o.get_partitioner(), get_partitioner()) ? 1 : -1; + } + if (get_grid() == o.get_grid()) + { + return 0; + } + // Grids differ: compare structurally (shape first, then element-by-element places) + return (get_grid() < o.get_grid()) ? -1 : 1; } - if (is_extension()) + void* allocate(::std::ptrdiff_t, cudaStream_t) const override { - return extension->affine_exec_place(); + throw ::std::logic_error("Composite places don't support direct allocation"); } - // This must be a device - return exec_place::device(devid); -} - -inline decorated_stream data_place::getDataStream() const -{ - return affine_exec_place().getStream(false); -} - -inline const exec_place_grid& data_place::get_grid() const -{ - return composite_desc->get_grid(); -}; -inline const get_executor_func_t& data_place::get_partitioner() const -{ - return composite_desc->get_partitioner(); -} - -inline bool data_place::operator==(const data_place& rhs) const -{ - if (is_composite() != rhs.is_composite()) + void deallocate(void*, size_t, cudaStream_t) const override { - return false; + throw ::std::logic_error("Composite places don't support direct deallocation"); } - if (is_extension() != rhs.is_extension()) + bool allocation_is_stream_ordered() const override { return false; } - if (!is_composite() && !is_extension()) + const exec_place_grid& get_grid() const override { - return devid == rhs.devid; + return grid_; } - if 
(is_extension()) + const get_executor_func_t& get_partitioner() const override { - _CCCL_ASSERT(devid == extension_devid, ""); - return (rhs.devid == extension_devid && *extension == *rhs.extension); + return partitioner_func_; } - return (get_grid() == rhs.get_grid() && (get_partitioner() == rhs.get_partitioner())); +private: + exec_place_grid grid_; + get_executor_func_t partitioner_func_; +}; + +inline data_place data_place::composite(get_executor_func_t f, const exec_place_grid& grid) +{ + return data_place(::std::make_shared(grid, f)); +} + +// User-visible API when the same partitioner as the one of the grid +template +data_place data_place::composite(partitioner_t, const exec_place_grid& g) +{ + return data_place::composite(&partitioner_t::get_executor, g); +} + +inline decorated_stream data_place::getDataStream() const +{ + return affine_exec_place().getStream(false); } #ifdef UNITTESTED_FILE