diff --git a/cpp/src/cluster/detail/kmeans.cuh b/cpp/src/cluster/detail/kmeans.cuh index 297b4b2fd3..a40f71dd5f 100644 --- a/cpp/src/cluster/detail/kmeans.cuh +++ b/cpp/src/cluster/detail/kmeans.cuh @@ -694,14 +694,15 @@ void kmeans_fit( rmm::device_uvector batch_workspace(streaming_batch_size, stream); - cuvs::spatial::knn::detail::utils::batch_load_iterator data_batches( - X.data_handle(), n_samples, n_features, streaming_batch_size, stream); + auto data_batches = cuvs::spatial::knn::detail::utils::make_batch_load_iterator( + handle, X.data_handle(), n_samples, n_features, streaming_batch_size, stream); // Host-path weight batches: only materialized when weights are provided and // the data resides on host - std::optional> weight_batches; + std::optional> weight_batches; if constexpr (!data_on_device) { if (weight_ptr != nullptr) { - weight_batches.emplace(weight_ptr, n_samples, 1, streaming_batch_size, stream); + weight_batches = cuvs::spatial::knn::detail::utils::make_batch_load_iterator( + handle, weight_ptr, n_samples, IndexT{1}, streaming_batch_size, stream); } else { raft::matrix::fill(handle, batch_weights_buf.view(), DataT{1}); } @@ -833,7 +834,7 @@ void kmeans_fit( raft::make_device_matrix_view(new_centroids_ptr, n_clusters, n_features); data_batches.reset(); - using wt_iter_t = cuvs::spatial::knn::detail::utils::batch_load_iterator; + using wt_iter_t = cuvs::spatial::knn::detail::utils::batch_load_iterator_dyn; std::optional wt_it; if (weight_batches.has_value()) { weight_batches->reset(); @@ -932,7 +933,7 @@ void kmeans_fit( iter_inertia = DataT{0}; data_batches.reset(); - using wt_iter_t = cuvs::spatial::knn::detail::utils::batch_load_iterator; + using wt_iter_t = cuvs::spatial::knn::detail::utils::batch_load_iterator_dyn; std::optional wt_it; if (weight_batches.has_value()) { weight_batches->reset(); diff --git a/cpp/src/neighbors/detail/ann_utils.cuh b/cpp/src/neighbors/detail/ann_utils.cuh index a7872f87a0..bdb2c310c6 100644 --- 
a/cpp/src/neighbors/detail/ann_utils.cuh +++ b/cpp/src/neighbors/detail/ann_utils.cuh @@ -6,8 +6,14 @@ #pragma once #include +#include +#include #include #include +#include +#include +#include +#include #include #include #include @@ -21,6 +27,9 @@ #include #include +#include +#include +#include namespace cuvs::spatial::knn::detail::utils { @@ -378,166 +387,446 @@ void copy_selected(IdxT n_rows, } /** - * A batch input iterator over the data source. - * Given an input pointer, it decides whether the current device has the access to the data and - * gives it back to the user in batches. Three scenarios are possible: + * Helper that returns the stream to use for non-blocking copies and a flag indicating whether + * cross-stream pipelining (prefetch / writeback) can be enabled. * - * 1. if `source == nullptr`: then `batch.data() == nullptr` - * 2. if `source` is accessible from the device, `batch.data()` points directly at the source at - * the proper offsets on each iteration. - * 3. if `source` is not accessible from the device, `batch.data()` points to an intermediate - * buffer; the corresponding data is copied in the given `stream` on every iterator dereference - * (i.e. batches can be skipped). Dereferencing the same batch two times in a row does not force - * the raft::copy. + * If `res` has a CUDA stream pool with at least one stream, the first pool stream is used and + * `true` is returned (prefetch can run concurrently with kernels on res's main stream). Otherwise + * the main stream itself is returned with `false`, and the caller should treat prefetch as a + * no-op (no overlap is possible on a single stream). 
+ */ +inline auto get_prefetch_stream(raft::resources const& res) + -> std::pair +{ + if (res.has_resource_factory(raft::resource::resource_type::CUDA_STREAM_POOL) && + raft::resource::get_stream_pool_size(res) >= 1) { + return {raft::resource::get_stream_from_stream_pool(res), true}; + } + return {raft::resource::get_cuda_stream(res), false}; +} + +/** + * Iterate a 2D mdspan in row-batches with optional pipelined H2D / D2H copies and kernel work. * - * In all three scenarios, the number of iterations, batch offsets and sizes are the same. + * Strategy is selected at compile time from `MdspanT::accessor_type::is_device_accessible`: + * * passthrough: each batch is a row-range of `input_view_` directly; no buffering, no copy. + * * copy_device: each batch is staged through one or two internal device buffers via + * `cudaMemcpyAsync` on the caller-supplied `copy_stream`. With `prefetch=true`, two buffers + * are used as a ring so that the next batch's H2D and the previous batch's D2H can overlap + * with the user kernel running on `res`'s main stream. * - * The iterator can be reused. If the number of iterations is one, at most one raft::copy will ever - * be invoked (i.e. small datasets are not reloaded multiple times). + * Three orthogonal flags control behavior of the copy_device strategy: + * * `prefetch`: if true, allocate two device buffers and pipeline copies using + * `prefetch_next_batch()`. If false, copies happen synchronously at `operator*` and one buffer + * is allocated. As an optimization, when `n_iters_ <= 1` (i.e. `batch_size >= + * input_view.extent(0)`) prefetching is internally downgraded to false and only one buffer is + * allocated, since there is no "next batch" to overlap with. + * * `initialize`: if true, stage source rows H2D into the buffer before yielding the batch. + * If false, the buffer is handed out uninitialized (kernel produces the data from scratch). 
+ * * `host_writeback`: if true, queue D2H of every advanced batch back to `input_view_`. + * Pending writebacks are flushed on destruction. * - * In the case of pageable host buffer input, the iterator is by default (almost) synchronous due to - * the behavior of raft::copy. In order to achieve kernel and copy overlapping, a - * prefetch_next_batch (synchronous) API is provided. Note that since prefetch API is synchronous, - * user may want to schedule kernel, which is asynchronous, first. User is responsible to properly - * manage the order of prefetch and kernel to ensure overlapping. + * `initialize` and `host_writeback` are independent: it is legal to skip H2D + * (`initialize=false, host_writeback=true`) when the kernel produces the result from scratch. + * At least one of them must be true. + * + * Stream model (prefetch=true): + * * The user passes `copy_stream`, which should be distinct from `res`'s main stream (use + * `get_prefetch_stream(res)`); otherwise no real overlap is possible. + * * Writeback is done with a 2-iteration delay so that the D2H of batch `i` is queued in the + * `prefetch_next_batch()` of iteration `i+2` -- right before the same slot is overwritten by + * the next H2D. This way the D2H reads a slot whose user kernel (`kernel_i`) is two + * iterations old and therefore guaranteed finished, without needing CUDA events for + * cross-stream synchronization. + * * `prefetch_next_batch()` queues, on `copy_stream`, in order: (1) D2H of the prefetch slot's + * stale kernel output (if `host_writeback`), then (2) H2D of `pos` into the prefetch slot + * (if `initialize`); then it `sync_stream(res)`'s so the host stall for the previous kernel + * overlaps with the just-queued copies. + * * `operator*` (next iteration's `load()`) swaps the ring slots and `synchronize()`'s + * `copy_stream` so the swapped-in slot is fully staged before the user's next kernel runs. 
+ * + * Stream model (prefetch=false): + * * Single buffer; copies happen synchronously inside `operator*` with `copy_stream` then + * `synchronize()`'d from the host. No overlap is attempted. + * + * Iteration ends when `operator++` reaches `n_iters_`. The iterator can be reused via `reset()`. + * + * Usage with prefetch (matches the legacy `batch_load_iterator` pattern): + * ``` + * auto [copy_stream, enable_prefetch] = utils::get_prefetch_stream(res); + * utils::batch_load_iterator iter(res, view, batch_size, copy_stream, mr, enable_prefetch); + * iter.prefetch_next_batch(); + * for (auto const& batch : iter) { + * kernel<<<..., raft::resource::get_cuda_stream(res)>>>(batch.data(), ...); + * iter.prefetch_next_batch(); + * } + * ``` + * + * Usage with writeback (replaces `batched_device_view`): + * ``` + * auto [copy_stream, enable_prefetch] = utils::get_prefetch_stream(res); + * utils::batch_load_iterator iter(res, view, batch_size, copy_stream, mr, enable_prefetch, + * /\*initialize=*\/false, /\*host_writeback=*\/true); + * iter.prefetch_next_batch(); + * for (auto& batch : iter) { + * kernel<<<..., raft::resource::get_cuda_stream(res)>>>(batch.view()); + * iter.prefetch_next_batch(); + * } + * ``` */ -template +template struct batch_load_iterator { - using size_type = size_t; + using mdspan_type = MdspanT; + using accessor_type = typename MdspanT::accessor_type; + using element_type = typename MdspanT::element_type; + using index_type = typename MdspanT::index_type; + using value_type_d = std::remove_const_t; + using size_type = size_t; + + static constexpr bool kPassthrough = accessor_type::is_device_accessible; + + // Type returned by `view()` for the passthrough strategy: a 2D submdspan of `input_view_` over a + // contiguous row range. Built without ever calling `data_handle()` on the input mdspan. 
+ // (Per the mdspan spec, slicing a `layout_right` with a `tuple{lo, hi}` over the leading dim + // yields a `layout_stride` mdspan with the input's accessor preserved.) + using passthrough_view_type = + decltype(cuda::std::submdspan(std::declval(), + std::declval>(), + cuda::std::full_extent)); + // Type returned by `view()` for the copy_device strategy: a row-major exhaustive device view + // over the iterator's internal device buffer. + using copy_view_type = raft::device_matrix_view; + using batch_view_type = std::conditional_t; - /** A single batch of data residing in device memory. */ + /** A single batch of data residing in (or accessible from) device memory. */ struct batch { ~batch() noexcept { - /* - If there's no copy, there's no allocation owned by the batch. - If there's no allocation, there's no guarantee that the device pointer is stream-ordered. - If there's no stream order guarantee, we must synchronize with the stream before the batch is - destroyed to make sure all GPU operations in that stream finish earlier. - */ - if (!does_copy()) { RAFT_CUDA_TRY_NO_THROW(cudaStreamSynchronize(stream_)); } + if constexpr (!kPassthrough) { + if (host_writeback_ && source_ != nullptr) { + // Two slots may still hold un-flushed kernel output: + // * `prefetch_dev_ptr_` for the batch the loop didn't reach (its prefetch_next_batch() + // that would have queued the D2H never happened). + // * `dev_ptr_` for the most-recently-loaded batch (its writeback is always deferred to + // the destructor since no future iteration recycles its slot). + // The user kernel that wrote either slot may still be in flight on `res`'s main stream + // when this destructor runs, so host-stall on it before issuing the D2Hs to avoid a + // read-while-write race. 
+ // This destructor is noexcept, but sync_stream() and queue_d2h() throw on CUDA errors; + // catch and log here so a failed flush cannot escalate to std::terminate. + try { + const bool has_pending = + (prefetch_ && prefetch_dirty_pos_.has_value()) || (dirty_cur_ && pos_.has_value()); + if (has_pending) { raft::resource::sync_stream(*res_); } + if (prefetch_ && prefetch_dirty_pos_.has_value()) { + queue_d2h(prefetch_dev_ptr_, *prefetch_dirty_pos_); + prefetch_dirty_pos_.reset(); + } + if (dirty_cur_ && pos_.has_value()) { + queue_d2h(dev_ptr_, *pos_); + dirty_cur_ = false; + } + } catch (std::exception const& e) { + RAFT_LOG_ERROR("batch_load_iterator: pending host writeback failed: %s", e.what()); + } catch (...) { + RAFT_LOG_ERROR("batch_load_iterator: pending host writeback failed"); + } + } + } + // Stream is shared with the iterator; it must be sync'd before the underlying buffers (or, + // in the passthrough case, the source mdspan) can be safely reused. The non-throwing variant + // is used because this destructor is noexcept. + copy_stream_.synchronize_no_throw(); } - /** Logical width of a single row in a batch, in elements of type `T`. */ [[nodiscard]] auto row_width() const -> size_type { return row_width_; } - /** Logical offset of the batch, in rows (`row_width()`) */ [[nodiscard]] auto offset() const -> size_type { return pos_.value_or(0) * batch_size_; } - /** Logical size of the batch, in rows (`row_width()`) */ [[nodiscard]] auto size() const -> size_type { return batch_len_; } - /** Logical size of the batch, in rows (`row_width()`) */ - [[nodiscard]] auto data() const -> const T* { return const_cast(dev_ptr_); } - /** Whether this batch copies the data (i.e. the source is inaccessible from the device). */ - [[nodiscard]] auto does_copy() const -> bool { return needs_copy_; } + [[nodiscard]] auto does_copy() const -> bool { return !kPassthrough; } + + /** + * 2D view of the staged batch. + * + * Passthrough: a `cuda::std::submdspan` of `input_view_` over the active row range. The + * implementation never calls `data_handle()` on `input_view_`; the mdspan's accessor is + * preserved end-to-end, which is the contract that lets future device mdspans without a raw + * pointer flow through this code path unchanged. + * + * Copy_device: a `device_matrix_view` over the internal device buffer (row-major exhaustive). 
+ */ + [[nodiscard]] auto view() const -> batch_view_type + { + if constexpr (kPassthrough) { + const index_type row_lo = static_cast(pos_.value_or(0) * batch_size_); + const index_type row_hi = static_cast(row_lo + batch_len_); + return cuda::std::submdspan( + input_view_, cuda::std::tuple{row_lo, row_hi}, cuda::std::full_extent); + } else { + return raft::make_device_matrix_view( + dev_ptr_, static_cast(batch_len_), static_cast(row_width_)); + } + } + + /** + * Raw device pointer of the staged batch. Provided for backward compatibility with raw-pointer + * call sites. In passthrough mode this forwards to `view().data_handle()`, which means it + * relies on the input mdspan's accessor exposing a pointer. Future device mdspans without a + * raw pointer should call `view()` instead and treat the result as an mdspan. + */ + [[nodiscard]] auto data() const -> element_type* + { + if constexpr (kPassthrough) { + return view().data_handle(); + } else { + return dev_ptr_; + } + } private: - batch(const T* source, - size_type n_rows, - size_type row_width, + template + friend struct batch_load_iterator; + + // Helper: only call `data_handle()` on the input mdspan in copy_device mode. In passthrough + // mode we keep `source_` at `nullptr` (it is never read) so the iterator imposes no + // raw-pointer requirement on the input accessor. 
+ static auto get_source(MdspanT input_view) noexcept -> element_type* + { + if constexpr (kPassthrough) { + return nullptr; + } else { + return input_view.data_handle(); + } + } + + batch(raft::resources const& res, + MdspanT input_view, size_type batch_size, - rmm::cuda_stream_view stream, + rmm::cuda_stream_view copy_stream, rmm::device_async_resource_ref mr, - bool prefetch = false) - : stream_(stream), - buf_0_(0, stream, mr), - buf_1_(0, stream, mr), - source_(source), - dev_ptr_(nullptr), - n_rows_(n_rows), - row_width_(row_width), - batch_size_(std::min(batch_size, n_rows)), - pos_(std::nullopt), - prefetch_pos_(std::nullopt), - n_iters_(raft::div_rounding_up_safe(n_rows, batch_size)), - needs_copy_(false), - prefetch_(prefetch) + bool prefetch, + bool initialize, + bool host_writeback) + : copy_stream_(copy_stream), + res_(&res), + input_view_(input_view), + source_(get_source(input_view)), + n_rows_(static_cast(input_view.extent(0))), + row_width_(static_cast(input_view.extent(1))), + batch_size_(std::min(batch_size, std::max(n_rows_, 1))), + n_iters_(n_rows_ == 0 ? 0 : raft::div_rounding_up_safe(n_rows_, batch_size_)), + prefetch_(prefetch), + initialize_(initialize), + host_writeback_(host_writeback), + buf_0_(0, copy_stream, mr), + buf_1_(0, copy_stream, mr) { - if (source_ == nullptr) { return; } - cudaPointerAttributes attr; - RAFT_CUDA_TRY(cudaPointerGetAttributes(&attr, source_)); - dev_ptr_ = reinterpret_cast(attr.devicePointer); - - if (dev_ptr_ == nullptr) { needs_copy_ = true; } - if (attr.type != cudaMemoryTypeDevice) { - // Although data might be accessible on device through HMM or ATS, - // it is preferred to copy the dataset explicitly when it is not device data. 
- needs_copy_ = true; - } - if (needs_copy_) { - buf_0_.resize(row_width_ * batch_size_, stream); - dev_ptr_ = buf_0_.data(); - if (prefetch_) { - buf_1_.resize(row_width_ * batch_size_, stream); - prefetch_dev_ptr_ = buf_1_.data(); + if (n_rows_ == 0) { return; } + RAFT_EXPECTS(initialize_ || host_writeback_, + "At least one of initialize or host_writeback must be true"); + RAFT_EXPECTS(!host_writeback_ || !std::is_const_v, + "host_writeback=true requires a non-const element type"); + + if constexpr (!kPassthrough) { + if (source_ == nullptr) { + // Null source: yield batches with the right offsets/sizes but data() == nullptr. + // Skip allocation and never queue copies. + return; + } + buf_0_.resize(row_width_ * batch_size_, copy_stream); + dev_ptr_ = reinterpret_cast(buf_0_.data()); + // The second buffer is only useful when there is more than one batch to overlap. With + // n_iters_ <= 1, there is no "next batch" to stage while a kernel runs on the current + // one, so prefetching offers no benefit. Downgrade `prefetch_` to false to skip the + // buf_1_ allocation and have `prefetch()` / `load()` take the single-buffer fast path. + if (prefetch_ && n_iters_ > 1) { + buf_1_.resize(row_width_ * batch_size_, copy_stream); + prefetch_dev_ptr_ = reinterpret_cast(buf_1_.data()); + } else { + prefetch_ = false; } } } - rmm::cuda_stream_view stream_; - rmm::device_uvector buf_0_; - rmm::device_uvector buf_1_; - const T* source_; - size_type n_rows_; - size_type row_width_; - size_type batch_size_; - size_type n_iters_; - bool needs_copy_; - bool prefetch_; - - std::optional pos_; - std::optional prefetch_pos_; - size_type batch_len_; - T* dev_ptr_; - T* prefetch_dev_ptr_; - friend class batch_load_iterator; /** - * Changes the state of the batch to point at the `pos` index. - * If necessary, copies the data from the source in the registered stream. + * Make this batch represent position `pos`. 
+ * + * Passthrough: pure bookkeeping; the per-batch view is recomputed on demand by `view()` via + * `cuda::std::submdspan`, never via pointer arithmetic on the input mdspan. + * + * Copy_device, prefetch=true: swap the ring slots and host-sync `copy_stream` so the + * swapped-in slot is fully staged before the user's next kernel runs. The slot we swapped + * out (now `prefetch_dev_ptr_`) carried the just-completed kernel's writes; if `host_writeback` + * is on, its writeback is recorded for the *next* `prefetch_next_batch()` to flush -- when + * that slot is about to be overwritten by a new H2D, by which time the kernel that wrote it + * is two iterations old and guaranteed finished (the legacy 2-iteration-delay model). + * + * Copy_device, prefetch=false: synchronously stage H2D into the single buffer and host-sync + * `copy_stream`. + * + * No-op if the buffer already holds `pos`. Iteration end is signaled by `pos >= n_iters_`. */ - void load(const size_type& pos) + void load(size_type pos) { - // No-op if the data is already loaded, or it's the end of the input. - if (pos == pos_ || pos >= n_iters_) { return; } - pos_.emplace(pos); - batch_len_ = std::min(batch_size_, n_rows_ - std::min(offset(), n_rows_)); - if (source_ == nullptr) { return; } - if (needs_copy_) { - if (size() > 0) { - RAFT_LOG_TRACE("batch_load_iterator::copy(offset = %zu, size = %zu, row_width = %zu)", - size_t(offset()), - size_t(size()), - size_t(row_width())); - if (prefetch_ && prefetch_pos_ == pos_) { - std::swap(dev_ptr_, prefetch_dev_ptr_); + if (n_iters_ == 0) { return; } + if (pos == pos_) { return; } + if (pos >= n_iters_) { return; } + + const size_type row_offset = pos * batch_size_; + const size_type len = + std::min(batch_size_, n_rows_ - std::min(row_offset, n_rows_)); + + if constexpr (kPassthrough) { + // Passthrough: just record the new slice; view() will compute the submdspan. 
+ pos_.emplace(pos); + batch_len_ = len; + return; + } else { + if (source_ == nullptr) { + pos_.emplace(pos); + batch_len_ = len; + // dev_ptr_ remains nullptr (or the empty-source buffer state). + return; + } + + if (prefetch_ && prefetch_pos_.has_value() && *prefetch_pos_ == pos) { + // Hand-off from the prefetch slot. Before swapping, transfer the soon-to-be-stale + // dev_ptr_'s dirty state into prefetch_dirty_pos_: after the swap, the slot that just + // received `kernel_(prev pos_)`'s writes becomes prefetch_dev_ptr_, and its D2H will be + // queued by the *next* prefetch_next_batch() right before the slot is overwritten by a + // new H2D. By that point, the kernel that wrote it is two iterations old -- finished + // by main-stream FIFO, no events required. + if (host_writeback_ && dirty_cur_ && pos_.has_value()) { + prefetch_dirty_pos_ = pos_; } else { - raft::copy(dev_ptr_, source_ + offset() * row_width(), size() * row_width(), stream_); + prefetch_dirty_pos_.reset(); } + std::swap(dev_ptr_, prefetch_dev_ptr_); + prefetch_pos_.reset(); + // Ensure prefetch_next_batch()'s queued H2D into this slot (and any prior D2H of the + // slot from the previous overwrite) finished before the user kernel reads it. + copy_stream_.synchronize(); + } else { + // Non-pipelined fast path (prefetch_=false, or prefetch_pos_ didn't match). + // The kernel enqueued for the batch currently in dev_ptr_ may still be running on + // `res`'s main stream; host-stall on it before `copy_stream` reads (D2H) or overwrites + // (H2D) that buffer. Redundant but harmless when `copy_stream` is the main stream. + if (pos_.has_value()) { raft::resource::sync_stream(*res_); } + if (host_writeback_ && dirty_cur_ && pos_.has_value()) { + queue_d2h(dev_ptr_, *pos_); + dirty_cur_ = false; + } + if (initialize_) { queue_h2d(dev_ptr_, row_offset, len); } + copy_stream_.synchronize(); + } + pos_.emplace(pos); + batch_len_ = len; + if (host_writeback_) { + // Every advanced batch is implicitly dirty: the user kernel will write to it before + // the next prefetch_next_batch() (or destructor) recycles the slot. + dirty_cur_ = true; } - } else { - dev_ptr_ = const_cast(source_) + offset() * row_width(); } } /** - * Helper function for prefetch. NOP if prefetch option is not enabled. This API is synchronous. 
+ * Cross-stream pipelining step (called by the user *after* enqueuing the kernel for the + * current batch). + * + * On `copy_stream`, queues -- in this order, into the *prefetch* slot only: + * 1. D2H of the prefetch slot's stale kernel output (if `host_writeback` and the slot was + * written by a kernel two iterations ago). The kernel that produced that output is + * already finished by main-stream FIFO (the user has since enqueued the kernel for + * the slot currently in `dev_ptr_`), so no event is needed -- the D2H reads a slot the + * main stream is no longer touching. + * 2. H2D of `pos` into the prefetch slot (if `initialize`). + * + * Then host-stalls on `res`'s main stream so the host has waited out the just-enqueued user + * kernel by the time control returns; meanwhile the D2H/H2D queued above runs concurrently + * on `copy_stream`, overlapping the writeback with the kernel. + * + * No-op if prefetch is disabled, source is null, passthrough, or `pos >= n_iters_`. */ - void prefetch(const size_type& pos) + void prefetch(size_type pos) { - if (pos >= n_iters_ || !prefetch_ || !needs_copy_ || source_ == nullptr) { return; } - size_type prefetch_offset = batch_size_ * pos; - size_type prefetch_size = std::min(batch_size_, n_rows_ - std::min(prefetch_offset, n_rows_)); - raft::common::nvtx::push_range( - "batch_load_iterator::prefetch(offset = %zu, size = %zu, row_width = %zu)", - size_t(prefetch_offset), - size_t(prefetch_size), - size_t(row_width())); - raft::copy(prefetch_dev_ptr_, - source_ + prefetch_offset * row_width(), - prefetch_size * row_width(), - stream_); - raft::common::nvtx::pop_range(); - stream_.synchronize(); + if constexpr (kPassthrough) { return; } + if (n_iters_ == 0 || pos >= n_iters_ || source_ == nullptr) { return; } + if (!prefetch_) { + // No-op: in non-pipelined mode load() does the staging synchronously in operator*. 
+ return; + } + + // 2-iteration-delayed writeback: the prefetch slot still holds kernel output from two + // iterations ago (from when this slot was last `dev_ptr_`); D2H it now -- right before we + // overwrite the slot with the next H2D. The kernel that wrote it is past on the main + // stream, so this D2H runs concurrently with the just-enqueued kernel for `dev_ptr_` + // without racing (different slots). + if (host_writeback_ && prefetch_dirty_pos_.has_value()) { + queue_d2h(prefetch_dev_ptr_, *prefetch_dirty_pos_); + prefetch_dirty_pos_.reset(); + } + + // H2D of the next batch into the prefetch slot. Sequenced after the D2H on `copy_stream`, + // so the H2D doesn't clobber the slot before its prior contents have been read out. + if (initialize_) { + const size_type row_offset = pos * batch_size_; + const size_type len = + std::min(batch_size_, n_rows_ - std::min(row_offset, n_rows_)); + queue_h2d(prefetch_dev_ptr_, row_offset, len); + } prefetch_pos_.emplace(pos); + + // Host-stall on the user's main stream: the just-enqueued kernel runs concurrently with + // the copies above. By the time control returns, the kernel is done and the next load() + // can safely read its writes (still in `dev_ptr_`'s slot, which becomes prefetch_dev_ptr_ + // after the upcoming swap). + raft::resource::sync_stream(*res_); } + + void queue_h2d(element_type* dst, size_type src_row_offset, size_type num_rows) + { + if (num_rows == 0) { return; } + const size_t n_bytes = num_rows * row_width_ * sizeof(value_type_d); + // dst is `element_type*` (potentially `const T*`), but it points into a non-const internal + // buffer (`rmm::device_uvector`); the const-cast restores the writable view. + // Use cudaMemcpyAsync directly (rather than raft::copy) to avoid issues with + // HMM/ATS-mapped host pointers being misclassified. 
+ RAFT_CUDA_TRY(cudaMemcpyAsync(const_cast(dst), + source_ + src_row_offset * row_width_, + n_bytes, + cudaMemcpyHostToDevice, + copy_stream_)); + } + + void queue_d2h(element_type* src, size_type pos) + { + const size_type row_offset = pos * batch_size_; + const size_type num_rows = + std::min(batch_size_, n_rows_ - std::min(row_offset, n_rows_)); + if (num_rows == 0) { return; } + const size_t n_bytes = num_rows * row_width_ * sizeof(value_type_d); + RAFT_CUDA_TRY(cudaMemcpyAsync(const_cast(source_) + row_offset * row_width_, + src, + n_bytes, + cudaMemcpyDeviceToHost, + copy_stream_)); + } + + rmm::cuda_stream_view copy_stream_; + raft::resources const* res_; + MdspanT input_view_; + element_type* source_; + size_type n_rows_; + size_type row_width_; + size_type batch_size_; + size_type n_iters_; + bool prefetch_; + bool initialize_; + bool host_writeback_; + + rmm::device_uvector buf_0_; + rmm::device_uvector buf_1_; + + // Slot bookkeeping (only meaningful for !kPassthrough). + element_type* dev_ptr_ = nullptr; + element_type* prefetch_dev_ptr_ = nullptr; + // `pos_`: batch index currently held in `dev_ptr_`'s slot (the user-facing slot). + // `prefetch_pos_`: batch index pre-staged in `prefetch_dev_ptr_`'s slot via H2D. + // `prefetch_dirty_pos_`: batch index whose kernel output is sitting in `prefetch_dev_ptr_`'s + // slot and still needs to be D2H'd back to host. Set by `load()`'s + // swap when the slot it just retired held kernel writes; consumed + // (D2H'd) by the *next* `prefetch()` -- the legacy 2-iteration delay. + std::optional pos_; + std::optional prefetch_pos_; + std::optional prefetch_dirty_pos_; + size_type batch_len_ = 0; + bool dirty_cur_ = false; }; using value_type = batch; @@ -545,72 +834,78 @@ struct batch_load_iterator { using pointer = const value_type*; /** - * Create a batch iterator over the data `source`. + * Construct an iterator over `input_view`. 
* - * For convenience, the data `source` is read in logical units of size `row_width`; batch sizes - * and offsets are calculated in logical rows. Hence, can interpret the data as a contiguous - * row-major matrix of size [n_rows, row_width], and the batches are the sub-matrices of size - * [x<=batch_size, n_rows]. - * - * If prefetch option is enabled, the batch_load_iterator could help to achieve overlapping with - * prefetch_next_batch() with other workloads. This is useful if source buffer is in host memory. - * To achieve overlapping, the other workloads have to be async and scheduled before - * prefetch_next_batch(). Users also need to use a different stream for the workloads. E.g., - * utils::batch_load_iterator batches(..., stream_1, ..., true); - * batches.prefetch_next_batch(); - * for (const auto& batch : batches) { - * // The following kernel and prefetch_next_batch() could be overlapped. - * kernel<<<..., stream_2>>>(...); - * batches.prefetch_next_batch(); - * } - * - * @param source the input data -- host, device, or nullptr. - * @param n_rows the size of the input in logical rows. - * @param row_width the size of the logical row in the elements of type `T`. - * @param batch_size the desired size of the batch. - * @param stream the ordering for the host->device copies, if applicable. - * @param mr a custom memory resource for the intermediate buffer, if applicable. - * @param prefetch enable prefetch feature in order to achieve kernel/copy overlapping. + * @param res raft resources (must outlive the iterator) + * @param input_view typed mdspan to iterate; row-major; passthrough vs copy is decided + * at compile time from the accessor. + * @param batch_size desired batch size in rows. Clamped to n_rows. + * @param copy_stream stream used for H2D / D2H copies in copy_device mode. Pass a non-main + * stream (see `get_prefetch_stream`) to enable real overlap. + * @param mr memory resource for the internal device buffer(s). 
+ * @param prefetch enable 2-buffer pipelining via `prefetch_next_batch()`. + * @param initialize stage H2D source rows before yielding each batch (default true). + * @param host_writeback queue D2H of every advanced batch back to `input_view`. + * At least one of `initialize` / `host_writeback` must be true for + * non-empty input. */ - batch_load_iterator( - const T* source, - size_type n_rows, - size_type row_width, - size_type batch_size, - rmm::cuda_stream_view stream, - rmm::device_async_resource_ref mr = rmm::mr::get_current_device_resource_ref(), - bool prefetch = false) - : cur_batch_(new batch(source, n_rows, row_width, batch_size, stream, mr, prefetch)), + batch_load_iterator(raft::resources const& res, + MdspanT input_view, + size_type batch_size, + rmm::cuda_stream_view copy_stream, + rmm::device_async_resource_ref mr, + bool prefetch = false, + bool initialize = true, + bool host_writeback = false) + : cur_batch_(new value_type( + res, input_view, batch_size, copy_stream, mr, prefetch, initialize, host_writeback)), cur_pos_(0), cur_prefetch_pos_(0) { } - /** - * Whether this iterator copies the data on every iteration - * (i.e. the source is inaccessible from the device). - */ - [[nodiscard]] auto does_copy() const -> bool { return cur_batch_->does_copy(); } - /** Reset the iterator position and prefetch position to `begin()` */ + + /** Convenience overload that uses `get_workspace_resource_ref(res)` as the memory resource. */ + batch_load_iterator(raft::resources const& res, + MdspanT input_view, + size_type batch_size, + rmm::cuda_stream_view copy_stream, + bool prefetch = false, + bool initialize = true, + bool host_writeback = false) + : batch_load_iterator(res, + input_view, + batch_size, + copy_stream, + raft::resource::get_workspace_resource_ref(res), + prefetch, + initialize, + host_writeback) + { + } + + /** Whether iteration copies the data on each step (i.e. not passthrough). 
*/ + [[nodiscard]] auto does_copy() const -> bool { return !kPassthrough; } + /** Reset the iterator (and prefetch) position to begin(). Reusable iteration. */ void reset() { cur_pos_ = 0; cur_prefetch_pos_ = 0; } - /** Reset the iterator position and prefetch position to `end()` */ + /** Reset the iterator (and prefetch) position to end(). */ void reset_to_end() { cur_pos_ = cur_batch_->n_iters_; cur_prefetch_pos_ = cur_batch_->n_iters_; } - [[nodiscard]] auto begin() const -> const batch_load_iterator + [[nodiscard]] auto begin() const -> const batch_load_iterator { - batch_load_iterator x(*this); + batch_load_iterator x(*this); x.reset(); return x; } - [[nodiscard]] auto end() const -> const batch_load_iterator + [[nodiscard]] auto end() const -> const batch_load_iterator { - batch_load_iterator x(*this); + batch_load_iterator x(*this); x.reset_to_end(); return x; } @@ -624,36 +919,35 @@ struct batch_load_iterator { cur_batch_->load(cur_pos_); return cur_batch_.get(); } - /* Prefetch next batch. Users are responsible for calling this method to enable kernel/copy - * overlapping. Note that this API is synchronous. */ + /** Issue the prefetch for the next batch (the one after the last batch prefetched). See class doc for stream semantics. 
*/ void prefetch_next_batch() { cur_batch_->prefetch(cur_prefetch_pos_++); } - friend auto operator==(const batch_load_iterator& x, const batch_load_iterator& y) -> bool + friend auto operator==(const batch_load_iterator& x, const batch_load_iterator& y) -> bool { return x.cur_batch_ == y.cur_batch_ && x.cur_pos_ == y.cur_pos_; - }; - friend auto operator!=(const batch_load_iterator& x, const batch_load_iterator& y) -> bool + } + friend auto operator!=(const batch_load_iterator& x, const batch_load_iterator& y) -> bool { return x.cur_batch_ != y.cur_batch_ || x.cur_pos_ != y.cur_pos_; - }; - auto operator++() -> batch_load_iterator& + } + auto operator++() -> batch_load_iterator& { ++cur_pos_; return *this; } - auto operator++(int) -> batch_load_iterator + auto operator++(int) -> batch_load_iterator { - batch_load_iterator x(*this); + batch_load_iterator x(*this); ++cur_pos_; return x; } - auto operator--() -> batch_load_iterator& + auto operator--() -> batch_load_iterator& { --cur_pos_; return *this; } - auto operator--(int) -> batch_load_iterator + auto operator--(int) -> batch_load_iterator { - batch_load_iterator x(*this); + batch_load_iterator x(*this); --cur_pos_; return x; } @@ -664,4 +958,299 @@ struct batch_load_iterator { size_type cur_prefetch_pos_; }; +/** + * Runtime-dispatched wrapper over `batch_load_iterator` that takes a raw pointer and + * picks the host- or device-typed mdspan instantiation based on `cudaPointerGetAttributes`, + * preserving the legacy "force copy unless cudaMemoryTypeDevice" policy. + * + * Use this at call sites that don't know statically whether `ptr` is host or device memory. + * Sites that already hold a typed mdspan should use `batch_load_iterator` directly to + * pick the strategy at compile time. 
+ */ +template +class batch_load_iterator_dyn { + using HostMd = raft::host_matrix_view; + using DeviceMd = raft::device_matrix_view; + using HostIter = batch_load_iterator; + using DeviceIter = batch_load_iterator; + + public: + using size_type = size_t; + + /** Uniform batch-proxy view across host/device branches. */ + class batch { + public: + [[nodiscard]] auto data() const -> T* { return dev_ptr_; } + [[nodiscard]] auto size() const -> size_type { return batch_len_; } + [[nodiscard]] auto offset() const -> size_type { return offset_; } + [[nodiscard]] auto row_width() const -> size_type { return row_width_; } + [[nodiscard]] auto does_copy() const -> bool { return does_copy_; } + [[nodiscard]] auto view() const -> raft::device_matrix_view + { + return raft::make_device_matrix_view( + dev_ptr_, static_cast(batch_len_), static_cast(row_width_)); + } + + private: + template + friend class batch_load_iterator_dyn; + T* dev_ptr_ = nullptr; + size_type batch_len_ = 0; + size_type offset_ = 0; + size_type row_width_ = 0; + bool does_copy_ = false; + }; + + using value_type = batch; + using reference = const value_type&; + using pointer = const value_type*; + + /** + * Construct via runtime pointer dispatch. + * + * If `ptr` is a pure-device pointer (`cudaMemoryTypeDevice` with non-null `devicePointer`), + * the device-accessor branch is selected (passthrough). Otherwise (host, pinned, managed, + * unregistered, HMM/ATS, or `nullptr`), the host-accessor branch is selected (copy_device). 
+ */ + batch_load_iterator_dyn(raft::resources const& res, + T* ptr, + IdxT n_rows, + IdxT row_width, + size_type batch_size, + rmm::cuda_stream_view copy_stream, + rmm::device_async_resource_ref mr, + bool prefetch = false, + bool initialize = true, + bool host_writeback = false) + : impl_(make_impl(res, + ptr, + n_rows, + row_width, + batch_size, + copy_stream, + mr, + prefetch, + initialize, + host_writeback)), + proxy_(std::make_shared()) + { + } + + /** Convenience overload that uses `get_workspace_resource_ref(res)` as the memory resource. */ + batch_load_iterator_dyn(raft::resources const& res, + T* ptr, + IdxT n_rows, + IdxT row_width, + size_type batch_size, + rmm::cuda_stream_view copy_stream, + bool prefetch = false, + bool initialize = true, + bool host_writeback = false) + : batch_load_iterator_dyn(res, + ptr, + n_rows, + row_width, + batch_size, + copy_stream, + raft::resource::get_workspace_resource_ref(res), + prefetch, + initialize, + host_writeback) + { + } + + [[nodiscard]] auto does_copy() const -> bool + { + return std::visit([](auto const& it) { return it.does_copy(); }, impl_); + } + void reset() + { + std::visit([](auto& it) { it.reset(); }, impl_); + } + void reset_to_end() + { + std::visit([](auto& it) { it.reset_to_end(); }, impl_); + } + [[nodiscard]] auto begin() const -> batch_load_iterator_dyn + { + batch_load_iterator_dyn x(*this); + x.reset(); + return x; + } + [[nodiscard]] auto end() const -> batch_load_iterator_dyn + { + batch_load_iterator_dyn x(*this); + x.reset_to_end(); + return x; + } + [[nodiscard]] auto operator*() const -> reference + { + std::visit( + [this](auto const& it) { + auto const& b = *it; + proxy_->dev_ptr_ = const_cast(b.data()); + proxy_->batch_len_ = b.size(); + proxy_->offset_ = b.offset(); + proxy_->row_width_ = b.row_width(); + proxy_->does_copy_ = b.does_copy(); + }, + impl_); + return *proxy_; + } + [[nodiscard]] auto operator->() const -> pointer + { + (void)**this; + return proxy_.get(); + } + void 
prefetch_next_batch() + { + std::visit([](auto& it) { it.prefetch_next_batch(); }, impl_); + } + friend auto operator==(const batch_load_iterator_dyn& x, const batch_load_iterator_dyn& y) -> bool + { + return x.impl_ == y.impl_; + } + friend auto operator!=(const batch_load_iterator_dyn& x, const batch_load_iterator_dyn& y) -> bool + { + return !(x == y); + } + auto operator++() -> batch_load_iterator_dyn& + { + std::visit([](auto& it) { ++it; }, impl_); + return *this; + } + auto operator++(int) -> batch_load_iterator_dyn + { + batch_load_iterator_dyn x(*this); + ++(*this); + return x; + } + auto operator--() -> batch_load_iterator_dyn& + { + std::visit([](auto& it) { --it; }, impl_); + return *this; + } + auto operator--(int) -> batch_load_iterator_dyn + { + batch_load_iterator_dyn x(*this); + --(*this); + return x; + } + + private: + std::variant impl_; + // Shared proxy: copies of the iterator share storage so that `*it1++` and `*it2` consistently + // observe the same backing buffer state, mirroring the legacy shared-batch contract. 
+ std::shared_ptr proxy_; + + static auto make_impl(raft::resources const& res, + T* ptr, + IdxT n_rows, + IdxT row_width, + size_type batch_size, + rmm::cuda_stream_view copy_stream, + rmm::device_async_resource_ref mr, + bool prefetch, + bool initialize, + bool host_writeback) -> std::variant + { + bool is_pure_device = false; + if (ptr != nullptr) { + cudaPointerAttributes attr{}; + RAFT_CUDA_TRY(cudaPointerGetAttributes(&attr, ptr)); + is_pure_device = (attr.type == cudaMemoryTypeDevice) && (attr.devicePointer != nullptr); + } + if (is_pure_device) { + return DeviceIter(res, + raft::make_device_matrix_view(ptr, n_rows, row_width), + batch_size, + copy_stream, + mr, + prefetch, + initialize, + host_writeback); + } + return HostIter(res, + raft::make_host_matrix_view(ptr, n_rows, row_width), + batch_size, + copy_stream, + mr, + prefetch, + initialize, + host_writeback); + } +}; + +// Locally-rolled `type_identity_t` so this header compiles in TUs that build with C++17 +// (e.g. the cuvs C tests). Equivalent to `std::type_identity_t` (C++20). +namespace detail { +template +struct type_identity { + using type = T; +}; +template +using type_identity_t = typename type_identity::type; +} // namespace detail + +/** + * Builder for `batch_load_iterator_dyn`. Use at sites that have a raw pointer with + * unknown memory location and want the legacy "force copy unless cudaMemoryTypeDevice" semantics. + * + * `ptr` is taken by `T const*` so callers can pass either a `T*` or a `const T*` (matching the + * legacy `batch_load_iterator(const T* source, ...)` API). The iterator's element type is `T` + * (non-const), so `batch.data()` returns `T*` for kernels that expect a non-const view of the + * source. Const-correctness at the API boundary is the caller's responsibility. 
+ * + * `n_rows` / `row_width` are placed in non-deduced contexts so that `IdxT` is taken from the + * explicit template argument (or its `int64_t` default) and the integer arguments are implicitly + * converted to `IdxT`, regardless of their incoming integer type. + */ +template +auto make_batch_load_iterator(raft::resources const& res, + T const* ptr, + detail::type_identity_t n_rows, + detail::type_identity_t row_width, + size_t batch_size, + rmm::cuda_stream_view copy_stream, + rmm::device_async_resource_ref mr, + bool prefetch = false, + bool initialize = true, + bool host_writeback = false) -> batch_load_iterator_dyn +{ + return batch_load_iterator_dyn(res, + const_cast(ptr), + n_rows, + row_width, + batch_size, + copy_stream, + mr, + prefetch, + initialize, + host_writeback); +} + +/** Convenience overload that uses `get_workspace_resource_ref(res)` as the memory resource. */ +template +auto make_batch_load_iterator(raft::resources const& res, + T const* ptr, + detail::type_identity_t n_rows, + detail::type_identity_t row_width, + size_t batch_size, + rmm::cuda_stream_view copy_stream, + bool prefetch = false, + bool initialize = true, + bool host_writeback = false) -> batch_load_iterator_dyn +{ + return make_batch_load_iterator(res, + ptr, + n_rows, + row_width, + batch_size, + copy_stream, + raft::resource::get_workspace_resource_ref(res), + prefetch, + initialize, + host_writeback); +} + } // namespace cuvs::spatial::knn::detail::utils diff --git a/cpp/src/neighbors/detail/cagra/add_nodes.cuh b/cpp/src/neighbors/detail/cagra/add_nodes.cuh index f198cf957d..5d0a6654e9 100644 --- a/cpp/src/neighbors/detail/cagra/add_nodes.cuh +++ b/cpp/src/neighbors/detail/cagra/add_nodes.cuh @@ -84,10 +84,11 @@ void add_node_core( auto host_neighbor_indices = raft::make_host_matrix(max_search_batch_size, base_degree); - cuvs::spatial::knn::detail::utils::batch_load_iterator additional_dataset_batch( + auto additional_dataset_batch = 
cuvs::spatial::knn::detail::utils::make_batch_load_iterator( + handle, additional_dataset_view.data_handle(), - num_add, - additional_dataset_view.stride(0), + static_cast(num_add), + static_cast(additional_dataset_view.stride(0)), max_search_batch_size, raft::resource::get_cuda_stream(handle), mr); diff --git a/cpp/src/neighbors/detail/cagra/cagra_build.cuh b/cpp/src/neighbors/detail/cagra/cagra_build.cuh index dd2042bd12..96ff8344d3 100644 --- a/cpp/src/neighbors/detail/cagra/cagra_build.cuh +++ b/cpp/src/neighbors/detail/cagra/cagra_build.cuh @@ -831,29 +831,34 @@ inline std::pair optimize_workspace_size(size_t n_rows, mst_host += (graph_degree - 1) * (graph_degree - 1) * index_size; // iB_candidates } + // batchsize for both prune and combine stages + size_t batch_size = std::min(static_cast(256 * 1024), n_rows); + // Prune stage memory // We neglect 8 bytes (both on host and device) for stats - size_t prune_host = n_rows * intermediate_degree * sizeof(uint8_t); // detour count - - size_t prune_dev = n_rows * intermediate_degree * 1; // detour count (uint8_t) - prune_dev += n_rows * sizeof(uint32_t); // d_num_detour_edges - prune_dev += n_rows * intermediate_degree * index_size; // d_input_graph + size_t prune_dev = batch_size * intermediate_degree * 1; // detour count (uint8_t) + prune_dev += batch_size * sizeof(uint32_t); // d_num_detour_edges + prune_dev += n_rows * intermediate_degree * index_size; // d_input_graph + prune_dev += 2 * batch_size * graph_degree * index_size; // d_output_graph(2*batch) // Reverse graph stage memory - size_t rev_host = n_rows * graph_degree * index_size; // rev_graph - rev_host += n_rows * sizeof(uint32_t); // rev_graph_count - rev_host += n_rows * index_size; // dest_nodes - size_t rev_dev = n_rows * graph_degree * index_size; // d_rev_graph rev_dev += n_rows * sizeof(uint32_t); // d_rev_graph_count - rev_dev += n_rows * sizeof(uint32_t); // d_dest_nodes + rev_dev += n_rows * index_size; // d_dest_nodes - // Memory for merging 
graphs (host only) + // Memory for merging graphs (host only optional) size_t combine_host = n_rows * sizeof(uint32_t) + graph_degree * sizeof(uint32_t); // in_edge_count + hist - size_t total_host = mst_host + std::max({prune_host, rev_host, combine_host}); - size_t total_dev = std::max(prune_dev, rev_dev); + // additional memory for combine stage on device (3 batches) + size_t combine_dev = 2 * batch_size * graph_degree * index_size; // d_output_graph(2*batch) + if (mst_optimize) { + combine_dev += 2 * batch_size * graph_degree * index_size; // d_mst_graph(2*batch) + combine_dev += 2 * batch_size * sizeof(uint32_t); // d_mst_graph_num_edges(2*batch) + } + + size_t total_host = mst_host + combine_host; + size_t total_dev = std::max(prune_dev, rev_dev + combine_dev); return std::make_pair(total_host, total_dev); } @@ -1725,11 +1730,12 @@ void build_knn_graph( bool first = true; const auto start_clock = std::chrono::system_clock::now(); - cuvs::spatial::knn::detail::utils::batch_load_iterator vec_batches( + auto vec_batches = cuvs::spatial::knn::detail::utils::make_batch_load_iterator( + res, dataset.data_handle(), - dataset.extent(0), - dataset.extent(1), - static_cast(max_queries), + static_cast(dataset.extent(0)), + static_cast(dataset.extent(1)), + static_cast(max_queries), raft::resource::get_cuda_stream(res), workspace_mr); @@ -2110,10 +2116,11 @@ auto iterative_build_graph( // Search. // Since there are many queries, divide them into batches and search them. 
- cuvs::spatial::knn::detail::utils::batch_load_iterator query_batch( + auto query_batch = cuvs::spatial::knn::detail::utils::make_batch_load_iterator( + res, dev_query_view.data_handle(), - curr_query_size, - dev_query_view.extent(1), + static_cast(curr_query_size), + static_cast(dev_query_view.extent(1)), max_chunk_size, raft::resource::get_cuda_stream(res), raft::resource::get_workspace_resource_ref(res)); diff --git a/cpp/src/neighbors/detail/cagra/graph_core.cuh b/cpp/src/neighbors/detail/cagra/graph_core.cuh index d16759b2b0..2c21018deb 100644 --- a/cpp/src/neighbors/detail/cagra/graph_core.cuh +++ b/cpp/src/neighbors/detail/cagra/graph_core.cuh @@ -22,17 +22,23 @@ #include #include +#include #include #include #include +#include +#include + #include #include #include #include +namespace cg = cooperative_groups; + namespace cuvs::neighbors::cagra::detail::graph { // unnamed namespace to avoid multiple definition error @@ -165,42 +171,98 @@ __global__ void kern_sort(const DATA_T* const dataset, // [dataset_chunk_size, } } -template -__global__ void kern_prune(const IdxT* const knn_graph, // [graph_chunk_size, graph_degree] - const uint32_t graph_size, - const uint32_t graph_degree, - const uint32_t degree, - const uint32_t batch_size, - const uint32_t batch_id, - uint8_t* const detour_count, // [graph_chunk_size, graph_degree] - uint32_t* const num_no_detour_edges, // [graph_size] - uint64_t* const stats) +template +__global__ void kern_make_rev_graph_k( + OutputMatrixView output_graph, // [graph_size, degree] + raft::device_matrix_view rev_graph, // [graph_size, degree] + raft::device_vector_view rev_graph_count, // [graph_size] + uint64_t k) +{ + const uint64_t tid = threadIdx.x + (blockDim.x * blockIdx.x); + const uint64_t tnum = blockDim.x * gridDim.x; + + const uint64_t graph_size = rev_graph.extent(0); + const uint32_t rev_graph_degree = rev_graph.extent(1); + const uint32_t output_graph_degree = output_graph.extent(1); + + for (uint64_t src_id = tid; 
src_id < graph_size; src_id += tnum) { + IdxT dest_id = output_graph(src_id, k); + if (dest_id >= graph_size) continue; + + const uint32_t pos = atomicAdd(&rev_graph_count(dest_id), 1); + if (pos < rev_graph_degree) { rev_graph(dest_id, pos) = static_cast(src_id); } + } +} + +// KnnGraphView and OutputGraphView are mdspans with element type IdxT and +// dynamic 2D extents. They may have layout_right (raft::device_matrix_view) or +// layout_stride (the result of cuda::std::submdspan), and any accessor that is +// device-accessible (default_accessor, raft::host_device_accessor with a +// device-accessible memory_type, etc.). +template +__global__ void kern_fused_prune(KnnGraphView knn_graph, // [graph_chunk_size, graph_degree] + OutputGraphView output_graph, // [batch_size, output_graph_degree] + const uint32_t batch_size, + const uint32_t batch_id, + uint32_t* const d_invalid_neighbor_list, + uint64_t* const stats) { - __shared__ uint32_t smem_num_detour[MAX_DEGREE]; + // Check assumption we have at least one warp per row of the batch + assert(blockDim.x == raft::WarpSize * num_warps); + assert(gridDim.x * num_warps >= batch_size); + + extern __shared__ unsigned char smem_buf[]; + + cg::thread_block block = cg::this_thread_block(); + cg::thread_block_tile<32> warp = cg::tiled_partition<32>(block); + + const uint32_t wid = threadIdx.x / raft::WarpSize; + const uint32_t lane_id = threadIdx.x % raft::WarpSize; + + const uint64_t graph_size = knn_graph.extent(0); + const uint32_t knn_graph_degree = knn_graph.extent(1); + const uint32_t output_graph_degree = output_graph.extent(1); + + IdxT* const smem_indices = + reinterpret_cast(smem_buf + wid * knn_graph_degree * sizeof(IdxT)); + uint32_t* const smem_num_detour = + reinterpret_cast(smem_buf + num_warps * knn_graph_degree * sizeof(IdxT) + + wid * knn_graph_degree * sizeof(uint32_t)); + +#ifndef NDEBUG uint64_t* const num_retain = stats; uint64_t* const num_full = stats + 1; +#endif + + const uint32_t maxval16 = 
0x0000ffff; - const uint64_t iA = blockIdx.x + (batch_size * batch_id); - if (iA >= graph_size) { return; } - for (uint32_t k = threadIdx.x; k < graph_degree; k += blockDim.x) { + const uint32_t nid_batch = blockIdx.x * num_warps + wid; + const uint64_t nid = static_cast(nid_batch) + + (static_cast(batch_size) * static_cast(batch_id)); + + if (nid >= graph_size) { return; } + + // Load this node's neighbor row into shared memory to reduce global reads + for (uint32_t k = lane_id; k < knn_graph_degree; k += raft::WarpSize) { smem_num_detour[k] = 0; - if (knn_graph[k + ((uint64_t)graph_degree * iA)] == iA) { + smem_indices[k] = knn_graph(nid, k); + if (smem_indices[k] == nid) { // Lower the priority of self-edge - smem_num_detour[k] = graph_degree; + smem_num_detour[k] = knn_graph_degree; } } - __syncthreads(); + warp.sync(); // count number of detours (A->D->B) - for (uint32_t kAD = 0; kAD < graph_degree - 1; kAD++) { - const uint64_t iD = knn_graph[kAD + (graph_degree * iA)]; + for (uint32_t kAD = 0; kAD < knn_graph_degree - 1; kAD++) { + const uint64_t iD = smem_indices[kAD]; if (iD >= graph_size) { continue; } - for (uint32_t kDB = threadIdx.x; kDB < graph_degree; kDB += blockDim.x) { - const uint64_t iB_candidate = knn_graph[kDB + ((uint64_t)graph_degree * iD)]; - for (uint32_t kAB = kAD + 1; kAB < graph_degree; kAB++) { + for (uint32_t kDB = lane_id; kDB < knn_graph_degree; kDB += raft::WarpSize) { + const uint64_t iB_candidate = knn_graph(iD, kDB); + for (uint32_t kAB = kAD + 1; kAB < knn_graph_degree; kAB++) { // if ( kDB < kAB ) { - const uint64_t iB = knn_graph[kAB + (graph_degree * iA)]; + const uint64_t iB = smem_indices[kAB]; if (iB == iB_candidate) { atomicAdd(smem_num_detour + kAB, 1); break; @@ -208,44 +270,203 @@ __global__ void kern_prune(const IdxT* const knn_graph, // [graph_chunk_size, g } } } - __syncthreads(); + warp.sync(); } uint32_t num_edges_no_detour = 0; - for (uint32_t k = threadIdx.x; k < graph_degree; k += blockDim.x) { - 
detour_count[k + (graph_degree * iA)] = min(smem_num_detour[k], (uint32_t)255); + for (uint32_t k = lane_id; k < knn_graph_degree; k += raft::WarpSize) { + smem_num_detour[k] = min(smem_num_detour[k], maxval16); if (smem_num_detour[k] == 0) { num_edges_no_detour++; } + if (smem_indices[k] >= graph_size) { smem_num_detour[k] = maxval16; } } - num_edges_no_detour += __shfl_xor_sync(0xffffffff, num_edges_no_detour, 1); - num_edges_no_detour += __shfl_xor_sync(0xffffffff, num_edges_no_detour, 2); - num_edges_no_detour += __shfl_xor_sync(0xffffffff, num_edges_no_detour, 4); - num_edges_no_detour += __shfl_xor_sync(0xffffffff, num_edges_no_detour, 8); - num_edges_no_detour += __shfl_xor_sync(0xffffffff, num_edges_no_detour, 16); - num_edges_no_detour = min(num_edges_no_detour, degree); - if (threadIdx.x == 0) { - num_no_detour_edges[iA] = num_edges_no_detour; + warp.sync(); + + num_edges_no_detour = cg::reduce(warp, num_edges_no_detour, cg::plus()); + num_edges_no_detour = min(num_edges_no_detour, output_graph_degree); + +#ifndef NDEBUG + if (lane_id == 0) { atomicAdd((unsigned long long int*)num_retain, (unsigned long long int)num_edges_no_detour); - if (num_edges_no_detour >= degree) { atomicAdd((unsigned long long int*)num_full, 1); } + if (num_edges_no_detour >= output_graph_degree) { + atomicAdd((unsigned long long int*)num_full, 1); + } + } +#endif + + for (uint32_t i = 0; i < output_graph_degree; i++) { + uint32_t local_min = maxval16; + uint32_t local_idx = maxval16; + for (uint32_t k = lane_id; k < knn_graph_degree; k += raft::WarpSize) { + if (smem_num_detour[k] < local_min) { + local_min = smem_num_detour[k]; + local_idx = k; + } + } + + uint32_t local_min_with_tag = (local_min << 16) | ((uint32_t)local_idx); + uint32_t warp_min_with_tag = cg::reduce(warp, local_min_with_tag, cg::less()); + uint32_t warp_min_count = warp_min_with_tag >> 16; + uint32_t warp_local_idx = warp_min_with_tag & 0xffff; + + if (warp_min_count == maxval16 || warp_local_idx == maxval16) 
{ + if (lane_id == 0) { atomicExch(d_invalid_neighbor_list, 1u); } + break; + } + + IdxT selected_node = smem_indices[warp_local_idx]; + + for (uint32_t k = lane_id; k < knn_graph_degree; k += raft::WarpSize) { + if (smem_indices[k] == selected_node) { smem_num_detour[k] = maxval16; } + } + warp.sync(); + + if (lane_id == 0) { output_graph(nid_batch, i) = selected_node; } } } -template -__global__ void kern_make_rev_graph(const IdxT* const dest_nodes, // [graph_size] - IdxT* const rev_graph, // [size, degree] - uint32_t* const rev_graph_count, // [graph_size] - const uint32_t graph_size, - const uint32_t degree) +// Helper functions for merging the graph +template +__device__ unsigned int warp_pos_in_array(T val, const T* array, uint64_t num) { - const uint32_t tid = threadIdx.x + (blockDim.x * blockIdx.x); - const uint32_t tnum = blockDim.x * gridDim.x; + unsigned int ret = num; + const uint32_t lane_id = threadIdx.x % 32; + for (uint64_t i = lane_id; i < num; i += 32) { + if (val == array[i]) { + ret = i; + break; + } + } - for (uint32_t src_id = tid; src_id < graph_size; src_id += tnum) { - const IdxT dest_id = dest_nodes[src_id]; - if (dest_id >= graph_size) continue; + cg::thread_block block = cg::this_thread_block(); + cg::thread_block_tile<32> warp = cg::tiled_partition<32>(block); + ret = cg::reduce(warp, ret, cg::less()); + return ret; +} + +template +__device__ void warp_shift_array_one_right(uint32_t lane_id, T* array, uint64_t num) +{ + if (num == 0) { return; } + for (auto chunk_end = static_cast(num); chunk_end >= 1; chunk_end -= 31) { + const int64_t chunk_start_lo = chunk_end - 31; + const int64_t chunk_start = (chunk_start_lo > 0) ? 
chunk_start_lo : 0; + const int64_t k = chunk_start + static_cast(lane_id); + T val{}; + const bool read_active = (k <= chunk_end); + if (read_active) { val = array[k]; } + const T shifted = raft::shfl_up(val, 1); + const bool write_active = (lane_id > 0) && read_active; + if (write_active) { array[k] = shifted; } + } +} + +// OutputGraphView, MstGraphView and MstNumEdgesView are 2D mdspans that may +// have layout_right or layout_stride and any device-accessible accessor; see +// the comment on kern_fused_prune for details. +template +__global__ void kern_merge_graph( + OutputGraphView output_graph, // [batch_size, output_graph_degree] + raft::device_matrix_view rev_graph, // [graph_size, output_graph_degree] + raft::device_vector_view rev_graph_count, // [graph_size] + MstGraphView mst_graph, // [batch_size, output_graph_degree] + MstNumEdgesView mst_graph_num_edges, // [batch_size, 1] + const uint32_t batch_size, + const uint32_t batch_id, + bool guarantee_connectivity, + uint32_t* check_num_protected_edges) +{ + // Check assumption we have at least one warp per row of the batch + assert(blockDim.x == raft::WarpSize * num_warps); + assert(gridDim.x * num_warps >= batch_size); + + const uint64_t graph_size = rev_graph.extent(0); + const uint32_t output_graph_degree = output_graph.extent(1); + const uint32_t mst_graph_degree = mst_graph.extent(1); + + extern __shared__ unsigned char smem_buf[]; + + const uint32_t wid = threadIdx.x / raft::WarpSize; + const uint32_t lane_id = threadIdx.x % raft::WarpSize; + + IdxT* smem_sorted_output_graph = + reinterpret_cast(smem_buf + wid * output_graph_degree * sizeof(IdxT)); + + const uint32_t nid_batch = blockIdx.x * num_warps + wid; + const uint64_t nid = static_cast(nid_batch) + + (static_cast(batch_size) * static_cast(batch_id)); + + if (nid >= graph_size) { return; } + + const auto current_mst_graph_num_edges = + guarantee_connectivity ? 
mst_graph_num_edges(nid_batch, 0) : 0; + // If guarantee_connectivity == true, use a temporal list to merge the + // neighbor lists of the graphs. + if (guarantee_connectivity) { + for (uint32_t i = lane_id; i < current_mst_graph_num_edges; i += raft::WarpSize) { + smem_sorted_output_graph[i] = mst_graph(nid_batch, i); + } + __syncwarp(); + for (uint32_t pruned_j = 0, output_j = current_mst_graph_num_edges; + (pruned_j < output_graph_degree) && (output_j < output_graph_degree); + pruned_j++) { + const auto v = output_graph(nid_batch, pruned_j); + unsigned int dup = 0; + for (uint32_t m = lane_id; m < output_j; m += raft::WarpSize) { + if (v == smem_sorted_output_graph[m]) { + dup = 1; + break; + } + } + + auto warp_dup = raft::ballot(dup); + if (warp_dup == 0) { + if (lane_id == 0) smem_sorted_output_graph[output_j] = v; + output_j++; + } + __syncwarp(); + } + } + + else { + for (uint32_t i = lane_id; i < output_graph_degree; i += raft::WarpSize) { + smem_sorted_output_graph[i] = output_graph(nid_batch, i); + } + __syncwarp(); + } + + const auto num_protected_edges = max(current_mst_graph_num_edges, output_graph_degree / 2); + + if (num_protected_edges > output_graph_degree) { + check_num_protected_edges[0] = 0u; + return; + } + if (num_protected_edges == output_graph_degree) { return; } + + auto kr = min(rev_graph_count(nid), output_graph_degree); + + while (kr) { + kr -= 1; + const auto rev_graph_value = rev_graph(nid, kr); + if (rev_graph_value < graph_size) { + uint64_t pos = + warp_pos_in_array(rev_graph_value, smem_sorted_output_graph, output_graph_degree); + if (pos < num_protected_edges) { continue; } + uint64_t num_shift = pos - num_protected_edges; + if (pos >= output_graph_degree) { num_shift = output_graph_degree - num_protected_edges - 1; } + warp_shift_array_one_right( + lane_id, smem_sorted_output_graph + num_protected_edges, num_shift); + if (lane_id == 0) { smem_sorted_output_graph[num_protected_edges] = rev_graph_value; } + __syncwarp(); + } + } - 
const uint32_t pos = atomicAdd(rev_graph_count + dest_id, 1); - if (pos < degree) { rev_graph[pos + ((uint64_t)degree * dest_id)] = src_id; } + for (uint32_t i = lane_id; i < output_graph_degree; i += raft::WarpSize) { + output_graph(nid_batch, i) = smem_sorted_output_graph[i]; } } @@ -301,7 +522,8 @@ __global__ void kern_mst_opt_update_graph(IdxT* mst_graph, // [graph_size, grap ret = 0; auto kj = atomicAdd(incoming_num_edges + j, (IdxT)1); if (kj < incoming_max_edges[j]) { - auto ki = outgoing_num_edges[i]++; + auto ki = outgoing_num_edges[i]; + outgoing_num_edges[i] = ki + 1; mst_graph[(graph_degree * (i)) + ki] = j; // outgoing mst_graph[(graph_degree * (j + 1)) - 1 - kj] = i; // incoming ret = 1; @@ -327,7 +549,7 @@ __global__ void kern_mst_opt_update_graph(IdxT* mst_graph, // [graph_size, grap // Check to avoid duplication for (uint64_t kl = 0; kl < graph_degree; kl++) { uint64_t m = mst_graph[(graph_degree * l) + kl]; - if (m > graph_size) continue; + if (m >= graph_size) continue; uint32_t rm = get_root_label(m, label); if (ri == rm) { ret = 0; @@ -356,30 +578,31 @@ __global__ void kern_mst_opt_labeling(IdxT* label, // [graph_size] uint64_t* stats) { const uint64_t i = threadIdx.x + (blockDim.x * blockIdx.x); - if (i >= graph_size) return; __shared__ uint32_t smem_updated[1]; if (threadIdx.x == 0) { smem_updated[0] = 0; } __syncthreads(); - for (uint64_t ki = 0; ki < graph_degree; ki++) { - uint64_t j = mst_graph[(graph_degree * i) + ki]; - if (j >= graph_size) continue; + if (i < graph_size) { + for (uint64_t ki = 0; ki < graph_degree; ki++) { + uint64_t j = mst_graph[(graph_degree * i) + ki]; + if (j >= graph_size) continue; - IdxT li = label[i]; - IdxT ri = get_root_label(i, label); - if (ri < li) { atomicMin(label + i, ri); } - IdxT lj = label[j]; - IdxT rj = get_root_label(j, label); - if (rj < lj) { atomicMin(label + j, rj); } - if (ri == rj) continue; - - if (ri > rj) { - atomicCAS(label + i, ri, rj); - } else if (rj > ri) { - atomicCAS(label + j, 
rj, ri); + IdxT li = label[i]; + IdxT ri = get_root_label(i, label); + if (ri < li) { atomicMin(label + i, ri); } + IdxT lj = label[j]; + IdxT rj = get_root_label(j, label); + if (rj < lj) { atomicMin(label + j, rj); } + if (ri == rj) continue; + + if (ri > rj) { + atomicCAS(label + i, ri, rj); + } else if (rj > ri) { + atomicCAS(label + j, rj, ri); + } + smem_updated[0] = 1; } - smem_updated[0] = 1; } __syncthreads(); @@ -393,18 +616,19 @@ __global__ void kern_mst_opt_cluster_size(IdxT* cluster_size, // [graph_size] uint64_t* stats) { const uint64_t i = threadIdx.x + (blockDim.x * blockIdx.x); - if (i >= graph_size) return; __shared__ uint64_t smem_num_clusters[1]; if (threadIdx.x == 0) { smem_num_clusters[0] = 0; } __syncthreads(); - IdxT ri = get_root_label(i, label); - if (ri == i) { - atomicAdd((unsigned long long int*)smem_num_clusters, 1); - } else { - atomicAdd(cluster_size + ri, cluster_size[i]); - cluster_size[i] = 0; + if (i < graph_size) { + IdxT ri = get_root_label(i, label); + if (ri == i) { + atomicAdd((unsigned long long int*)smem_num_clusters, 1); + } else { + atomicAdd(cluster_size + ri, cluster_size[i]); + cluster_size[i] = 0; + } } __syncthreads(); @@ -424,8 +648,6 @@ __global__ void kern_mst_opt_postprocessing(IdxT* outgoing_num_edges, // [graph uint64_t* stats) { const uint64_t i = threadIdx.x + (blockDim.x * blockIdx.x); - if (i >= graph_size) return; - __shared__ uint64_t smem_cluster_size_min[1]; __shared__ uint64_t smem_cluster_size_max[1]; __shared__ uint64_t smem_total_outgoing_edges[1]; @@ -438,34 +660,36 @@ __global__ void kern_mst_opt_postprocessing(IdxT* outgoing_num_edges, // [graph } __syncthreads(); - // Adjust incoming_num_edges - if (incoming_num_edges[i] > incoming_max_edges[i]) { - incoming_num_edges[i] = incoming_max_edges[i]; - } - - // Calculate min/max of cluster_size - if (cluster_size[i] > 0) { - if (smem_cluster_size_min[0] > cluster_size[i]) { - atomicMin((unsigned long long int*)smem_cluster_size_min, - (unsigned long 
long int)(cluster_size[i])); + if (i < graph_size) { + // Adjust incoming_num_edges + if (incoming_num_edges[i] > incoming_max_edges[i]) { + incoming_num_edges[i] = incoming_max_edges[i]; } - if (smem_cluster_size_max[0] < cluster_size[i]) { - atomicMax((unsigned long long int*)smem_cluster_size_max, - (unsigned long long int)(cluster_size[i])); + + // Calculate min/max of cluster_size + if (cluster_size[i] > 0) { + if (smem_cluster_size_min[0] > cluster_size[i]) { + atomicMin((unsigned long long int*)smem_cluster_size_min, + (unsigned long long int)(cluster_size[i])); + } + if (smem_cluster_size_max[0] < cluster_size[i]) { + atomicMax((unsigned long long int*)smem_cluster_size_max, + (unsigned long long int)(cluster_size[i])); + } } - } - // Calculate total number of outgoing/incoming edges - atomicAdd((unsigned long long int*)smem_total_outgoing_edges, - (unsigned long long int)(outgoing_num_edges[i])); - atomicAdd((unsigned long long int*)smem_total_incoming_edges, - (unsigned long long int)(incoming_num_edges[i])); - - // Adjust incoming/outgoing_max_edges - if (outgoing_num_edges[i] == outgoing_max_edges[i]) { - if (outgoing_num_edges[i] + incoming_num_edges[i] < graph_degree) { - outgoing_max_edges[i] += 1; - incoming_max_edges[i] -= 1; + // Calculate total number of outgoing/incoming edges + atomicAdd((unsigned long long int*)smem_total_outgoing_edges, + (unsigned long long int)(outgoing_num_edges[i])); + atomicAdd((unsigned long long int*)smem_total_incoming_edges, + (unsigned long long int)(incoming_num_edges[i])); + + // Adjust incoming/outgoing_max_edges + if (outgoing_num_edges[i] == outgoing_max_edges[i]) { + if (outgoing_num_edges[i] + incoming_num_edges[i] < graph_degree) { + outgoing_max_edges[i] += 1; + incoming_max_edges[i] -= 1; + } } } @@ -482,23 +706,270 @@ __global__ void kern_mst_opt_postprocessing(IdxT* outgoing_num_edges, // [graph } } -template -uint64_t pos_in_array(T val, const T* array, uint64_t num) +template +void 
log_incoming_edges_histogram( + raft::resources const& res, + raft::mdspan, raft::row_major, AccessorOutputGraph> + output_graph) +{ + raft::common::nvtx::range block_scope( + "cagra::graph::optimize/check_edges"); + const uint64_t graph_size = output_graph.extent(0); + const uint64_t output_graph_degree = output_graph.extent(1); + + // copy to host if required + IdxT* output_graph_ptr = nullptr; + int64_t buffer_size = AccessorOutputGraph::is_host_accessible ? graph_size : 0; + auto host_copy_output_graph = + raft::make_host_matrix(buffer_size, output_graph_degree); + if constexpr (AccessorOutputGraph::is_host_accessible) { + output_graph_ptr = output_graph.data_handle(); + } else { + raft::copy(res, host_copy_output_graph.view(), output_graph); + output_graph_ptr = host_copy_output_graph.data_handle(); + } + + auto in_edge_count = raft::make_host_vector(graph_size); + auto in_edge_count_ptr = in_edge_count.data_handle(); +#pragma omp parallel for + for (uint64_t i = 0; i < graph_size; i++) { + in_edge_count_ptr[i] = 0; + } +#pragma omp parallel for + for (uint64_t i = 0; i < graph_size; i++) { + for (uint64_t k = 0; k < output_graph_degree; k++) { + const uint64_t j = output_graph_ptr[k + (output_graph_degree * i)]; + if (j >= graph_size) continue; +#pragma omp atomic + in_edge_count_ptr[j] += 1; + } + } + auto hist = raft::make_host_vector(output_graph_degree); + auto hist_ptr = hist.data_handle(); + for (uint64_t k = 0; k < output_graph_degree; k++) { + hist_ptr[k] = 0; + } +#pragma omp parallel for + for (uint64_t i = 0; i < graph_size; i++) { + uint32_t count = in_edge_count_ptr[i]; + if (count >= output_graph_degree) continue; +#pragma omp atomic + hist_ptr[count] += 1; + } + RAFT_LOG_DEBUG("# Histogram for number of incoming edges\n"); + uint32_t sum_hist = 0; + for (uint64_t k = 0; k < output_graph_degree; k++) { + sum_hist += hist_ptr[k]; + RAFT_LOG_DEBUG("# %3lu, %8u, %lf, (%8u, %lf)\n", + k, + hist_ptr[k], + (double)hist_ptr[k] / graph_size, + sum_hist, 
+ (double)sum_hist / graph_size); + } +} + +template +void check_duplicates_and_out_of_range( + raft::resources const& res, + raft::mdspan, raft::row_major, AccessorOutputGraph> + output_graph) { - for (uint64_t i = 0; i < num; i++) { - if (val == array[i]) { return i; } + raft::common::nvtx::range block_scope( + "cagra::graph::optimize/check_duplicates"); + + const uint64_t graph_size = output_graph.extent(0); + const uint64_t output_graph_degree = output_graph.extent(1); + + // copy to host if required + IdxT* output_graph_ptr = nullptr; + int64_t buffer_size = AccessorOutputGraph::is_host_accessible ? graph_size : 0; + auto host_copy_output_graph = + raft::make_host_matrix(buffer_size, output_graph_degree); + if constexpr (AccessorOutputGraph::is_host_accessible) { + output_graph_ptr = output_graph.data_handle(); + } else { + raft::copy(res, host_copy_output_graph.view(), output_graph); + output_graph_ptr = host_copy_output_graph.data_handle(); + } + + uint64_t num_dup = 0; + uint64_t num_oor = 0; +#pragma omp parallel for reduction(+ : num_dup) reduction(+ : num_oor) + for (uint64_t i = 0; i < graph_size; i++) { + auto my_out_graph = output_graph_ptr + (output_graph_degree * i); + for (uint32_t j = 0; j < output_graph_degree; j++) { + const auto neighbor_a = my_out_graph[j]; + + if (neighbor_a >= graph_size) { + num_oor++; + continue; + } + + for (uint32_t k = j + 1; k < output_graph_degree; k++) { + const auto neighbor_b = my_out_graph[k]; + if (neighbor_a == neighbor_b) { + num_dup++; + break; + } + } + } } - return num; + RAFT_EXPECTS( + num_dup == 0, "%lu duplicated node(s) are found in the generated CAGRA graph", num_dup); + RAFT_EXPECTS( + num_oor == 0, "%lu out-of-range index node(s) are found in the generated CAGRA graph", num_oor); } -template -void shift_array(T* array, uint64_t num) +template +void merge_graph_gpu( + raft::resources const& res, + raft::mdspan, raft::row_major, AccessorOutputGraph> + output_graph, + raft::device_matrix_view 
d_rev_graph, + raft::device_vector_view d_rev_graph_count, + raft::host_matrix_view mst_graph, + raft::host_vector_view mst_graph_num_edges, + bool guarantee_connectivity) { - for (uint64_t i = num; i > 0; i--) { - array[i] = array[i - 1]; + const uint64_t graph_size = output_graph.extent(0); + const uint64_t output_graph_degree = output_graph.extent(1); + + raft::common::nvtx::range block_scope( + "cagra::graph::optimize/combine"); + + const double merge_graph_start = cur_time(); + + auto d_check_num_protected_edges = raft::make_device_scalar(res, 1u); + + // The batchsize is statically set to 256 * 1024 which corresponds to 256MB for a graph + // degree of 128 and 16byte index type. This is a trade-off between memory usage and performance. + // When choosing dynamically based on available memory, we would also need to modify the static + // size assumption in the cagra_build.cuh::optimize_workspace_size function. + uint32_t batch_size = static_cast(std::min(graph_size, 256 * 1024)); + const uint32_t num_batch = (graph_size + batch_size - 1) / batch_size; + + namespace bli = cuvs::spatial::knn::detail::utils; + auto [copy_stream, enable_prefetch] = bli::get_prefetch_stream(res); + auto workspace_mr = raft::resource::get_workspace_resource_ref(res); + + bli::batch_load_iterator< + raft::mdspan, raft::row_major, AccessorOutputGraph>> + d_output_graph(res, + output_graph, + batch_size, + copy_stream, + workspace_mr, + enable_prefetch, + /*initialize=*/true, + /*host_writeback=*/true); + + bli::batch_load_iterator> d_mst_graph( + res, mst_graph, batch_size, copy_stream, workspace_mr, enable_prefetch); + + bli::batch_load_iterator> d_mst_graph_num_edges( + res, + raft::make_host_matrix_view( + mst_graph_num_edges.data_handle(), mst_graph_num_edges.extent(0), 1l), + batch_size, + copy_stream, + workspace_mr, + enable_prefetch); + + d_output_graph.prefetch_next_batch(); + d_mst_graph.prefetch_next_batch(); + d_mst_graph_num_edges.prefetch_next_batch(); + + const uint32_t 
num_warps = 4; + const dim3 threads_merge(raft::WarpSize * num_warps, 1, 1); + const dim3 blocks_merge(raft::ceildiv(batch_size, num_warps), 1, 1); + const size_t merge_smem_size = num_warps * output_graph_degree * sizeof(IdxT); + for (uint32_t i_batch = 0; i_batch < num_batch; i_batch++) { + auto mst_graph_view = (*d_mst_graph).view(); + auto mst_graph_num_edges_view = (*d_mst_graph_num_edges).view(); + auto output_view = (*d_output_graph).view(); + kern_merge_graph + <<>>( + output_view, + d_rev_graph, + d_rev_graph_count, + mst_graph_view, + mst_graph_num_edges_view, + batch_size, + i_batch, + guarantee_connectivity, + d_check_num_protected_edges.data_handle()); + + d_output_graph.prefetch_next_batch(); + d_mst_graph.prefetch_next_batch(); + d_mst_graph_num_edges.prefetch_next_batch(); + ++d_output_graph; + ++d_mst_graph; + ++d_mst_graph_num_edges; + } + + uint32_t check_num_protected_edges = 1u; + raft::copy(res, + raft::make_host_scalar_view(&check_num_protected_edges), + d_check_num_protected_edges.view()); + raft::resource::sync_stream(res); + const auto merge_graph_end = cur_time(); + RAFT_EXPECTS(check_num_protected_edges, + "Failed to merge the MST, pruned, and reverse edge graphs. 
" + "Some nodes have too " + "many MST optimization edges."); + + RAFT_LOG_DEBUG("# Time for merging graphs: %.1lf ms", + (merge_graph_end - merge_graph_start) * 1000.0); +} + +template +void make_reverse_graph_gpu( + raft::resources const& res, + raft::mdspan, raft::row_major, AccessorOutputGraph> + output_graph, + raft::device_matrix_view d_rev_graph, + raft::device_vector_view d_rev_graph_count) +{ + const uint64_t graph_size = output_graph.extent(0); + const uint64_t output_graph_degree = output_graph.extent(1); + + raft::common::nvtx::range block_scope( + "cagra::graph::optimize/reverse"); + + // + // Make reverse graph + // + raft::matrix::fill(res, d_rev_graph, IdxT(-1)); + raft::matrix::fill(res, d_rev_graph_count, uint32_t(0)); + + if constexpr (AccessorOutputGraph::is_device_accessible) { + // output graph is fully device accessible, so we need no copy to device + dim3 threads(256, 1, 1); + dim3 blocks(1024, 1, 1); + for (uint64_t k = 0; k < output_graph_degree; k++) { + kern_make_rev_graph_k<<>>( + output_graph, d_rev_graph, d_rev_graph_count, k); + } + } else { + auto d_dest_nodes = raft::make_device_matrix(res, graph_size, 1); + auto dest_nodes = raft::make_host_vector(graph_size); + for (uint64_t k = 0; k < output_graph_degree; k++) { +#pragma omp parallel for + for (uint64_t i = 0; i < graph_size; i++) { + dest_nodes(i) = output_graph(i, k); + } + raft::copy(res, d_dest_nodes.view(), raft::make_const_mdspan(dest_nodes.view())); + + dim3 threads(256, 1, 1); + dim3 blocks(1024, 1, 1); + kern_make_rev_graph_k<<>>( + d_dest_nodes.view(), d_rev_graph, d_rev_graph_count, 0); + raft::resource::sync_stream(res); + RAFT_LOG_DEBUG("# Making reverse graph on GPUs: %lu / %u \r", k, output_graph_degree); + } } } -} // namespace template graph_size) continue; + if (m >= graph_size) continue; if (label_ptr[i] == label_ptr[m]) { ret = 0; break; @@ -711,12 +1182,13 @@ void mst_opt_update_graph(IdxT* mst_graph_ptr, // an approximate MST. 
// * If the input kNN graph is disconnected, random connection is added to the largest cluster. // -template -void mst_optimization(raft::resources const& res, - raft::host_matrix_view input_graph, - raft::host_matrix_view output_graph, - raft::host_vector_view mst_graph_num_edges, - bool use_gpu = true) +template +void mst_optimization( + raft::resources const& res, + raft::mdspan, raft::row_major, AccessorKnnGraph> input_graph, + raft::host_matrix_view output_graph, + raft::host_vector_view mst_graph_num_edges, + bool use_gpu = true) { if (use_gpu) { RAFT_LOG_DEBUG("# MST optimization on GPU"); @@ -728,9 +1200,6 @@ void mst_optimization(raft::resources const& res, const IdxT graph_size = input_graph.extent(0); const uint32_t input_graph_degree = input_graph.extent(1); const uint32_t output_graph_degree = output_graph.extent(1); - auto input_graph_ptr = input_graph.data_handle(); - auto output_graph_ptr = output_graph.data_handle(); - auto mst_graph_num_edges_ptr = mst_graph_num_edges.data_handle(); // Allocate temporal arrays const uint32_t mst_graph_degree = output_graph_degree; @@ -848,9 +1317,28 @@ void mst_optimization(raft::resources const& res, } } else { // Copy rank-k edges from the input knn graph to 'candidate_edges' + if constexpr (AccessorKnnGraph::is_host_accessible) { #pragma omp parallel for - for (uint64_t i = 0; i < graph_size; i++) { - candidate_edges_ptr[i] = input_graph_ptr[k + (input_graph_degree * i)]; + for (uint64_t i = 0; i < graph_size; i++) { + candidate_edges_ptr[i] = input_graph(i, k); + } + } else { + // handle device knn graph + RAFT_CUDA_TRY(cudaMemcpy2DAsync(candidate_edges_ptr, + sizeof(IdxT), + &input_graph(0, k), + input_graph_degree * sizeof(IdxT), + 1 * sizeof(IdxT), // width + graph_size, + cudaMemcpyDeviceToHost, + raft::resource::get_cuda_stream(res))); + raft::resource::sync_stream(res); + + // FIXME: use submdspan and raft::copy once supported + /*auto column_view = cuda::std::submdspan(input_graph, + 
cuda::std::tuple{uint64_t(0), + uint64_t(graph_size)}, cuda::std::tuple{uint64_t(k), uint64_t(k+1)}); raft::copy(res, + candidate_edges.view(), raft::make_const_mdspan(column_view));*/ } } @@ -1069,94 +1557,174 @@ void mst_optimization(raft::resources const& res, for (uint64_t i = 0; i < graph_size; i++) { uint64_t k = 0; for (uint64_t kj = 0; kj < mst_graph_degree; kj++) { - uint64_t j = mst_graph_ptr[(mst_graph_degree * i) + kj]; + uint64_t j = mst_graph(i, kj); if (j >= graph_size) continue; // Check to avoid duplication auto flag_match = false; for (uint64_t ki = 0; ki < k; ki++) { - if (j == output_graph_ptr[(output_graph_degree * i) + ki]) { + if (j == output_graph(i, ki)) { flag_match = true; break; } } if (flag_match) continue; - output_graph_ptr[(output_graph_degree * i) + k] = j; + output_graph(i, k) = j; k += 1; } - mst_graph_num_edges_ptr[i] = k; + mst_graph_num_edges(i) = k; } const double time_mst_opt_end = cur_time(); RAFT_LOG_DEBUG("# MST optimization time: %.1lf sec", time_mst_opt_end - time_mst_opt_start); } -template -void count_2hop_detours(raft::host_matrix_view knn_graph, - raft::host_matrix_view detour_count) +// +// Prune unimportant edges based on 2-hop detour counts. +// +// The edge to be retained is determined without explicitly considering distance or angle. +// Suppose the edge is the k-th edge of some node-A to node-B (A->B). Among the edges +// originating at node-A, there are k-1 edges shorter than the edge A->B. Each of these +// k-1 edges are connected to a different k-1 nodes. Among these k-1 nodes, count the +// number of nodes with edges to node-B, which is the number of 2-hop detours for the +// edge A->B. Once the number of 2-hop detours has been counted for all edges, the +// specified number of edges are picked up for each node, starting with the edge with +// the lowest number of 2-hop detours. 
+// +template +void prune_graph_gpu( + raft::resources const& res, + raft::mdspan, raft::row_major, AccessorKnnGraph> knn_graph, + raft::mdspan, raft::row_major, AccessorOutputGraph> + output_graph) { - RAFT_EXPECTS(knn_graph.extent(0) == detour_count.extent(0), - "knn_graph and detour_count are expected to have the same number of rows"); - RAFT_EXPECTS(knn_graph.extent(1) == detour_count.extent(1), - "knn_graph and detour_count are expected to have the same number of cols"); - const uint64_t graph_size = knn_graph.extent(0); - const uint64_t graph_degree = knn_graph.extent(1); - -#pragma omp parallel for - for (IdxT iA = 0; iA < graph_size; iA++) { - // Create a list of nodes, iB_candidates, that can be reached in 2-hops from node A. - auto iB_candidates = - raft::make_host_vector((graph_degree - 1) * (graph_degree - 1)); - for (uint64_t kAC = 0; kAC < graph_degree - 1; kAC++) { - IdxT iC = knn_graph(iA, kAC); - for (uint64_t kCB = 0; kCB < graph_degree - 1; kCB++) { - IdxT iB_candidate; - if (iC == iA || iC >= graph_size) { - iB_candidate = graph_size; - } else { - iB_candidate = knn_graph(iC, kCB); - if (iB_candidate == iA || iB_candidate == iC) { iB_candidate = graph_size; } - } - uint64_t idx; - if (kAC < kCB) { - idx = (kCB * kCB) + kAC; - } else { - idx = (kAC * (kAC + 1)) + kCB; - } - iB_candidates(idx) = iB_candidate; - } - } - // Count how many 2-hop detours are on each edge of node A. 
- for (uint64_t kAB = 0; kAB < graph_degree; kAB++) { - constexpr uint32_t max_count = 255; - uint32_t count = 0; - IdxT iB = knn_graph(iA, kAB); - if (iB == iA) { - count = max_count; - } else { - for (uint64_t idx = 0; idx < kAB * kAB; idx++) { - if (iB_candidates(idx) == iB) { count += 1; } - } - } - detour_count(iA, kAB) = std::min(count, max_count); - } + const uint64_t graph_size = output_graph.extent(0); + const uint64_t knn_graph_degree = knn_graph.extent(1); + const uint64_t output_graph_degree = output_graph.extent(1); + + raft::common::nvtx::range block_scope( + "cagra::graph::optimize/prune"); + + // The batchsize is statically set to 256 * 1024 which corresponds to 256MB for a graph + // degree of 128 and 16byte index type. This is a trade-off between memory usage and performance. + // When choosing dynamically based on available memory, we would also need to modify the static + // size assumption in the cagra_build.cuh::optimize_workspace_size function. + uint32_t batch_size = static_cast(std::min(graph_size, 256 * 1024)); + const uint32_t num_batch = (graph_size + batch_size - 1) / batch_size; + + RAFT_LOG_DEBUG("# Pruning kNN Graph on GPUs\r"); + + const double prune_start = cur_time(); + + [[maybe_unused]] uint64_t num_keep = 0; + [[maybe_unused]] uint64_t num_full = 0; + auto dev_stats = raft::make_device_vector(res, 2); + auto host_stats = raft::make_host_vector(2); + raft::matrix::fill(res, dev_stats.view(), uint64_t(0)); + + namespace bli = cuvs::spatial::knn::detail::utils; + auto [copy_stream, enable_prefetch] = bli::get_prefetch_stream(res); + auto workspace_mr = raft::resource::get_workspace_resource_ref(res); + + // Single-batch read-only iterator for the input graph (graph_size rows fit in one batch). 
+ bli::batch_load_iterator< + raft::mdspan, raft::row_major, AccessorKnnGraph>> + d_input_graph(res, knn_graph, graph_size, copy_stream, workspace_mr); + auto input_view = (*d_input_graph).view(); + + bli::batch_load_iterator< + raft::mdspan, raft::row_major, AccessorOutputGraph>> + d_output_graph(res, + output_graph, + batch_size, + copy_stream, + workspace_mr, + enable_prefetch, + /*initialize=*/false, + /*host_writeback=*/true); + d_output_graph.prefetch_next_batch(); + + auto d_invalid_neighbor_list = raft::make_device_scalar(res, 0u); + + const uint32_t num_warps = 4; + const dim3 threads_prune(raft::WarpSize * num_warps, 1, 1); + const dim3 blocks_prune(raft::ceildiv(batch_size, num_warps), 1, 1); + const size_t prune_smem_size = num_warps * knn_graph_degree * (sizeof(IdxT) + sizeof(uint32_t)); + + for (uint32_t i_batch = 0; i_batch < num_batch; i_batch++) { + auto output_view = (*d_output_graph).view(); + kern_fused_prune + <<>>( + input_view, + output_view, + batch_size, + i_batch, + d_invalid_neighbor_list.data_handle(), + dev_stats.data_handle()); + + d_output_graph.prefetch_next_batch(); + ++d_output_graph; + RAFT_LOG_DEBUG( + "# Pruning kNN Graph on GPUs (%.1lf %%)\r", + (double)std::min((i_batch + 1) * batch_size, graph_size) / graph_size * 100); } + raft::resource::sync_stream(res); + RAFT_LOG_DEBUG("\n"); + + uint32_t invalid_neighbor_list = 0; + raft::copy( + res, raft::make_host_scalar_view(&invalid_neighbor_list), d_invalid_neighbor_list.view()); + raft::copy(res, host_stats.view(), raft::make_const_mdspan(dev_stats.view())); + raft::resource::sync_stream(res); + + RAFT_EXPECTS( + invalid_neighbor_list == 0, + "Could not generate an intermediate CAGRA graph because the initial kNN graph contains too " + "many invalid or duplicated neighbor nodes. 
This error can occur, for example, if too many " + "overflows occur during the norm computation between the dataset vectors."); + + num_keep = host_stats.data_handle()[0]; + num_full = host_stats.data_handle()[1]; + + const double prune_end = cur_time(); + RAFT_LOG_DEBUG( + "# Time for pruning on GPU: %.1lf sec, " + "avg_no_detour_edges_per_node: %.2lf/%u, " + "nodes_with_no_detour_at_all_edges: %.1lf%%", + prune_end - prune_start, + (double)num_keep / graph_size, + output_graph_degree, + (double)num_full / graph_size * 100); } +} // namespace + template , raft::memory_type::host>, + typename AccessorOutputGraph = raft::host_device_accessor, raft::memory_type::host>> void optimize( raft::resources const& res, - raft::mdspan, raft::row_major, g_accessor> knn_graph, - raft::host_matrix_view new_graph, - const bool guarantee_connectivity = true, - const bool use_gpu = true) + raft::mdspan, raft::row_major, AccessorKnnGraph> knn_graph, + raft::mdspan, raft::row_major, AccessorOutputGraph> new_graph, + const bool guarantee_connectivity = true, + const bool use_gpu_for_mst_optimization = true) { RAFT_LOG_DEBUG( "# Pruning kNN graph (size=%lu, degree=%lu)\n", knn_graph.extent(0), knn_graph.extent(1)); + + // large temporary memory for large arrays, e.g. everything >= O(graph_size) auto large_tmp_mr = raft::resource::get_large_workspace_resource_ref(res); + // temporary memory for small arrays, e.g. 
everything <= O(batchsize * graph_degree) + auto default_ws_mr = raft::resource::get_workspace_resource_ref(res); + + // create a stream pool if not already present + if (!res.has_resource_factory(raft::resource::resource_type::CUDA_STREAM_POOL) || + raft::resource::get_stream_pool_size(res) < 1) { + raft::resource::set_cuda_stream_pool(res, std::make_shared(1)); + } RAFT_EXPECTS(knn_graph.extent(0) == new_graph.extent(0), "Each input array is expected to have the same number of rows"); @@ -1166,465 +1734,82 @@ void optimize( const uint64_t knn_graph_degree = knn_graph.extent(1); const uint64_t output_graph_degree = new_graph.extent(1); const uint64_t graph_size = new_graph.extent(0); - // auto input_graph_ptr = knn_graph.data_handle(); - auto output_graph_ptr = new_graph.data_handle(); + raft::common::nvtx::range fun_scope( "cagra::graph::optimize(%zu, %zu, %u)", graph_size, knn_graph_degree, output_graph_degree); // MST optimization - auto mst_graph = raft::make_host_matrix(0, 0); - auto mst_graph_num_edges = raft::make_host_vector(graph_size); - auto mst_graph_num_edges_ptr = mst_graph_num_edges.data_handle(); -#pragma omp parallel for - for (uint64_t i = 0; i < graph_size; i++) { - mst_graph_num_edges_ptr[i] = 0; - } + // currently, only using GPU path for MST optimization + int64_t mst_graph_size = guarantee_connectivity ? 
graph_size : 0; + auto mst_graph = + raft::make_host_matrix(mst_graph_size, output_graph_degree); + auto mst_graph_num_edges = raft::make_host_vector(mst_graph_size); + if (guarantee_connectivity) { raft::common::nvtx::range block_scope( "cagra::graph::optimize/check_connectivity"); - mst_graph = - raft::make_host_matrix(graph_size, output_graph_degree); RAFT_LOG_INFO("MST optimization is used to guarantee graph connectivity."); - mst_optimization(res, knn_graph, mst_graph.view(), mst_graph_num_edges.view(), use_gpu); + mst_optimization( + res, knn_graph, mst_graph.view(), mst_graph_num_edges.view(), use_gpu_for_mst_optimization); for (uint64_t i = 0; i < graph_size; i++) { if (i < 8 || i >= graph_size - 8) { - RAFT_LOG_DEBUG("# mst_graph_num_edges_ptr[%lu]: %u\n", i, mst_graph_num_edges_ptr[i]); + RAFT_LOG_DEBUG("# mst_graph_num_edges[%lu]: %u\n", i, mst_graph_num_edges(i)); } } } + // prune graph -- will always use GPU path { - raft::common::nvtx::range block_scope( - "cagra::graph::optimize/prune"); - const double time_prune_start = cur_time(); - - // - // Prune unimportant edges. - // - // The edge to be retained is determined without explicitly considering - // distance or angle. Suppose the edge is the k-th edge of some node-A to - // node-B (A->B). Among the edges originating at node-A, there are k-1 edges - // shorter than the edge A->B. Each of these k-1 edges are connected to a - // different k-1 nodes. Among these k-1 nodes, count the number of nodes with - // edges to node-B, which is the number of 2-hop detours for the edge A->B. - // Once the number of 2-hop detours has been counted for all edges, the - // specified number of edges are picked up for each node, starting with the - // edge with the lowest number of 2-hop detours. - // - auto detour_count = raft::make_host_matrix(graph_size, knn_graph_degree); - - // - // If the available device memory is insufficient, do not use the GPU to count - // the number of 2-hop detours, but use the CPU. 
- // - bool _use_gpu = use_gpu; - if (_use_gpu) { - try { - auto d_detour_count = raft::make_device_mdarray( - res, large_tmp_mr, raft::make_extents(graph_size, knn_graph_degree)); - auto d_num_no_detour_edges = raft::make_device_mdarray( - res, large_tmp_mr, raft::make_extents(graph_size)); - auto d_input_graph = raft::make_device_mdarray( - res, large_tmp_mr, raft::make_extents(graph_size, knn_graph_degree)); - } catch (std::bad_alloc& e) { - RAFT_LOG_DEBUG("Insufficient memory for 2-hop node counting on GPU"); - _use_gpu = false; - } catch (raft::logic_error& e) { - RAFT_LOG_DEBUG("Insufficient memory for 2-hop node counting on GPU (logic error)"); - _use_gpu = false; - } - } - if (_use_gpu) { - // Count 2-hop detours on GPU - raft::common::nvtx::range block_scope( - "cagra::graph::optimize/prune/2-hop-counting-by-GPU"); - const double time_2hop_count_start = cur_time(); - - uint64_t num_keep __attribute__((unused)) = 0; - uint64_t num_full __attribute__((unused)) = 0; - auto d_detour_count = raft::make_device_mdarray( - res, large_tmp_mr, raft::make_extents(graph_size, knn_graph_degree)); - - raft::matrix::fill(res, d_detour_count.view(), uint8_t(0xff)); - - auto d_num_no_detour_edges = raft::make_device_mdarray( - res, large_tmp_mr, raft::make_extents(graph_size)); - raft::matrix::fill(res, d_num_no_detour_edges.view(), uint32_t(0)); - - auto dev_stats = raft::make_device_vector(res, 2); - auto host_stats = raft::make_host_vector(2); - - RAFT_LOG_DEBUG("# Pruning kNN Graph on GPUs\r"); - - // Copy knn_graph over to device if necessary - device_matrix_view_from_host d_input_graph( - res, - raft::make_host_matrix_view( - knn_graph.data_handle(), graph_size, knn_graph_degree)); - - constexpr int MAX_DEGREE = 1024; - if (knn_graph_degree > MAX_DEGREE) { - RAFT_FAIL( - "The degree of input knn graph is too large (%zu). 
" - "It must be equal to or smaller than %d.", - knn_graph_degree, - MAX_DEGREE); - } - const uint32_t batch_size = - std::min(static_cast(graph_size), static_cast(256 * 1024)); - const uint32_t num_batch = (graph_size + batch_size - 1) / batch_size; - const dim3 threads_prune(32, 1, 1); - const dim3 blocks_prune(batch_size, 1, 1); - - raft::matrix::fill(res, dev_stats.view(), uint64_t(0)); - - for (uint32_t i_batch = 0; i_batch < num_batch; i_batch++) { - kern_prune - <<>>( - d_input_graph.data_handle(), - graph_size, - knn_graph_degree, - output_graph_degree, - batch_size, - i_batch, - d_detour_count.data_handle(), - d_num_no_detour_edges.data_handle(), - dev_stats.data_handle()); - raft::resource::sync_stream(res); - RAFT_LOG_DEBUG( - "# Pruning kNN Graph on GPUs (%.1lf %%)\r", - (double)std::min((i_batch + 1) * batch_size, graph_size) / graph_size * 100); - } - raft::resource::sync_stream(res); - RAFT_LOG_DEBUG("\n"); - - raft::copy(res, detour_count.view(), raft::make_const_mdspan(d_detour_count.view())); - - raft::copy(res, host_stats.view(), raft::make_const_mdspan(dev_stats.view())); - num_keep = host_stats.data_handle()[0]; - num_full = host_stats.data_handle()[1]; - - const double time_2hop_count_end = cur_time(); - RAFT_LOG_DEBUG( - "# Time for 2-hop detour counting on GPU: %.1lf sec, " - "avg_no_detour_edges_per_node: %.2lf/%u, " - "nodes_with_no_detour_at_all_edges: %.1lf%%", - time_2hop_count_end - time_2hop_count_start, - (double)num_keep / graph_size, - output_graph_degree, - (double)num_full / graph_size * 100); - } else { - // Count 2-hop detours on CPU - raft::common::nvtx::range block_scope( - "cagra::graph::optimize/prune/2-hop-counting-by-CPU"); - const double time_2hop_count_start = cur_time(); - - count_2hop_detours(knn_graph, detour_count.view()); - - const double time_2hop_count_end = cur_time(); - RAFT_LOG_DEBUG("# Time for 2-hop detour counting on CPU: %.1lf sec", - time_2hop_count_end - time_2hop_count_start); - } - - // Create pruned 
kNN graph - bool invalid_neighbor_list = false; -#pragma omp parallel for - for (uint64_t i = 0; i < graph_size; i++) { - // Find the `output_graph_degree` smallest detourable count nodes by checking the detourable - // count of the neighbors while increasing the target detourable count from zero. - uint64_t pk = 0; - uint32_t num_detour = 0; - for (uint32_t l = 0; l < knn_graph_degree && pk < output_graph_degree; l++) { - uint32_t next_num_detour = std::numeric_limits::max(); - for (uint64_t k = 0; k < knn_graph_degree; k++) { - const auto num_detour_k = detour_count(i, k); - // Find the detourable count to check in the next iteration - if (num_detour_k > num_detour) { - next_num_detour = std::min(static_cast(num_detour_k), next_num_detour); - } - - // Store the neighbor index if its detourable count is equal to `num_detour`. - if (num_detour_k != num_detour) { continue; } - - // Check duplication and append - const auto candidate_node = knn_graph(i, k); - bool dup = false; - for (uint32_t dk = 0; dk < pk; dk++) { - if (candidate_node == output_graph_ptr[i * output_graph_degree + dk]) { - dup = true; - break; - } - } - if (!dup && candidate_node < graph_size) { - output_graph_ptr[i * output_graph_degree + pk] = candidate_node; - pk += 1; - } - if (pk >= output_graph_degree) break; - } - if (pk >= output_graph_degree) break; - - if (next_num_detour == std::numeric_limits::max()) { - // There are no valid edges enough in the initial kNN graph. Break the loop here and catch - // the error at the next validation (pk != output_graph_degree). 
- break; - } - num_detour = next_num_detour; - } - if (pk != output_graph_degree) { - RAFT_LOG_DEBUG( - "Couldn't find the output_graph_degree (%lu) smallest detourable count nodes for " - "node %lu in the rank-based node reranking process", - output_graph_degree, - i); - invalid_neighbor_list = true; - } - } - RAFT_EXPECTS( - !invalid_neighbor_list, - "Could not generate an intermediate CAGRA graph because the initial kNN graph contains too " - "many invalid or duplicated neighbor nodes. This error can occur, for example, if too many " - "overflows occur during the norm computation between the dataset vectors."); - - const double time_prune_end = cur_time(); - RAFT_LOG_DEBUG("# Pruning time: %.1lf ms", (time_prune_end - time_prune_start) * 1000.0); + prune_graph_gpu(res, knn_graph, new_graph); } - auto rev_graph = raft::make_host_matrix(graph_size, output_graph_degree); - auto rev_graph_count = raft::make_host_vector(graph_size); - - { - raft::common::nvtx::range block_scope( - "cagra::graph::optimize/reverse"); - // - // Make reverse graph - // - const double time_make_start = cur_time(); - - device_matrix_view_from_host d_rev_graph(res, rev_graph.view()); - raft::matrix::fill(res, - raft::make_device_vector_view( - d_rev_graph.data_handle(), graph_size * output_graph_degree), - IdxT(-1)); - - auto d_rev_graph_count = raft::make_device_mdarray( - res, large_tmp_mr, raft::make_extents(graph_size)); - raft::matrix::fill(res, d_rev_graph_count.view(), uint32_t(0)); - - auto dest_nodes = raft::make_host_vector(graph_size); - auto d_dest_nodes = - raft::make_device_mdarray(res, large_tmp_mr, raft::make_extents(graph_size)); - - for (uint64_t k = 0; k < output_graph_degree; k++) { -#pragma omp parallel for - for (uint64_t i = 0; i < graph_size; i++) { - // dest_nodes.data_handle()[i] = output_graph_ptr[k + (output_graph_degree * i)]; - dest_nodes(i) = output_graph_ptr[k + (output_graph_degree * i)]; - } - raft::resource::sync_stream(res); - - raft::copy(res, 
d_dest_nodes.view(), raft::make_const_mdspan(dest_nodes.view())); - - dim3 threads(256, 1, 1); - dim3 blocks(1024, 1, 1); - kern_make_rev_graph<<>>( - d_dest_nodes.data_handle(), - d_rev_graph.data_handle(), - d_rev_graph_count.data_handle(), - graph_size, - output_graph_degree); - RAFT_LOG_DEBUG("# Making reverse graph on GPUs: %lu / %u \r", k, output_graph_degree); - } - - raft::resource::sync_stream(res); - RAFT_LOG_DEBUG("\n"); - - if (d_rev_graph.allocated_memory()) { - raft::copy(res, rev_graph.view(), raft::make_const_mdspan(d_rev_graph.view())); - } - raft::copy(res, rev_graph_count.view(), raft::make_const_mdspan(d_rev_graph_count.view())); - - const double time_make_end = cur_time(); - RAFT_LOG_DEBUG("# Making reverse graph time: %.1lf ms", - (time_make_end - time_make_start) * 1000.0); + // reverse graph creation will always use the GPU + // using default workspace resource for random access + // otherwise will be managed memory which is slow upon first access + auto d_rev_graph = raft::make_device_mdarray(res, raft::make_extents(0, 0)); + try { + d_rev_graph = raft::make_device_mdarray( + res, default_ws_mr, raft::make_extents(graph_size, output_graph_degree)); + } catch (const std::exception& e) { + RAFT_LOG_DEBUG( + "Failed to create device matrix for reverse graph, switching to large workspace resource"); + d_rev_graph = raft::make_device_mdarray( + res, large_tmp_mr, raft::make_extents(graph_size, output_graph_degree)); } + // This should use the default workspace resource for random access / atomics + auto d_rev_graph_count = raft::make_device_mdarray( + res, default_ws_mr, raft::make_extents(graph_size)); - { - raft::common::nvtx::range block_scope( - "cagra::graph::optimize/combine"); - // - // Create search graphs from MST and pruned and reverse graphs - // - const double time_replace_start = cur_time(); - - bool check_num_protected_edges = true; -#pragma omp parallel for - for (uint64_t i = 0; i < graph_size; i++) { - auto my_rev_graph = 
rev_graph.data_handle() + (output_graph_degree * i); - auto my_out_graph = output_graph_ptr + (output_graph_degree * i); - - // If guarantee_connectivity == true, use a temporal list to merge the neighbor lists of the - // graphs. - std::vector temp_output_neighbor_list; - if (guarantee_connectivity) { - temp_output_neighbor_list.resize(output_graph_degree); - my_out_graph = temp_output_neighbor_list.data(); - const auto mst_graph_num_edges = mst_graph_num_edges_ptr[i]; - - // Set MST graph edges - for (uint32_t j = 0; j < mst_graph_num_edges; j++) { - my_out_graph[j] = mst_graph(i, j); - } - - // Set pruned graph edges - for (uint32_t pruned_j = 0, output_j = mst_graph_num_edges; - (pruned_j < output_graph_degree) && (output_j < output_graph_degree); - pruned_j++) { - const auto v = output_graph_ptr[output_graph_degree * i + pruned_j]; - - // duplication check - bool dup = false; - for (uint32_t m = 0; m < output_j; m++) { - if (v == my_out_graph[m]) { - dup = true; - break; - } - } - - if (!dup) { - my_out_graph[output_j] = v; - output_j++; - } - } - } - - const auto num_protected_edges = - std::max(mst_graph_num_edges_ptr[i], output_graph_degree / 2); - if (num_protected_edges > output_graph_degree) { check_num_protected_edges = false; } - if (num_protected_edges == output_graph_degree) continue; - - // Replace some edges of the output graph with edges of the reverse graph. 
- auto kr = std::min(rev_graph_count.data_handle()[i], output_graph_degree); - while (kr) { - kr -= 1; - if (my_rev_graph[kr] < graph_size) { - uint64_t pos = pos_in_array(my_rev_graph[kr], my_out_graph, output_graph_degree); - if (pos < num_protected_edges) { continue; } - uint64_t num_shift = pos - num_protected_edges; - if (pos >= output_graph_degree) { - num_shift = output_graph_degree - num_protected_edges - 1; - } - shift_array(my_out_graph + num_protected_edges, num_shift); - my_out_graph[num_protected_edges] = my_rev_graph[kr]; - } - } + const double time_make_start = cur_time(); - // If guarantee_connectivity == true, move the output neighbor list from the temporal list to - // the output list. If false, the copy is not needed because my_out_graph is a pointer to the - // output buffer. - if (guarantee_connectivity) { - for (uint32_t j = 0; j < output_graph_degree; j++) { - output_graph_ptr[(output_graph_degree * i) + j] = my_out_graph[j]; - } - } - } - RAFT_EXPECTS(check_num_protected_edges, - "Failed to merge the MST, pruned, and reverse edge graphs. 
Some nodes have too " - "many MST optimization edges."); + make_reverse_graph_gpu(res, new_graph, d_rev_graph.view(), d_rev_graph_count.view()); - const double time_replace_end = cur_time(); - RAFT_LOG_DEBUG("# Replacing edges time: %.1lf ms", - (time_replace_end - time_replace_start) * 1000.0); + raft::resource::sync_stream(res); - /* stats */ - uint64_t num_replaced_edges = 0; -#pragma omp parallel for reduction(+ : num_replaced_edges) - for (uint64_t i = 0; i < graph_size; i++) { - for (uint64_t k = 0; k < output_graph_degree; k++) { - const uint64_t j = output_graph_ptr[k + (output_graph_degree * i)]; - const uint64_t pos = - pos_in_array(j, output_graph_ptr + (output_graph_degree * i), output_graph_degree); - if (pos == output_graph_degree) { num_replaced_edges += 1; } - } - } - RAFT_LOG_DEBUG("# Average number of replaced edges per node: %.2f", - (double)num_replaced_edges / graph_size); - } + const double time_make_end = cur_time(); + RAFT_LOG_DEBUG("\n# Making reverse graph time: %.1lf ms", + (time_make_end - time_make_start) * 1000.0); - // Check number of incoming edges + // merge graph -- will always use GPU path { - raft::common::nvtx::range block_scope( - "cagra::graph::optimize/check_edges"); - auto in_edge_count = raft::make_host_vector(graph_size); - auto in_edge_count_ptr = in_edge_count.data_handle(); -#pragma omp parallel for - for (uint64_t i = 0; i < graph_size; i++) { - in_edge_count_ptr[i] = 0; - } -#pragma omp parallel for - for (uint64_t i = 0; i < graph_size; i++) { - for (uint64_t k = 0; k < output_graph_degree; k++) { - const uint64_t j = output_graph_ptr[k + (output_graph_degree * i)]; - if (j >= graph_size) continue; -#pragma omp atomic - in_edge_count_ptr[j] += 1; - } - } - auto hist = raft::make_host_vector(output_graph_degree); - auto hist_ptr = hist.data_handle(); - for (uint64_t k = 0; k < output_graph_degree; k++) { - hist_ptr[k] = 0; - } -#pragma omp parallel for - for (uint64_t i = 0; i < graph_size; i++) { - uint32_t count = 
in_edge_count_ptr[i]; - if (count >= output_graph_degree) continue; -#pragma omp atomic - hist_ptr[count] += 1; - } - RAFT_LOG_DEBUG("# Histogram for number of incoming edges\n"); - uint32_t sum_hist = 0; - for (uint64_t k = 0; k < output_graph_degree; k++) { - sum_hist += hist_ptr[k]; - RAFT_LOG_DEBUG("# %3lu, %8u, %lf, (%8u, %lf)\n", - k, - hist_ptr[k], - (double)hist_ptr[k] / graph_size, - sum_hist, - (double)sum_hist / graph_size); - } + merge_graph_gpu(res, + new_graph, + d_rev_graph.view(), + d_rev_graph_count.view(), + mst_graph.view(), + mst_graph_num_edges.view(), + guarantee_connectivity); } - // Check duplication and out-of-range indices - { - raft::common::nvtx::range block_scope( - "cagra::graph::optimize/check_duplicates"); - uint64_t num_dup = 0; - uint64_t num_oor = 0; -#pragma omp parallel for reduction(+ : num_dup) reduction(+ : num_oor) - for (uint64_t i = 0; i < graph_size; i++) { - auto my_out_graph = output_graph_ptr + (output_graph_degree * i); - for (uint32_t j = 0; j < output_graph_degree; j++) { - const auto neighbor_a = my_out_graph[j]; - - // Check oor - if (neighbor_a > graph_size) { - num_oor++; - continue; - } + raft::resource::sync_stream(res); - // Check duplication - for (uint32_t k = j + 1; k < output_graph_degree; k++) { - const auto neighbor_b = my_out_graph[k]; - if (neighbor_a == neighbor_b) { num_dup++; } - } - } - } - RAFT_EXPECTS( - num_dup == 0, "%lu duplicated node(s) are found in the generated CAGRA graph", num_dup); - RAFT_EXPECTS(num_oor == 0, - "%lu out-of-range index node(s) are found in the generated CAGRA graph", - num_oor); + // These host-side checks are expensive (O(N*D^2)) and only used as debug + // diagnostics, so only run them when debug logging is active at runtime. 
+ if (raft::default_logger().should_log(rapids_logger::level_enum::debug)) { + log_incoming_edges_histogram(res, new_graph); + + check_duplicates_and_out_of_range(res, new_graph); } } diff --git a/cpp/src/neighbors/detail/cagra/utils.hpp b/cpp/src/neighbors/detail/cagra/utils.hpp index 91d6619c7d..58bf68bb43 100644 --- a/cpp/src/neighbors/detail/cagra/utils.hpp +++ b/cpp/src/neighbors/detail/cagra/utils.hpp @@ -9,10 +9,15 @@ #include #include #include +#include +#include +#include +#include +#include #include #include #include - +#include #include #include @@ -20,6 +25,7 @@ #include #include +#include #include namespace cuvs::neighbors::cagra::detail { @@ -294,4 +300,5 @@ void copy_with_padding( raft::resource::get_cuda_stream(res)); } } + } // namespace cuvs::neighbors::cagra::detail diff --git a/cpp/src/neighbors/detail/nn_descent.cuh b/cpp/src/neighbors/detail/nn_descent.cuh index 19a86dba56..61b8f80f10 100644 --- a/cpp/src/neighbors/detail/nn_descent.cuh +++ b/cpp/src/neighbors/detail/nn_descent.cuh @@ -1507,8 +1507,13 @@ void GNND::build(Data_t* data, RAFT_CUDA_TRY(cudaPointerGetAttributes(&data_ptr_attr, data)); size_t batch_size = (data_ptr_attr.devicePointer == nullptr) ? 
100000 : nrow_; - cuvs::spatial::knn::detail::utils::batch_load_iterator vec_batches{ - data, static_cast(nrow_), build_config_.dataset_dim, batch_size, stream}; + auto vec_batches = cuvs::spatial::knn::detail::utils::make_batch_load_iterator( + res, + data, + static_cast(nrow_), + static_cast(build_config_.dataset_dim), + batch_size, + stream); for (auto const& batch : vec_batches) { if (d_data_float_.has_value()) { preprocess_data_kernel<< build( // TODO: with scaling workspace we could choose the batch size dynamically constexpr uint32_t kReasonableMaxBatchSize = 65536; const uint32_t max_batch_size = std::min(n_rows, kReasonableMaxBatchSize); - for (const auto& batch : cuvs::spatial::knn::detail::utils::batch_load_iterator( - dataset.data_handle(), - n_rows, - dim, - max_batch_size, - raft::resource::get_cuda_stream(res), - raft::resource::get_workspace_resource_ref(res))) { + auto _vamana_batches = cuvs::spatial::knn::detail::utils::make_batch_load_iterator( + res, + dataset.data_handle(), + static_cast(n_rows), + static_cast(dim), + static_cast(max_batch_size), + raft::resource::get_cuda_stream(res), + raft::resource::get_workspace_resource_ref(res)); + for (const auto& batch : _vamana_batches) { // perform rotation auto dataset_rotated = raft::make_device_matrix(res, batch.size(), dim); if constexpr (std::is_same_v) { diff --git a/cpp/src/neighbors/detail/vpq_dataset.cuh b/cpp/src/neighbors/detail/vpq_dataset.cuh index 75785da5d2..d4410afa69 100644 --- a/cpp/src/neighbors/detail/vpq_dataset.cuh +++ b/cpp/src/neighbors/detail/vpq_dataset.cuh @@ -506,13 +506,15 @@ void process_and_fill_codes( return; } - for (const auto& batch : cuvs::spatial::knn::detail::utils::batch_load_iterator( - dataset.data_handle(), - n_rows, - dim, - max_batch_size, - stream, - rmm::mr::get_current_device_resource_ref())) { + auto _vpq_batches_codes = cuvs::spatial::knn::detail::utils::make_batch_load_iterator( + res, + dataset.data_handle(), + static_cast(n_rows), + 
static_cast(dim), + static_cast(max_batch_size), + stream, + rmm::mr::get_current_device_resource_ref()); + for (const auto& batch : _vpq_batches_codes) { auto batch_view = raft::make_device_matrix_view(batch.data(), ix_t(batch.size()), dim); auto batch_labels_view = raft::make_device_vector_view(nullptr, 0); if (inline_vq_labels) { @@ -899,11 +901,12 @@ void process_and_fill_codes_subspaces( enable_prefetch_stream = true; copy_stream = raft::resource::get_stream_from_stream_pool(res); } - auto vec_batches = cuvs::spatial::knn::detail::utils::batch_load_iterator( + auto vec_batches = cuvs::spatial::knn::detail::utils::make_batch_load_iterator( + res, dataset.data_handle(), - n_rows, - dim, - max_batch_size, + static_cast(n_rows), + static_cast(dim), + static_cast(max_batch_size), copy_stream, raft::resource::get_workspace_resource_ref(res), enable_prefetch_stream); diff --git a/cpp/src/neighbors/ivf_flat/ivf_flat_build.cuh b/cpp/src/neighbors/ivf_flat/ivf_flat_build.cuh index 35005b4279..fffe5134ae 100644 --- a/cpp/src/neighbors/ivf_flat/ivf_flat_build.cuh +++ b/cpp/src/neighbors/ivf_flat/ivf_flat_build.cuh @@ -203,13 +203,15 @@ void extend(raft::resources const& handle, } } // Predict the cluster labels for the new data, in batches if necessary - utils::batch_load_iterator vec_batches(new_vectors, - n_rows, - index->dim(), - max_batch_size, - copy_stream, - raft::resource::get_workspace_resource_ref(handle), - enable_prefetch); + auto vec_batches = + utils::make_batch_load_iterator(handle, + new_vectors, + n_rows, + IdxT{index->dim()}, + max_batch_size, + copy_stream, + raft::resource::get_workspace_resource_ref(handle), + enable_prefetch); vec_batches.prefetch_next_batch(); for (const auto& batch : vec_batches) { @@ -296,17 +298,19 @@ void extend(raft::resources const& handle, raft::make_device_vector_view(list_sizes_ptr, n_lists), raft::make_device_vector_view(old_list_sizes_dev.data_handle(), n_lists)); - utils::batch_load_iterator vec_indices(new_indices, - 
n_rows, - 1, - max_batch_size, - stream, - raft::resource::get_workspace_resource_ref(handle)); + auto vec_indices = + utils::make_batch_load_iterator(handle, + new_indices, + n_rows, + IdxT{1}, + max_batch_size, + stream, + raft::resource::get_workspace_resource_ref(handle)); vec_batches.reset(); vec_batches.prefetch_next_batch(); - utils::batch_load_iterator idx_batch = vec_indices.begin(); - size_t next_report_offset = 0; - size_t d_report_offset = n_rows * 5 / 100; + auto idx_batch = vec_indices.begin(); + size_t next_report_offset = 0; + size_t d_report_offset = n_rows * 5 / 100; for (const auto& batch : vec_batches) { auto batch_data_view = raft::make_device_matrix_view(batch.data(), batch.size(), index->dim()); diff --git a/cpp/src/neighbors/ivf_pq/ivf_pq_build.cuh b/cpp/src/neighbors/ivf_pq/ivf_pq_build.cuh index d7eb46c5a3..50f9564ca5 100644 --- a/cpp/src/neighbors/ivf_pq/ivf_pq_build.cuh +++ b/cpp/src/neighbors/ivf_pq/ivf_pq_build.cuh @@ -1078,8 +1078,14 @@ void extend(raft::resources const& handle, } } // Predict the cluster labels for the new data, in batches if necessary - utils::batch_load_iterator vec_batches( - new_vectors, n_rows, index->dim(), max_batch_size, copy_stream, device_memory, enable_prefetch); + auto vec_batches = utils::make_batch_load_iterator(handle, + new_vectors, + n_rows, + index->dim(), + max_batch_size, + copy_stream, + device_memory, + enable_prefetch); // Release the placeholder memory, because we don't intend to allocate any more long-living // temporary buffers before we allocate the index data. // This memory could potentially speed up UVM accesses, if any. 
@@ -1175,8 +1181,8 @@ void extend(raft::resources const& handle, // By this point, the index state is updated and valid except it doesn't contain the new data // Fill the extended index with the new data (possibly, in batches) - utils::batch_load_iterator idx_batches( - new_indices, n_rows, 1, max_batch_size, stream, batches_mr); + auto idx_batches = utils::make_batch_load_iterator( + handle, new_indices, n_rows, IdxT{1}, max_batch_size, stream, batches_mr); vec_batches.reset(); vec_batches.prefetch_next_batch(); for (const auto& vec_batch : vec_batches) { diff --git a/cpp/src/neighbors/ivf_pq/ivf_pq_transform.cuh b/cpp/src/neighbors/ivf_pq/ivf_pq_transform.cuh index 1d55692488..b5ea8afab4 100644 --- a/cpp/src/neighbors/ivf_pq/ivf_pq_transform.cuh +++ b/cpp/src/neighbors/ivf_pq/ivf_pq_transform.cuh @@ -140,13 +140,14 @@ void transform(raft::resources const& res, constexpr size_t max_batch_size = 65536; rmm::device_async_resource_ref device_memory = raft::resource::get_workspace_resource_ref(res); - utils::batch_load_iterator vec_batches(dataset.data_handle(), - n_rows, - index.dim(), - max_batch_size, - copy_stream, - device_memory, - enable_prefetch); + auto vec_batches = utils::make_batch_load_iterator(res, + dataset.data_handle(), + n_rows, + IdxT{index.dim()}, + max_batch_size, + copy_stream, + device_memory, + enable_prefetch); vec_batches.prefetch_next_batch(); for (const auto& batch : vec_batches) { diff --git a/cpp/src/neighbors/scann/detail/scann_build.cuh b/cpp/src/neighbors/scann/detail/scann_build.cuh index 66ebd994b2..c01e50bc83 100644 --- a/cpp/src/neighbors/scann/detail/scann_build.cuh +++ b/cpp/src/neighbors/scann/detail/scann_build.cuh @@ -102,13 +102,15 @@ index build( } } - utils::batch_load_iterator dataset_vec_batches(dataset.data_handle(), - dataset.extent(0), - dataset.extent(1), - max_batch_size, - copy_stream, - device_memory, - enable_prefetch); + auto dataset_vec_batches = + utils::make_batch_load_iterator(res, + dataset.data_handle(), + 
static_cast(dataset.extent(0)), + static_cast(dataset.extent(1)), + max_batch_size, + copy_stream, + device_memory, + enable_prefetch); dataset_vec_batches.reset(); dataset_vec_batches.prefetch_next_batch(); diff --git a/cpp/tests/CMakeLists.txt b/cpp/tests/CMakeLists.txt index b30f108789..2f46abb812 100644 --- a/cpp/tests/CMakeLists.txt +++ b/cpp/tests/CMakeLists.txt @@ -182,7 +182,7 @@ ConfigureTest( ConfigureTest( NAME NEIGHBORS_ANN_CAGRA_HELPERS_TEST - PATH neighbors/ann_cagra/test_optimize_uint32_t.cu + PATH neighbors/ann_cagra/test_optimize_uint32_t.cu neighbors/ann_cagra/test_batch_load_iterator.cu GPUS 1 PERCENT 100 ) diff --git a/cpp/tests/neighbors/ann_cagra/test_batch_load_iterator.cu b/cpp/tests/neighbors/ann_cagra/test_batch_load_iterator.cu new file mode 100644 index 0000000000..c0df2b8826 --- /dev/null +++ b/cpp/tests/neighbors/ann_cagra/test_batch_load_iterator.cu @@ -0,0 +1,375 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2023-2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +#include "../../../src/neighbors/detail/ann_utils.cuh" + +#include +#include +#include +#include + +namespace cuvs::neighbors::cagra { + +namespace bli = cuvs::spatial::knn::detail::utils; + +using IdxT = uint32_t; +using DeviceAccessor = + raft::host_device_accessor, raft::memory_type::device>; +using HostAccessor = + raft::host_device_accessor, raft::memory_type::host>; +using PinnedAccessor = + raft::host_device_accessor, raft::memory_type::pinned>; +using ManagedAccessor = + raft::host_device_accessor, raft::memory_type::managed>; + +struct BatchConfig { + bool initialize; + bool host_writeback; +}; + +struct DimsConfig { + int64_t n_rows; + int64_t n_cols; + uint64_t batch_size; +}; + +class BatchLoadIteratorTest : public ::testing::Test { + protected: + void SetUp() override + { + // 
Provide a stream pool so get_prefetch_stream(res) returns a non-main stream and + // the iterator exercises its real pipelined path. + raft::resource::set_cuda_stream_pool(res, std::make_shared(1)); + raft::resource::sync_stream(res); + } + + /** + * Run batch_load_iterator over input_view, copy device views back, and verify against the + * input. Mirrors the old batched_device_view test: writeback fills with IdxT(17), readback + * (when initialize==true) verifies pre-fill of IdxT(13). + */ + template + void run_and_verify_batched( + raft::mdspan, raft::row_major, AccessorInputView> input_view, + uint64_t batch_size, + bool host_writeback, + bool initialize) + { + using mdspan_t = decltype(input_view); + + int64_t n_rows = input_view.extent(0); + int64_t n_cols = input_view.extent(1); + + std::vector readback(n_rows * n_cols); + + int64_t total_processed = 0; + + auto [copy_stream, enable_prefetch] = bli::get_prefetch_stream(res); + auto workspace_mr = raft::resource::get_workspace_resource_ref(res); + + { + bli::batch_load_iterator iter(res, + input_view, + batch_size, + copy_stream, + workspace_mr, + enable_prefetch, + initialize, + host_writeback); + iter.prefetch_next_batch(); + + for (auto& batch : iter) { + if (batch.size() == 0) break; + + auto dev_view = batch.view(); + + if (initialize) { + raft::copy(readback.data() + total_processed * n_cols, + dev_view.data_handle(), + dev_view.extent(0) * dev_view.extent(1), + raft::resource::get_cuda_stream(res)); + } + if (host_writeback) { + // The passthrough strategy returns `cuda::std::submdspan(input_view, ...)`, which + // may be `layout_stride` and/or carry a non-default accessor (e.g. PinnedAccessor), + // neither of which `raft::matrix::fill`'s `device_matrix_view` overload accepts. + // Re-wrap as a plain device_matrix_view since the slice is always exhaustive + // (a contiguous row range of a row-major input). 
This re-wrap should eventually be + // removed once raft::matrix::fill grows a more generic mdspan overload. + raft::matrix::fill(res, + raft::make_device_matrix_view( + dev_view.data_handle(), dev_view.extent(0), dev_view.extent(1)), + IdxT(17)); + } + total_processed += dev_view.extent(0); + + iter.prefetch_next_batch(); + } + } + raft::resource::sync_stream(res); + + EXPECT_EQ(total_processed, n_rows); + if (initialize) { + for (int64_t i = 0; i < n_rows * n_cols; ++i) { + EXPECT_EQ(readback[i], IdxT(13)) << "Mismatch (initialize) at index " << i; + } + } + if (host_writeback) { + auto readback_view = + raft::make_host_matrix_view(readback.data(), n_rows, n_cols); + raft::copy(res, readback_view, input_view); + raft::resource::sync_stream(res); + for (int64_t i = 0; i < n_rows * n_cols; ++i) { + EXPECT_EQ(readback[i], IdxT(17)) << "Mismatch (host_writeback) at index " << i; + } + } + } + + raft::resources res; +}; + +TEST_F(BatchLoadIteratorTest, EmptyViewFromHost) +{ + auto host_empty = raft::make_host_matrix(0, 8); + auto host_view = host_empty.view(); + + auto [copy_stream, enable_prefetch] = bli::get_prefetch_stream(res); + auto workspace_mr = raft::resource::get_workspace_resource_ref(res); + + bli::batch_load_iterator> iter( + res, host_view, /*batch_size=*/128, copy_stream, workspace_mr, enable_prefetch); + EXPECT_TRUE(iter.begin() == iter.end()); +} + +TEST_F(BatchLoadIteratorTest, EmptyViewFromDevice) +{ + auto device_empty = raft::make_device_matrix(res, 0, 8); + auto device_view = device_empty.view(); + + auto [copy_stream, enable_prefetch] = bli::get_prefetch_stream(res); + auto workspace_mr = raft::resource::get_workspace_resource_ref(res); + + bli::batch_load_iterator> iter( + res, device_view, /*batch_size=*/128, copy_stream, workspace_mr, enable_prefetch); + EXPECT_TRUE(iter.begin() == iter.end()); +} + +using BatchDimsParam = std::tuple; + +class BatchLoadIteratorParameterizedTest : public BatchLoadIteratorTest, + public 
::testing::WithParamInterface {}; + +TEST_P(BatchLoadIteratorParameterizedTest, VectorHostData) +{ + auto [batch_config, dims_config] = GetParam(); + auto [initialize, host_writeback] = batch_config; + auto [n_rows, n_cols, batch_size] = dims_config; + + std::vector host_data(n_rows * n_cols); + auto host_view = raft::make_host_matrix_view(host_data.data(), n_rows, n_cols); + + std::fill(host_view.data_handle(), host_view.data_handle() + n_rows * n_cols, IdxT(13)); + + run_and_verify_batched(host_view, batch_size, host_writeback, initialize); +} + +TEST_P(BatchLoadIteratorParameterizedTest, PinnedMemory) +{ + auto [batch_config, dims_config] = GetParam(); + auto [initialize, host_writeback] = batch_config; + auto [n_rows, n_cols, batch_size] = dims_config; + + auto pinned_matrix = raft::make_pinned_matrix(res, n_rows, n_cols); + auto pinned_view = + raft::mdspan, raft::row_major, PinnedAccessor>( + pinned_matrix.data_handle(), n_rows, n_cols); + std::fill(pinned_view.data_handle(), pinned_view.data_handle() + n_rows * n_cols, IdxT(13)); + + run_and_verify_batched(pinned_view, batch_size, host_writeback, initialize); +} + +TEST_P(BatchLoadIteratorParameterizedTest, PinnedMemoryForcedToHost) +{ + auto [batch_config, dims_config] = GetParam(); + auto [initialize, host_writeback] = batch_config; + auto [n_rows, n_cols, batch_size] = dims_config; + + auto pinned_matrix = raft::make_pinned_matrix(res, n_rows, n_cols); + + auto pinned_view = + raft::mdspan, raft::row_major, HostAccessor>( + pinned_matrix.data_handle(), n_rows, n_cols); + + std::fill(pinned_view.data_handle(), pinned_view.data_handle() + n_rows * n_cols, IdxT(13)); + run_and_verify_batched(pinned_view, batch_size, host_writeback, initialize); +} + +TEST_P(BatchLoadIteratorParameterizedTest, ManagedMemory) +{ + auto [batch_config, dims_config] = GetParam(); + auto [initialize, host_writeback] = batch_config; + auto [n_rows, n_cols, batch_size] = dims_config; + + auto managed_matrix = 
raft::make_managed_matrix(res, n_rows, n_cols); + auto managed_view = managed_matrix.view(); + + std::fill(managed_view.data_handle(), managed_view.data_handle() + n_rows * n_cols, IdxT(13)); + + run_and_verify_batched(managed_view, batch_size, host_writeback, initialize); +} + +TEST_P(BatchLoadIteratorParameterizedTest, ManagedMemoryForcedToHost) +{ + auto [batch_config, dims_config] = GetParam(); + auto [initialize, host_writeback] = batch_config; + auto [n_rows, n_cols, batch_size] = dims_config; + + auto managed_matrix = raft::make_managed_matrix(res, n_rows, n_cols); + + auto managed_view = + raft::mdspan, raft::row_major, HostAccessor>( + managed_matrix.data_handle(), n_rows, n_cols); + + std::fill(managed_view.data_handle(), managed_view.data_handle() + n_rows * n_cols, IdxT(13)); + + run_and_verify_batched(managed_view, batch_size, host_writeback, initialize); +} + +TEST_P(BatchLoadIteratorParameterizedTest, DeviceMemory) +{ + auto [batch_config, dims_config] = GetParam(); + auto [initialize, host_writeback] = batch_config; + auto [n_rows, n_cols, batch_size] = dims_config; + + auto device_matrix = raft::make_device_matrix(res, n_rows, n_cols); + auto device_view = device_matrix.view(); + + raft::matrix::fill(res, device_view, IdxT(13)); + + run_and_verify_batched(device_view, batch_size, host_writeback, initialize); +} + +/** + * Drive the runtime-dispatched wrapper. Verifies that: + * * a host pointer dispatches to the copy_device branch (does_copy() == true), and + * * a device pointer dispatches to passthrough (does_copy() == false), + * and that initialize-only iteration yields the expected pre-fill on each batch. 
+ */ +TEST_F(BatchLoadIteratorTest, MakeBatchLoadIteratorHostPtr) +{ + const int64_t n_rows = 256; + const int64_t n_cols = 32; + const size_t batch_size_rows = 64; + + std::vector host_data(n_rows * n_cols, IdxT(13)); + auto [copy_stream, enable_prefetch] = bli::get_prefetch_stream(res); + + auto iter = bli::make_batch_load_iterator(res, + host_data.data(), + n_rows, + n_cols, + batch_size_rows, + copy_stream, + raft::resource::get_workspace_resource_ref(res), + enable_prefetch); + EXPECT_TRUE(iter.does_copy()); + + std::vector readback(n_rows * n_cols, IdxT(0)); + int64_t total = 0; + iter.prefetch_next_batch(); + for (auto const& batch : iter) { + if (batch.size() == 0) break; + raft::copy(readback.data() + total * n_cols, + batch.data(), + batch.size() * n_cols, + raft::resource::get_cuda_stream(res)); + total += batch.size(); + iter.prefetch_next_batch(); + } + raft::resource::sync_stream(res); + EXPECT_EQ(total, n_rows); + for (int64_t i = 0; i < n_rows * n_cols; ++i) { + EXPECT_EQ(readback[i], IdxT(13)) << "Mismatch at index " << i; + } +} + +TEST_F(BatchLoadIteratorTest, MakeBatchLoadIteratorDevicePtr) +{ + const int64_t n_rows = 256; + const int64_t n_cols = 32; + const size_t batch_size_rows = 64; + + auto device_matrix = raft::make_device_matrix(res, n_rows, n_cols); + raft::matrix::fill(res, device_matrix.view(), IdxT(13)); + + auto [copy_stream, enable_prefetch] = bli::get_prefetch_stream(res); + auto iter = bli::make_batch_load_iterator(res, + device_matrix.data_handle(), + n_rows, + n_cols, + batch_size_rows, + copy_stream, + raft::resource::get_workspace_resource_ref(res), + enable_prefetch); + EXPECT_FALSE(iter.does_copy()); + + std::vector readback(n_rows * n_cols, IdxT(0)); + int64_t total = 0; + iter.prefetch_next_batch(); + for (auto const& batch : iter) { + if (batch.size() == 0) break; + raft::copy(readback.data() + total * n_cols, + batch.data(), + batch.size() * n_cols, + raft::resource::get_cuda_stream(res)); + total += batch.size(); + 
iter.prefetch_next_batch(); + } + raft::resource::sync_stream(res); + EXPECT_EQ(total, n_rows); + for (int64_t i = 0; i < n_rows * n_cols; ++i) { + EXPECT_EQ(readback[i], IdxT(13)) << "Mismatch at index " << i; + } +} + +static const std::array kBatchConfigs = {{ + {/*initialize=*/true, /*host_writeback=*/false}, + {/*initialize=*/false, /*host_writeback=*/true}, + {/*initialize=*/true, /*host_writeback=*/true}, +}}; + +static const std::array kDimsConfigs = {{ + {/*n_rows=*/64, /*n_cols=*/32, /*batch_size=*/256}, // rows less than batch size, single batch + {/*n_rows=*/64, /*n_cols=*/32, /*batch_size=*/64}, // single batch + {/*n_rows=*/256, /*n_cols=*/32, /*batch_size=*/32}, // multiple batches + {/*n_rows=*/500, + /*n_cols=*/32, + /*batch_size=*/128}, // multiple batches, partial batch in the end +}}; + +INSTANTIATE_TEST_SUITE_P(BatchConfigs, + BatchLoadIteratorParameterizedTest, + ::testing::Combine(::testing::ValuesIn(kBatchConfigs), + ::testing::ValuesIn(kDimsConfigs))); + +} // namespace cuvs::neighbors::cagra