From 609b0f3db6643c0a62dd89d59202ec4a43b339c5 Mon Sep 17 00:00:00 2001 From: Malte Foerster Date: Mon, 16 Feb 2026 18:52:40 +0000 Subject: [PATCH 01/40] prune kernel smem --- cpp/src/neighbors/detail/cagra/graph_core.cuh | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/cpp/src/neighbors/detail/cagra/graph_core.cuh b/cpp/src/neighbors/detail/cagra/graph_core.cuh index 41efa1686f..270e76838a 100644 --- a/cpp/src/neighbors/detail/cagra/graph_core.cuh +++ b/cpp/src/neighbors/detail/cagra/graph_core.cuh @@ -166,14 +166,20 @@ __global__ void kern_prune(const IdxT* const knn_graph, // [graph_chunk_size, g uint64_t* const stats) { __shared__ uint32_t smem_num_detour[MAX_DEGREE]; + extern __shared__ unsigned char smem_buf[]; + IdxT* const smem_knn_iA_neighbors = reinterpret_cast(smem_buf); + uint64_t* const num_retain = stats; uint64_t* const num_full = stats + 1; const uint64_t iA = blockIdx.x + (batch_size * batch_id); if (iA >= graph_size) { return; } + + // Load this node's neighbor row into shared memory to reduce global reads for (uint32_t k = threadIdx.x; k < graph_degree; k += blockDim.x) { - smem_num_detour[k] = 0; - if (knn_graph[k + ((uint64_t)graph_degree * iA)] == iA) { + smem_num_detour[k] = 0; + smem_knn_iA_neighbors[k] = knn_graph[k + ((uint64_t)graph_degree * iA)]; + if (smem_knn_iA_neighbors[k] == iA) { // Lower the priority of self-edge smem_num_detour[k] = graph_degree; } @@ -182,14 +188,14 @@ __global__ void kern_prune(const IdxT* const knn_graph, // [graph_chunk_size, g // count number of detours (A->D->B) for (uint32_t kAD = 0; kAD < graph_degree - 1; kAD++) { - const uint64_t iD = knn_graph[kAD + (graph_degree * iA)]; + const uint64_t iD = smem_knn_iA_neighbors[kAD]; if (iD >= graph_size) { continue; } for (uint32_t kDB = threadIdx.x; kDB < graph_degree; kDB += blockDim.x) { const uint64_t iB_candidate = knn_graph[kDB + ((uint64_t)graph_degree * iD)]; for (uint32_t kAB = kAD + 1; kAB < graph_degree; kAB++) { // if 
( kDB < kAB ) { - const uint64_t iB = knn_graph[kAB + (graph_degree * iA)]; + const uint64_t iB = smem_knn_iA_neighbors[kAB]; if (iB == iB_candidate) { atomicAdd(smem_num_detour + kAB, 1); break; @@ -1298,9 +1304,10 @@ void optimize( RAFT_CUDA_TRY(cudaMemsetAsync( dev_stats.data_handle(), 0, sizeof(uint64_t) * 2, raft::resource::get_cuda_stream(res))); + const size_t prune_smem_size = knn_graph_degree * sizeof(IdxT); for (uint32_t i_batch = 0; i_batch < num_batch; i_batch++) { kern_prune - <<>>( + <<>>( d_input_graph.data_handle(), graph_size, knn_graph_degree, From a320e0e90453527e156da007bb96dc00de3898c0 Mon Sep 17 00:00:00 2001 From: Malte Foerster Date: Wed, 18 Feb 2026 16:26:54 +0000 Subject: [PATCH 02/40] reduce copies within reverse graph compute --- cpp/src/neighbors/detail/cagra/graph_core.cuh | 61 ++++++++++++------- 1 file changed, 39 insertions(+), 22 deletions(-) diff --git a/cpp/src/neighbors/detail/cagra/graph_core.cuh b/cpp/src/neighbors/detail/cagra/graph_core.cuh index 270e76838a..f3b0f0778e 100644 --- a/cpp/src/neighbors/detail/cagra/graph_core.cuh +++ b/cpp/src/neighbors/detail/cagra/graph_core.cuh @@ -244,6 +244,29 @@ __global__ void kern_make_rev_graph(const IdxT* const dest_nodes, // [graph_ } } +// Build reverse graph from column k of output_graph (avoids per-column host fill and H2D copy). 
+template +__global__ void kern_make_rev_graph_column(const IdxT* const output_graph, // [graph_size, degree] + IdxT* const rev_graph, + uint32_t* const rev_graph_count, + const uint32_t graph_size, + const uint32_t degree, + const uint32_t k) +{ + const uint64_t tid = threadIdx.x + (blockDim.x * blockIdx.x); + const uint64_t tnum = blockDim.x * gridDim.x; + + for (uint64_t src_id = tid; src_id < graph_size; src_id += tnum) { + const IdxT dest_id = output_graph[k + (static_cast(degree) * src_id)]; + if (dest_id >= graph_size) continue; + + const uint32_t pos = atomicAdd(rev_graph_count + dest_id, 1); + if (pos < degree) { + rev_graph[(static_cast(degree) * dest_id) + pos] = static_cast(src_id); + } + } +} + template __device__ __host__ LabelT get_root_label(IdxT i, const LabelT* label) { @@ -1444,32 +1467,26 @@ void optimize( graph_size * sizeof(uint32_t), raft::resource::get_cuda_stream(res))); - auto dest_nodes = raft::make_host_vector(graph_size); - auto d_dest_nodes = - raft::make_device_mdarray(res, large_tmp_mr, raft::make_extents(graph_size)); - - for (uint64_t k = 0; k < output_graph_degree; k++) { -#pragma omp parallel for - for (uint64_t i = 0; i < graph_size; i++) { - // dest_nodes.data_handle()[i] = output_graph_ptr[k + (output_graph_degree * i)]; - dest_nodes(i) = output_graph_ptr[k + (output_graph_degree * i)]; - } - raft::resource::sync_stream(res); - - raft::copy(d_dest_nodes.data_handle(), - dest_nodes.data_handle(), - graph_size, - raft::resource::get_cuda_stream(res)); + // Copy full output graph to device once; kernel indexes by column k (no per-column H2D copy). + // TODO: depending on available device memory, this may need to be split into multiple copies. 
+ auto d_output_graph = raft::make_device_mdarray( + res, large_tmp_mr, raft::make_extents(graph_size, output_graph_degree)); + raft::copy(d_output_graph.data_handle(), + output_graph_ptr, + static_cast(graph_size) * output_graph_degree, + raft::resource::get_cuda_stream(res)); - dim3 threads(256, 1, 1); - dim3 blocks(1024, 1, 1); - kern_make_rev_graph<<>>( - d_dest_nodes.data_handle(), + dim3 threads(256, 1, 1); + dim3 blocks(1024, 1, 1); + for (uint32_t k = 0; k < output_graph_degree; k++) { + kern_make_rev_graph_column<<>>( + d_output_graph.data_handle(), d_rev_graph.data_handle(), d_rev_graph_count.data_handle(), graph_size, - output_graph_degree); - RAFT_LOG_DEBUG("# Making reverse graph on GPUs: %lu / %u \r", k, output_graph_degree); + output_graph_degree, + k); + RAFT_LOG_DEBUG("# Making reverse graph on GPUs: %u / %u \r", k, output_graph_degree); } raft::resource::sync_stream(res); From 6d1a6187f2cc138c4941608d4d9c746a11e0d774 Mon Sep 17 00:00:00 2001 From: Malte Foerster Date: Thu, 19 Feb 2026 23:07:23 +0000 Subject: [PATCH 03/40] optimize() draft move more compute to GPU --- .../neighbors/detail/cagra/cagra_build.cuh | 2 + cpp/src/neighbors/detail/cagra/graph_core.cuh | 864 ++++++++++++------ 2 files changed, 590 insertions(+), 276 deletions(-) diff --git a/cpp/src/neighbors/detail/cagra/cagra_build.cuh b/cpp/src/neighbors/detail/cagra/cagra_build.cuh index 97d7bb1bac..152b603286 100644 --- a/cpp/src/neighbors/detail/cagra/cagra_build.cuh +++ b/cpp/src/neighbors/detail/cagra/cagra_build.cuh @@ -822,6 +822,8 @@ inline std::pair optimize_workspace_size(size_t n_rows, size_t index_size, bool mst_optimize = false) { + // TODO: MODIFY!! 
+ // MST optimization memory (host only) size_t mst_host = n_rows * index_size; // mst_graph_num_edges if (mst_optimize) { diff --git a/cpp/src/neighbors/detail/cagra/graph_core.cuh b/cpp/src/neighbors/detail/cagra/graph_core.cuh index f3b0f0778e..f2cd79ecb6 100644 --- a/cpp/src/neighbors/detail/cagra/graph_core.cuh +++ b/cpp/src/neighbors/detail/cagra/graph_core.cuh @@ -161,8 +161,8 @@ __global__ void kern_prune(const IdxT* const knn_graph, // [graph_chunk_size, g const uint32_t degree, const uint32_t batch_size, const uint32_t batch_id, - uint8_t* const detour_count, // [graph_chunk_size, graph_degree] - uint32_t* const num_no_detour_edges, // [graph_size] + uint8_t* const detour_count, // [batch_size, graph_degree] + uint32_t* const num_no_detour_edges, // [batch_size] uint64_t* const stats) { __shared__ uint32_t smem_num_detour[MAX_DEGREE]; @@ -172,7 +172,9 @@ __global__ void kern_prune(const IdxT* const knn_graph, // [graph_chunk_size, g uint64_t* const num_retain = stats; uint64_t* const num_full = stats + 1; - const uint64_t iA = blockIdx.x + (batch_size * batch_id); + const uint64_t iA = blockIdx.x + (batch_size * batch_id); + const uint64_t iA_batch = iA % static_cast(batch_size); + if (iA >= graph_size) { return; } // Load this node's neighbor row into shared memory to reduce global reads @@ -208,7 +210,7 @@ __global__ void kern_prune(const IdxT* const knn_graph, // [graph_chunk_size, g uint32_t num_edges_no_detour = 0; for (uint32_t k = threadIdx.x; k < graph_degree; k += blockDim.x) { - detour_count[k + (graph_degree * iA)] = min(smem_num_detour[k], (uint32_t)255); + detour_count[k + (graph_degree * iA_batch)] = min(smem_num_detour[k], (uint32_t)255); if (smem_num_detour[k] == 0) { num_edges_no_detour++; } } num_edges_no_detour += __shfl_xor_sync(0xffffffff, num_edges_no_detour, 1); @@ -219,7 +221,7 @@ __global__ void kern_prune(const IdxT* const knn_graph, // [graph_chunk_size, g num_edges_no_detour = min(num_edges_no_detour, degree); if (threadIdx.x 
== 0) { - num_no_detour_edges[iA] = num_edges_no_detour; + num_no_detour_edges[iA_batch] = num_edges_no_detour; atomicAdd((unsigned long long int*)num_retain, (unsigned long long int)num_edges_no_detour); if (num_edges_no_detour >= degree) { atomicAdd((unsigned long long int*)num_full, 1); } } @@ -244,26 +246,179 @@ __global__ void kern_make_rev_graph(const IdxT* const dest_nodes, // [graph_ } } -// Build reverse graph from column k of output_graph (avoids per-column host fill and H2D copy). +// Select output_graph_degree neighbors with smallest detour count per node (writes to device). template -__global__ void kern_make_rev_graph_column(const IdxT* const output_graph, // [graph_size, degree] - IdxT* const rev_graph, - uint32_t* const rev_graph_count, - const uint32_t graph_size, - const uint32_t degree, - const uint32_t k) +__global__ void kern_select_smallest_detour_neighbors( + const IdxT* const knn_graph, + uint64_t graph_size, + uint64_t knn_graph_degree, + uint64_t output_graph_degree, + const uint8_t* const d_detour_count, // [batch_size, graph_degree] + IdxT* output_graph_ptr, // [batch_size, output_graph_degree] + const uint32_t batch_size, + const uint32_t batch_id) { - const uint64_t tid = threadIdx.x + (blockDim.x * blockIdx.x); - const uint64_t tnum = blockDim.x * gridDim.x; + // FIXME: this does not really work for num_warps > 1 + constexpr unsigned warp_mask = 0xffffffff; + const uint32_t num_warps = blockDim.x / raft::WarpSize; + extern __shared__ unsigned char smem_buf[]; + uint32_t* smem_indices = reinterpret_cast(smem_buf); + uint16_t* smem_detour_count = + reinterpret_cast(&smem_indices[knn_graph_degree * num_warps]); - for (uint64_t src_id = tid; src_id < graph_size; src_id += tnum) { - const IdxT dest_id = output_graph[k + (static_cast(degree) * src_id)]; - if (dest_id >= graph_size) continue; + const uint32_t wid = threadIdx.x / raft::WarpSize; + const uint32_t lane_id = threadIdx.x % raft::WarpSize; + const uint64_t nid = 
static_cast(blockIdx.x) * num_warps + + (static_cast(batch_size) * batch_id * num_warps) + wid; - const uint32_t pos = atomicAdd(rev_graph_count + dest_id, 1); - if (pos < degree) { - rev_graph[(static_cast(degree) * dest_id) + pos] = static_cast(src_id); + const uint64_t nid_batch = nid % static_cast(batch_size); + + if (nid >= graph_size) return; + + for (uint32_t k = lane_id; k < knn_graph_degree; k += raft::WarpSize) { + smem_detour_count[(knn_graph_degree * wid) + k] = + d_detour_count[nid_batch * knn_graph_degree + k]; + smem_indices[(knn_graph_degree * wid) + k] = k; + } + __syncwarp(warp_mask); + + for (uint32_t i = 0; i < output_graph_degree; i++) { + uint32_t local_min = 256; + uint32_t local_idx = 0xffffffff; + for (uint32_t k = lane_id; k < knn_graph_degree; k += raft::WarpSize) { + uint32_t c = smem_detour_count[(knn_graph_degree * wid) + k]; + if (c < local_min) { + local_min = c; + local_idx = smem_indices[(knn_graph_degree * wid) + k]; + } + } + uint32_t local_min_with_tag = (local_min << 16) | local_idx; + for (int offset = raft::WarpSize / 2; offset > 0; offset /= 2) { + uint32_t other = __shfl_down_sync(warp_mask, local_min_with_tag, offset); + local_min_with_tag = (local_min_with_tag <= other) ? 
local_min_with_tag : other; + } + uint32_t warp_min_tag = __shfl_sync(warp_mask, local_min_with_tag, 0); + uint32_t warp_local_idx = warp_min_tag & 0xffff; + + if (local_idx == warp_local_idx) { + output_graph_ptr[nid_batch * output_graph_degree + i] = + knn_graph[knn_graph_degree * nid + warp_local_idx]; + smem_detour_count[knn_graph_degree * wid + warp_local_idx] = 255; + } + __syncwarp(warp_mask); + } +} + +// Helper functions for merging the graph +template +__device__ unsigned int warp_pos_in_array(T val, const T* array, uint64_t num) +{ + unsigned int ret = num; + const uint32_t lane_id = threadIdx.x % 32; + for (uint64_t i = lane_id; i < num; i += 32) { + if (val == array[i]) { + ret = i; + break; + } + } + ret = __reduce_min_sync(0xffffffff, ret); + return ret; +} + +template +__device__ void thread_shift_array(T* array, uint64_t num) +{ + for (uint64_t i = num; i > 0; i--) { + array[i] = array[i - 1]; + } +} + +template +__global__ void kern_merge_graph(IdxT* output_graph, + const IdxT* const rev_graph, + uint32_t* const rev_graph_count, // [graph_size] + const uint32_t graph_size, + const uint32_t output_graph_degree, + const IdxT* const mst_graph, + const uint32_t mst_graph_degree, + const uint32_t* const mst_graph_num_edges_ptr, + const uint32_t batch_size, + const uint32_t batch_id, + bool guarantee_connectivity, + bool* check_num_protected_edges) +{ + extern __shared__ unsigned char smem_buf[]; + IdxT* smem_sorted_output_graph = reinterpret_cast(smem_buf); + + const uint32_t wid = threadIdx.x / 32; + const uint32_t lane_id = threadIdx.x % 32; + const uint32_t num_warps = blockDim.x / 32; + const uint64_t nid = blockIdx.x * num_warps + (batch_size * batch_id * num_warps) + wid; + if (nid >= graph_size) { return; } + + if (lane_id == 0) check_num_protected_edges[0] = true; + + const auto mst_graph_num_edges = mst_graph_num_edges_ptr[nid]; + // If guarantee_connectivity == true, use a temporal list to merge the + // neighbor lists of the graphs. 
+ if (guarantee_connectivity) { + for (uint32_t i = lane_id; i < mst_graph_degree; i += 32) { + smem_sorted_output_graph[i] = mst_graph[nid * mst_graph_degree + i]; } + __syncwarp(); + for (uint32_t pruned_j = 0, output_j = mst_graph_num_edges; + (pruned_j < output_graph_degree) && (output_j < output_graph_degree); + pruned_j++) { + const auto v = output_graph[output_graph_degree * nid + pruned_j]; + unsigned int dup = 0; + for (uint32_t m = lane_id; m < output_j; m += 32) { + if (v == smem_sorted_output_graph[m]) { + dup = 1; + break; + } + } + + unsigned int warp_dup = __ballot_sync(0xffffffff, dup); + if (warp_dup == 0) { + if (lane_id == 0) smem_sorted_output_graph[output_j] = v; + output_j++; + } + __syncwarp(); + } + } + + else { + for (uint32_t i = lane_id; i < output_graph_degree; i += 32) { + smem_sorted_output_graph[i] = output_graph[output_graph_degree * nid + i]; + } + __syncwarp(); + } + + const auto num_protected_edges = max(mst_graph_num_edges, output_graph_degree / 2); + + if (num_protected_edges > output_graph_degree) { check_num_protected_edges[0] = false; } + if (num_protected_edges == output_graph_degree) { return; } + + auto kr = min(rev_graph_count[nid], output_graph_degree); + + while (kr) { + kr -= 1; + if (rev_graph[kr + (output_graph_degree * nid)] < graph_size) { + uint64_t pos = warp_pos_in_array( + rev_graph[kr + (output_graph_degree * nid)], smem_sorted_output_graph, output_graph_degree); + if (pos < num_protected_edges) { continue; } + uint64_t num_shift = pos - num_protected_edges; + if (pos >= output_graph_degree) { num_shift = output_graph_degree - num_protected_edges - 1; } + if (lane_id == 0) { + thread_shift_array(smem_sorted_output_graph + num_protected_edges, num_shift); + smem_sorted_output_graph[num_protected_edges] = rev_graph[kr + (output_graph_degree * nid)]; + } + __syncwarp(); + } + } + + for (uint32_t i = lane_id; i < output_graph_degree; i += 32) { + output_graph[(output_graph_degree * nid) + i] = 
smem_sorted_output_graph[i]; } } @@ -737,11 +892,11 @@ void mst_opt_update_graph(IdxT* mst_graph_ptr, // an approximate MST. // * If the input kNN graph is disconnected, random connection is added to the largest cluster. // -template +template void mst_optimization(raft::resources const& res, - raft::host_matrix_view input_graph, - raft::host_matrix_view output_graph, - raft::host_vector_view mst_graph_num_edges, + InputMatrixView input_graph, + OutputMatrixView output_graph, + VectorView mst_graph_num_edges, bool use_gpu = true) { if (use_gpu) { @@ -1185,6 +1340,7 @@ void count_2hop_detours(raft::host_matrix_view k } } +// TODO allow pinned input for both knn_graph and new_graph template , raft::memory_type::host>> @@ -1213,9 +1369,10 @@ void optimize( "cagra::graph::optimize(%zu, %zu, %u)", graph_size, knn_graph_degree, output_graph_degree); // MST optimization - auto mst_graph = raft::make_host_matrix(0, 0); - auto mst_graph_num_edges = raft::make_host_vector(graph_size); + auto mst_graph = raft::make_pinned_matrix(res, 0, 0); + auto mst_graph_num_edges = raft::make_pinned_vector(res, graph_size); auto mst_graph_num_edges_ptr = mst_graph_num_edges.data_handle(); + #pragma omp parallel for for (uint64_t i = 0; i < graph_size; i++) { mst_graph_num_edges_ptr[i] = 0; @@ -1223,10 +1380,10 @@ void optimize( if (guarantee_connectivity) { raft::common::nvtx::range block_scope( "cagra::graph::optimize/check_connectivity"); - mst_graph = - raft::make_host_matrix(graph_size, output_graph_degree); + mst_graph = raft::make_pinned_matrix( + res, graph_size, output_graph_degree); RAFT_LOG_INFO("MST optimization is used to guarantee graph connectivity."); - mst_optimization(res, knn_graph, mst_graph.view(), mst_graph_num_edges.view(), use_gpu); + mst_optimization(res, knn_graph, mst_graph.view(), mst_graph_num_edges.view(), use_gpu); for (uint64_t i = 0; i < graph_size; i++) { if (i < 8 || i >= graph_size - 8) { @@ -1235,6 +1392,37 @@ void optimize( } } + uint32_t batch_size = 
+ std::min(static_cast(graph_size), static_cast(256 * 1024)); + const uint32_t num_batch = (graph_size + batch_size - 1) / batch_size; + + // + // If the available device memory is insufficient, do not use the GPU to count + // the number of 2-hop detours, but use the CPU. + // + // TODO: we should decide on a global strategy for this in a single place + // it comes down to input memory type and available memory which data should be copied to GPU + bool _use_gpu_prune = use_gpu; + if (_use_gpu_prune) { + try { + auto d_detour_count = raft::make_device_mdarray( + res, large_tmp_mr, raft::make_extents(batch_size, knn_graph_degree)); + auto d_num_no_detour_edges = raft::make_device_mdarray( + res, large_tmp_mr, raft::make_extents(batch_size)); + auto d_output_graph = raft::make_device_mdarray( + res, large_tmp_mr, raft::make_extents(batch_size, output_graph_degree)); + // TODO we also want to consider pinned memory in case we are short on memory + auto d_input_graph = raft::make_device_mdarray( + res, large_tmp_mr, raft::make_extents(graph_size, knn_graph_degree)); + } catch (std::bad_alloc& e) { + RAFT_LOG_DEBUG("Insufficient memory for pruning on GPU"); + _use_gpu_prune = false; + } catch (raft::logic_error& e) { + RAFT_LOG_DEBUG("Insufficient memory for pruning on GPU (logic error)"); + _use_gpu_prune = false; + } + } + { raft::common::nvtx::range block_scope( "cagra::graph::optimize/prune"); @@ -1253,63 +1441,10 @@ void optimize( // specified number of edges are picked up for each node, starting with the // edge with the lowest number of 2-hop detours. // - auto detour_count = raft::make_host_matrix(graph_size, knn_graph_degree); - - // - // If the available device memory is insufficient, do not use the GPU to count - // the number of 2-hop detours, but use the CPU. 
- // - bool _use_gpu = use_gpu; - if (_use_gpu) { - try { - auto d_detour_count = raft::make_device_mdarray( - res, large_tmp_mr, raft::make_extents(graph_size, knn_graph_degree)); - auto d_num_no_detour_edges = raft::make_device_mdarray( - res, large_tmp_mr, raft::make_extents(graph_size)); - auto d_input_graph = raft::make_device_mdarray( - res, large_tmp_mr, raft::make_extents(graph_size, knn_graph_degree)); - } catch (std::bad_alloc& e) { - RAFT_LOG_DEBUG("Insufficient memory for 2-hop node counting on GPU"); - _use_gpu = false; - } catch (raft::logic_error& e) { - RAFT_LOG_DEBUG("Insufficient memory for 2-hop node counting on GPU (logic error)"); - _use_gpu = false; - } - } - if (_use_gpu) { - // Count 2-hop detours on GPU - raft::common::nvtx::range block_scope( - "cagra::graph::optimize/prune/2-hop-counting-by-GPU"); - const double time_2hop_count_start = cur_time(); - - uint64_t num_keep __attribute__((unused)) = 0; - uint64_t num_full __attribute__((unused)) = 0; - auto d_detour_count = raft::make_device_mdarray( - res, large_tmp_mr, raft::make_extents(graph_size, knn_graph_degree)); - - RAFT_CUDA_TRY(cudaMemsetAsync(d_detour_count.data_handle(), - 0xff, - graph_size * knn_graph_degree * sizeof(uint8_t), - raft::resource::get_cuda_stream(res))); - - auto d_num_no_detour_edges = raft::make_device_mdarray( - res, large_tmp_mr, raft::make_extents(graph_size)); - RAFT_CUDA_TRY(cudaMemsetAsync(d_num_no_detour_edges.data_handle(), - 0x00, - graph_size * sizeof(uint32_t), - raft::resource::get_cuda_stream(res))); - - auto dev_stats = raft::make_device_vector(res, 2); - auto host_stats = raft::make_host_vector(2); - + if (_use_gpu_prune) { + // Pruning on GPU RAFT_LOG_DEBUG("# Pruning kNN Graph on GPUs\r"); - // Copy knn_graph over to device if necessary - device_matrix_view_from_host d_input_graph( - res, - raft::make_host_matrix_view( - knn_graph.data_handle(), graph_size, knn_graph_degree)); - constexpr int MAX_DEGREE = 1024; if (knn_graph_degree > MAX_DEGREE) 
{ RAFT_FAIL( @@ -1318,17 +1453,47 @@ void optimize( knn_graph_degree, MAX_DEGREE); } - const uint32_t batch_size = - std::min(static_cast(graph_size), static_cast(256 * 1024)); - const uint32_t num_batch = (graph_size + batch_size - 1) / batch_size; - const dim3 threads_prune(32, 1, 1); - const dim3 blocks_prune(batch_size, 1, 1); + const double prune_start = cur_time(); + + uint64_t num_keep __attribute__((unused)) = 0; + uint64_t num_full __attribute__((unused)) = 0; + auto dev_stats = raft::make_device_vector(res, 2); + auto host_stats = raft::make_host_vector(2); RAFT_CUDA_TRY(cudaMemsetAsync( dev_stats.data_handle(), 0, sizeof(uint64_t) * 2, raft::resource::get_cuda_stream(res))); - const size_t prune_smem_size = knn_graph_degree * sizeof(IdxT); + // Copy knn_graph over to device if necessary + // TODO: should we use pinned memory if we have issues fitting on GPU? + device_matrix_view_from_host d_input_graph( + res, + raft::make_host_matrix_view( + knn_graph.data_handle(), graph_size, knn_graph_degree)); + + // data structures per batch + auto d_detour_count = raft::make_device_mdarray( + res, large_tmp_mr, raft::make_extents(batch_size, knn_graph_degree)); + auto d_num_no_detour_edges = raft::make_device_mdarray( + res, large_tmp_mr, raft::make_extents(batch_size)); + auto d_output_graph = raft::make_device_mdarray( + res, large_tmp_mr, raft::make_extents(batch_size, output_graph_degree)); + for (uint32_t i_batch = 0; i_batch < num_batch; i_batch++) { + // initialize the detour_count and num_no_detour_edges for the current batch + RAFT_CUDA_TRY(cudaMemsetAsync(d_detour_count.data_handle(), + 0xff, + batch_size * knn_graph_degree * sizeof(uint8_t), + raft::resource::get_cuda_stream(res))); + + RAFT_CUDA_TRY(cudaMemsetAsync(d_num_no_detour_edges.data_handle(), + 0x00, + batch_size * sizeof(uint32_t), + raft::resource::get_cuda_stream(res))); + + // count 2-hop detours for the current batch + const dim3 threads_prune(32, 1, 1); + const dim3 
blocks_prune(batch_size, 1, 1); + const size_t prune_smem_size = knn_graph_degree * sizeof(IdxT); kern_prune <<>>( d_input_graph.data_handle(), @@ -1340,6 +1505,30 @@ void optimize( d_detour_count.data_handle(), d_num_no_detour_edges.data_handle(), dev_stats.data_handle()); + + // select smallest-detour neighbors for the current batch + const size_t select_smem_size = + (knn_graph_degree * knn_graph_degree) * (sizeof(uint16_t) + sizeof(uint32_t)); + const dim3 threads_select(32, 1, 1); + const dim3 blocks_select(batch_size, 1, 1); + kern_select_smallest_detour_neighbors + <<>>(d_input_graph.data_handle(), + graph_size, + knn_graph_degree, + output_graph_degree, + d_detour_count.data_handle(), + d_output_graph.data_handle(), + batch_size, + i_batch); + + raft::copy(output_graph_ptr, + d_output_graph.data_handle() + i_batch * batch_size * output_graph_degree, + static_cast(batch_size) * output_graph_degree, + raft::resource::get_cuda_stream(res)); + raft::resource::sync_stream(res); RAFT_LOG_DEBUG( "# Pruning kNN Graph on GPUs (%.1lf %%)\r", @@ -1348,96 +1537,93 @@ void optimize( raft::resource::sync_stream(res); RAFT_LOG_DEBUG("\n"); - raft::copy(detour_count.data_handle(), - d_detour_count.data_handle(), - detour_count.size(), - raft::resource::get_cuda_stream(res)); - raft::copy( host_stats.data_handle(), dev_stats.data_handle(), 2, raft::resource::get_cuda_stream(res)); num_keep = host_stats.data_handle()[0]; num_full = host_stats.data_handle()[1]; - const double time_2hop_count_end = cur_time(); + const double prune_end = cur_time(); RAFT_LOG_DEBUG( - "# Time for 2-hop detour counting on GPU: %.1lf sec, " + "# Time for pruning on GPU: %.1lf sec, " "avg_no_detour_edges_per_node: %.2lf/%u, " "nodes_with_no_detour_at_all_edges: %.1lf%%", - time_2hop_count_end - time_2hop_count_start, + prune_end - prune_start, (double)num_keep / graph_size, output_graph_degree, (double)num_full / graph_size * 100); } else { - // Count 2-hop detours on CPU - 
raft::common::nvtx::range block_scope( - "cagra::graph::optimize/prune/2-hop-counting-by-CPU"); - const double time_2hop_count_start = cur_time(); + // Pruning on CPU + auto detour_count = raft::make_host_matrix(graph_size, knn_graph_degree); - count_2hop_detours(knn_graph, detour_count.view()); + { + raft::common::nvtx::range block_scope( + "cagra::graph::optimize/prune/2-hop-counting-by-CPU"); + const double time_2hop_count_start = cur_time(); - const double time_2hop_count_end = cur_time(); - RAFT_LOG_DEBUG("# Time for 2-hop detour counting on CPU: %.1lf sec", - time_2hop_count_end - time_2hop_count_start); - } + count_2hop_detours(knn_graph, detour_count.view()); - // Create pruned kNN graph - bool invalid_neighbor_list = false; + const double time_2hop_count_end = cur_time(); + RAFT_LOG_DEBUG("# Time for 2-hop detour counting on CPU: %.1lf sec", + time_2hop_count_end - time_2hop_count_start); + } + bool invalid_neighbor_list = false; #pragma omp parallel for - for (uint64_t i = 0; i < graph_size; i++) { - // Find the `output_graph_degree` smallest detourable count nodes by checking the detourable - // count of the neighbors while increasing the target detourable count from zero. - uint64_t pk = 0; - uint32_t num_detour = 0; - for (uint32_t l = 0; l < knn_graph_degree && pk < output_graph_degree; l++) { - uint32_t next_num_detour = std::numeric_limits::max(); - for (uint64_t k = 0; k < knn_graph_degree; k++) { - const auto num_detour_k = detour_count(i, k); - // Find the detourable count to check in the next iteration - if (num_detour_k > num_detour) { - next_num_detour = std::min(static_cast(num_detour_k), next_num_detour); - } - - // Store the neighbor index if its detourable count is equal to `num_detour`. 
- if (num_detour_k != num_detour) { continue; } + for (uint64_t i = 0; i < graph_size; i++) { + // Find the `output_graph_degree` smallest detourable count nodes by checking the detourable + // count of the neighbors while increasing the target detourable count from zero. + uint64_t pk = 0; + uint32_t num_detour = 0; + for (uint32_t l = 0; l < knn_graph_degree && pk < output_graph_degree; l++) { + uint32_t next_num_detour = std::numeric_limits::max(); + for (uint64_t k = 0; k < knn_graph_degree; k++) { + const auto num_detour_k = detour_count(i, k); + // Find the detourable count to check in the next iteration + if (num_detour_k > num_detour) { + next_num_detour = std::min(static_cast(num_detour_k), next_num_detour); + } - // Check duplication and append - const auto candidate_node = knn_graph(i, k); - bool dup = false; - for (uint32_t dk = 0; dk < pk; dk++) { - if (candidate_node == output_graph_ptr[i * output_graph_degree + dk]) { - dup = true; - break; + // Store the neighbor index if its detourable count is equal to `num_detour`. + if (num_detour_k != num_detour) { continue; } + + // Check duplication and append + const auto candidate_node = knn_graph(i, k); + bool dup = false; + for (uint32_t dk = 0; dk < pk; dk++) { + if (candidate_node == output_graph_ptr[i * output_graph_degree + dk]) { + dup = true; + break; + } } - } - if (!dup && candidate_node < graph_size) { - output_graph_ptr[i * output_graph_degree + pk] = candidate_node; - pk += 1; + if (!dup && candidate_node < graph_size) { + output_graph_ptr[i * output_graph_degree + pk] = candidate_node; + pk += 1; + } + if (pk >= output_graph_degree) break; } if (pk >= output_graph_degree) break; - } - if (pk >= output_graph_degree) break; - if (next_num_detour == std::numeric_limits::max()) { - // There are no valid edges enough in the initial kNN graph. Break the loop here and catch - // the error at the next validation (pk != output_graph_degree). 
- break; + if (next_num_detour == std::numeric_limits::max()) { + // There are no valid edges enough in the initial kNN graph. Break the loop here and + // catch the error at the next validation (pk != output_graph_degree). + break; + } + num_detour = next_num_detour; + } + if (pk != output_graph_degree) { + RAFT_LOG_DEBUG( + "Couldn't find the output_graph_degree (%lu) smallest detourable count nodes for " + "node %lu in the rank-based node reranking process", + output_graph_degree, + i); + invalid_neighbor_list = true; } - num_detour = next_num_detour; - } - if (pk != output_graph_degree) { - RAFT_LOG_DEBUG( - "Couldn't find the output_graph_degree (%lu) smallest detourable count nodes for " - "node %lu in the rank-based node reranking process", - output_graph_degree, - i); - invalid_neighbor_list = true; } + RAFT_EXPECTS( + !invalid_neighbor_list, + "Could not generate an intermediate CAGRA graph because the initial kNN graph contains too " + "many invalid or duplicated neighbor nodes. This error can occur, for example, if too many " + "overflows occur during the norm computation between the dataset vectors."); } - RAFT_EXPECTS( - !invalid_neighbor_list, - "Could not generate an intermediate CAGRA graph because the initial kNN graph contains too " - "many invalid or duplicated neighbor nodes. This error can occur, for example, if too many " - "overflows occur during the norm computation between the dataset vectors."); const double time_prune_end = cur_time(); RAFT_LOG_DEBUG("# Pruning time: %.1lf ms", (time_prune_end - time_prune_start) * 1000.0); @@ -1446,155 +1632,281 @@ void optimize( auto rev_graph = raft::make_host_matrix(graph_size, output_graph_degree); auto rev_graph_count = raft::make_host_vector(graph_size); - { - raft::common::nvtx::range block_scope( - "cagra::graph::optimize/reverse"); + bool _use_gpu_rev_graph = use_gpu; + // TODO: should we use pinned memory if we have issues fitting on GPU? 
+ if (_use_gpu_rev_graph) { + try { + auto d_rev_graph_count = raft::make_device_mdarray( + res, large_tmp_mr, raft::make_extents(graph_size)); + auto d_dest_nodes = + raft::make_device_mdarray(res, large_tmp_mr, raft::make_extents(graph_size)); + auto d_rev_graph = raft::make_device_mdarray( + res, large_tmp_mr, raft::make_extents(graph_size, output_graph_degree)); + } catch (std::bad_alloc& e) { + RAFT_LOG_DEBUG("Insufficient memory for reverse graph on GPU"); + _use_gpu_rev_graph = false; + } catch (raft::logic_error& e) { + RAFT_LOG_DEBUG("Insufficient memory for reverse graph on GPU (logic error)"); + _use_gpu_rev_graph = false; + } + } + + const double time_make_start = cur_time(); + if (_use_gpu_rev_graph) { // - // Make reverse graph + // Make reverse graph on GPU // - const double time_make_start = cur_time(); + auto d_rev_graph_count = raft::make_device_mdarray( + res, large_tmp_mr, raft::make_extents(graph_size)); device_matrix_view_from_host d_rev_graph(res, rev_graph.view()); - RAFT_CUDA_TRY(cudaMemsetAsync(d_rev_graph.data_handle(), - 0xff, - graph_size * output_graph_degree * sizeof(IdxT), - raft::resource::get_cuda_stream(res))); + device_matrix_view_from_host d_output_graph( + res, + raft::make_host_matrix_view( + output_graph_ptr, graph_size, output_graph_degree)); - auto d_rev_graph_count = raft::make_device_mdarray( - res, large_tmp_mr, raft::make_extents(graph_size)); - RAFT_CUDA_TRY(cudaMemsetAsync(d_rev_graph_count.data_handle(), - 0x00, - graph_size * sizeof(uint32_t), - raft::resource::get_cuda_stream(res))); - - // Copy full output graph to device once; kernel indexes by column k (no per-column H2D copy). - // TODO: depending on available device memory, this may need to be split into multiple copies. 
- auto d_output_graph = raft::make_device_mdarray( - res, large_tmp_mr, raft::make_extents(graph_size, output_graph_degree)); - raft::copy(d_output_graph.data_handle(), - output_graph_ptr, - static_cast(graph_size) * output_graph_degree, - raft::resource::get_cuda_stream(res)); + { + raft::common::nvtx::range block_scope( + "cagra::graph::optimize/reverse"); + auto dest_nodes = raft::make_host_vector(graph_size); + auto d_dest_nodes = + raft::make_device_mdarray(res, large_tmp_mr, raft::make_extents(graph_size)); - dim3 threads(256, 1, 1); - dim3 blocks(1024, 1, 1); - for (uint32_t k = 0; k < output_graph_degree; k++) { - kern_make_rev_graph_column<<>>( - d_output_graph.data_handle(), - d_rev_graph.data_handle(), - d_rev_graph_count.data_handle(), - graph_size, - output_graph_degree, - k); - RAFT_LOG_DEBUG("# Making reverse graph on GPUs: %u / %u \r", k, output_graph_degree); - } + RAFT_CUDA_TRY(cudaMemsetAsync(d_rev_graph.data_handle(), + 0xff, + graph_size * output_graph_degree * sizeof(IdxT), + raft::resource::get_cuda_stream(res))); - raft::resource::sync_stream(res); - RAFT_LOG_DEBUG("\n"); + RAFT_CUDA_TRY(cudaMemsetAsync(d_rev_graph_count.data_handle(), + 0x00, + graph_size * sizeof(uint32_t), + raft::resource::get_cuda_stream(res))); - if (d_rev_graph.allocated_memory()) { - raft::copy(rev_graph.data_handle(), - d_rev_graph.data_handle(), - graph_size * output_graph_degree, + for (uint64_t k = 0; k < output_graph_degree; k++) { +#pragma omp parallel for + for (uint64_t i = 0; i < graph_size; i++) { + // dest_nodes.data_handle()[i] = output_graph_ptr[k + (output_graph_degree * i)]; + dest_nodes(i) = output_graph_ptr[k + (output_graph_degree * i)]; + } + raft::resource::sync_stream(res); + + raft::copy(d_dest_nodes.data_handle(), + dest_nodes.data_handle(), + graph_size, + raft::resource::get_cuda_stream(res)); + + dim3 threads(256, 1, 1); + dim3 blocks(1024, 1, 1); + kern_make_rev_graph<<>>( + d_dest_nodes.data_handle(), + d_rev_graph.data_handle(), + 
d_rev_graph_count.data_handle(), + graph_size, + output_graph_degree); + RAFT_LOG_DEBUG("# Making reverse graph on GPUs: %lu / %u \r", k, output_graph_degree); + } + + raft::resource::sync_stream(res); + RAFT_LOG_DEBUG("\n"); + + if (d_rev_graph.allocated_memory()) { + raft::copy(rev_graph.data_handle(), + d_rev_graph.data_handle(), + graph_size * output_graph_degree, + raft::resource::get_cuda_stream(res)); + } + raft::copy(rev_graph_count.data_handle(), + d_rev_graph_count.data_handle(), + graph_size, raft::resource::get_cuda_stream(res)); + + const double time_make_end = cur_time(); + RAFT_LOG_DEBUG("# Making reverse graph time: %.1lf ms", + (time_make_end - time_make_start) * 1000.0); } - raft::copy(rev_graph_count.data_handle(), - d_rev_graph_count.data_handle(), - graph_size, - raft::resource::get_cuda_stream(res)); - const double time_make_end = cur_time(); - RAFT_LOG_DEBUG("# Making reverse graph time: %.1lf ms", - (time_make_end - time_make_start) * 1000.0); - } + { + raft::common::nvtx::range block_scope( + "cagra::graph::optimize/combine"); - { - raft::common::nvtx::range block_scope( - "cagra::graph::optimize/combine"); - // - // Create search graphs from MST and pruned and reverse graphs - // - const double time_replace_start = cur_time(); + // Merging the prunned graph and the reverse graph + const double merge_graph_start = cur_time(); + + // Create a boolean variable on the GPU using RAFT device allocator + auto d_check_num_protected_edges = raft::make_device_scalar(res, true); + + const dim3 threads_merge(32, 1, 1); + const dim3 blocks_merge(batch_size, 1, 1); + const size_t merge_smem_size = (output_graph_degree + output_graph_degree) * sizeof(IdxT); + for (uint32_t i_batch = 0; i_batch < num_batch; i_batch++) { + kern_merge_graph + <<>>( + d_output_graph.data_handle(), + d_rev_graph.data_handle(), + d_rev_graph_count.data_handle(), + graph_size, + output_graph_degree, + mst_graph.data_handle(), + output_graph_degree, + mst_graph_num_edges_ptr, + 
batch_size, + i_batch, + guarantee_connectivity, + d_check_num_protected_edges.data_handle()); + } + + bool check_num_protected_edges = true; + raft::copy(&check_num_protected_edges, + d_check_num_protected_edges.data_handle(), + 1, + raft::resource::get_cuda_stream(res)); + raft::resource::sync_stream(res); + + // TODO: is this required? + if (d_output_graph.allocated_memory()) { + raft::copy(output_graph_ptr, + d_output_graph.data_handle(), + graph_size * output_graph_degree, + raft::resource::get_cuda_stream(res)); + } + + const auto merge_graph_end = cur_time(); + RAFT_EXPECTS(check_num_protected_edges, + "Failed to merge the MST, pruned, and reverse edge graphs. " + "Some nodes have too " + "many MST optimization edges."); + + RAFT_LOG_DEBUG("# Time for merging graphs: %.1lf ms", + (merge_graph_end - merge_graph_start) * 1000.0); + } + } else { + { + // Make reverse graph on CPU + raft::common::nvtx::range block_scope( + "cagra::graph::optimize/reverse"); + + auto rev_graph_ptr = rev_graph.data_handle(); + auto rev_graph_count_ptr = rev_graph_count.data_handle(); - bool check_num_protected_edges = true; #pragma omp parallel for - for (uint64_t i = 0; i < graph_size; i++) { - auto my_rev_graph = rev_graph.data_handle() + (output_graph_degree * i); - auto my_out_graph = output_graph_ptr + (output_graph_degree * i); + for (uint64_t i = 0; i < graph_size; i++) { + rev_graph_count_ptr[i] = 0; + } - // If guarantee_connectivity == true, use a temporal list to merge the neighbor lists of the - // graphs. 
- std::vector temp_output_neighbor_list; - if (guarantee_connectivity) { - temp_output_neighbor_list.resize(output_graph_degree); - my_out_graph = temp_output_neighbor_list.data(); - const auto mst_graph_num_edges = mst_graph_num_edges_ptr[i]; - - // Set MST graph edges - for (uint32_t j = 0; j < mst_graph_num_edges; j++) { - my_out_graph[j] = mst_graph(i, j); + for (uint32_t k = 0; k < output_graph_degree; k++) { +#pragma omp parallel for + for (uint64_t src_id = 0; src_id < graph_size; src_id++) { + const IdxT dest_id = + output_graph_ptr[k + (static_cast(output_graph_degree) * src_id)]; + if (dest_id >= graph_size) continue; + uint32_t pos; +#pragma omp atomic capture + pos = rev_graph_count_ptr[dest_id]++; + if (pos < output_graph_degree) { + rev_graph_ptr[(static_cast(output_graph_degree) * dest_id) + pos] = + static_cast(src_id); + } } + } - // Set pruned graph edges - for (uint32_t pruned_j = 0, output_j = mst_graph_num_edges; - (pruned_j < output_graph_degree) && (output_j < output_graph_degree); - pruned_j++) { - const auto v = output_graph_ptr[output_graph_degree * i + pruned_j]; - - // duplication check - bool dup = false; - for (uint32_t m = 0; m < output_j; m++) { - if (v == my_out_graph[m]) { - dup = true; - break; - } + const double time_make_end = cur_time(); + RAFT_LOG_DEBUG("# Making reverse graph time (CPU): %.1lf ms", + (time_make_end - time_make_start) * 1000.0); + } + + { + raft::common::nvtx::range block_scope( + "cagra::graph::optimize/combine"); + // + // Create search graphs from MST and pruned and reverse graphs + // + const double time_replace_start = cur_time(); + + bool check_num_protected_edges = true; +#pragma omp parallel for + for (uint64_t i = 0; i < graph_size; i++) { + auto my_rev_graph = rev_graph.data_handle() + (output_graph_degree * i); + auto my_out_graph = output_graph_ptr + (output_graph_degree * i); + + // If guarantee_connectivity == true, use a temporal list to merge the neighbor lists of the + // graphs. 
+ std::vector temp_output_neighbor_list; + if (guarantee_connectivity) { + temp_output_neighbor_list.resize(output_graph_degree); + my_out_graph = temp_output_neighbor_list.data(); + const auto mst_graph_num_edges = mst_graph_num_edges_ptr[i]; + + // Set MST graph edges + for (uint32_t j = 0; j < mst_graph_num_edges; j++) { + my_out_graph[j] = mst_graph(i, j); } - if (!dup) { - my_out_graph[output_j] = v; - output_j++; + // Set pruned graph edges + for (uint32_t pruned_j = 0, output_j = mst_graph_num_edges; + (pruned_j < output_graph_degree) && (output_j < output_graph_degree); + pruned_j++) { + const auto v = output_graph_ptr[output_graph_degree * i + pruned_j]; + + // duplication check + bool dup = false; + for (uint32_t m = 0; m < output_j; m++) { + if (v == my_out_graph[m]) { + dup = true; + break; + } + } + + if (!dup) { + my_out_graph[output_j] = v; + output_j++; + } } } - } - const auto num_protected_edges = - std::max(mst_graph_num_edges_ptr[i], output_graph_degree / 2); - if (num_protected_edges > output_graph_degree) { check_num_protected_edges = false; } - if (num_protected_edges == output_graph_degree) continue; - - // Replace some edges of the output graph with edges of the reverse graph. - auto kr = std::min(rev_graph_count.data_handle()[i], output_graph_degree); - while (kr) { - kr -= 1; - if (my_rev_graph[kr] < graph_size) { - uint64_t pos = pos_in_array(my_rev_graph[kr], my_out_graph, output_graph_degree); - if (pos < num_protected_edges) { continue; } - uint64_t num_shift = pos - num_protected_edges; - if (pos >= output_graph_degree) { - num_shift = output_graph_degree - num_protected_edges - 1; + const auto num_protected_edges = + std::max(mst_graph_num_edges_ptr[i], output_graph_degree / 2); + if (num_protected_edges > output_graph_degree) { check_num_protected_edges = false; } + if (num_protected_edges == output_graph_degree) continue; + + // Replace some edges of the output graph with edges of the reverse graph. 
+ auto kr = std::min(rev_graph_count.data_handle()[i], output_graph_degree); + while (kr) { + kr -= 1; + if (my_rev_graph[kr] < graph_size) { + uint64_t pos = pos_in_array(my_rev_graph[kr], my_out_graph, output_graph_degree); + if (pos < num_protected_edges) { continue; } + uint64_t num_shift = pos - num_protected_edges; + if (pos >= output_graph_degree) { + num_shift = output_graph_degree - num_protected_edges - 1; + } + shift_array(my_out_graph + num_protected_edges, num_shift); + my_out_graph[num_protected_edges] = my_rev_graph[kr]; } - shift_array(my_out_graph + num_protected_edges, num_shift); - my_out_graph[num_protected_edges] = my_rev_graph[kr]; } - } - // If guarantee_connectivity == true, move the output neighbor list from the temporal list to - // the output list. If false, the copy is not needed because my_out_graph is a pointer to the - // output buffer. - if (guarantee_connectivity) { - for (uint32_t j = 0; j < output_graph_degree; j++) { - output_graph_ptr[(output_graph_degree * i) + j] = my_out_graph[j]; + // If guarantee_connectivity == true, move the output neighbor list from the temporal list + // to the output list. If false, the copy is not needed because my_out_graph is a pointer to + // the output buffer. + if (guarantee_connectivity) { + for (uint32_t j = 0; j < output_graph_degree; j++) { + output_graph_ptr[(output_graph_degree * i) + j] = my_out_graph[j]; + } } } - } - RAFT_EXPECTS(check_num_protected_edges, - "Failed to merge the MST, pruned, and reverse edge graphs. Some nodes have too " - "many MST optimization edges."); + RAFT_EXPECTS(check_num_protected_edges, + "Failed to merge the MST, pruned, and reverse edge graphs. 
Some nodes have too " + "many MST optimization edges."); - const double time_replace_end = cur_time(); - RAFT_LOG_DEBUG("# Replacing edges time: %.1lf ms", - (time_replace_end - time_replace_start) * 1000.0); + const double time_replace_end = cur_time(); + RAFT_LOG_DEBUG("# Replacing edges time: %.1lf ms", + (time_replace_end - time_replace_start) * 1000.0); + } + } + // Check stats + { + raft::common::nvtx::range block_scope( + "cagra::graph::optimize/stats"); /* stats */ uint64_t num_replaced_edges = 0; #pragma omp parallel for reduction(+ : num_replaced_edges) From 822faea739f9b77f13642c8201090273e27d32bc Mon Sep 17 00:00:00 2001 From: Malte Foerster Date: Fri, 20 Feb 2026 12:14:35 +0000 Subject: [PATCH 04/40] some fixes, cleanup --- cpp/src/neighbors/detail/cagra/graph_core.cuh | 137 ++++++++++-------- 1 file changed, 77 insertions(+), 60 deletions(-) diff --git a/cpp/src/neighbors/detail/cagra/graph_core.cuh b/cpp/src/neighbors/detail/cagra/graph_core.cuh index f2cd79ecb6..1b7e46e535 100644 --- a/cpp/src/neighbors/detail/cagra/graph_core.cuh +++ b/cpp/src/neighbors/detail/cagra/graph_core.cuh @@ -173,7 +173,7 @@ __global__ void kern_prune(const IdxT* const knn_graph, // [graph_chunk_size, g uint64_t* const num_full = stats + 1; const uint64_t iA = blockIdx.x + (batch_size * batch_id); - const uint64_t iA_batch = iA % static_cast(batch_size); + const uint64_t iA_batch = blockIdx.x; if (iA >= graph_size) { return; } @@ -246,66 +246,69 @@ __global__ void kern_make_rev_graph(const IdxT* const dest_nodes, // [graph_ } } -// Select output_graph_degree neighbors with smallest detour count per node (writes to device). 
-template +// Based on the detour count, select the smallest detour count and its index +// (Pruning Update Kernel) +template __global__ void kern_select_smallest_detour_neighbors( - const IdxT* const knn_graph, + const IdxT* const knn_graph, // [graph_chunk_size, graph_degree] uint64_t graph_size, uint64_t knn_graph_degree, uint64_t output_graph_degree, - const uint8_t* const d_detour_count, // [batch_size, graph_degree] - IdxT* output_graph_ptr, // [batch_size, output_graph_degree] - const uint32_t batch_size, - const uint32_t batch_id) + uint8_t* const d_detour_count, // [batch_size, graph_degree] + IdxT* output_graph_ptr, + const uint32_t batch_size, // [batch_size, output_graph_degree] + const uint32_t batch_id, + uint32_t* const d_invalid_neighbor_list) { - // FIXME: this does not really work for num_warps > 1 - constexpr unsigned warp_mask = 0xffffffff; - const uint32_t num_warps = blockDim.x / raft::WarpSize; - extern __shared__ unsigned char smem_buf[]; - uint32_t* smem_indices = reinterpret_cast(smem_buf); - uint16_t* smem_detour_count = - reinterpret_cast(&smem_indices[knn_graph_degree * num_warps]); + assert(blockDim.x == 32); - const uint32_t wid = threadIdx.x / raft::WarpSize; - const uint32_t lane_id = threadIdx.x % raft::WarpSize; - const uint64_t nid = static_cast(blockIdx.x) * num_warps + - (static_cast(batch_size) * batch_id * num_warps) + wid; + // Allocate shared memory for detour counts and their indices + extern __shared__ IdxT smem_indices[]; + uint16_t* smem_detour_count = (uint16_t*)&smem_indices[knn_graph_degree]; - const uint64_t nid_batch = nid % static_cast(batch_size); + const uint64_t nid = blockIdx.x + (batch_size * batch_id); + const uint64_t nid_batch = blockIdx.x; - if (nid >= graph_size) return; + if (nid >= graph_size) { return; } - for (uint32_t k = lane_id; k < knn_graph_degree; k += raft::WarpSize) { - smem_detour_count[(knn_graph_degree * wid) + k] = - d_detour_count[nid_batch * knn_graph_degree + k]; - 
smem_indices[(knn_graph_degree * wid) + k] = k; + // Each uint64_t loads detour_count for its assigned k + for (uint32_t k = threadIdx.x; k < knn_graph_degree; k += blockDim.x) { + smem_detour_count[k] = d_detour_count[nid_batch * knn_graph_degree + k]; + smem_indices[k] = knn_graph[knn_graph_degree * nid + k]; } - __syncwarp(warp_mask); + __syncwarp(); + + const unsigned warp_mask = 0xffffffff; for (uint32_t i = 0; i < output_graph_degree; i++) { - uint32_t local_min = 256; + uint32_t local_min = 255; uint32_t local_idx = 0xffffffff; - for (uint32_t k = lane_id; k < knn_graph_degree; k += raft::WarpSize) { - uint32_t c = smem_detour_count[(knn_graph_degree * wid) + k]; - if (c < local_min) { - local_min = c; - local_idx = smem_indices[(knn_graph_degree * wid) + k]; + for (uint32_t k = threadIdx.x; k < knn_graph_degree; k += blockDim.x) { + if (smem_detour_count[k] < local_min) { + local_min = smem_detour_count[k]; + local_idx = k; } } - uint32_t local_min_with_tag = (local_min << 16) | local_idx; - for (int offset = raft::WarpSize / 2; offset > 0; offset /= 2) { - uint32_t other = __shfl_down_sync(warp_mask, local_min_with_tag, offset); - local_min_with_tag = (local_min_with_tag <= other) ? 
local_min_with_tag : other; + + uint32_t local_min_with_tag = (local_min << 16) | ((uint32_t)local_idx); + uint32_t warp_min_with_tag = __reduce_min_sync(warp_mask, local_min_with_tag); + uint32_t warp_min_count = warp_min_with_tag >> 16; + uint32_t warp_local_idx = warp_min_with_tag & 0xffff; + + if (warp_min_count == 255) { + // No valid position left; set error flag and fill remaining slots with sentinel + if (threadIdx.x == 0) { atomicExch(d_invalid_neighbor_list, 1u); } + break; } - uint32_t warp_min_tag = __shfl_sync(warp_mask, local_min_with_tag, 0); - uint32_t warp_local_idx = warp_min_tag & 0xffff; - if (local_idx == warp_local_idx) { - output_graph_ptr[nid_batch * output_graph_degree + i] = - knn_graph[knn_graph_degree * nid + warp_local_idx]; - smem_detour_count[knn_graph_degree * wid + warp_local_idx] = 255; + IdxT selected_node = smem_indices[warp_local_idx]; + + for (uint32_t k = threadIdx.x; k < knn_graph_degree; k += blockDim.x) { + if (smem_indices[k] == selected_node) { smem_detour_count[k] = 255; } } __syncwarp(warp_mask); + + if (threadIdx.x == 0) { output_graph_ptr[nid_batch * output_graph_degree + i] = selected_node; } } } @@ -350,19 +353,18 @@ __global__ void kern_merge_graph(IdxT* output_graph, extern __shared__ unsigned char smem_buf[]; IdxT* smem_sorted_output_graph = reinterpret_cast(smem_buf); - const uint32_t wid = threadIdx.x / 32; - const uint32_t lane_id = threadIdx.x % 32; - const uint32_t num_warps = blockDim.x / 32; - const uint64_t nid = blockIdx.x * num_warps + (batch_size * batch_id * num_warps) + wid; + assert(blockDim.x == 32); + + const uint64_t nid = blockIdx.x + (batch_size * batch_id); if (nid >= graph_size) { return; } - if (lane_id == 0) check_num_protected_edges[0] = true; + if (threadIdx.x == 0) check_num_protected_edges[0] = true; const auto mst_graph_num_edges = mst_graph_num_edges_ptr[nid]; // If guarantee_connectivity == true, use a temporal list to merge the // neighbor lists of the graphs. 
if (guarantee_connectivity) { - for (uint32_t i = lane_id; i < mst_graph_degree; i += 32) { + for (uint32_t i = threadIdx.x; i < mst_graph_degree; i += 32) { smem_sorted_output_graph[i] = mst_graph[nid * mst_graph_degree + i]; } __syncwarp(); @@ -371,7 +373,7 @@ __global__ void kern_merge_graph(IdxT* output_graph, pruned_j++) { const auto v = output_graph[output_graph_degree * nid + pruned_j]; unsigned int dup = 0; - for (uint32_t m = lane_id; m < output_j; m += 32) { + for (uint32_t m = threadIdx.x; m < output_j; m += 32) { if (v == smem_sorted_output_graph[m]) { dup = 1; break; @@ -380,7 +382,7 @@ __global__ void kern_merge_graph(IdxT* output_graph, unsigned int warp_dup = __ballot_sync(0xffffffff, dup); if (warp_dup == 0) { - if (lane_id == 0) smem_sorted_output_graph[output_j] = v; + if (threadIdx.x == 0) smem_sorted_output_graph[output_j] = v; output_j++; } __syncwarp(); @@ -388,7 +390,7 @@ __global__ void kern_merge_graph(IdxT* output_graph, } else { - for (uint32_t i = lane_id; i < output_graph_degree; i += 32) { + for (uint32_t i = threadIdx.x; i < output_graph_degree; i += 32) { smem_sorted_output_graph[i] = output_graph[output_graph_degree * nid + i]; } __syncwarp(); @@ -409,7 +411,7 @@ __global__ void kern_merge_graph(IdxT* output_graph, if (pos < num_protected_edges) { continue; } uint64_t num_shift = pos - num_protected_edges; if (pos >= output_graph_degree) { num_shift = output_graph_degree - num_protected_edges - 1; } - if (lane_id == 0) { + if (threadIdx.x == 0) { thread_shift_array(smem_sorted_output_graph + num_protected_edges, num_shift); smem_sorted_output_graph[num_protected_edges] = rev_graph[kr + (output_graph_degree * nid)]; } @@ -417,7 +419,7 @@ __global__ void kern_merge_graph(IdxT* output_graph, } } - for (uint32_t i = lane_id; i < output_graph_degree; i += 32) { + for (uint32_t i = threadIdx.x; i < output_graph_degree; i += 32) { output_graph[(output_graph_degree * nid) + i] = smem_sorted_output_graph[i]; } } @@ -1477,6 +1479,7 @@ void 
optimize( res, large_tmp_mr, raft::make_extents(batch_size)); auto d_output_graph = raft::make_device_mdarray( res, large_tmp_mr, raft::make_extents(batch_size, output_graph_degree)); + auto d_invalid_neighbor_list = raft::make_device_scalar(res, 0u); for (uint32_t i_batch = 0; i_batch < num_batch; i_batch++) { // initialize the detour_count and num_no_detour_edges for the current batch @@ -1507,8 +1510,7 @@ void optimize( dev_stats.data_handle()); // select smallest-detour neighbors for the current batch - const size_t select_smem_size = - (knn_graph_degree * knn_graph_degree) * (sizeof(uint16_t) + sizeof(uint32_t)); + const size_t select_smem_size = (knn_graph_degree) * (sizeof(uint16_t) + sizeof(IdxT)); const dim3 threads_select(32, 1, 1); const dim3 blocks_select(batch_size, 1, 1); kern_select_smallest_detour_neighbors @@ -1522,10 +1524,11 @@ void optimize( d_detour_count.data_handle(), d_output_graph.data_handle(), batch_size, - i_batch); + i_batch, + d_invalid_neighbor_list.data_handle()); - raft::copy(output_graph_ptr, - d_output_graph.data_handle() + i_batch * batch_size * output_graph_degree, + raft::copy(output_graph_ptr + i_batch * batch_size * output_graph_degree, + d_output_graph.data_handle(), static_cast(batch_size) * output_graph_degree, raft::resource::get_cuda_stream(res)); @@ -1537,6 +1540,18 @@ void optimize( raft::resource::sync_stream(res); RAFT_LOG_DEBUG("\n"); + uint32_t invalid_neighbor_list = 0; + raft::copy(&invalid_neighbor_list, + d_invalid_neighbor_list.data_handle(), + 1, + raft::resource::get_cuda_stream(res)); + raft::resource::sync_stream(res); + RAFT_EXPECTS( + invalid_neighbor_list == 0, + "Could not generate an intermediate CAGRA graph because the initial kNN graph contains too " + "many invalid or duplicated neighbor nodes. 
This error can occur, for example, if too many " + "overflows occur during the norm computation between the dataset vectors."); + raft::copy( host_stats.data_handle(), dev_stats.data_handle(), 2, raft::resource::get_cuda_stream(res)); num_keep = host_stats.data_handle()[0]; @@ -1642,6 +1657,8 @@ void optimize( raft::make_device_mdarray(res, large_tmp_mr, raft::make_extents(graph_size)); auto d_rev_graph = raft::make_device_mdarray( res, large_tmp_mr, raft::make_extents(graph_size, output_graph_degree)); + auto d_output_graph = raft::make_device_mdarray( + res, large_tmp_mr, raft::make_extents(graph_size, output_graph_degree)); } catch (std::bad_alloc& e) { RAFT_LOG_DEBUG("Insufficient memory for reverse graph on GPU"); _use_gpu_rev_graph = false; @@ -1760,9 +1777,7 @@ void optimize( d_check_num_protected_edges.data_handle(), 1, raft::resource::get_cuda_stream(res)); - raft::resource::sync_stream(res); - // TODO: is this required? if (d_output_graph.allocated_memory()) { raft::copy(output_graph_ptr, d_output_graph.data_handle(), @@ -1770,6 +1785,8 @@ void optimize( raft::resource::get_cuda_stream(res)); } + raft::resource::sync_stream(res); + const auto merge_graph_end = cur_time(); RAFT_EXPECTS(check_num_protected_edges, "Failed to merge the MST, pruned, and reverse edge graphs. 
" From 9b1f741ca39eb4624a08b51d54a2429fa6b08eff Mon Sep 17 00:00:00 2001 From: Malte Foerster Date: Wed, 25 Feb 2026 15:41:10 +0000 Subject: [PATCH 05/40] some fixes --- cpp/src/neighbors/detail/cagra/graph_core.cuh | 26 ++++++++++++------- 1 file changed, 17 insertions(+), 9 deletions(-) diff --git a/cpp/src/neighbors/detail/cagra/graph_core.cuh b/cpp/src/neighbors/detail/cagra/graph_core.cuh index 1b7e46e535..8705f555b6 100644 --- a/cpp/src/neighbors/detail/cagra/graph_core.cuh +++ b/cpp/src/neighbors/detail/cagra/graph_core.cuh @@ -268,21 +268,23 @@ __global__ void kern_select_smallest_detour_neighbors( const uint64_t nid = blockIdx.x + (batch_size * batch_id); const uint64_t nid_batch = blockIdx.x; + const uint32_t maxval16 = 0x0000ffff; if (nid >= graph_size) { return; } - // Each uint64_t loads detour_count for its assigned k + // Load indices and detour counts for each neighbor; invalidate out-of-bounds entries for (uint32_t k = threadIdx.x; k < knn_graph_degree; k += blockDim.x) { - smem_detour_count[k] = d_detour_count[nid_batch * knn_graph_degree + k]; smem_indices[k] = knn_graph[knn_graph_degree * nid + k]; + smem_detour_count[k] = (smem_indices[k] >= graph_size) + ? 
maxval16 + : (uint16_t)d_detour_count[nid_batch * knn_graph_degree + k]; } __syncwarp(); const unsigned warp_mask = 0xffffffff; - for (uint32_t i = 0; i < output_graph_degree; i++) { - uint32_t local_min = 255; - uint32_t local_idx = 0xffffffff; + uint32_t local_min = maxval16; + uint32_t local_idx = maxval16; for (uint32_t k = threadIdx.x; k < knn_graph_degree; k += blockDim.x) { if (smem_detour_count[k] < local_min) { local_min = smem_detour_count[k]; @@ -295,8 +297,7 @@ __global__ void kern_select_smallest_detour_neighbors( uint32_t warp_min_count = warp_min_with_tag >> 16; uint32_t warp_local_idx = warp_min_with_tag & 0xffff; - if (warp_min_count == 255) { - // No valid position left; set error flag and fill remaining slots with sentinel + if (warp_min_count == maxval16 || warp_local_idx == maxval16) { if (threadIdx.x == 0) { atomicExch(d_invalid_neighbor_list, 1u); } break; } @@ -304,7 +305,7 @@ __global__ void kern_select_smallest_detour_neighbors( IdxT selected_node = smem_indices[warp_local_idx]; for (uint32_t k = threadIdx.x; k < knn_graph_degree; k += blockDim.x) { - if (smem_indices[k] == selected_node) { smem_detour_count[k] = 255; } + if (smem_indices[k] == selected_node) { smem_detour_count[k] = maxval16; } } __syncwarp(warp_mask); @@ -1355,7 +1356,11 @@ void optimize( { RAFT_LOG_DEBUG( "# Pruning kNN graph (size=%lu, degree=%lu)\n", knn_graph.extent(0), knn_graph.extent(1)); + + // large temporary memory for large arrays, e.g. everything >= O(graph_size) auto large_tmp_mr = raft::resource::get_large_workspace_resource(res); + // temporary memory for small arrays, e.g. 
everything <= O(batchsize * graph_degree) + // auto tmp_mr = raft::resource::get_tmp_workspace_resource(res); RAFT_EXPECTS(knn_graph.extent(0) == new_graph.extent(0), "Each input array is expected to have the same number of rows"); @@ -1527,9 +1532,12 @@ void optimize( i_batch, d_invalid_neighbor_list.data_handle()); + size_t copy_size = + std::min(static_cast(batch_size), graph_size - i_batch * batch_size) * + output_graph_degree; raft::copy(output_graph_ptr + i_batch * batch_size * output_graph_degree, d_output_graph.data_handle(), - static_cast(batch_size) * output_graph_degree, + copy_size, raft::resource::get_cuda_stream(res)); raft::resource::sync_stream(res); From ecf3b1db009a78adaeea268ff51d1c2d79763eb9 Mon Sep 17 00:00:00 2001 From: Malte Foerster Date: Fri, 27 Feb 2026 15:39:01 +0000 Subject: [PATCH 06/40] extract prune into separate function --- cpp/src/neighbors/detail/cagra/graph_core.cuh | 497 +++++++++--------- 1 file changed, 245 insertions(+), 252 deletions(-) diff --git a/cpp/src/neighbors/detail/cagra/graph_core.cuh b/cpp/src/neighbors/detail/cagra/graph_core.cuh index 8705f555b6..70cd29aa4a 100644 --- a/cpp/src/neighbors/detail/cagra/graph_core.cuh +++ b/cpp/src/neighbors/detail/cagra/graph_core.cuh @@ -1343,6 +1343,246 @@ void count_2hop_detours(raft::host_matrix_view k } } +// +// Prune unimportant edges based on 2-hop detour counts. +// +// The edge to be retained is determined without explicitly considering distance or angle. +// Suppose the edge is the k-th edge of some node-A to node-B (A->B). Among the edges +// originating at node-A, there are k-1 edges shorter than the edge A->B. Each of these +// k-1 edges are connected to a different k-1 nodes. Among these k-1 nodes, count the +// number of nodes with edges to node-B, which is the number of 2-hop detours for the +// edge A->B. 
Once the number of 2-hop detours has been counted for all edges, the +// specified number of edges are picked up for each node, starting with the edge with +// the lowest number of 2-hop detours. +// +template +void prune_graph(raft::resources const& res, + InputMatrixView knn_graph, + OutputMatrixView output_graph, + bool use_gpu) +{ + const uint64_t graph_size = output_graph.extent(0); + const uint64_t knn_graph_degree = knn_graph.extent(1); + const uint64_t output_graph_degree = output_graph.extent(1); + auto output_graph_ptr = output_graph.data_handle(); + + auto large_tmp_mr = raft::resource::get_large_workspace_resource(res); + + uint32_t batch_size = + std::min(static_cast(graph_size), static_cast(256 * 1024)); + const uint32_t num_batch = (graph_size + batch_size - 1) / batch_size; + + bool use_gpu_prune = use_gpu; + if (use_gpu_prune) { + try { + auto d_detour_count = raft::make_device_mdarray( + res, large_tmp_mr, raft::make_extents(batch_size, knn_graph_degree)); + auto d_num_no_detour_edges = raft::make_device_mdarray( + res, large_tmp_mr, raft::make_extents(batch_size)); + auto d_output_graph = raft::make_device_mdarray( + res, large_tmp_mr, raft::make_extents(batch_size, output_graph_degree)); + auto d_input_graph = raft::make_device_mdarray( + res, large_tmp_mr, raft::make_extents(graph_size, knn_graph_degree)); + } catch (std::bad_alloc& e) { + RAFT_LOG_DEBUG("Insufficient memory for pruning on GPU"); + use_gpu_prune = false; + } catch (raft::logic_error& e) { + RAFT_LOG_DEBUG("Insufficient memory for pruning on GPU (logic error)"); + use_gpu_prune = false; + } + } + + raft::common::nvtx::range block_scope( + "cagra::graph::optimize/prune"); + const double time_prune_start = cur_time(); + + if (use_gpu_prune) { + RAFT_LOG_DEBUG("# Pruning kNN Graph on GPUs\r"); + + constexpr int MAX_DEGREE = 1024; + if (knn_graph_degree > MAX_DEGREE) { + RAFT_FAIL( + "The degree of input knn graph is too large (%zu). 
" + "It must be equal to or smaller than %d.", + knn_graph_degree, + MAX_DEGREE); + } + + const double prune_start = cur_time(); + + uint64_t num_keep __attribute__((unused)) = 0; + uint64_t num_full __attribute__((unused)) = 0; + auto dev_stats = raft::make_device_vector(res, 2); + auto host_stats = raft::make_host_vector(2); + RAFT_CUDA_TRY(cudaMemsetAsync( + dev_stats.data_handle(), 0, sizeof(uint64_t) * 2, raft::resource::get_cuda_stream(res))); + + device_matrix_view_from_host d_input_graph( + res, + raft::make_host_matrix_view( + knn_graph.data_handle(), graph_size, knn_graph_degree)); + + auto d_detour_count = raft::make_device_mdarray( + res, large_tmp_mr, raft::make_extents(batch_size, knn_graph_degree)); + auto d_num_no_detour_edges = raft::make_device_mdarray( + res, large_tmp_mr, raft::make_extents(batch_size)); + auto d_output_graph = raft::make_device_mdarray( + res, large_tmp_mr, raft::make_extents(batch_size, output_graph_degree)); + auto d_invalid_neighbor_list = raft::make_device_scalar(res, 0u); + + for (uint32_t i_batch = 0; i_batch < num_batch; i_batch++) { + RAFT_CUDA_TRY(cudaMemsetAsync(d_detour_count.data_handle(), + 0xff, + batch_size * knn_graph_degree * sizeof(uint8_t), + raft::resource::get_cuda_stream(res))); + + RAFT_CUDA_TRY(cudaMemsetAsync(d_num_no_detour_edges.data_handle(), + 0x00, + batch_size * sizeof(uint32_t), + raft::resource::get_cuda_stream(res))); + + const dim3 threads_prune(32, 1, 1); + const dim3 blocks_prune(batch_size, 1, 1); + const size_t prune_smem_size = knn_graph_degree * sizeof(IdxT); + kern_prune + <<>>( + d_input_graph.data_handle(), + graph_size, + knn_graph_degree, + output_graph_degree, + batch_size, + i_batch, + d_detour_count.data_handle(), + d_num_no_detour_edges.data_handle(), + dev_stats.data_handle()); + + const size_t select_smem_size = (knn_graph_degree) * (sizeof(uint16_t) + sizeof(IdxT)); + const dim3 threads_select(32, 1, 1); + const dim3 blocks_select(batch_size, 1, 1); + 
kern_select_smallest_detour_neighbors + <<>>( + d_input_graph.data_handle(), + graph_size, + knn_graph_degree, + output_graph_degree, + d_detour_count.data_handle(), + d_output_graph.data_handle(), + batch_size, + i_batch, + d_invalid_neighbor_list.data_handle()); + + size_t copy_size = + std::min(static_cast(batch_size), graph_size - i_batch * batch_size) * + output_graph_degree; + raft::copy(output_graph_ptr + i_batch * batch_size * output_graph_degree, + d_output_graph.data_handle(), + copy_size, + raft::resource::get_cuda_stream(res)); + + raft::resource::sync_stream(res); + RAFT_LOG_DEBUG( + "# Pruning kNN Graph on GPUs (%.1lf %%)\r", + (double)std::min((i_batch + 1) * batch_size, graph_size) / graph_size * 100); + } + raft::resource::sync_stream(res); + RAFT_LOG_DEBUG("\n"); + + uint32_t invalid_neighbor_list = 0; + raft::copy(&invalid_neighbor_list, + d_invalid_neighbor_list.data_handle(), + 1, + raft::resource::get_cuda_stream(res)); + raft::resource::sync_stream(res); + RAFT_EXPECTS( + invalid_neighbor_list == 0, + "Could not generate an intermediate CAGRA graph because the initial kNN graph contains too " + "many invalid or duplicated neighbor nodes. 
This error can occur, for example, if too many " + "overflows occur during the norm computation between the dataset vectors."); + + raft::copy( + host_stats.data_handle(), dev_stats.data_handle(), 2, raft::resource::get_cuda_stream(res)); + num_keep = host_stats.data_handle()[0]; + num_full = host_stats.data_handle()[1]; + + const double prune_end = cur_time(); + RAFT_LOG_DEBUG( + "# Time for pruning on GPU: %.1lf sec, " + "avg_no_detour_edges_per_node: %.2lf/%u, " + "nodes_with_no_detour_at_all_edges: %.1lf%%", + prune_end - prune_start, + (double)num_keep / graph_size, + output_graph_degree, + (double)num_full / graph_size * 100); + } else { + auto detour_count = raft::make_host_matrix(graph_size, knn_graph_degree); + + { + raft::common::nvtx::range block_scope( + "cagra::graph::optimize/prune/2-hop-counting-by-CPU"); + const double time_2hop_count_start = cur_time(); + + auto knn_graph_view = raft::make_host_matrix_view( + knn_graph.data_handle(), knn_graph.extent(0), knn_graph.extent(1)); + count_2hop_detours(knn_graph_view, detour_count.view()); + + const double time_2hop_count_end = cur_time(); + RAFT_LOG_DEBUG("# Time for 2-hop detour counting on CPU: %.1lf sec", + time_2hop_count_end - time_2hop_count_start); + } + bool invalid_neighbor_list = false; +#pragma omp parallel for + for (uint64_t i = 0; i < graph_size; i++) { + uint64_t pk = 0; + uint32_t num_detour = 0; + for (uint32_t l = 0; l < knn_graph_degree && pk < output_graph_degree; l++) { + uint32_t next_num_detour = std::numeric_limits::max(); + for (uint64_t k = 0; k < knn_graph_degree; k++) { + const auto num_detour_k = detour_count(i, k); + if (num_detour_k > num_detour) { + next_num_detour = std::min(static_cast(num_detour_k), next_num_detour); + } + + if (num_detour_k != num_detour) { continue; } + + const auto candidate_node = knn_graph(i, k); + bool dup = false; + for (uint32_t dk = 0; dk < pk; dk++) { + if (candidate_node == output_graph_ptr[i * output_graph_degree + dk]) { + dup = true; + 
break; + } + } + if (!dup && candidate_node < graph_size) { + output_graph_ptr[i * output_graph_degree + pk] = candidate_node; + pk += 1; + } + if (pk >= output_graph_degree) break; + } + if (pk >= output_graph_degree) break; + + if (next_num_detour == std::numeric_limits::max()) { break; } + num_detour = next_num_detour; + } + if (pk != output_graph_degree) { + RAFT_LOG_DEBUG( + "Couldn't find the output_graph_degree (%lu) smallest detourable count nodes for " + "node %lu in the rank-based node reranking process", + output_graph_degree, + i); + invalid_neighbor_list = true; + } + } + RAFT_EXPECTS( + !invalid_neighbor_list, + "Could not generate an intermediate CAGRA graph because the initial kNN graph contains too " + "many invalid or duplicated neighbor nodes. This error can occur, for example, if too many " + "overflows occur during the norm computation between the dataset vectors."); + } + + const double time_prune_end = cur_time(); + RAFT_LOG_DEBUG("# Pruning time: %.1lf ms", (time_prune_end - time_prune_start) * 1000.0); +} + // TODO allow pinned input for both knn_graph and new_graph template (graph_size), static_cast(256 * 1024)); - const uint32_t num_batch = (graph_size + batch_size - 1) / batch_size; - - // - // If the available device memory is insufficient, do not use the GPU to count - // the number of 2-hop detours, but use the CPU. 
- // - // TODO: we should decide on a global strategy for this in a single place - // it comes down to input memory type and available memory which data should be copied to GPU - bool _use_gpu_prune = use_gpu; - if (_use_gpu_prune) { - try { - auto d_detour_count = raft::make_device_mdarray( - res, large_tmp_mr, raft::make_extents(batch_size, knn_graph_degree)); - auto d_num_no_detour_edges = raft::make_device_mdarray( - res, large_tmp_mr, raft::make_extents(batch_size)); - auto d_output_graph = raft::make_device_mdarray( - res, large_tmp_mr, raft::make_extents(batch_size, output_graph_degree)); - // TODO we also want to consider pinned memory in case we are short on memory - auto d_input_graph = raft::make_device_mdarray( - res, large_tmp_mr, raft::make_extents(graph_size, knn_graph_degree)); - } catch (std::bad_alloc& e) { - RAFT_LOG_DEBUG("Insufficient memory for pruning on GPU"); - _use_gpu_prune = false; - } catch (raft::logic_error& e) { - RAFT_LOG_DEBUG("Insufficient memory for pruning on GPU (logic error)"); - _use_gpu_prune = false; - } - } - - { - raft::common::nvtx::range block_scope( - "cagra::graph::optimize/prune"); - const double time_prune_start = cur_time(); - - // - // Prune unimportant edges. - // - // The edge to be retained is determined without explicitly considering - // distance or angle. Suppose the edge is the k-th edge of some node-A to - // node-B (A->B). Among the edges originating at node-A, there are k-1 edges - // shorter than the edge A->B. Each of these k-1 edges are connected to a - // different k-1 nodes. Among these k-1 nodes, count the number of nodes with - // edges to node-B, which is the number of 2-hop detours for the edge A->B. - // Once the number of 2-hop detours has been counted for all edges, the - // specified number of edges are picked up for each node, starting with the - // edge with the lowest number of 2-hop detours. 
- // - if (_use_gpu_prune) { - // Pruning on GPU - RAFT_LOG_DEBUG("# Pruning kNN Graph on GPUs\r"); - - constexpr int MAX_DEGREE = 1024; - if (knn_graph_degree > MAX_DEGREE) { - RAFT_FAIL( - "The degree of input knn graph is too large (%zu). " - "It must be equal to or smaller than %d.", - knn_graph_degree, - MAX_DEGREE); - } - - const double prune_start = cur_time(); - - uint64_t num_keep __attribute__((unused)) = 0; - uint64_t num_full __attribute__((unused)) = 0; - auto dev_stats = raft::make_device_vector(res, 2); - auto host_stats = raft::make_host_vector(2); - RAFT_CUDA_TRY(cudaMemsetAsync( - dev_stats.data_handle(), 0, sizeof(uint64_t) * 2, raft::resource::get_cuda_stream(res))); - - // Copy knn_graph over to device if necessary - // TODO: should we use pinned memory if we have issues fitting on GPU? - device_matrix_view_from_host d_input_graph( - res, - raft::make_host_matrix_view( - knn_graph.data_handle(), graph_size, knn_graph_degree)); - - // data structures per batch - auto d_detour_count = raft::make_device_mdarray( - res, large_tmp_mr, raft::make_extents(batch_size, knn_graph_degree)); - auto d_num_no_detour_edges = raft::make_device_mdarray( - res, large_tmp_mr, raft::make_extents(batch_size)); - auto d_output_graph = raft::make_device_mdarray( - res, large_tmp_mr, raft::make_extents(batch_size, output_graph_degree)); - auto d_invalid_neighbor_list = raft::make_device_scalar(res, 0u); - - for (uint32_t i_batch = 0; i_batch < num_batch; i_batch++) { - // initialize the detour_count and num_no_detour_edges for the current batch - RAFT_CUDA_TRY(cudaMemsetAsync(d_detour_count.data_handle(), - 0xff, - batch_size * knn_graph_degree * sizeof(uint8_t), - raft::resource::get_cuda_stream(res))); - - RAFT_CUDA_TRY(cudaMemsetAsync(d_num_no_detour_edges.data_handle(), - 0x00, - batch_size * sizeof(uint32_t), - raft::resource::get_cuda_stream(res))); - - // count 2-hop detours for the current batch - const dim3 threads_prune(32, 1, 1); - const dim3 
blocks_prune(batch_size, 1, 1); - const size_t prune_smem_size = knn_graph_degree * sizeof(IdxT); - kern_prune - <<>>( - d_input_graph.data_handle(), - graph_size, - knn_graph_degree, - output_graph_degree, - batch_size, - i_batch, - d_detour_count.data_handle(), - d_num_no_detour_edges.data_handle(), - dev_stats.data_handle()); - - // select smallest-detour neighbors for the current batch - const size_t select_smem_size = (knn_graph_degree) * (sizeof(uint16_t) + sizeof(IdxT)); - const dim3 threads_select(32, 1, 1); - const dim3 blocks_select(batch_size, 1, 1); - kern_select_smallest_detour_neighbors - <<>>(d_input_graph.data_handle(), - graph_size, - knn_graph_degree, - output_graph_degree, - d_detour_count.data_handle(), - d_output_graph.data_handle(), - batch_size, - i_batch, - d_invalid_neighbor_list.data_handle()); - - size_t copy_size = - std::min(static_cast(batch_size), graph_size - i_batch * batch_size) * - output_graph_degree; - raft::copy(output_graph_ptr + i_batch * batch_size * output_graph_degree, - d_output_graph.data_handle(), - copy_size, - raft::resource::get_cuda_stream(res)); - - raft::resource::sync_stream(res); - RAFT_LOG_DEBUG( - "# Pruning kNN Graph on GPUs (%.1lf %%)\r", - (double)std::min((i_batch + 1) * batch_size, graph_size) / graph_size * 100); - } - raft::resource::sync_stream(res); - RAFT_LOG_DEBUG("\n"); - - uint32_t invalid_neighbor_list = 0; - raft::copy(&invalid_neighbor_list, - d_invalid_neighbor_list.data_handle(), - 1, - raft::resource::get_cuda_stream(res)); - raft::resource::sync_stream(res); - RAFT_EXPECTS( - invalid_neighbor_list == 0, - "Could not generate an intermediate CAGRA graph because the initial kNN graph contains too " - "many invalid or duplicated neighbor nodes. 
This error can occur, for example, if too many " - "overflows occur during the norm computation between the dataset vectors."); - - raft::copy( - host_stats.data_handle(), dev_stats.data_handle(), 2, raft::resource::get_cuda_stream(res)); - num_keep = host_stats.data_handle()[0]; - num_full = host_stats.data_handle()[1]; - - const double prune_end = cur_time(); - RAFT_LOG_DEBUG( - "# Time for pruning on GPU: %.1lf sec, " - "avg_no_detour_edges_per_node: %.2lf/%u, " - "nodes_with_no_detour_at_all_edges: %.1lf%%", - prune_end - prune_start, - (double)num_keep / graph_size, - output_graph_degree, - (double)num_full / graph_size * 100); - } else { - // Pruning on CPU - auto detour_count = raft::make_host_matrix(graph_size, knn_graph_degree); - - { - raft::common::nvtx::range block_scope( - "cagra::graph::optimize/prune/2-hop-counting-by-CPU"); - const double time_2hop_count_start = cur_time(); - - count_2hop_detours(knn_graph, detour_count.view()); - - const double time_2hop_count_end = cur_time(); - RAFT_LOG_DEBUG("# Time for 2-hop detour counting on CPU: %.1lf sec", - time_2hop_count_end - time_2hop_count_start); - } - bool invalid_neighbor_list = false; -#pragma omp parallel for - for (uint64_t i = 0; i < graph_size; i++) { - // Find the `output_graph_degree` smallest detourable count nodes by checking the detourable - // count of the neighbors while increasing the target detourable count from zero. - uint64_t pk = 0; - uint32_t num_detour = 0; - for (uint32_t l = 0; l < knn_graph_degree && pk < output_graph_degree; l++) { - uint32_t next_num_detour = std::numeric_limits::max(); - for (uint64_t k = 0; k < knn_graph_degree; k++) { - const auto num_detour_k = detour_count(i, k); - // Find the detourable count to check in the next iteration - if (num_detour_k > num_detour) { - next_num_detour = std::min(static_cast(num_detour_k), next_num_detour); - } - - // Store the neighbor index if its detourable count is equal to `num_detour`. 
- if (num_detour_k != num_detour) { continue; } - - // Check duplication and append - const auto candidate_node = knn_graph(i, k); - bool dup = false; - for (uint32_t dk = 0; dk < pk; dk++) { - if (candidate_node == output_graph_ptr[i * output_graph_degree + dk]) { - dup = true; - break; - } - } - if (!dup && candidate_node < graph_size) { - output_graph_ptr[i * output_graph_degree + pk] = candidate_node; - pk += 1; - } - if (pk >= output_graph_degree) break; - } - if (pk >= output_graph_degree) break; - - if (next_num_detour == std::numeric_limits::max()) { - // There are no valid edges enough in the initial kNN graph. Break the loop here and - // catch the error at the next validation (pk != output_graph_degree). - break; - } - num_detour = next_num_detour; - } - if (pk != output_graph_degree) { - RAFT_LOG_DEBUG( - "Couldn't find the output_graph_degree (%lu) smallest detourable count nodes for " - "node %lu in the rank-based node reranking process", - output_graph_degree, - i); - invalid_neighbor_list = true; - } - } - RAFT_EXPECTS( - !invalid_neighbor_list, - "Could not generate an intermediate CAGRA graph because the initial kNN graph contains too " - "many invalid or duplicated neighbor nodes. 
This error can occur, for example, if too many " - "overflows occur during the norm computation between the dataset vectors."); - } - - const double time_prune_end = cur_time(); - RAFT_LOG_DEBUG("# Pruning time: %.1lf ms", (time_prune_end - time_prune_start) * 1000.0); - } + prune_graph(res, knn_graph, new_graph, use_gpu); auto rev_graph = raft::make_host_matrix(graph_size, output_graph_degree); auto rev_graph_count = raft::make_host_vector(graph_size); @@ -1760,6 +1749,10 @@ void optimize( // Create a boolean variable on the GPU using RAFT device allocator auto d_check_num_protected_edges = raft::make_device_scalar(res, true); + uint32_t batch_size = + std::min(static_cast(graph_size), static_cast(256 * 1024)); + const uint32_t num_batch = (graph_size + batch_size - 1) / batch_size; + const dim3 threads_merge(32, 1, 1); const dim3 blocks_merge(batch_size, 1, 1); const size_t merge_smem_size = (output_graph_degree + output_graph_degree) * sizeof(IdxT); From 972d278c77c05add60e4518150e00b0f0f7898cf Mon Sep 17 00:00:00 2001 From: Malte Foerster Date: Mon, 2 Mar 2026 14:41:22 +0000 Subject: [PATCH 07/40] extract optimize components --- cpp/src/neighbors/detail/cagra/graph_core.cuh | 1174 +++++++++-------- 1 file changed, 616 insertions(+), 558 deletions(-) diff --git a/cpp/src/neighbors/detail/cagra/graph_core.cuh b/cpp/src/neighbors/detail/cagra/graph_core.cuh index 70cd29aa4a..713b03ca20 100644 --- a/cpp/src/neighbors/detail/cagra/graph_core.cuh +++ b/cpp/src/neighbors/detail/cagra/graph_core.cuh @@ -674,6 +674,316 @@ void shift_array(T* array, uint64_t num) array[i] = array[i - 1]; } } + +template +void log_replaced_edges_stats(const IdxT* output_graph_ptr, + uint64_t graph_size, + uint64_t output_graph_degree) +{ + raft::common::nvtx::range block_scope( + "cagra::graph::optimize/stats"); + uint64_t num_replaced_edges = 0; +#pragma omp parallel for reduction(+ : num_replaced_edges) + for (uint64_t i = 0; i < graph_size; i++) { + for (uint64_t k = 0; k < 
output_graph_degree; k++) { + const uint64_t j = output_graph_ptr[k + (output_graph_degree * i)]; + const uint64_t pos = + pos_in_array(j, output_graph_ptr + (output_graph_degree * i), output_graph_degree); + if (pos == output_graph_degree) { num_replaced_edges += 1; } + } + } + RAFT_LOG_DEBUG("# Average number of replaced edges per node: %.2f", + (double)num_replaced_edges / graph_size); +} + +template +void log_incoming_edges_histogram(const IdxT* output_graph_ptr, + uint64_t graph_size, + uint64_t output_graph_degree) +{ + raft::common::nvtx::range block_scope( + "cagra::graph::optimize/check_edges"); + auto in_edge_count = raft::make_host_vector(graph_size); + auto in_edge_count_ptr = in_edge_count.data_handle(); +#pragma omp parallel for + for (uint64_t i = 0; i < graph_size; i++) { + in_edge_count_ptr[i] = 0; + } +#pragma omp parallel for + for (uint64_t i = 0; i < graph_size; i++) { + for (uint64_t k = 0; k < output_graph_degree; k++) { + const uint64_t j = output_graph_ptr[k + (output_graph_degree * i)]; + if (j >= graph_size) continue; +#pragma omp atomic + in_edge_count_ptr[j] += 1; + } + } + auto hist = raft::make_host_vector(output_graph_degree); + auto hist_ptr = hist.data_handle(); + for (uint64_t k = 0; k < output_graph_degree; k++) { + hist_ptr[k] = 0; + } +#pragma omp parallel for + for (uint64_t i = 0; i < graph_size; i++) { + uint32_t count = in_edge_count_ptr[i]; + if (count >= output_graph_degree) continue; +#pragma omp atomic + hist_ptr[count] += 1; + } + RAFT_LOG_DEBUG("# Histogram for number of incoming edges\n"); + uint32_t sum_hist = 0; + for (uint64_t k = 0; k < output_graph_degree; k++) { + sum_hist += hist_ptr[k]; + RAFT_LOG_DEBUG("# %3lu, %8u, %lf, (%8u, %lf)\n", + k, + hist_ptr[k], + (double)hist_ptr[k] / graph_size, + sum_hist, + (double)sum_hist / graph_size); + } +} + +template +void check_duplicates_and_out_of_range(const IdxT* output_graph_ptr, + uint64_t graph_size, + uint64_t output_graph_degree) +{ + raft::common::nvtx::range 
block_scope( + "cagra::graph::optimize/check_duplicates"); + uint64_t num_dup = 0; + uint64_t num_oor = 0; +#pragma omp parallel for reduction(+ : num_dup) reduction(+ : num_oor) + for (uint64_t i = 0; i < graph_size; i++) { + auto my_out_graph = output_graph_ptr + (output_graph_degree * i); + for (uint32_t j = 0; j < output_graph_degree; j++) { + const auto neighbor_a = my_out_graph[j]; + + if (neighbor_a > graph_size) { + num_oor++; + continue; + } + + for (uint32_t k = j + 1; k < output_graph_degree; k++) { + const auto neighbor_b = my_out_graph[k]; + if (neighbor_a == neighbor_b) { num_dup++; } + } + } + } + RAFT_EXPECTS( + num_dup == 0, "%lu duplicated node(s) are found in the generated CAGRA graph", num_dup); + RAFT_EXPECTS( + num_oor == 0, "%lu out-of-range index node(s) are found in the generated CAGRA graph", num_oor); +} + +template +void merge_graph_gpu(raft::resources const& res, + IdxT* output_graph_ptr, + const IdxT* d_rev_graph, + uint32_t* d_rev_graph_count, + const IdxT* mst_graph_ptr, + const uint32_t* mst_graph_num_edges_ptr, + uint64_t graph_size, + uint64_t output_graph_degree, + bool guarantee_connectivity) +{ + raft::common::nvtx::range block_scope( + "cagra::graph::optimize/combine"); + + const double merge_graph_start = cur_time(); + + device_matrix_view_from_host d_output_graph( + res, + raft::make_host_matrix_view(output_graph_ptr, graph_size, output_graph_degree)); + + auto d_check_num_protected_edges = raft::make_device_scalar(res, true); + + uint32_t batch_size = + std::min(static_cast(graph_size), static_cast(256 * 1024)); + const uint32_t num_batch = (graph_size + batch_size - 1) / batch_size; + + const dim3 threads_merge(32, 1, 1); + const dim3 blocks_merge(batch_size, 1, 1); + const size_t merge_smem_size = (output_graph_degree + output_graph_degree) * sizeof(IdxT); + for (uint32_t i_batch = 0; i_batch < num_batch; i_batch++) { + kern_merge_graph + <<>>( + d_output_graph.data_handle(), + d_rev_graph, + d_rev_graph_count, + 
static_cast(graph_size), + static_cast(output_graph_degree), + mst_graph_ptr, + static_cast(output_graph_degree), + mst_graph_num_edges_ptr, + batch_size, + i_batch, + guarantee_connectivity, + d_check_num_protected_edges.data_handle()); + } + + bool check_num_protected_edges = true; + raft::copy(&check_num_protected_edges, + d_check_num_protected_edges.data_handle(), + 1, + raft::resource::get_cuda_stream(res)); + + if (d_output_graph.allocated_memory()) { + raft::copy(output_graph_ptr, + d_output_graph.data_handle(), + graph_size * output_graph_degree, + raft::resource::get_cuda_stream(res)); + } + + raft::resource::sync_stream(res); + + const auto merge_graph_end = cur_time(); + RAFT_EXPECTS(check_num_protected_edges, + "Failed to merge the MST, pruned, and reverse edge graphs. " + "Some nodes have too " + "many MST optimization edges."); + + RAFT_LOG_DEBUG("# Time for merging graphs: %.1lf ms", + (merge_graph_end - merge_graph_start) * 1000.0); +} + +template +void merge_graph_cpu(IdxT* output_graph_ptr, + const IdxT* rev_graph_ptr, + const uint32_t* rev_graph_count_ptr, + const IdxT* mst_graph_ptr, + const uint32_t* mst_graph_num_edges_ptr, + uint64_t graph_size, + uint64_t output_graph_degree, + bool guarantee_connectivity) +{ + raft::common::nvtx::range block_scope( + "cagra::graph::optimize/combine"); + + const double time_replace_start = cur_time(); + + bool check_num_protected_edges = true; +#pragma omp parallel for + for (uint64_t i = 0; i < graph_size; i++) { + auto my_rev_graph = rev_graph_ptr + (output_graph_degree * i); + auto my_out_graph = output_graph_ptr + (output_graph_degree * i); + + std::vector temp_output_neighbor_list; + if (guarantee_connectivity) { + temp_output_neighbor_list.resize(output_graph_degree); + my_out_graph = temp_output_neighbor_list.data(); + const auto mst_graph_num_edges = mst_graph_num_edges_ptr[i]; + + for (uint32_t j = 0; j < mst_graph_num_edges; j++) { + my_out_graph[j] = mst_graph_ptr[i * output_graph_degree + j]; + } 
+ + for (uint32_t pruned_j = 0, output_j = mst_graph_num_edges; + (pruned_j < output_graph_degree) && (output_j < output_graph_degree); + pruned_j++) { + const auto v = output_graph_ptr[output_graph_degree * i + pruned_j]; + + bool dup = false; + for (uint32_t m = 0; m < output_j; m++) { + if (v == my_out_graph[m]) { + dup = true; + break; + } + } + + if (!dup) { + my_out_graph[output_j] = v; + output_j++; + } + } + } + + const auto num_protected_edges = + std::max(mst_graph_num_edges_ptr[i], output_graph_degree / 2); + if (num_protected_edges > output_graph_degree) { check_num_protected_edges = false; } + if (num_protected_edges == output_graph_degree) continue; + + auto kr = std::min(rev_graph_count_ptr[i], output_graph_degree); + while (kr) { + kr -= 1; + if (my_rev_graph[kr] < graph_size) { + uint64_t pos = pos_in_array(my_rev_graph[kr], my_out_graph, output_graph_degree); + if (pos < num_protected_edges) { continue; } + uint64_t num_shift = pos - num_protected_edges; + if (pos >= output_graph_degree) { + num_shift = output_graph_degree - num_protected_edges - 1; + } + shift_array(my_out_graph + num_protected_edges, num_shift); + my_out_graph[num_protected_edges] = my_rev_graph[kr]; + } + } + + if (guarantee_connectivity) { + for (uint32_t j = 0; j < output_graph_degree; j++) { + output_graph_ptr[(output_graph_degree * i) + j] = my_out_graph[j]; + } + } + } + RAFT_EXPECTS(check_num_protected_edges, + "Failed to merge the MST, pruned, and reverse edge graphs. 
Some nodes have too " + "many MST optimization edges."); + + const double time_replace_end = cur_time(); + RAFT_LOG_DEBUG("# Replacing edges time: %.1lf ms", + (time_replace_end - time_replace_start) * 1000.0); +} + +template +void make_reverse_graph_gpu(raft::resources const& res, + IdxT* d_rev_graph, + uint32_t* d_rev_graph_count, + raft::host_matrix_view new_graph) +{ + const uint64_t graph_size = new_graph.extent(0); + const uint64_t output_graph_degree = new_graph.extent(1); + const IdxT* output_graph_ptr = new_graph.data_handle(); + + raft::common::nvtx::range block_scope( + "cagra::graph::optimize/reverse"); + + auto large_tmp_mr = raft::resource::get_large_workspace_resource(res); + auto dest_nodes = raft::make_host_vector(graph_size); + auto d_dest_nodes = + raft::make_device_mdarray(res, large_tmp_mr, raft::make_extents(graph_size)); + + RAFT_CUDA_TRY(cudaMemsetAsync(d_rev_graph, + 0xff, + graph_size * output_graph_degree * sizeof(IdxT), + raft::resource::get_cuda_stream(res))); + + RAFT_CUDA_TRY(cudaMemsetAsync( + d_rev_graph_count, 0x00, graph_size * sizeof(uint32_t), raft::resource::get_cuda_stream(res))); + + for (uint64_t k = 0; k < output_graph_degree; k++) { +#pragma omp parallel for + for (uint64_t i = 0; i < graph_size; i++) { + dest_nodes(i) = output_graph_ptr[k + (output_graph_degree * i)]; + } + raft::resource::sync_stream(res); + + raft::copy(d_dest_nodes.data_handle(), + dest_nodes.data_handle(), + graph_size, + raft::resource::get_cuda_stream(res)); + + dim3 threads(256, 1, 1); + dim3 blocks(1024, 1, 1); + kern_make_rev_graph<<>>( + d_dest_nodes.data_handle(), + d_rev_graph, + d_rev_graph_count, + static_cast(graph_size), + static_cast(output_graph_degree)); + RAFT_LOG_DEBUG("# Making reverse graph on GPUs: %lu / %lu \r", k, output_graph_degree); + } + + raft::resource::sync_stream(res); + RAFT_LOG_DEBUG("\n"); +} } // namespace template k // specified number of edges are picked up for each node, starting with the edge with // the lowest 
number of 2-hop detours. // -template -void prune_graph(raft::resources const& res, - InputMatrixView knn_graph, - OutputMatrixView output_graph, - bool use_gpu) +template +void prune_graph_gpu(raft::resources const& res, + IdxT* knn_graph_ptr, + uint64_t graph_size, + uint64_t knn_graph_degree, + IdxT* output_graph_ptr, + uint64_t output_graph_degree) { - const uint64_t graph_size = output_graph.extent(0); - const uint64_t knn_graph_degree = knn_graph.extent(1); - const uint64_t output_graph_degree = output_graph.extent(1); - auto output_graph_ptr = output_graph.data_handle(); - - auto large_tmp_mr = raft::resource::get_large_workspace_resource(res); + auto default_ws_mr = raft::resource::get_workspace_resource(res); uint32_t batch_size = std::min(static_cast(graph_size), static_cast(256 * 1024)); const uint32_t num_batch = (graph_size + batch_size - 1) / batch_size; - bool use_gpu_prune = use_gpu; - if (use_gpu_prune) { - try { - auto d_detour_count = raft::make_device_mdarray( - res, large_tmp_mr, raft::make_extents(batch_size, knn_graph_degree)); - auto d_num_no_detour_edges = raft::make_device_mdarray( - res, large_tmp_mr, raft::make_extents(batch_size)); - auto d_output_graph = raft::make_device_mdarray( - res, large_tmp_mr, raft::make_extents(batch_size, output_graph_degree)); - auto d_input_graph = raft::make_device_mdarray( - res, large_tmp_mr, raft::make_extents(graph_size, knn_graph_degree)); - } catch (std::bad_alloc& e) { - RAFT_LOG_DEBUG("Insufficient memory for pruning on GPU"); - use_gpu_prune = false; - } catch (raft::logic_error& e) { - RAFT_LOG_DEBUG("Insufficient memory for pruning on GPU (logic error)"); - use_gpu_prune = false; - } - } - - raft::common::nvtx::range block_scope( - "cagra::graph::optimize/prune"); - const double time_prune_start = cur_time(); + RAFT_LOG_DEBUG("# Pruning kNN Graph on GPUs\r"); - if (use_gpu_prune) { - RAFT_LOG_DEBUG("# Pruning kNN Graph on GPUs\r"); + constexpr int MAX_DEGREE = 1024; + if (knn_graph_degree > 
MAX_DEGREE) { + RAFT_FAIL( + "The degree of input knn graph is too large (%zu). " + "It must be equal to or smaller than %d.", + knn_graph_degree, + MAX_DEGREE); + } - constexpr int MAX_DEGREE = 1024; - if (knn_graph_degree > MAX_DEGREE) { - RAFT_FAIL( - "The degree of input knn graph is too large (%zu). " - "It must be equal to or smaller than %d.", + const double prune_start = cur_time(); + + uint64_t num_keep __attribute__((unused)) = 0; + uint64_t num_full __attribute__((unused)) = 0; + auto dev_stats = raft::make_device_vector(res, 2); + auto host_stats = raft::make_host_vector(2); + RAFT_CUDA_TRY(cudaMemsetAsync( + dev_stats.data_handle(), 0, sizeof(uint64_t) * 2, raft::resource::get_cuda_stream(res))); + + device_matrix_view_from_host d_input_graph( + res, raft::make_host_matrix_view(knn_graph_ptr, graph_size, knn_graph_degree)); + + auto d_detour_count = raft::make_device_mdarray( + res, default_ws_mr, raft::make_extents(batch_size, knn_graph_degree)); + auto d_num_no_detour_edges = raft::make_device_mdarray( + res, default_ws_mr, raft::make_extents(batch_size)); + auto d_output_graph = raft::make_device_mdarray( + res, default_ws_mr, raft::make_extents(batch_size, output_graph_degree)); + auto d_invalid_neighbor_list = raft::make_device_scalar(res, 0u); + + for (uint32_t i_batch = 0; i_batch < num_batch; i_batch++) { + RAFT_CUDA_TRY(cudaMemsetAsync(d_detour_count.data_handle(), + 0xff, + batch_size * knn_graph_degree * sizeof(uint8_t), + raft::resource::get_cuda_stream(res))); + + RAFT_CUDA_TRY(cudaMemsetAsync(d_num_no_detour_edges.data_handle(), + 0x00, + batch_size * sizeof(uint32_t), + raft::resource::get_cuda_stream(res))); + + const dim3 threads_prune(32, 1, 1); + const dim3 blocks_prune(batch_size, 1, 1); + const size_t prune_smem_size = knn_graph_degree * sizeof(IdxT); + kern_prune + <<>>( + d_input_graph.data_handle(), + graph_size, knn_graph_degree, - MAX_DEGREE); - } + output_graph_degree, + batch_size, + i_batch, + d_detour_count.data_handle(), 
+ d_num_no_detour_edges.data_handle(), + dev_stats.data_handle()); + + const size_t select_smem_size = (knn_graph_degree) * (sizeof(uint16_t) + sizeof(IdxT)); + const dim3 threads_select(32, 1, 1); + const dim3 blocks_select(batch_size, 1, 1); + kern_select_smallest_detour_neighbors + <<>>( + d_input_graph.data_handle(), + graph_size, + knn_graph_degree, + output_graph_degree, + d_detour_count.data_handle(), + d_output_graph.data_handle(), + batch_size, + i_batch, + d_invalid_neighbor_list.data_handle()); + + size_t copy_size = + std::min(static_cast(batch_size), graph_size - i_batch * batch_size) * + output_graph_degree; + raft::copy(output_graph_ptr + i_batch * batch_size * output_graph_degree, + d_output_graph.data_handle(), + copy_size, + raft::resource::get_cuda_stream(res)); - const double prune_start = cur_time(); + raft::resource::sync_stream(res); + RAFT_LOG_DEBUG( + "# Pruning kNN Graph on GPUs (%.1lf %%)\r", + (double)std::min((i_batch + 1) * batch_size, graph_size) / graph_size * 100); + } + raft::resource::sync_stream(res); + RAFT_LOG_DEBUG("\n"); - uint64_t num_keep __attribute__((unused)) = 0; - uint64_t num_full __attribute__((unused)) = 0; - auto dev_stats = raft::make_device_vector(res, 2); - auto host_stats = raft::make_host_vector(2); - RAFT_CUDA_TRY(cudaMemsetAsync( - dev_stats.data_handle(), 0, sizeof(uint64_t) * 2, raft::resource::get_cuda_stream(res))); + uint32_t invalid_neighbor_list = 0; + raft::copy(&invalid_neighbor_list, + d_invalid_neighbor_list.data_handle(), + 1, + raft::resource::get_cuda_stream(res)); + raft::resource::sync_stream(res); + RAFT_EXPECTS( + invalid_neighbor_list == 0, + "Could not generate an intermediate CAGRA graph because the initial kNN graph contains too " + "many invalid or duplicated neighbor nodes. 
This error can occur, for example, if too many " + "overflows occur during the norm computation between the dataset vectors."); - device_matrix_view_from_host d_input_graph( - res, - raft::make_host_matrix_view( - knn_graph.data_handle(), graph_size, knn_graph_degree)); + raft::copy( + host_stats.data_handle(), dev_stats.data_handle(), 2, raft::resource::get_cuda_stream(res)); + num_keep = host_stats.data_handle()[0]; + num_full = host_stats.data_handle()[1]; - auto d_detour_count = raft::make_device_mdarray( - res, large_tmp_mr, raft::make_extents(batch_size, knn_graph_degree)); - auto d_num_no_detour_edges = raft::make_device_mdarray( - res, large_tmp_mr, raft::make_extents(batch_size)); - auto d_output_graph = raft::make_device_mdarray( - res, large_tmp_mr, raft::make_extents(batch_size, output_graph_degree)); - auto d_invalid_neighbor_list = raft::make_device_scalar(res, 0u); - - for (uint32_t i_batch = 0; i_batch < num_batch; i_batch++) { - RAFT_CUDA_TRY(cudaMemsetAsync(d_detour_count.data_handle(), - 0xff, - batch_size * knn_graph_degree * sizeof(uint8_t), - raft::resource::get_cuda_stream(res))); - - RAFT_CUDA_TRY(cudaMemsetAsync(d_num_no_detour_edges.data_handle(), - 0x00, - batch_size * sizeof(uint32_t), - raft::resource::get_cuda_stream(res))); - - const dim3 threads_prune(32, 1, 1); - const dim3 blocks_prune(batch_size, 1, 1); - const size_t prune_smem_size = knn_graph_degree * sizeof(IdxT); - kern_prune - <<>>( - d_input_graph.data_handle(), - graph_size, - knn_graph_degree, - output_graph_degree, - batch_size, - i_batch, - d_detour_count.data_handle(), - d_num_no_detour_edges.data_handle(), - dev_stats.data_handle()); - - const size_t select_smem_size = (knn_graph_degree) * (sizeof(uint16_t) + sizeof(IdxT)); - const dim3 threads_select(32, 1, 1); - const dim3 blocks_select(batch_size, 1, 1); - kern_select_smallest_detour_neighbors - <<>>( - d_input_graph.data_handle(), - graph_size, - knn_graph_degree, - output_graph_degree, - 
d_detour_count.data_handle(), - d_output_graph.data_handle(), - batch_size, - i_batch, - d_invalid_neighbor_list.data_handle()); - - size_t copy_size = - std::min(static_cast(batch_size), graph_size - i_batch * batch_size) * - output_graph_degree; - raft::copy(output_graph_ptr + i_batch * batch_size * output_graph_degree, - d_output_graph.data_handle(), - copy_size, - raft::resource::get_cuda_stream(res)); + const double prune_end = cur_time(); + RAFT_LOG_DEBUG( + "# Time for pruning on GPU: %.1lf sec, " + "avg_no_detour_edges_per_node: %.2lf/%u, " + "nodes_with_no_detour_at_all_edges: %.1lf%%", + prune_end - prune_start, + (double)num_keep / graph_size, + output_graph_degree, + (double)num_full / graph_size * 100); +} - raft::resource::sync_stream(res); - RAFT_LOG_DEBUG( - "# Pruning kNN Graph on GPUs (%.1lf %%)\r", - (double)std::min((i_batch + 1) * batch_size, graph_size) / graph_size * 100); - } - raft::resource::sync_stream(res); - RAFT_LOG_DEBUG("\n"); +template +void prune_graph_cpu(IdxT* knn_graph_ptr, + uint64_t graph_size, + uint64_t knn_graph_degree, + IdxT* output_graph_ptr, + uint64_t output_graph_degree) +{ + auto detour_count = raft::make_host_matrix(graph_size, knn_graph_degree); - uint32_t invalid_neighbor_list = 0; - raft::copy(&invalid_neighbor_list, - d_invalid_neighbor_list.data_handle(), - 1, - raft::resource::get_cuda_stream(res)); - raft::resource::sync_stream(res); - RAFT_EXPECTS( - invalid_neighbor_list == 0, - "Could not generate an intermediate CAGRA graph because the initial kNN graph contains too " - "many invalid or duplicated neighbor nodes. 
This error can occur, for example, if too many " - "overflows occur during the norm computation between the dataset vectors."); - - raft::copy( - host_stats.data_handle(), dev_stats.data_handle(), 2, raft::resource::get_cuda_stream(res)); - num_keep = host_stats.data_handle()[0]; - num_full = host_stats.data_handle()[1]; - - const double prune_end = cur_time(); - RAFT_LOG_DEBUG( - "# Time for pruning on GPU: %.1lf sec, " - "avg_no_detour_edges_per_node: %.2lf/%u, " - "nodes_with_no_detour_at_all_edges: %.1lf%%", - prune_end - prune_start, - (double)num_keep / graph_size, - output_graph_degree, - (double)num_full / graph_size * 100); - } else { - auto detour_count = raft::make_host_matrix(graph_size, knn_graph_degree); + auto knn_graph_view = + raft::make_host_matrix_view(knn_graph_ptr, graph_size, knn_graph_degree); - { - raft::common::nvtx::range block_scope( - "cagra::graph::optimize/prune/2-hop-counting-by-CPU"); - const double time_2hop_count_start = cur_time(); + { + raft::common::nvtx::range block_scope( + "cagra::graph::optimize/prune/2-hop-counting-by-CPU"); + const double time_2hop_count_start = cur_time(); - auto knn_graph_view = raft::make_host_matrix_view( - knn_graph.data_handle(), knn_graph.extent(0), knn_graph.extent(1)); - count_2hop_detours(knn_graph_view, detour_count.view()); + count_2hop_detours(knn_graph_view, detour_count.view()); - const double time_2hop_count_end = cur_time(); - RAFT_LOG_DEBUG("# Time for 2-hop detour counting on CPU: %.1lf sec", - time_2hop_count_end - time_2hop_count_start); - } - bool invalid_neighbor_list = false; + const double time_2hop_count_end = cur_time(); + RAFT_LOG_DEBUG("# Time for 2-hop detour counting on CPU: %.1lf sec", + time_2hop_count_end - time_2hop_count_start); + } + bool invalid_neighbor_list = false; #pragma omp parallel for - for (uint64_t i = 0; i < graph_size; i++) { - uint64_t pk = 0; - uint32_t num_detour = 0; - for (uint32_t l = 0; l < knn_graph_degree && pk < output_graph_degree; l++) { - 
uint32_t next_num_detour = std::numeric_limits::max(); - for (uint64_t k = 0; k < knn_graph_degree; k++) { - const auto num_detour_k = detour_count(i, k); - if (num_detour_k > num_detour) { - next_num_detour = std::min(static_cast(num_detour_k), next_num_detour); - } + for (uint64_t i = 0; i < graph_size; i++) { + uint64_t pk = 0; + uint32_t num_detour = 0; + for (uint32_t l = 0; l < knn_graph_degree && pk < output_graph_degree; l++) { + uint32_t next_num_detour = std::numeric_limits::max(); + for (uint64_t k = 0; k < knn_graph_degree; k++) { + const auto num_detour_k = detour_count(i, k); + if (num_detour_k > num_detour) { + next_num_detour = std::min(static_cast(num_detour_k), next_num_detour); + } - if (num_detour_k != num_detour) { continue; } + if (num_detour_k != num_detour) { continue; } - const auto candidate_node = knn_graph(i, k); - bool dup = false; - for (uint32_t dk = 0; dk < pk; dk++) { - if (candidate_node == output_graph_ptr[i * output_graph_degree + dk]) { - dup = true; - break; - } - } - if (!dup && candidate_node < graph_size) { - output_graph_ptr[i * output_graph_degree + pk] = candidate_node; - pk += 1; + const auto candidate_node = knn_graph_view(i, k); + bool dup = false; + for (uint32_t dk = 0; dk < pk; dk++) { + if (candidate_node == output_graph_ptr[i * output_graph_degree + dk]) { + dup = true; + break; } - if (pk >= output_graph_degree) break; + } + if (!dup && candidate_node < graph_size) { + output_graph_ptr[i * output_graph_degree + pk] = candidate_node; + pk += 1; } if (pk >= output_graph_degree) break; - - if (next_num_detour == std::numeric_limits::max()) { break; } - num_detour = next_num_detour; - } - if (pk != output_graph_degree) { - RAFT_LOG_DEBUG( - "Couldn't find the output_graph_degree (%lu) smallest detourable count nodes for " - "node %lu in the rank-based node reranking process", - output_graph_degree, - i); - invalid_neighbor_list = true; } + if (pk >= output_graph_degree) break; + + if (next_num_detour == 
std::numeric_limits::max()) { break; } + num_detour = next_num_detour; + } + if (pk != output_graph_degree) { + RAFT_LOG_DEBUG( + "Couldn't find the output_graph_degree (%lu) smallest detourable count nodes for " + "node %lu in the rank-based node reranking process", + output_graph_degree, + i); + invalid_neighbor_list = true; } - RAFT_EXPECTS( - !invalid_neighbor_list, - "Could not generate an intermediate CAGRA graph because the initial kNN graph contains too " - "many invalid or duplicated neighbor nodes. This error can occur, for example, if too many " - "overflows occur during the norm computation between the dataset vectors."); } + RAFT_EXPECTS( + !invalid_neighbor_list, + "Could not generate an intermediate CAGRA graph because the initial kNN graph contains too " + "many invalid or duplicated neighbor nodes. This error can occur, for example, if too many " + "overflows occur during the norm computation between the dataset vectors."); +} - const double time_prune_end = cur_time(); - RAFT_LOG_DEBUG("# Pruning time: %.1lf ms", (time_prune_end - time_prune_start) * 1000.0); +template +bool is_gpu_accessible(T* ptr) +{ + cudaPointerAttributes attr; + RAFT_CUDA_TRY(cudaPointerGetAttributes(&attr, ptr)); + return attr.devicePointer != nullptr; } // TODO allow pinned input for both knn_graph and new_graph @@ -1600,7 +1893,7 @@ void optimize( // large temporary memory for large arrays, e.g. everything >= O(graph_size) auto large_tmp_mr = raft::resource::get_large_workspace_resource(res); // temporary memory for small arrays, e.g. 
everything <= O(batchsize * graph_degree) - // auto tmp_mr = raft::resource::get_tmp_workspace_resource(res); + auto default_ws_mr = raft::resource::get_workspace_resource(res); RAFT_EXPECTS(knn_graph.extent(0) == new_graph.extent(0), "Each input array is expected to have the same number of rows"); @@ -1611,409 +1904,174 @@ void optimize( const uint64_t output_graph_degree = new_graph.extent(1); const uint64_t graph_size = new_graph.extent(0); // auto input_graph_ptr = knn_graph.data_handle(); - auto output_graph_ptr = new_graph.data_handle(); raft::common::nvtx::range fun_scope( "cagra::graph::optimize(%zu, %zu, %u)", graph_size, knn_graph_degree, output_graph_degree); - // MST optimization - auto mst_graph = raft::make_pinned_matrix(res, 0, 0); - auto mst_graph_num_edges = raft::make_pinned_vector(res, graph_size); - auto mst_graph_num_edges_ptr = mst_graph_num_edges.data_handle(); + // check if input and output are both device accessible + // in this case we assume data to be ONLY device accessible and not host accessible + // furthermore we ensure all large allocations to go to the large workspace resource + // and all small allocations to go to the default workspace resource + bool inout_device_accessible = false; + { + bool input_device_accessible = is_gpu_accessible(knn_graph.data_handle()); + bool output_device_accessible = is_gpu_accessible(new_graph.data_handle()); + RAFT_EXPECTS(input_device_accessible == output_device_accessible, + "Input and output must be either both device accessible or both host accessible"); + inout_device_accessible = input_device_accessible && output_device_accessible; + } + // MST optimization + // currently, only using GPU path for MST optimization + auto p_mst_graph = raft::make_pinned_matrix(res, 0, 0); + auto p_mst_graph_num_edges = raft::make_pinned_vector(res, graph_size); + auto p_mst_graph_num_edges_ptr = p_mst_graph_num_edges.data_handle(); #pragma omp parallel for for (uint64_t i = 0; i < graph_size; i++) { - 
mst_graph_num_edges_ptr[i] = 0; + p_mst_graph_num_edges_ptr[i] = 0; } if (guarantee_connectivity) { raft::common::nvtx::range block_scope( "cagra::graph::optimize/check_connectivity"); - mst_graph = raft::make_pinned_matrix( + p_mst_graph = raft::make_pinned_matrix( res, graph_size, output_graph_degree); RAFT_LOG_INFO("MST optimization is used to guarantee graph connectivity."); - mst_optimization(res, knn_graph, mst_graph.view(), mst_graph_num_edges.view(), use_gpu); + mst_optimization( + res, knn_graph, p_mst_graph.view(), p_mst_graph_num_edges.view(), use_gpu); for (uint64_t i = 0; i < graph_size; i++) { if (i < 8 || i >= graph_size - 8) { - RAFT_LOG_DEBUG("# mst_graph_num_edges_ptr[%lu]: %u\n", i, mst_graph_num_edges_ptr[i]); + RAFT_LOG_DEBUG("# p_mst_graph_num_edges_ptr[%lu]: %u\n", i, p_mst_graph_num_edges_ptr[i]); } } } - prune_graph(res, knn_graph, new_graph, use_gpu); - - auto rev_graph = raft::make_host_matrix(graph_size, output_graph_degree); - auto rev_graph_count = raft::make_host_vector(graph_size); - - bool _use_gpu_rev_graph = use_gpu; - // TODO: should we use pinned memory if we have issues fitting on GPU? 
- if (_use_gpu_rev_graph) { + // prune graph -- will use GPU path if possible, otherwise CPU path + // we only need to check in case input is not already device accessible + bool use_gpu_prune = use_gpu; + if (!inout_device_accessible) { try { - auto d_rev_graph_count = raft::make_device_mdarray( - res, large_tmp_mr, raft::make_extents(graph_size)); - auto d_dest_nodes = - raft::make_device_mdarray(res, large_tmp_mr, raft::make_extents(graph_size)); - auto d_rev_graph = raft::make_device_mdarray( - res, large_tmp_mr, raft::make_extents(graph_size, output_graph_degree)); - auto d_output_graph = raft::make_device_mdarray( - res, large_tmp_mr, raft::make_extents(graph_size, output_graph_degree)); + auto d_input_graph = raft::make_device_mdarray( + res, large_tmp_mr, raft::make_extents(graph_size, knn_graph_degree)); } catch (std::bad_alloc& e) { - RAFT_LOG_DEBUG("Insufficient memory for reverse graph on GPU"); - _use_gpu_rev_graph = false; + RAFT_LOG_DEBUG("Insufficient memory for pruning on GPU"); + use_gpu_prune = false; } catch (raft::logic_error& e) { - RAFT_LOG_DEBUG("Insufficient memory for reverse graph on GPU (logic error)"); - _use_gpu_rev_graph = false; + RAFT_LOG_DEBUG("Insufficient memory for pruning on GPU (logic error)"); + use_gpu_prune = false; } } - - const double time_make_start = cur_time(); - if (_use_gpu_rev_graph) { - // - // Make reverse graph on GPU - // - auto d_rev_graph_count = raft::make_device_mdarray( - res, large_tmp_mr, raft::make_extents(graph_size)); - - device_matrix_view_from_host d_rev_graph(res, rev_graph.view()); - device_matrix_view_from_host d_output_graph( + if (use_gpu_prune) { + // should be noop in case input is already device accessible + device_matrix_view_from_host d_input_graph( res, raft::make_host_matrix_view( - output_graph_ptr, graph_size, output_graph_degree)); - - { - raft::common::nvtx::range block_scope( - "cagra::graph::optimize/reverse"); - auto dest_nodes = raft::make_host_vector(graph_size); - auto 
d_dest_nodes = - raft::make_device_mdarray(res, large_tmp_mr, raft::make_extents(graph_size)); - - RAFT_CUDA_TRY(cudaMemsetAsync(d_rev_graph.data_handle(), - 0xff, - graph_size * output_graph_degree * sizeof(IdxT), - raft::resource::get_cuda_stream(res))); - - RAFT_CUDA_TRY(cudaMemsetAsync(d_rev_graph_count.data_handle(), - 0x00, - graph_size * sizeof(uint32_t), - raft::resource::get_cuda_stream(res))); - - for (uint64_t k = 0; k < output_graph_degree; k++) { -#pragma omp parallel for - for (uint64_t i = 0; i < graph_size; i++) { - // dest_nodes.data_handle()[i] = output_graph_ptr[k + (output_graph_degree * i)]; - dest_nodes(i) = output_graph_ptr[k + (output_graph_degree * i)]; - } - raft::resource::sync_stream(res); - - raft::copy(d_dest_nodes.data_handle(), - dest_nodes.data_handle(), - graph_size, - raft::resource::get_cuda_stream(res)); - - dim3 threads(256, 1, 1); - dim3 blocks(1024, 1, 1); - kern_make_rev_graph<<>>( - d_dest_nodes.data_handle(), - d_rev_graph.data_handle(), - d_rev_graph_count.data_handle(), - graph_size, - output_graph_degree); - RAFT_LOG_DEBUG("# Making reverse graph on GPUs: %lu / %u \r", k, output_graph_degree); - } - - raft::resource::sync_stream(res); - RAFT_LOG_DEBUG("\n"); - - if (d_rev_graph.allocated_memory()) { - raft::copy(rev_graph.data_handle(), - d_rev_graph.data_handle(), - graph_size * output_graph_degree, - raft::resource::get_cuda_stream(res)); - } - raft::copy(rev_graph_count.data_handle(), - d_rev_graph_count.data_handle(), - graph_size, - raft::resource::get_cuda_stream(res)); - - const double time_make_end = cur_time(); - RAFT_LOG_DEBUG("# Making reverse graph time: %.1lf ms", - (time_make_end - time_make_start) * 1000.0); - } - - { - raft::common::nvtx::range block_scope( - "cagra::graph::optimize/combine"); - - // Merging the prunned graph and the reverse graph - const double merge_graph_start = cur_time(); - - // Create a boolean variable on the GPU using RAFT device allocator - auto d_check_num_protected_edges = 
raft::make_device_scalar(res, true); - - uint32_t batch_size = - std::min(static_cast(graph_size), static_cast(256 * 1024)); - const uint32_t num_batch = (graph_size + batch_size - 1) / batch_size; - - const dim3 threads_merge(32, 1, 1); - const dim3 blocks_merge(batch_size, 1, 1); - const size_t merge_smem_size = (output_graph_degree + output_graph_degree) * sizeof(IdxT); - for (uint32_t i_batch = 0; i_batch < num_batch; i_batch++) { - kern_merge_graph - <<>>( - d_output_graph.data_handle(), - d_rev_graph.data_handle(), - d_rev_graph_count.data_handle(), - graph_size, - output_graph_degree, - mst_graph.data_handle(), - output_graph_degree, - mst_graph_num_edges_ptr, - batch_size, - i_batch, - guarantee_connectivity, - d_check_num_protected_edges.data_handle()); - } - - bool check_num_protected_edges = true; - raft::copy(&check_num_protected_edges, - d_check_num_protected_edges.data_handle(), - 1, - raft::resource::get_cuda_stream(res)); - - if (d_output_graph.allocated_memory()) { - raft::copy(output_graph_ptr, - d_output_graph.data_handle(), - graph_size * output_graph_degree, - raft::resource::get_cuda_stream(res)); - } - - raft::resource::sync_stream(res); + knn_graph.data_handle(), graph_size, knn_graph_degree)); - const auto merge_graph_end = cur_time(); - RAFT_EXPECTS(check_num_protected_edges, - "Failed to merge the MST, pruned, and reverse edge graphs. 
" - "Some nodes have too " - "many MST optimization edges."); + prune_graph_gpu(res, + d_input_graph.data_handle(), + graph_size, + knn_graph_degree, + new_graph.data_handle(), + output_graph_degree); - RAFT_LOG_DEBUG("# Time for merging graphs: %.1lf ms", - (merge_graph_end - merge_graph_start) * 1000.0); - } } else { - { - // Make reverse graph on CPU - raft::common::nvtx::range block_scope( - "cagra::graph::optimize/reverse"); - - auto rev_graph_ptr = rev_graph.data_handle(); - auto rev_graph_count_ptr = rev_graph_count.data_handle(); - -#pragma omp parallel for - for (uint64_t i = 0; i < graph_size; i++) { - rev_graph_count_ptr[i] = 0; - } - - for (uint32_t k = 0; k < output_graph_degree; k++) { -#pragma omp parallel for - for (uint64_t src_id = 0; src_id < graph_size; src_id++) { - const IdxT dest_id = - output_graph_ptr[k + (static_cast(output_graph_degree) * src_id)]; - if (dest_id >= graph_size) continue; - uint32_t pos; -#pragma omp atomic capture - pos = rev_graph_count_ptr[dest_id]++; - if (pos < output_graph_degree) { - rev_graph_ptr[(static_cast(output_graph_degree) * dest_id) + pos] = - static_cast(src_id); - } - } - } + prune_graph_cpu(knn_graph.data_handle(), + graph_size, + knn_graph_degree, + new_graph.data_handle(), + output_graph_degree); + } - const double time_make_end = cur_time(); - RAFT_LOG_DEBUG("# Making reverse graph time (CPU): %.1lf ms", - (time_make_end - time_make_start) * 1000.0); - } + // reverse graph creation will always use the GPU + auto d_rev_graph = raft::make_device_mdarray( + res, large_tmp_mr, raft::make_extents(graph_size, output_graph_degree)); - { - raft::common::nvtx::range block_scope( - "cagra::graph::optimize/combine"); - // - // Create search graphs from MST and pruned and reverse graphs - // - const double time_replace_start = cur_time(); + // This should use the default workspace resource for random access / atomics + auto d_rev_graph_count = raft::make_device_mdarray( + res, default_ws_mr, 
raft::make_extents(graph_size)); - bool check_num_protected_edges = true; -#pragma omp parallel for - for (uint64_t i = 0; i < graph_size; i++) { - auto my_rev_graph = rev_graph.data_handle() + (output_graph_degree * i); - auto my_out_graph = output_graph_ptr + (output_graph_degree * i); - - // If guarantee_connectivity == true, use a temporal list to merge the neighbor lists of the - // graphs. - std::vector temp_output_neighbor_list; - if (guarantee_connectivity) { - temp_output_neighbor_list.resize(output_graph_degree); - my_out_graph = temp_output_neighbor_list.data(); - const auto mst_graph_num_edges = mst_graph_num_edges_ptr[i]; - - // Set MST graph edges - for (uint32_t j = 0; j < mst_graph_num_edges; j++) { - my_out_graph[j] = mst_graph(i, j); - } - - // Set pruned graph edges - for (uint32_t pruned_j = 0, output_j = mst_graph_num_edges; - (pruned_j < output_graph_degree) && (output_j < output_graph_degree); - pruned_j++) { - const auto v = output_graph_ptr[output_graph_degree * i + pruned_j]; - - // duplication check - bool dup = false; - for (uint32_t m = 0; m < output_j; m++) { - if (v == my_out_graph[m]) { - dup = true; - break; - } - } + const double time_make_start = cur_time(); - if (!dup) { - my_out_graph[output_j] = v; - output_j++; - } - } - } + make_reverse_graph_gpu( + res, d_rev_graph.data_handle(), d_rev_graph_count.data_handle(), new_graph); - const auto num_protected_edges = - std::max(mst_graph_num_edges_ptr[i], output_graph_degree / 2); - if (num_protected_edges > output_graph_degree) { check_num_protected_edges = false; } - if (num_protected_edges == output_graph_degree) continue; - - // Replace some edges of the output graph with edges of the reverse graph. 
- auto kr = std::min(rev_graph_count.data_handle()[i], output_graph_degree); - while (kr) { - kr -= 1; - if (my_rev_graph[kr] < graph_size) { - uint64_t pos = pos_in_array(my_rev_graph[kr], my_out_graph, output_graph_degree); - if (pos < num_protected_edges) { continue; } - uint64_t num_shift = pos - num_protected_edges; - if (pos >= output_graph_degree) { - num_shift = output_graph_degree - num_protected_edges - 1; - } - shift_array(my_out_graph + num_protected_edges, num_shift); - my_out_graph[num_protected_edges] = my_rev_graph[kr]; - } - } + const double time_make_end = cur_time(); + RAFT_LOG_DEBUG("# Making reverse graph time: %.1lf ms", + (time_make_end - time_make_start) * 1000.0); - // If guarantee_connectivity == true, move the output neighbor list from the temporal list - // to the output list. If false, the copy is not needed because my_out_graph is a pointer to - // the output buffer. - if (guarantee_connectivity) { - for (uint32_t j = 0; j < output_graph_degree; j++) { - output_graph_ptr[(output_graph_degree * i) + j] = my_out_graph[j]; - } - } - } - RAFT_EXPECTS(check_num_protected_edges, - "Failed to merge the MST, pruned, and reverse edge graphs. 
Some nodes have too " - "many MST optimization edges."); - - const double time_replace_end = cur_time(); - RAFT_LOG_DEBUG("# Replacing edges time: %.1lf ms", - (time_replace_end - time_replace_start) * 1000.0); + // merge graph -- will use GPU path if possible, otherwise CPU path + // we only need to check in case output is not already device accessible + bool use_gpu_merge = use_gpu; + if (!inout_device_accessible) { + try { + auto d_new_graph = raft::make_device_mdarray( + res, large_tmp_mr, raft::make_extents(graph_size, output_graph_degree)); + } catch (std::bad_alloc& e) { + RAFT_LOG_DEBUG("Insufficient memory for merging on GPU"); + use_gpu_merge = false; + } catch (raft::logic_error& e) { + RAFT_LOG_DEBUG("Insufficient memory for merging on GPU (logic error)"); + use_gpu_merge = false; } } - // Check stats - { - raft::common::nvtx::range block_scope( - "cagra::graph::optimize/stats"); - /* stats */ - uint64_t num_replaced_edges = 0; -#pragma omp parallel for reduction(+ : num_replaced_edges) - for (uint64_t i = 0; i < graph_size; i++) { - for (uint64_t k = 0; k < output_graph_degree; k++) { - const uint64_t j = output_graph_ptr[k + (output_graph_degree * i)]; - const uint64_t pos = - pos_in_array(j, output_graph_ptr + (output_graph_degree * i), output_graph_degree); - if (pos == output_graph_degree) { num_replaced_edges += 1; } - } + if (use_gpu_merge) { + // should be noop in case output is already device accessible + device_matrix_view_from_host d_new_graph( + res, + raft::make_host_matrix_view( + new_graph.data_handle(), graph_size, output_graph_degree)); + + merge_graph_gpu(res, + d_new_graph.data_handle(), + d_rev_graph.data_handle(), + d_rev_graph_count.data_handle(), + p_mst_graph.data_handle(), + p_mst_graph_num_edges.data_handle(), + graph_size, + output_graph_degree, + guarantee_connectivity); + + if (d_new_graph.allocated_memory()) { + raft::copy(new_graph.data_handle(), + d_new_graph.data_handle(), + graph_size * output_graph_degree, + 
raft::resource::get_cuda_stream(res)); } - RAFT_LOG_DEBUG("# Average number of replaced edges per node: %.2f", - (double)num_replaced_edges / graph_size); - } + } else { + auto rev_graph = raft::make_host_matrix(graph_size, output_graph_degree); + auto rev_graph_count = raft::make_host_vector(graph_size); + auto mst_graph = raft::make_host_matrix(0, 0); + raft::copy(rev_graph.data_handle(), + d_rev_graph.data_handle(), + graph_size * output_graph_degree, + raft::resource::get_cuda_stream(res)); + raft::copy(rev_graph_count.data_handle(), + d_rev_graph_count.data_handle(), + graph_size, + raft::resource::get_cuda_stream(res)); + raft::resource::sync_stream(res); - // Check number of incoming edges - { - raft::common::nvtx::range block_scope( - "cagra::graph::optimize/check_edges"); - auto in_edge_count = raft::make_host_vector(graph_size); - auto in_edge_count_ptr = in_edge_count.data_handle(); -#pragma omp parallel for - for (uint64_t i = 0; i < graph_size; i++) { - in_edge_count_ptr[i] = 0; - } -#pragma omp parallel for - for (uint64_t i = 0; i < graph_size; i++) { - for (uint64_t k = 0; k < output_graph_degree; k++) { - const uint64_t j = output_graph_ptr[k + (output_graph_degree * i)]; - if (j >= graph_size) continue; -#pragma omp atomic - in_edge_count_ptr[j] += 1; - } - } - auto hist = raft::make_host_vector(output_graph_degree); - auto hist_ptr = hist.data_handle(); - for (uint64_t k = 0; k < output_graph_degree; k++) { - hist_ptr[k] = 0; - } -#pragma omp parallel for - for (uint64_t i = 0; i < graph_size; i++) { - uint32_t count = in_edge_count_ptr[i]; - if (count >= output_graph_degree) continue; -#pragma omp atomic - hist_ptr[count] += 1; - } - RAFT_LOG_DEBUG("# Histogram for number of incoming edges\n"); - uint32_t sum_hist = 0; - for (uint64_t k = 0; k < output_graph_degree; k++) { - sum_hist += hist_ptr[k]; - RAFT_LOG_DEBUG("# %3lu, %8u, %lf, (%8u, %lf)\n", - k, - hist_ptr[k], - (double)hist_ptr[k] / graph_size, - sum_hist, - (double)sum_hist / 
graph_size); - } + merge_graph_cpu(new_graph.data_handle(), + rev_graph.data_handle(), + rev_graph_count.data_handle(), + p_mst_graph.data_handle(), + p_mst_graph_num_edges_ptr, + graph_size, + output_graph_degree, + guarantee_connectivity); } - // Check duplication and out-of-range indices - { - raft::common::nvtx::range block_scope( - "cagra::graph::optimize/check_duplicates"); - uint64_t num_dup = 0; - uint64_t num_oor = 0; -#pragma omp parallel for reduction(+ : num_dup) reduction(+ : num_oor) - for (uint64_t i = 0; i < graph_size; i++) { - auto my_out_graph = output_graph_ptr + (output_graph_degree * i); - for (uint32_t j = 0; j < output_graph_degree; j++) { - const auto neighbor_a = my_out_graph[j]; + if (!inout_device_accessible) { + // following checks require host access + log_replaced_edges_stats(new_graph.data_handle(), graph_size, output_graph_degree); - // Check oor - if (neighbor_a > graph_size) { - num_oor++; - continue; - } + log_incoming_edges_histogram(new_graph.data_handle(), graph_size, output_graph_degree); - // Check duplication - for (uint32_t k = j + 1; k < output_graph_degree; k++) { - const auto neighbor_b = my_out_graph[k]; - if (neighbor_a == neighbor_b) { num_dup++; } - } - } - } - RAFT_EXPECTS( - num_dup == 0, "%lu duplicated node(s) are found in the generated CAGRA graph", num_dup); - RAFT_EXPECTS(num_oor == 0, - "%lu out-of-range index node(s) are found in the generated CAGRA graph", - num_oor); + check_duplicates_and_out_of_range( + new_graph.data_handle(), graph_size, output_graph_degree); + } else { + RAFT_LOG_DEBUG("Output graph is on GPU, skipping checks"); } } From 5e9ebc53950e472c8ee0035f280905dc5b1984b5 Mon Sep 17 00:00:00 2001 From: Malte Foerster Date: Mon, 2 Mar 2026 17:34:17 +0000 Subject: [PATCH 08/40] enable both host/device inout graphs for optimize --- .../neighbors/detail/cagra/cagra_build.cuh | 23 ++--- cpp/src/neighbors/detail/cagra/graph_core.cuh | 97 +++++++++++-------- cpp/src/neighbors/detail/cagra/utils.hpp | 
18 +++- 3 files changed, 86 insertions(+), 52 deletions(-) diff --git a/cpp/src/neighbors/detail/cagra/cagra_build.cuh b/cpp/src/neighbors/detail/cagra/cagra_build.cuh index a1c16250c5..009362aa96 100644 --- a/cpp/src/neighbors/detail/cagra/cagra_build.cuh +++ b/cpp/src/neighbors/detail/cagra/cagra_build.cuh @@ -822,8 +822,6 @@ inline std::pair optimize_workspace_size(size_t n_rows, size_t index_size, bool mst_optimize = false) { - // TODO: MODIFY!! - // MST optimization memory (host only) size_t mst_host = n_rows * index_size; // mst_graph_num_edges if (mst_optimize) { @@ -835,27 +833,26 @@ inline std::pair optimize_workspace_size(size_t n_rows, // Prune stage memory // We neglect 8 bytes (both on host and device) for stats - size_t prune_host = n_rows * intermediate_degree * sizeof(uint8_t); // detour count + size_t batch_size = std::min(static_cast(256 * 1024), n_rows); - size_t prune_dev = n_rows * intermediate_degree * 1; // detour count (uint8_t) - prune_dev += n_rows * sizeof(uint32_t); // d_num_detour_edges - prune_dev += n_rows * intermediate_degree * index_size; // d_input_graph + size_t prune_dev = batch_size * intermediate_degree * 1; // detour count (uint8_t) + prune_dev += batch_size * sizeof(uint32_t); // d_num_detour_edges + prune_dev += n_rows * intermediate_degree * index_size; // d_input_graph // Reverse graph stage memory - size_t rev_host = n_rows * graph_degree * index_size; // rev_graph - rev_host += n_rows * sizeof(uint32_t); // rev_graph_count - rev_host += n_rows * index_size; // dest_nodes - size_t rev_dev = n_rows * graph_degree * index_size; // d_rev_graph rev_dev += n_rows * sizeof(uint32_t); // d_rev_graph_count rev_dev += n_rows * sizeof(uint32_t); // d_dest_nodes - // Memory for merging graphs (host only) + // Memory for merging graphs (host only optional) size_t combine_host = n_rows * sizeof(uint32_t) + graph_degree * sizeof(uint32_t); // in_edge_count + hist - size_t total_host = mst_host + std::max({prune_host, rev_host, 
combine_host}); - size_t total_dev = std::max(prune_dev, rev_dev); + // additional memory for combine stage on device + size_t combine_dev = n_rows * graph_degree * index_size; // d_output_graph + + size_t total_host = mst_host + combine_host; + size_t total_dev = std::max(prune_dev, rev_dev + combine_dev); return std::make_pair(total_host, total_dev); } diff --git a/cpp/src/neighbors/detail/cagra/graph_core.cuh b/cpp/src/neighbors/detail/cagra/graph_core.cuh index 713b03ca20..96110c9613 100644 --- a/cpp/src/neighbors/detail/cagra/graph_core.cuh +++ b/cpp/src/neighbors/detail/cagra/graph_core.cuh @@ -246,6 +246,26 @@ __global__ void kern_make_rev_graph(const IdxT* const dest_nodes, // [graph_ } } +template +__global__ void kern_make_rev_graph_k(const IdxT* const dest_nodes, // [graph_size] + IdxT* const rev_graph, // [size, degree] + uint32_t* const rev_graph_count, // [graph_size] + const uint32_t graph_size, + const uint32_t degree, + uint64_t k) +{ + const uint64_t tid = threadIdx.x + (blockDim.x * blockIdx.x); + const uint64_t tnum = blockDim.x * gridDim.x; + + for (uint64_t src_id = tid; src_id < graph_size; src_id += tnum) { + IdxT dest_id = dest_nodes[k + (degree * src_id)]; + if (dest_id >= graph_size) continue; + + const uint32_t pos = atomicAdd(rev_graph_count + dest_id, 1); + if (pos < degree) { rev_graph[(degree * dest_id) + pos] = static_cast(src_id); } + } +} + // Based on the detour count, select the smallest detour count and its index // (Pruning Update Kernel) template @@ -932,11 +952,11 @@ void merge_graph_cpu(IdxT* output_graph_ptr, (time_replace_end - time_replace_start) * 1000.0); } -template +template void make_reverse_graph_gpu(raft::resources const& res, IdxT* d_rev_graph, uint32_t* d_rev_graph_count, - raft::host_matrix_view new_graph) + InOutMatrixView new_graph) { const uint64_t graph_size = new_graph.extent(0); const uint64_t output_graph_degree = new_graph.extent(1); @@ -958,26 +978,38 @@ void make_reverse_graph_gpu(raft::resources 
const& res, RAFT_CUDA_TRY(cudaMemsetAsync( d_rev_graph_count, 0x00, graph_size * sizeof(uint32_t), raft::resource::get_cuda_stream(res))); + bool output_graph_device_accessible = is_ptr_device_accessible(output_graph_ptr); + dim3 threads(256, 1, 1); + dim3 blocks(1024, 1, 1); + for (uint64_t k = 0; k < output_graph_degree; k++) { + if (output_graph_device_accessible) { + kern_make_rev_graph_k<<>>( + output_graph_ptr, + d_rev_graph, + d_rev_graph_count, + static_cast(graph_size), + static_cast(output_graph_degree), + k); + } else { #pragma omp parallel for - for (uint64_t i = 0; i < graph_size; i++) { - dest_nodes(i) = output_graph_ptr[k + (output_graph_degree * i)]; - } - raft::resource::sync_stream(res); + for (uint64_t i = 0; i < graph_size; i++) { + dest_nodes(i) = output_graph_ptr[k + (output_graph_degree * i)]; + } + raft::resource::sync_stream(res); - raft::copy(d_dest_nodes.data_handle(), - dest_nodes.data_handle(), - graph_size, - raft::resource::get_cuda_stream(res)); + raft::copy(d_dest_nodes.data_handle(), + dest_nodes.data_handle(), + graph_size, + raft::resource::get_cuda_stream(res)); - dim3 threads(256, 1, 1); - dim3 blocks(1024, 1, 1); - kern_make_rev_graph<<>>( - d_dest_nodes.data_handle(), - d_rev_graph, - d_rev_graph_count, - static_cast(graph_size), - static_cast(output_graph_degree)); + kern_make_rev_graph<<>>( + d_dest_nodes.data_handle(), + d_rev_graph, + d_rev_graph_count, + static_cast(graph_size), + static_cast(output_graph_degree)); + } RAFT_LOG_DEBUG("# Making reverse graph on GPUs: %lu / %lu \r", k, output_graph_degree); } @@ -1868,24 +1900,13 @@ void prune_graph_cpu(IdxT* knn_graph_ptr, "overflows occur during the norm computation between the dataset vectors."); } -template -bool is_gpu_accessible(T* ptr) -{ - cudaPointerAttributes attr; - RAFT_CUDA_TRY(cudaPointerGetAttributes(&attr, ptr)); - return attr.devicePointer != nullptr; -} - // TODO allow pinned input for both knn_graph and new_graph -template , raft::memory_type::host>> 
-void optimize( - raft::resources const& res, - raft::mdspan, raft::row_major, g_accessor> knn_graph, - raft::host_matrix_view new_graph, - const bool guarantee_connectivity = true, - const bool use_gpu = true) +template +void optimize(raft::resources const& res, + InOutMatrixView knn_graph, + InOutMatrixView new_graph, + const bool guarantee_connectivity = true, + const bool use_gpu = true) { RAFT_LOG_DEBUG( "# Pruning kNN graph (size=%lu, degree=%lu)\n", knn_graph.extent(0), knn_graph.extent(1)); @@ -1913,8 +1934,8 @@ void optimize( // and all small allocations to go to the default workspace resource bool inout_device_accessible = false; { - bool input_device_accessible = is_gpu_accessible(knn_graph.data_handle()); - bool output_device_accessible = is_gpu_accessible(new_graph.data_handle()); + bool input_device_accessible = is_ptr_device_accessible(knn_graph.data_handle()); + bool output_device_accessible = is_ptr_device_accessible(new_graph.data_handle()); RAFT_EXPECTS(input_device_accessible == output_device_accessible, "Input and output must be either both device accessible or both host accessible"); inout_device_accessible = input_device_accessible && output_device_accessible; @@ -2062,7 +2083,7 @@ void optimize( guarantee_connectivity); } - if (!inout_device_accessible) { + if (is_ptr_host_accessible(new_graph.data_handle())) { // following checks require host access log_replaced_edges_stats(new_graph.data_handle(), graph_size, output_graph_degree); diff --git a/cpp/src/neighbors/detail/cagra/utils.hpp b/cpp/src/neighbors/detail/cagra/utils.hpp index 30c7287430..7889d6d9a9 100644 --- a/cpp/src/neighbors/detail/cagra/utils.hpp +++ b/cpp/src/neighbors/detail/cagra/utils.hpp @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2023-2024, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2023-2026, NVIDIA CORPORATION. 
* SPDX-License-Identifier: Apache-2.0 */ #pragma once @@ -152,6 +152,22 @@ struct gen_index_msb_1_mask { }; } // namespace utils +template +bool is_ptr_device_accessible(T* ptr) +{ + cudaPointerAttributes attr; + RAFT_CUDA_TRY(cudaPointerGetAttributes(&attr, ptr)); + return attr.devicePointer != nullptr; +} + +template +bool is_ptr_host_accessible(T* ptr) +{ + cudaPointerAttributes attr; + RAFT_CUDA_TRY(cudaPointerGetAttributes(&attr, ptr)); + return attr.hostPointer != nullptr; +} + /** * Utility to sync memory from a host_matrix_view to a device_matrix_view * From 40977e2e456f2fd9ee32413be3590acfe2e7bdd4 Mon Sep 17 00:00:00 2001 From: Malte Foerster Date: Mon, 2 Mar 2026 23:35:32 +0000 Subject: [PATCH 09/40] smaller fixes --- cpp/src/neighbors/detail/cagra/graph_core.cuh | 61 +++++++++++-------- 1 file changed, 35 insertions(+), 26 deletions(-) diff --git a/cpp/src/neighbors/detail/cagra/graph_core.cuh b/cpp/src/neighbors/detail/cagra/graph_core.cuh index 6c0d6c747c..f77f6367e5 100644 --- a/cpp/src/neighbors/detail/cagra/graph_core.cuh +++ b/cpp/src/neighbors/detail/cagra/graph_core.cuh @@ -801,8 +801,8 @@ void check_duplicates_and_out_of_range(const IdxT* output_graph_ptr, template void merge_graph_gpu(raft::resources const& res, IdxT* output_graph_ptr, - const IdxT* d_rev_graph, - uint32_t* d_rev_graph_count, + const IdxT* d_rev_graph_ptr, + uint32_t* d_rev_graph_count_ptr, const IdxT* mst_graph_ptr, const uint32_t* mst_graph_num_edges_ptr, uint64_t graph_size, @@ -831,8 +831,8 @@ void merge_graph_gpu(raft::resources const& res, kern_merge_graph <<>>( d_output_graph.data_handle(), - d_rev_graph, - d_rev_graph_count, + d_rev_graph_ptr, + d_rev_graph_count_ptr, static_cast(graph_size), static_cast(output_graph_degree), mst_graph_ptr, @@ -955,8 +955,8 @@ void merge_graph_cpu(IdxT* output_graph_ptr, template void make_reverse_graph_gpu(raft::resources const& res, - IdxT* d_rev_graph, - uint32_t* d_rev_graph_count, + IdxT* d_rev_graph_ptr, + uint32_t* 
d_rev_graph_count_ptr, InOutMatrixView new_graph) { const uint64_t graph_size = new_graph.extent(0); @@ -966,18 +966,19 @@ void make_reverse_graph_gpu(raft::resources const& res, raft::common::nvtx::range block_scope( "cagra::graph::optimize/reverse"); - auto large_tmp_mr = raft::resource::get_large_workspace_resource(res); auto dest_nodes = raft::make_host_vector(graph_size); - auto d_dest_nodes = - raft::make_device_mdarray(res, large_tmp_mr, raft::make_extents(graph_size)); + auto d_dest_nodes = raft::make_device_mdarray( + res, raft::resource::get_workspace_resource(res), raft::make_extents(graph_size)); raft::matrix::fill( res, - raft::make_device_vector_view(d_rev_graph, graph_size * output_graph_degree), + raft::make_device_vector_view(d_rev_graph_ptr, graph_size * output_graph_degree), IdxT(-1)); raft::matrix::fill( - res, raft::make_device_vector_view(d_rev_graph_count, graph_size), uint32_t(0)); + res, + raft::make_device_vector_view(d_rev_graph_count_ptr, graph_size), + uint32_t(0)); bool output_graph_device_accessible = is_ptr_device_accessible(output_graph_ptr); dim3 threads(256, 1, 1); @@ -987,8 +988,8 @@ void make_reverse_graph_gpu(raft::resources const& res, if (output_graph_device_accessible) { kern_make_rev_graph_k<<>>( output_graph_ptr, - d_rev_graph, - d_rev_graph_count, + d_rev_graph_ptr, + d_rev_graph_count_ptr, static_cast(graph_size), static_cast(output_graph_degree), k); @@ -1003,8 +1004,8 @@ void make_reverse_graph_gpu(raft::resources const& res, kern_make_rev_graph<<>>( d_dest_nodes.data_handle(), - d_rev_graph, - d_rev_graph_count, + d_rev_graph_ptr, + d_rev_graph_count_ptr, static_cast(graph_size), static_cast(output_graph_degree)); } @@ -1679,6 +1680,8 @@ void prune_graph_gpu(raft::resources const& res, IdxT* output_graph_ptr, uint64_t output_graph_degree) { + raft::common::nvtx::range block_scope( + "cagra::graph::optimize/prune"); auto default_ws_mr = raft::resource::get_workspace_resource(res); uint32_t batch_size = @@ -1715,6 
+1718,8 @@ void prune_graph_gpu(raft::resources const& res, res, default_ws_mr, raft::make_extents(batch_size, output_graph_degree)); auto d_invalid_neighbor_list = raft::make_device_scalar(res, 0u); + bool output_device_accessible = is_ptr_device_accessible(output_graph_ptr); + for (uint32_t i_batch = 0; i_batch < num_batch; i_batch++) { raft::matrix::fill(res, d_detour_count.view(), uint8_t(0xff)); raft::matrix::fill(res, d_num_no_detour_edges.view(), uint32_t(0)); @@ -1744,18 +1749,21 @@ void prune_graph_gpu(raft::resources const& res, knn_graph_degree, output_graph_degree, d_detour_count.data_handle(), - d_output_graph.data_handle(), + output_device_accessible ? d_output_graph.data_handle() + : output_graph_ptr + i_batch * batch_size * output_graph_degree, batch_size, i_batch, d_invalid_neighbor_list.data_handle()); - size_t copy_size = - std::min(static_cast(batch_size), graph_size - i_batch * batch_size) * - output_graph_degree; - raft::copy(output_graph_ptr + i_batch * batch_size * output_graph_degree, - d_output_graph.data_handle(), - copy_size, - raft::resource::get_cuda_stream(res)); + if (!output_device_accessible) { + size_t copy_size = + std::min(static_cast(batch_size), graph_size - i_batch * batch_size) * + output_graph_degree; + raft::copy(output_graph_ptr + i_batch * batch_size * output_graph_degree, + d_output_graph.data_handle(), + copy_size, + raft::resource::get_cuda_stream(res)); + } raft::resource::sync_stream(res); RAFT_LOG_DEBUG( @@ -1799,14 +1807,14 @@ void prune_graph_cpu(IdxT* knn_graph_ptr, IdxT* output_graph_ptr, uint64_t output_graph_degree) { + raft::common::nvtx::range block_scope( + "cagra::graph::optimize/prune"); auto detour_count = raft::make_host_matrix(graph_size, knn_graph_degree); auto knn_graph_view = raft::make_host_matrix_view(knn_graph_ptr, graph_size, knn_graph_degree); { - raft::common::nvtx::range block_scope( - "cagra::graph::optimize/prune/2-hop-counting-by-CPU"); const double time_2hop_count_start = cur_time(); 
count_2hop_detours(knn_graph_view, detour_count.view()); @@ -2022,7 +2030,6 @@ void optimize(raft::resources const& res, } else { auto rev_graph = raft::make_host_matrix(graph_size, output_graph_degree); auto rev_graph_count = raft::make_host_vector(graph_size); - auto mst_graph = raft::make_host_matrix(0, 0); raft::copy(res, rev_graph.view(), d_rev_graph.view()); raft::copy(res, rev_graph_count.view(), d_rev_graph_count.view()); @@ -2036,6 +2043,8 @@ void optimize(raft::resources const& res, guarantee_connectivity); } + raft::resource::sync_stream(res); + if (is_ptr_host_accessible(new_graph.data_handle())) { // following checks require host access log_replaced_edges_stats(new_graph.data_handle(), graph_size, output_graph_degree); From 14e9f3ebc94aec7031d5c8eb685dc9b6fb36595d Mon Sep 17 00:00:00 2001 From: Malte Foerster Date: Tue, 3 Mar 2026 12:41:33 +0000 Subject: [PATCH 10/40] bugfix --- cpp/src/neighbors/detail/cagra/graph_core.cuh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cpp/src/neighbors/detail/cagra/graph_core.cuh b/cpp/src/neighbors/detail/cagra/graph_core.cuh index f77f6367e5..5b2893e77f 100644 --- a/cpp/src/neighbors/detail/cagra/graph_core.cuh +++ b/cpp/src/neighbors/detail/cagra/graph_core.cuh @@ -1749,8 +1749,8 @@ void prune_graph_gpu(raft::resources const& res, knn_graph_degree, output_graph_degree, d_detour_count.data_handle(), - output_device_accessible ? d_output_graph.data_handle() - : output_graph_ptr + i_batch * batch_size * output_graph_degree, + output_device_accessible ? 
output_graph_ptr + i_batch * batch_size * output_graph_degree + : d_output_graph.data_handle(), batch_size, i_batch, d_invalid_neighbor_list.data_handle()); From 416558d40b1207ca7f1b8aad0aeda68b24e68aea Mon Sep 17 00:00:00 2001 From: Malte Foerster Date: Thu, 5 Mar 2026 21:43:49 +0000 Subject: [PATCH 11/40] fuse and simplify pruning, remove CPU path --- cpp/src/neighbors/detail/cagra/graph_core.cuh | 331 +++++------------- 1 file changed, 92 insertions(+), 239 deletions(-) diff --git a/cpp/src/neighbors/detail/cagra/graph_core.cuh b/cpp/src/neighbors/detail/cagra/graph_core.cuh index 5b2893e77f..25be6ae393 100644 --- a/cpp/src/neighbors/detail/cagra/graph_core.cuh +++ b/cpp/src/neighbors/detail/cagra/graph_core.cuh @@ -157,79 +157,6 @@ __global__ void kern_sort(const DATA_T* const dataset, // [dataset_chunk_size, } } -template -__global__ void kern_prune(const IdxT* const knn_graph, // [graph_chunk_size, graph_degree] - const uint32_t graph_size, - const uint32_t graph_degree, - const uint32_t degree, - const uint32_t batch_size, - const uint32_t batch_id, - uint8_t* const detour_count, // [batch_size, graph_degree] - uint32_t* const num_no_detour_edges, // [batch_size] - uint64_t* const stats) -{ - __shared__ uint32_t smem_num_detour[MAX_DEGREE]; - extern __shared__ unsigned char smem_buf[]; - IdxT* const smem_knn_iA_neighbors = reinterpret_cast(smem_buf); - - uint64_t* const num_retain = stats; - uint64_t* const num_full = stats + 1; - - const uint64_t iA = blockIdx.x + (batch_size * batch_id); - const uint64_t iA_batch = blockIdx.x; - - if (iA >= graph_size) { return; } - - // Load this node's neighbor row into shared memory to reduce global reads - for (uint32_t k = threadIdx.x; k < graph_degree; k += blockDim.x) { - smem_num_detour[k] = 0; - smem_knn_iA_neighbors[k] = knn_graph[k + ((uint64_t)graph_degree * iA)]; - if (smem_knn_iA_neighbors[k] == iA) { - // Lower the priority of self-edge - smem_num_detour[k] = graph_degree; - } - } - __syncthreads(); - - // 
count number of detours (A->D->B) - for (uint32_t kAD = 0; kAD < graph_degree - 1; kAD++) { - const uint64_t iD = smem_knn_iA_neighbors[kAD]; - if (iD >= graph_size) { continue; } - for (uint32_t kDB = threadIdx.x; kDB < graph_degree; kDB += blockDim.x) { - const uint64_t iB_candidate = knn_graph[kDB + ((uint64_t)graph_degree * iD)]; - for (uint32_t kAB = kAD + 1; kAB < graph_degree; kAB++) { - // if ( kDB < kAB ) - { - const uint64_t iB = smem_knn_iA_neighbors[kAB]; - if (iB == iB_candidate) { - atomicAdd(smem_num_detour + kAB, 1); - break; - } - } - } - } - __syncthreads(); - } - - uint32_t num_edges_no_detour = 0; - for (uint32_t k = threadIdx.x; k < graph_degree; k += blockDim.x) { - detour_count[k + (graph_degree * iA_batch)] = min(smem_num_detour[k], (uint32_t)255); - if (smem_num_detour[k] == 0) { num_edges_no_detour++; } - } - num_edges_no_detour += __shfl_xor_sync(0xffffffff, num_edges_no_detour, 1); - num_edges_no_detour += __shfl_xor_sync(0xffffffff, num_edges_no_detour, 2); - num_edges_no_detour += __shfl_xor_sync(0xffffffff, num_edges_no_detour, 4); - num_edges_no_detour += __shfl_xor_sync(0xffffffff, num_edges_no_detour, 8); - num_edges_no_detour += __shfl_xor_sync(0xffffffff, num_edges_no_detour, 16); - num_edges_no_detour = min(num_edges_no_detour, degree); - - if (threadIdx.x == 0) { - num_no_detour_edges[iA_batch] = num_edges_no_detour; - atomicAdd((unsigned long long int*)num_retain, (unsigned long long int)num_edges_no_detour); - if (num_edges_no_detour >= degree) { atomicAdd((unsigned long long int*)num_full, 1); } - } -} - template __global__ void kern_make_rev_graph(const IdxT* const dest_nodes, // [graph_size] IdxT* const rev_graph, // [size, degree] @@ -269,48 +196,98 @@ __global__ void kern_make_rev_graph_k(const IdxT* const dest_nodes, // [grap } } -// Based on the detour count, select the smallest detour count and its index -// (Pruning Update Kernel) -template -__global__ void kern_select_smallest_detour_neighbors( - const IdxT* const 
knn_graph, // [graph_chunk_size, graph_degree] - uint64_t graph_size, - uint64_t knn_graph_degree, - uint64_t output_graph_degree, - uint8_t* const d_detour_count, // [batch_size, graph_degree] - IdxT* output_graph_ptr, - const uint32_t batch_size, // [batch_size, output_graph_degree] - const uint32_t batch_id, - uint32_t* const d_invalid_neighbor_list) +template +__global__ void kern_fused_prune(const IdxT* const knn_graph, // [graph_chunk_size, graph_degree] + IdxT* const output_graph_ptr, + const uint32_t graph_size, + const uint32_t knn_graph_degree, + const uint32_t output_graph_degree, + const uint32_t batch_size, + const uint32_t batch_id, + uint32_t* const d_invalid_neighbor_list, + uint64_t* const stats) { - assert(blockDim.x == 32); + extern __shared__ unsigned char smem_buf[]; - // Allocate shared memory for detour counts and their indices - extern __shared__ IdxT smem_indices[]; - uint16_t* smem_detour_count = (uint16_t*)&smem_indices[knn_graph_degree]; + const uint32_t wid = threadIdx.x / raft::WarpSize; + const uint32_t lane_id = threadIdx.x % raft::WarpSize; - const uint64_t nid = blockIdx.x + (batch_size * batch_id); - const uint64_t nid_batch = blockIdx.x; + IdxT* const smem_indices = + reinterpret_cast(smem_buf + wid * knn_graph_degree * sizeof(IdxT)); + uint32_t* const smem_num_detour = reinterpret_cast( + smem_buf + wid * knn_graph_degree * sizeof(IdxT) + num_warps * knn_graph_degree * sizeof(IdxT)); + + uint64_t* const num_retain = stats; + uint64_t* const num_full = stats + 1; + + const unsigned warp_mask = 0xffffffff; const uint32_t maxval16 = 0x0000ffff; + const uint64_t nid_batch = blockIdx.x * num_warps + wid; + const uint64_t nid = nid_batch + (batch_size * batch_id); + if (nid >= graph_size) { return; } - // Load indices and detour counts for each neighbor; invalidate out-of-bounds entries - for (uint32_t k = threadIdx.x; k < knn_graph_degree; k += blockDim.x) { - smem_indices[k] = knn_graph[knn_graph_degree * nid + k]; - 
smem_detour_count[k] = (smem_indices[k] >= graph_size) - ? maxval16 - : (uint16_t)d_detour_count[nid_batch * knn_graph_degree + k]; + // Load this node's neighbor row into shared memory to reduce global reads + for (uint32_t k = lane_id; k < knn_graph_degree; k += raft::WarpSize) { + smem_num_detour[k] = 0; + smem_indices[k] = knn_graph[k + ((uint64_t)knn_graph_degree * nid)]; + if (smem_indices[k] == nid) { + // Lower the priority of self-edge + smem_num_detour[k] = knn_graph_degree; + } } __syncwarp(); - const unsigned warp_mask = 0xffffffff; + // count number of detours (A->D->B) + for (uint32_t kAD = 0; kAD < knn_graph_degree - 1; kAD++) { + const uint64_t iD = smem_indices[kAD]; + if (iD >= graph_size) { continue; } + for (uint32_t kDB = lane_id; kDB < knn_graph_degree; kDB += raft::WarpSize) { + const uint64_t iB_candidate = knn_graph[kDB + ((uint64_t)knn_graph_degree * iD)]; + for (uint32_t kAB = kAD + 1; kAB < knn_graph_degree; kAB++) { + // if ( kDB < kAB ) + { + const uint64_t iB = smem_indices[kAB]; + if (iB == iB_candidate) { + atomicAdd(smem_num_detour + kAB, 1); + break; + } + } + } + } + __syncwarp(); + } + + uint32_t num_edges_no_detour = 0; + for (uint32_t k = lane_id; k < knn_graph_degree; k += raft::WarpSize) { + smem_num_detour[k] = min(smem_num_detour[k], maxval16); + if (smem_num_detour[k] == 0) { num_edges_no_detour++; } + if (smem_indices[k] >= graph_size) { smem_num_detour[k] = maxval16; } + } + + __syncwarp(); + + num_edges_no_detour += __shfl_xor_sync(0xffffffff, num_edges_no_detour, 1); + num_edges_no_detour += __shfl_xor_sync(0xffffffff, num_edges_no_detour, 2); + num_edges_no_detour += __shfl_xor_sync(0xffffffff, num_edges_no_detour, 4); + num_edges_no_detour += __shfl_xor_sync(0xffffffff, num_edges_no_detour, 8); + num_edges_no_detour += __shfl_xor_sync(0xffffffff, num_edges_no_detour, 16); + num_edges_no_detour = min(num_edges_no_detour, output_graph_degree); + + if (lane_id == 0) { + atomicAdd((unsigned long long int*)num_retain, 
(unsigned long long int)num_edges_no_detour); + if (num_edges_no_detour >= output_graph_degree) { + atomicAdd((unsigned long long int*)num_full, 1); + } + } + for (uint32_t i = 0; i < output_graph_degree; i++) { uint32_t local_min = maxval16; uint32_t local_idx = maxval16; - for (uint32_t k = threadIdx.x; k < knn_graph_degree; k += blockDim.x) { - if (smem_detour_count[k] < local_min) { - local_min = smem_detour_count[k]; + for (uint32_t k = lane_id; k < knn_graph_degree; k += raft::WarpSize) { + if (smem_num_detour[k] < local_min) { + local_min = smem_num_detour[k]; local_idx = k; } } @@ -321,18 +298,18 @@ __global__ void kern_select_smallest_detour_neighbors( uint32_t warp_local_idx = warp_min_with_tag & 0xffff; if (warp_min_count == maxval16 || warp_local_idx == maxval16) { - if (threadIdx.x == 0) { atomicExch(d_invalid_neighbor_list, 1u); } + if (lane_id == 0) { atomicExch(d_invalid_neighbor_list, 1u); } break; } IdxT selected_node = smem_indices[warp_local_idx]; - for (uint32_t k = threadIdx.x; k < knn_graph_degree; k += blockDim.x) { - if (smem_indices[k] == selected_node) { smem_detour_count[k] = maxval16; } + for (uint32_t k = lane_id; k < knn_graph_degree; k += raft::WarpSize) { + if (smem_indices[k] == selected_node) { smem_num_detour[k] = maxval16; } } __syncwarp(warp_mask); - if (threadIdx.x == 0) { output_graph_ptr[nid_batch * output_graph_degree + i] = selected_node; } + if (lane_id == 0) { output_graph_ptr[nid_batch * output_graph_degree + i] = selected_node; } } } @@ -1690,15 +1667,6 @@ void prune_graph_gpu(raft::resources const& res, RAFT_LOG_DEBUG("# Pruning kNN Graph on GPUs\r"); - constexpr int MAX_DEGREE = 1024; - if (knn_graph_degree > MAX_DEGREE) { - RAFT_FAIL( - "The degree of input knn graph is too large (%zu). 
" - "It must be equal to or smaller than %d.", - knn_graph_degree, - MAX_DEGREE); - } - const double prune_start = cur_time(); uint64_t num_keep __attribute__((unused)) = 0; @@ -1710,10 +1678,6 @@ void prune_graph_gpu(raft::resources const& res, device_matrix_view_from_host d_input_graph( res, raft::make_host_matrix_view(knn_graph_ptr, graph_size, knn_graph_degree)); - auto d_detour_count = raft::make_device_mdarray( - res, default_ws_mr, raft::make_extents(batch_size, knn_graph_degree)); - auto d_num_no_detour_edges = raft::make_device_mdarray( - res, default_ws_mr, raft::make_extents(batch_size)); auto d_output_graph = raft::make_device_mdarray( res, default_ws_mr, raft::make_extents(batch_size, output_graph_degree)); auto d_invalid_neighbor_list = raft::make_device_scalar(res, 0u); @@ -1721,40 +1685,23 @@ void prune_graph_gpu(raft::resources const& res, bool output_device_accessible = is_ptr_device_accessible(output_graph_ptr); for (uint32_t i_batch = 0; i_batch < num_batch; i_batch++) { - raft::matrix::fill(res, d_detour_count.view(), uint8_t(0xff)); - raft::matrix::fill(res, d_num_no_detour_edges.view(), uint32_t(0)); - - const dim3 threads_prune(32, 1, 1); + const uint32_t num_warps = 4; + const dim3 threads_prune(raft::WarpSize * num_warps, 1, 1); const dim3 blocks_prune(batch_size, 1, 1); - const size_t prune_smem_size = knn_graph_degree * sizeof(IdxT); - kern_prune + const size_t prune_smem_size = num_warps * knn_graph_degree * (sizeof(IdxT) + sizeof(uint32_t)); + kern_fused_prune <<>>( d_input_graph.data_handle(), + output_device_accessible ? 
output_graph_ptr + i_batch * batch_size * output_graph_degree + : d_output_graph.data_handle(), graph_size, knn_graph_degree, output_graph_degree, batch_size, i_batch, - d_detour_count.data_handle(), - d_num_no_detour_edges.data_handle(), + d_invalid_neighbor_list.data_handle(), dev_stats.data_handle()); - const size_t select_smem_size = (knn_graph_degree) * (sizeof(uint16_t) + sizeof(IdxT)); - const dim3 threads_select(32, 1, 1); - const dim3 blocks_select(batch_size, 1, 1); - kern_select_smallest_detour_neighbors - <<>>( - d_input_graph.data_handle(), - graph_size, - knn_graph_degree, - output_graph_degree, - d_detour_count.data_handle(), - output_device_accessible ? output_graph_ptr + i_batch * batch_size * output_graph_degree - : d_output_graph.data_handle(), - batch_size, - i_batch, - d_invalid_neighbor_list.data_handle()); - if (!output_device_accessible) { size_t copy_size = std::min(static_cast(batch_size), graph_size - i_batch * batch_size) * @@ -1800,79 +1747,6 @@ void prune_graph_gpu(raft::resources const& res, (double)num_full / graph_size * 100); } -template -void prune_graph_cpu(IdxT* knn_graph_ptr, - uint64_t graph_size, - uint64_t knn_graph_degree, - IdxT* output_graph_ptr, - uint64_t output_graph_degree) -{ - raft::common::nvtx::range block_scope( - "cagra::graph::optimize/prune"); - auto detour_count = raft::make_host_matrix(graph_size, knn_graph_degree); - - auto knn_graph_view = - raft::make_host_matrix_view(knn_graph_ptr, graph_size, knn_graph_degree); - - { - const double time_2hop_count_start = cur_time(); - - count_2hop_detours(knn_graph_view, detour_count.view()); - - const double time_2hop_count_end = cur_time(); - RAFT_LOG_DEBUG("# Time for 2-hop detour counting on CPU: %.1lf sec", - time_2hop_count_end - time_2hop_count_start); - } - bool invalid_neighbor_list = false; -#pragma omp parallel for - for (uint64_t i = 0; i < graph_size; i++) { - uint64_t pk = 0; - uint32_t num_detour = 0; - for (uint32_t l = 0; l < knn_graph_degree && pk < 
output_graph_degree; l++) { - uint32_t next_num_detour = std::numeric_limits::max(); - for (uint64_t k = 0; k < knn_graph_degree; k++) { - const auto num_detour_k = detour_count(i, k); - if (num_detour_k > num_detour) { - next_num_detour = std::min(static_cast(num_detour_k), next_num_detour); - } - - if (num_detour_k != num_detour) { continue; } - - const auto candidate_node = knn_graph_view(i, k); - bool dup = false; - for (uint32_t dk = 0; dk < pk; dk++) { - if (candidate_node == output_graph_ptr[i * output_graph_degree + dk]) { - dup = true; - break; - } - } - if (!dup && candidate_node < graph_size) { - output_graph_ptr[i * output_graph_degree + pk] = candidate_node; - pk += 1; - } - if (pk >= output_graph_degree) break; - } - if (pk >= output_graph_degree) break; - - if (next_num_detour == std::numeric_limits::max()) { break; } - num_detour = next_num_detour; - } - if (pk != output_graph_degree) { - RAFT_LOG_DEBUG( - "Couldn't find the output_graph_degree (%lu) smallest detourable count nodes for " - "node %lu in the rank-based node reranking process", - output_graph_degree, - i); - invalid_neighbor_list = true; - } - } - RAFT_EXPECTS( - !invalid_neighbor_list, - "Could not generate an intermediate CAGRA graph because the initial kNN graph contains too " - "many invalid or duplicated neighbor nodes. 
This error can occur, for example, if too many " - "overflows occur during the norm computation between the dataset vectors."); -} - // TODO allow pinned input for both knn_graph and new_graph template void optimize(raft::resources const& res, @@ -1939,22 +1813,8 @@ void optimize(raft::resources const& res, } } - // prune graph -- will use GPU path if possible, otherwise CPU path - // we only need to check in case input is not alreadydevice accessible - bool use_gpu_prune = use_gpu; - if (!inout_device_accessible) { - try { - auto d_input_graph = raft::make_device_mdarray( - res, large_tmp_mr, raft::make_extents(graph_size, knn_graph_degree)); - } catch (std::bad_alloc& e) { - RAFT_LOG_DEBUG("Insufficient memory for pruning on GPU"); - use_gpu_prune = false; - } catch (raft::logic_error& e) { - RAFT_LOG_DEBUG("Insufficient memory for pruning on GPU (logic error)"); - use_gpu_prune = false; - } - } - if (use_gpu_prune) { + // prune graph -- will always use GPU path + { // should be noop in case input is already device accessible device_matrix_view_from_host d_input_graph( res, @@ -1967,13 +1827,6 @@ void optimize(raft::resources const& res, knn_graph_degree, new_graph.data_handle(), output_graph_degree); - - } else { - prune_graph_cpu(knn_graph.data_handle(), - graph_size, - knn_graph_degree, - new_graph.data_handle(), - output_graph_degree); } // reverse graph creation will always use the GPU From d8d8bd877db9596720efaf67bb1373084dbf17c8 Mon Sep 17 00:00:00 2001 From: Malte Foerster Date: Thu, 5 Mar 2026 22:49:16 +0000 Subject: [PATCH 12/40] cleanup merge, remove CPU path --- cpp/src/neighbors/detail/cagra/graph_core.cuh | 339 +++++------------- 1 file changed, 85 insertions(+), 254 deletions(-) diff --git a/cpp/src/neighbors/detail/cagra/graph_core.cuh b/cpp/src/neighbors/detail/cagra/graph_core.cuh index 25be6ae393..392edc97d9 100644 --- a/cpp/src/neighbors/detail/cagra/graph_core.cuh +++ b/cpp/src/neighbors/detail/cagra/graph_core.cuh @@ -197,8 +197,8 @@ 
__global__ void kern_make_rev_graph_k(const IdxT* const dest_nodes, // [grap } template -__global__ void kern_fused_prune(const IdxT* const knn_graph, // [graph_chunk_size, graph_degree] - IdxT* const output_graph_ptr, +__global__ void kern_fused_prune(const IdxT* const knn_graph, // [graph_chunk_size, graph_degree] + IdxT* const output_graph_ptr, // [batch_size, output_graph_degree] const uint32_t graph_size, const uint32_t knn_graph_degree, const uint32_t output_graph_degree, @@ -337,8 +337,8 @@ __device__ void thread_shift_array(T* array, uint64_t num) } } -template -__global__ void kern_merge_graph(IdxT* output_graph, +template +__global__ void kern_merge_graph(IdxT* output_graph, // [batch_size, output_graph_degree] const IdxT* const rev_graph, uint32_t* const rev_graph_count, // [graph_size] const uint32_t graph_size, @@ -352,29 +352,32 @@ __global__ void kern_merge_graph(IdxT* output_graph, bool* check_num_protected_edges) { extern __shared__ unsigned char smem_buf[]; - IdxT* smem_sorted_output_graph = reinterpret_cast(smem_buf); - assert(blockDim.x == 32); + const uint32_t wid = threadIdx.x / raft::WarpSize; + const uint32_t lane_id = threadIdx.x % raft::WarpSize; - const uint64_t nid = blockIdx.x + (batch_size * batch_id); - if (nid >= graph_size) { return; } + IdxT* smem_sorted_output_graph = + reinterpret_cast(smem_buf + wid * output_graph_degree * sizeof(IdxT)); + + const uint64_t nid_batch = blockIdx.x * num_warps + wid; + const uint64_t nid = nid_batch + (batch_size * batch_id); - if (threadIdx.x == 0) check_num_protected_edges[0] = true; + if (nid >= graph_size) { return; } - const auto mst_graph_num_edges = mst_graph_num_edges_ptr[nid]; + const auto mst_graph_num_edges = guarantee_connectivity ? mst_graph_num_edges_ptr[nid] : 0; // If guarantee_connectivity == true, use a temporal list to merge the // neighbor lists of the graphs. 
if (guarantee_connectivity) { - for (uint32_t i = threadIdx.x; i < mst_graph_degree; i += 32) { + for (uint32_t i = lane_id; i < mst_graph_degree; i += raft::WarpSize) { smem_sorted_output_graph[i] = mst_graph[nid * mst_graph_degree + i]; } __syncwarp(); for (uint32_t pruned_j = 0, output_j = mst_graph_num_edges; (pruned_j < output_graph_degree) && (output_j < output_graph_degree); pruned_j++) { - const auto v = output_graph[output_graph_degree * nid + pruned_j]; + const auto v = output_graph[output_graph_degree * nid_batch + pruned_j]; unsigned int dup = 0; - for (uint32_t m = threadIdx.x; m < output_j; m += 32) { + for (uint32_t m = lane_id; m < output_j; m += raft::WarpSize) { if (v == smem_sorted_output_graph[m]) { dup = 1; break; @@ -383,7 +386,7 @@ __global__ void kern_merge_graph(IdxT* output_graph, unsigned int warp_dup = __ballot_sync(0xffffffff, dup); if (warp_dup == 0) { - if (threadIdx.x == 0) smem_sorted_output_graph[output_j] = v; + if (lane_id == 0) smem_sorted_output_graph[output_j] = v; output_j++; } __syncwarp(); @@ -391,8 +394,8 @@ __global__ void kern_merge_graph(IdxT* output_graph, } else { - for (uint32_t i = threadIdx.x; i < output_graph_degree; i += 32) { - smem_sorted_output_graph[i] = output_graph[output_graph_degree * nid + i]; + for (uint32_t i = lane_id; i < output_graph_degree; i += raft::WarpSize) { + smem_sorted_output_graph[i] = output_graph[output_graph_degree * nid_batch + i]; } __syncwarp(); } @@ -412,7 +415,7 @@ __global__ void kern_merge_graph(IdxT* output_graph, if (pos < num_protected_edges) { continue; } uint64_t num_shift = pos - num_protected_edges; if (pos >= output_graph_degree) { num_shift = output_graph_degree - num_protected_edges - 1; } - if (threadIdx.x == 0) { + if (lane_id == 0) { thread_shift_array(smem_sorted_output_graph + num_protected_edges, num_shift); smem_sorted_output_graph[num_protected_edges] = rev_graph[kr + (output_graph_degree * nid)]; } @@ -420,8 +423,8 @@ __global__ void kern_merge_graph(IdxT* 
output_graph, } } - for (uint32_t i = threadIdx.x; i < output_graph_degree; i += 32) { - output_graph[(output_graph_degree * nid) + i] = smem_sorted_output_graph[i]; + for (uint32_t i = lane_id; i < output_graph_degree; i += raft::WarpSize) { + output_graph[(output_graph_degree * nid_batch) + i] = smem_sorted_output_graph[i]; } } @@ -780,8 +783,8 @@ void merge_graph_gpu(raft::resources const& res, IdxT* output_graph_ptr, const IdxT* d_rev_graph_ptr, uint32_t* d_rev_graph_count_ptr, - const IdxT* mst_graph_ptr, - const uint32_t* mst_graph_num_edges_ptr, + IdxT* mst_graph_ptr, + uint32_t* mst_graph_num_edges_ptr, uint64_t graph_size, uint64_t output_graph_degree, bool guarantee_connectivity) @@ -789,36 +792,62 @@ void merge_graph_gpu(raft::resources const& res, raft::common::nvtx::range block_scope( "cagra::graph::optimize/combine"); + auto default_ws_mr = raft::resource::get_workspace_resource(res); const double merge_graph_start = cur_time(); - device_matrix_view_from_host d_output_graph( - res, - raft::make_host_matrix_view(output_graph_ptr, graph_size, output_graph_degree)); - auto d_check_num_protected_edges = raft::make_device_scalar(res, true); + auto d_invalid_neighbor_list = raft::make_device_scalar(res, 0u); uint32_t batch_size = std::min(static_cast(graph_size), static_cast(256 * 1024)); const uint32_t num_batch = (graph_size + batch_size - 1) / batch_size; - const dim3 threads_merge(32, 1, 1); - const dim3 blocks_merge(batch_size, 1, 1); - const size_t merge_smem_size = (output_graph_degree + output_graph_degree) * sizeof(IdxT); + bool output_device_accessible = is_ptr_device_accessible(output_graph_ptr); + auto d_output_graph = raft::make_device_mdarray( + res, + default_ws_mr, + raft::make_extents(output_device_accessible ? 0 : batch_size, output_graph_degree)); + + device_matrix_view_from_host d_mst_graph( + res, + raft::make_host_matrix_view( + mst_graph_ptr, guarantee_connectivity ? 
graph_size : 0, output_graph_degree)); + + device_matrix_view_from_host d_mst_graph_num_edges( + res, + raft::make_host_matrix_view( + mst_graph_num_edges_ptr, guarantee_connectivity ? graph_size : 0, 1)); + + const uint32_t num_warps = 4; + const dim3 threads_merge(raft::WarpSize * num_warps, 1, 1); + const dim3 blocks_merge(batch_size / num_warps, 1, 1); + const size_t merge_smem_size = num_warps * output_graph_degree * sizeof(IdxT); for (uint32_t i_batch = 0; i_batch < num_batch; i_batch++) { - kern_merge_graph + kern_merge_graph <<>>( - d_output_graph.data_handle(), + output_device_accessible ? output_graph_ptr + (i_batch * batch_size * output_graph_degree) + : d_output_graph.data_handle(), d_rev_graph_ptr, d_rev_graph_count_ptr, static_cast(graph_size), static_cast(output_graph_degree), - mst_graph_ptr, + d_mst_graph.data_handle(), static_cast(output_graph_degree), - mst_graph_num_edges_ptr, + d_mst_graph_num_edges.data_handle(), batch_size, i_batch, guarantee_connectivity, d_check_num_protected_edges.data_handle()); + + if (!output_device_accessible) { + size_t copy_size = + std::min(static_cast(batch_size), graph_size - i_batch * batch_size) * + output_graph_degree; + raft::copy(output_graph_ptr + i_batch * batch_size * output_graph_degree, + d_output_graph.data_handle(), + copy_size, + raft::resource::get_cuda_stream(res)); + } } bool check_num_protected_edges = true; @@ -827,13 +856,6 @@ void merge_graph_gpu(raft::resources const& res, 1, raft::resource::get_cuda_stream(res)); - if (d_output_graph.allocated_memory()) { - raft::copy( - res, - raft::make_host_matrix_view(output_graph_ptr, graph_size, output_graph_degree), - d_output_graph.view()); - } - const auto merge_graph_end = cur_time(); RAFT_EXPECTS(check_num_protected_edges, "Failed to merge the MST, pruned, and reverse edge graphs. 
" @@ -844,92 +866,6 @@ void merge_graph_gpu(raft::resources const& res, (merge_graph_end - merge_graph_start) * 1000.0); } -template -void merge_graph_cpu(IdxT* output_graph_ptr, - const IdxT* rev_graph_ptr, - const uint32_t* rev_graph_count_ptr, - const IdxT* mst_graph_ptr, - const uint32_t* mst_graph_num_edges_ptr, - uint64_t graph_size, - uint64_t output_graph_degree, - bool guarantee_connectivity) -{ - raft::common::nvtx::range block_scope( - "cagra::graph::optimize/combine"); - - const double time_replace_start = cur_time(); - - bool check_num_protected_edges = true; -#pragma omp parallel for - for (uint64_t i = 0; i < graph_size; i++) { - auto my_rev_graph = rev_graph_ptr + (output_graph_degree * i); - auto my_out_graph = output_graph_ptr + (output_graph_degree * i); - - std::vector temp_output_neighbor_list; - if (guarantee_connectivity) { - temp_output_neighbor_list.resize(output_graph_degree); - my_out_graph = temp_output_neighbor_list.data(); - const auto mst_graph_num_edges = mst_graph_num_edges_ptr[i]; - - for (uint32_t j = 0; j < mst_graph_num_edges; j++) { - my_out_graph[j] = mst_graph_ptr[i * output_graph_degree + j]; - } - - for (uint32_t pruned_j = 0, output_j = mst_graph_num_edges; - (pruned_j < output_graph_degree) && (output_j < output_graph_degree); - pruned_j++) { - const auto v = output_graph_ptr[output_graph_degree * i + pruned_j]; - - bool dup = false; - for (uint32_t m = 0; m < output_j; m++) { - if (v == my_out_graph[m]) { - dup = true; - break; - } - } - - if (!dup) { - my_out_graph[output_j] = v; - output_j++; - } - } - } - - const auto num_protected_edges = - std::max(mst_graph_num_edges_ptr[i], output_graph_degree / 2); - if (num_protected_edges > output_graph_degree) { check_num_protected_edges = false; } - if (num_protected_edges == output_graph_degree) continue; - - auto kr = std::min(rev_graph_count_ptr[i], output_graph_degree); - while (kr) { - kr -= 1; - if (my_rev_graph[kr] < graph_size) { - uint64_t pos = 
pos_in_array(my_rev_graph[kr], my_out_graph, output_graph_degree); - if (pos < num_protected_edges) { continue; } - uint64_t num_shift = pos - num_protected_edges; - if (pos >= output_graph_degree) { - num_shift = output_graph_degree - num_protected_edges - 1; - } - shift_array(my_out_graph + num_protected_edges, num_shift); - my_out_graph[num_protected_edges] = my_rev_graph[kr]; - } - } - - if (guarantee_connectivity) { - for (uint32_t j = 0; j < output_graph_degree; j++) { - output_graph_ptr[(output_graph_degree * i) + j] = my_out_graph[j]; - } - } - } - RAFT_EXPECTS(check_num_protected_edges, - "Failed to merge the MST, pruned, and reverse edge graphs. Some nodes have too " - "many MST optimization edges."); - - const double time_replace_end = cur_time(); - RAFT_LOG_DEBUG("# Replacing edges time: %.1lf ms", - (time_replace_end - time_replace_start) * 1000.0); -} - template void make_reverse_graph_gpu(raft::resources const& res, IdxT* d_rev_graph_ptr, @@ -1585,58 +1521,6 @@ void mst_optimization(raft::resources const& res, RAFT_LOG_DEBUG("# MST optimization time: %.1lf sec", time_mst_opt_end - time_mst_opt_start); } -template -void count_2hop_detours(raft::host_matrix_view knn_graph, - raft::host_matrix_view detour_count) -{ - RAFT_EXPECTS(knn_graph.extent(0) == detour_count.extent(0), - "knn_graph and detour_count are expected to have the same number of rows"); - RAFT_EXPECTS(knn_graph.extent(1) == detour_count.extent(1), - "knn_graph and detour_count are expected to have the same number of cols"); - const uint64_t graph_size = knn_graph.extent(0); - const uint64_t graph_degree = knn_graph.extent(1); - -#pragma omp parallel for - for (IdxT iA = 0; iA < graph_size; iA++) { - // Create a list of nodes, iB_candidates, that can be reached in 2-hops from node A. 
- auto iB_candidates = - raft::make_host_vector((graph_degree - 1) * (graph_degree - 1)); - for (uint64_t kAC = 0; kAC < graph_degree - 1; kAC++) { - IdxT iC = knn_graph(iA, kAC); - for (uint64_t kCB = 0; kCB < graph_degree - 1; kCB++) { - IdxT iB_candidate; - if (iC == iA || iC >= graph_size) { - iB_candidate = graph_size; - } else { - iB_candidate = knn_graph(iC, kCB); - if (iB_candidate == iA || iB_candidate == iC) { iB_candidate = graph_size; } - } - uint64_t idx; - if (kAC < kCB) { - idx = (kCB * kCB) + kAC; - } else { - idx = (kAC * (kAC + 1)) + kCB; - } - iB_candidates(idx) = iB_candidate; - } - } - // Count how many 2-hop detours are on each edge of node A. - for (uint64_t kAB = 0; kAB < graph_degree; kAB++) { - constexpr uint32_t max_count = 255; - uint32_t count = 0; - IdxT iB = knn_graph(iA, kAB); - if (iB == iA) { - count = max_count; - } else { - for (uint64_t idx = 0; idx < kAB * kAB; idx++) { - if (iB_candidates(idx) == iB) { count += 1; } - } - } - detour_count(iA, kAB) = std::min(count, max_count); - } - } -} - // // Prune unimportant edges based on 2-hop detour counts. // @@ -1678,16 +1562,18 @@ void prune_graph_gpu(raft::resources const& res, device_matrix_view_from_host d_input_graph( res, raft::make_host_matrix_view(knn_graph_ptr, graph_size, knn_graph_degree)); - auto d_output_graph = raft::make_device_mdarray( - res, default_ws_mr, raft::make_extents(batch_size, output_graph_degree)); auto d_invalid_neighbor_list = raft::make_device_scalar(res, 0u); bool output_device_accessible = is_ptr_device_accessible(output_graph_ptr); + auto d_output_graph = raft::make_device_mdarray( + res, + default_ws_mr, + raft::make_extents(output_device_accessible ? 
0 : batch_size, output_graph_degree)); for (uint32_t i_batch = 0; i_batch < num_batch; i_batch++) { const uint32_t num_warps = 4; const dim3 threads_prune(raft::WarpSize * num_warps, 1, 1); - const dim3 blocks_prune(batch_size, 1, 1); + const dim3 blocks_prune(batch_size / num_warps, 1, 1); const size_t prune_smem_size = num_warps * knn_graph_degree * (sizeof(IdxT) + sizeof(uint32_t)); kern_fused_prune <<>>( @@ -1775,54 +1661,36 @@ void optimize(raft::resources const& res, raft::common::nvtx::range fun_scope( "cagra::graph::optimize(%zu, %zu, %u)", graph_size, knn_graph_degree, output_graph_degree); - // check if input and output are both device accessible - // in this case we assume data to be ONLY device accessible and not host accessible - // furthermore we ensure all large allocations to go to the large workspace resource - // and all small allocations to go to the default workspace resource - bool inout_device_accessible = false; - { - bool input_device_accessible = is_ptr_device_accessible(knn_graph.data_handle()); - bool output_device_accessible = is_ptr_device_accessible(new_graph.data_handle()); - RAFT_EXPECTS(input_device_accessible == output_device_accessible, - "Input and output must be either both device accessible or both host accessible"); - inout_device_accessible = input_device_accessible && output_device_accessible; - } - // MST optimization // currently, only using GPU path for MST optimization - auto p_mst_graph = raft::make_pinned_matrix(res, 0, 0); - auto p_mst_graph_num_edges = raft::make_pinned_vector(res, graph_size); - auto p_mst_graph_num_edges_ptr = p_mst_graph_num_edges.data_handle(); -#pragma omp parallel for - for (uint64_t i = 0; i < graph_size; i++) { - p_mst_graph_num_edges_ptr[i] = 0; - } + auto mst_graph = raft::make_host_matrix(0, 0); + auto mst_graph_num_edges = raft::make_host_vector(0); + if (guarantee_connectivity) { + auto mst_graph_num_edges = raft::make_host_vector(graph_size); + auto mst_graph_num_edges_ptr = 
mst_graph_num_edges.data_handle(); +#pragma omp parallel for + for (uint64_t i = 0; i < graph_size; i++) { + mst_graph_num_edges_ptr[i] = 0; + } raft::common::nvtx::range block_scope( "cagra::graph::optimize/check_connectivity"); - p_mst_graph = raft::make_pinned_matrix( - res, graph_size, output_graph_degree); + mst_graph = + raft::make_host_matrix(graph_size, output_graph_degree); RAFT_LOG_INFO("MST optimization is used to guarantee graph connectivity."); - mst_optimization( - res, knn_graph, p_mst_graph.view(), p_mst_graph_num_edges.view(), use_gpu); + mst_optimization(res, knn_graph, mst_graph.view(), mst_graph_num_edges.view(), use_gpu); for (uint64_t i = 0; i < graph_size; i++) { if (i < 8 || i >= graph_size - 8) { - RAFT_LOG_DEBUG("# p_mst_graph_num_edges_ptr[%lu]: %u\n", i, p_mst_graph_num_edges_ptr[i]); + RAFT_LOG_DEBUG("# mst_graph_num_edges_ptr[%lu]: %u\n", i, mst_graph_num_edges_ptr[i]); } } } // prune graph -- will always use GPU path { - // should be noop in case input is already device accessible - device_matrix_view_from_host d_input_graph( - res, - raft::make_host_matrix_view( - knn_graph.data_handle(), graph_size, knn_graph_degree)); - prune_graph_gpu(res, - d_input_graph.data_handle(), + knn_graph.data_handle(), graph_size, knn_graph_degree, new_graph.data_handle(), @@ -1846,51 +1714,14 @@ void optimize(raft::resources const& res, RAFT_LOG_DEBUG("# Making reverse graph time: %.1lf ms", (time_make_end - time_make_start) * 1000.0); - // merge graph -- will use GPU path if possible, otherwise CPU path - // we only need to check in case output is not already device accessible - bool use_gpu_merge = use_gpu; - if (!inout_device_accessible) { - try { - auto d_new_graph = raft::make_device_mdarray( - res, large_tmp_mr, raft::make_extents(graph_size, output_graph_degree)); - } catch (std::bad_alloc& e) { - RAFT_LOG_DEBUG("Insufficient memory for merging on GPU"); - use_gpu_merge = false; - } catch (raft::logic_error& e) { - RAFT_LOG_DEBUG("Insufficient 
memory for merging on GPU (logic error)"); - use_gpu_merge = false; - } - } - - if (use_gpu_merge) { - // should be noop in case output is already device accessible - device_matrix_view_from_host d_new_graph( - res, - raft::make_host_matrix_view( - new_graph.data_handle(), graph_size, output_graph_degree)); - + // merge graph -- will always use GPU path + { merge_graph_gpu(res, - d_new_graph.data_handle(), + new_graph.data_handle(), d_rev_graph.data_handle(), d_rev_graph_count.data_handle(), - p_mst_graph.data_handle(), - p_mst_graph_num_edges.data_handle(), - graph_size, - output_graph_degree, - guarantee_connectivity); - - if (d_new_graph.allocated_memory()) { raft::copy(res, new_graph, d_new_graph.view()); } - } else { - auto rev_graph = raft::make_host_matrix(graph_size, output_graph_degree); - auto rev_graph_count = raft::make_host_vector(graph_size); - raft::copy(res, rev_graph.view(), d_rev_graph.view()); - raft::copy(res, rev_graph_count.view(), d_rev_graph_count.view()); - - merge_graph_cpu(new_graph.data_handle(), - rev_graph.data_handle(), - rev_graph_count.data_handle(), - p_mst_graph.data_handle(), - p_mst_graph_num_edges_ptr, + mst_graph.data_handle(), + mst_graph_num_edges.data_handle(), graph_size, output_graph_degree, guarantee_connectivity); From 00c42045aa9f0f7d148865ec7e570078e5f16658 Mon Sep 17 00:00:00 2001 From: Malte Foerster Date: Fri, 6 Mar 2026 00:01:10 +0000 Subject: [PATCH 13/40] batch reverse creation --- cpp/src/neighbors/detail/cagra/graph_core.cuh | 116 ++++++++---------- 1 file changed, 52 insertions(+), 64 deletions(-) diff --git a/cpp/src/neighbors/detail/cagra/graph_core.cuh b/cpp/src/neighbors/detail/cagra/graph_core.cuh index 392edc97d9..9f2bc09d86 100644 --- a/cpp/src/neighbors/detail/cagra/graph_core.cuh +++ b/cpp/src/neighbors/detail/cagra/graph_core.cuh @@ -157,38 +157,24 @@ __global__ void kern_sort(const DATA_T* const dataset, // [dataset_chunk_size, } } -template -__global__ void kern_make_rev_graph(const IdxT* const 
dest_nodes, // [graph_size] - IdxT* const rev_graph, // [size, degree] - uint32_t* const rev_graph_count, // [graph_size] - const uint32_t graph_size, - const uint32_t degree) -{ - const uint32_t tid = threadIdx.x + (blockDim.x * blockIdx.x); - const uint32_t tnum = blockDim.x * gridDim.x; - - for (uint32_t src_id = tid; src_id < graph_size; src_id += tnum) { - const IdxT dest_id = dest_nodes[src_id]; - if (dest_id >= graph_size) continue; - - const uint32_t pos = atomicAdd(rev_graph_count + dest_id, 1); - if (pos < degree) { rev_graph[pos + ((uint64_t)degree * dest_id)] = src_id; } - } -} - template -__global__ void kern_make_rev_graph_k(const IdxT* const dest_nodes, // [graph_size] - IdxT* const rev_graph, // [size, degree] - uint32_t* const rev_graph_count, // [graph_size] - const uint32_t graph_size, - const uint32_t degree, - uint64_t k) +__global__ void kern_rev_graph_batched(const IdxT* const dest_nodes, // [batch_size, degree] + IdxT* const rev_graph, // [graph_size, degree] + uint32_t* const rev_graph_count, // [graph_size] + const uint32_t graph_size, + const uint32_t degree, + const uint32_t batch_size, + const uint32_t batch_id) { const uint64_t tid = threadIdx.x + (blockDim.x * blockIdx.x); const uint64_t tnum = blockDim.x * gridDim.x; - for (uint64_t src_id = tid; src_id < graph_size; src_id += tnum) { - IdxT dest_id = dest_nodes[k + (degree * src_id)]; + const uint64_t block_batch_size = min(batch_size, graph_size - batch_id * batch_size); + + for (uint64_t idx = tid; idx < block_batch_size * degree; idx += tnum) { + const IdxT dest_id = dest_nodes[idx]; + const uint32_t src_id = idx / degree; + if (dest_id >= graph_size) continue; const uint32_t pos = atomicAdd(rev_graph_count + dest_id, 1); @@ -866,22 +852,18 @@ void merge_graph_gpu(raft::resources const& res, (merge_graph_end - merge_graph_start) * 1000.0); } -template +template void make_reverse_graph_gpu(raft::resources const& res, + IdxT* output_graph_ptr, IdxT* d_rev_graph_ptr, uint32_t* 
d_rev_graph_count_ptr, - InOutMatrixView new_graph) + uint64_t graph_size, + uint64_t output_graph_degree) { - const uint64_t graph_size = new_graph.extent(0); - const uint64_t output_graph_degree = new_graph.extent(1); - const IdxT* output_graph_ptr = new_graph.data_handle(); - raft::common::nvtx::range block_scope( "cagra::graph::optimize/reverse"); - auto dest_nodes = raft::make_host_vector(graph_size); - auto d_dest_nodes = raft::make_device_mdarray( - res, raft::resource::get_workspace_resource(res), raft::make_extents(graph_size)); + auto default_ws_mr = raft::resource::get_workspace_resource(res); raft::matrix::fill( res, @@ -893,36 +875,38 @@ void make_reverse_graph_gpu(raft::resources const& res, raft::make_device_vector_view(d_rev_graph_count_ptr, graph_size), uint32_t(0)); - bool output_graph_device_accessible = is_ptr_device_accessible(output_graph_ptr); - dim3 threads(256, 1, 1); - dim3 blocks(1024, 1, 1); + const uint32_t batch_size = + std::min(static_cast(graph_size), static_cast(256 * 1024)); + const uint32_t num_batch = (graph_size + batch_size - 1) / batch_size; - for (uint64_t k = 0; k < output_graph_degree; k++) { - if (output_graph_device_accessible) { - kern_make_rev_graph_k<<>>( - output_graph_ptr, - d_rev_graph_ptr, - d_rev_graph_count_ptr, - static_cast(graph_size), - static_cast(output_graph_degree), - k); - } else { -#pragma omp parallel for - for (uint64_t i = 0; i < graph_size; i++) { - dest_nodes(i) = output_graph_ptr[k + (output_graph_degree * i)]; - } - raft::resource::sync_stream(res); + bool output_device_accessible = is_ptr_device_accessible(output_graph_ptr); + auto d_output_graph = raft::make_device_mdarray( + res, + default_ws_mr, + raft::make_extents(output_device_accessible ? 
0 : batch_size, output_graph_degree)); - raft::copy(res, d_dest_nodes.view(), dest_nodes.view()); + for (uint32_t i_batch = 0; i_batch < num_batch; i_batch++) { + dim3 threads(256, 1, 1); + dim3 blocks(1024, 1, 1); - kern_make_rev_graph<<>>( - d_dest_nodes.data_handle(), - d_rev_graph_ptr, - d_rev_graph_count_ptr, - static_cast(graph_size), - static_cast(output_graph_degree)); + if (!output_device_accessible) { + size_t copy_size = + std::min(static_cast(batch_size), graph_size - i_batch * batch_size) * + output_graph_degree; + raft::copy(d_output_graph.data_handle(), + output_graph_ptr + i_batch * batch_size * output_graph_degree, + copy_size, + raft::resource::get_cuda_stream(res)); } - RAFT_LOG_DEBUG("# Making reverse graph on GPUs: %lu / %lu \r", k, output_graph_degree); + kern_rev_graph_batched<<>>( + output_device_accessible ? output_graph_ptr + (i_batch * batch_size * output_graph_degree) + : d_output_graph.data_handle(), + d_rev_graph_ptr, + d_rev_graph_count_ptr, + static_cast(graph_size), + static_cast(output_graph_degree), + static_cast(batch_size), + static_cast(i_batch)); } raft::resource::sync_stream(res); @@ -1707,8 +1691,12 @@ void optimize(raft::resources const& res, const double time_make_start = cur_time(); - make_reverse_graph_gpu( - res, d_rev_graph.data_handle(), d_rev_graph_count.data_handle(), new_graph); + make_reverse_graph_gpu(res, + new_graph.data_handle(), + d_rev_graph.data_handle(), + d_rev_graph_count.data_handle(), + graph_size, + output_graph_degree); const double time_make_end = cur_time(); RAFT_LOG_DEBUG("# Making reverse graph time: %.1lf ms", From 9e63a7c442d6725703cbb52e575b68e2625f0694 Mon Sep 17 00:00:00 2001 From: Malte Foerster Date: Fri, 6 Mar 2026 12:05:48 +0000 Subject: [PATCH 14/40] add prefetch view to handle managed & host --- cpp/src/neighbors/detail/cagra/graph_core.cuh | 17 +- cpp/src/neighbors/detail/cagra/utils.hpp | 300 +++++++++++++++++- 2 files changed, 313 insertions(+), 4 deletions(-) diff --git 
a/cpp/src/neighbors/detail/cagra/graph_core.cuh b/cpp/src/neighbors/detail/cagra/graph_core.cuh index 9f2bc09d86..28006fa133 100644 --- a/cpp/src/neighbors/detail/cagra/graph_core.cuh +++ b/cpp/src/neighbors/detail/cagra/graph_core.cuh @@ -1543,8 +1543,19 @@ void prune_graph_gpu(raft::resources const& res, auto host_stats = raft::make_host_vector(2); raft::matrix::fill(res, dev_stats.view(), uint64_t(0)); - device_matrix_view_from_host d_input_graph( - res, raft::make_host_matrix_view(knn_graph_ptr, graph_size, knn_graph_degree)); + // device_matrix_view_from_host d_input_graph( + // res, raft::make_host_matrix_view(knn_graph_ptr, graph_size, + // knn_graph_degree)); + + batched_device_view_from_host d_input_graph( + res, + raft::make_host_matrix_view(knn_graph_ptr, graph_size, knn_graph_degree), + /*batch_size*/ graph_size, + /*read_only*/ true, + /*host_writeback*/ false, + /*initialize*/ true, + /*evict*/ true); + auto input_view = d_input_graph.next_view(); auto d_invalid_neighbor_list = raft::make_device_scalar(res, 0u); @@ -1561,7 +1572,7 @@ void prune_graph_gpu(raft::resources const& res, const size_t prune_smem_size = num_warps * knn_graph_degree * (sizeof(IdxT) + sizeof(uint32_t)); kern_fused_prune <<>>( - d_input_graph.data_handle(), + input_view.data_handle(), output_device_accessible ? 
output_graph_ptr + i_batch * batch_size * output_graph_degree : d_output_graph.data_handle(), graph_size, diff --git a/cpp/src/neighbors/detail/cagra/utils.hpp b/cpp/src/neighbors/detail/cagra/utils.hpp index a59ac7fd57..c3d15e59f4 100644 --- a/cpp/src/neighbors/detail/cagra/utils.hpp +++ b/cpp/src/neighbors/detail/cagra/utils.hpp @@ -9,9 +9,13 @@ #include #include #include +#include +#include +#include +#include +#include #include #include - #include #include @@ -308,4 +312,298 @@ void copy_with_padding( raft::resource::get_cuda_stream(res))); } } + +/** + * Utility to create a batched device view from a host view + * + * This utility will create a batched device view from a host view and will handle the prefetch and + * writeback of the data Each batch can be referenced exactlyonce by calling the next_view() + * function + * + * @tparam T The type of the data + * @tparam IdxT The type of the index + * @param res The resources + * @param host_view The host view to create the batched device view from + * @param batch_size The batch size + * @param read_only Whether the data is read only (only for managed memory) + * @param host_writeback Whether to write back the data to the host (only for host memory) + * @param initialize Whether to initialize the data (only for managed memory) + * @param evict Whether to evict the data (only for managed memory) + * + * @return The batched device view + */ +template +class batched_device_view_from_host { + public: + batched_device_view_from_host(raft::resources const& res, + raft::host_matrix_view host_view, + uint64_t batch_size, + bool read_only = false, + bool host_writeback = false, + bool initialize = true, + bool evict = false) + : res_(res), + host_view_(host_view), + batch_size_(batch_size), + offset_(0), + batch_id_(0), + num_buffers_(2), + read_only_(read_only), + host_writeback_(host_writeback), + next_buffer_pos_(0), + evict_(evict), + initialize_(initialize) + { + cudaPointerAttributes attr; + 
RAFT_CUDA_TRY(cudaPointerGetAttributes(&attr, host_view.data_handle())); + mem_type_ = attr.type; + // cudaMemoryTypeUnregistered = 0 + // cudaMemoryTypeHost = 1 + // cudaMemoryTypeDevice = 2 + // cudaMemoryTypeManaged = 3 + + prefetch_stream_ = raft::resource::get_cuda_stream(res); + writeback_stream_ = raft::resource::get_cuda_stream(res); + if (res.has_resource_factory(raft::resource::resource_type::CUDA_STREAM_POOL)) { + if (raft::resource::get_stream_pool_size(res) >= 1) { + prefetch_stream_ = raft::resource::get_stream_from_stream_pool(res); + writeback_stream_ = raft::resource::get_stream_from_stream_pool(res); + } + } + + // allocations + if (mem_type_ == cudaMemoryTypeHost || mem_type_ == cudaMemoryTypeUnregistered) { + device_mem_[0].emplace(raft::make_device_mdarray( + res, + raft::resource::get_large_workspace_resource(res), + raft::make_extents(batch_size, host_view.extent(1)))); + device_ptr[0] = device_mem_[0]->data_handle(); + if (batch_size < static_cast(host_view.extent(0))) { + device_mem_[1].emplace(raft::make_device_mdarray( + res, + raft::resource::get_large_workspace_resource(res), + raft::make_extents(batch_size, host_view.extent(1)))); + device_ptr[1] = device_mem_[1]->data_handle(); + } + if (host_writeback_ && batch_size * 2 < static_cast(host_view.extent(0))) { + num_buffers_ = 3; + device_mem_[2].emplace(raft::make_device_mdarray( + res, + raft::resource::get_large_workspace_resource(res), + raft::make_extents(batch_size, host_view.extent(1)))); + device_ptr[2] = device_mem_[2]->data_handle(); + } + } + + // if data is managed and not for_write_ we can set the attribute on the device ptr + if (mem_type_ == cudaMemoryTypeManaged) { + // location_.type = CU_MEM_LOCATION_TYPE_DEVICE; + location_.type = cudaMemLocationTypeDevice; + location_.id = static_cast(raft::resource::get_device_id(res_)); + if (read_only_) { +#if CUDA_VERSION >= 13000 + RAFT_CUDA_TRY(cudaMemAdvise(host_view_.data_handle(), + host_view_.extent(0) * 
host_view_.extent(1) * sizeof(T), + cudaMemAdviseSetReadMostly, + location_)); +#else + RAFT_CUDA_TRY(cudaMemAdvise_v2(host_view_.data_handle(), + host_view_.extent(0) * host_view_.extent(1) * sizeof(T), + cudaMemAdviseSetReadMostly, + location_)); +#endif + // TODO maybe also reset upon destruction + } + } + + // prefetch next batch (0) + prefetch_next_batch(); + } + + bool prefetch_next_batch() + { + // this function will ensure the device_ptr [next_buffer_pos_] is pointing to the correct memory + // after the next synchronization with the prefetch stream + + // if data is on host and we are writing to it we will have to copy it back + // if data is on host we will have to copy it to the device_ptr + + // if data is managed and evict_ is true we can evict the data from device memory + // if data is managed we have to prefetch it + + bool next_batch_exists = offset_ < static_cast(host_view_.extent(0)); + + if (next_batch_exists) { + actual_batch_size_[next_buffer_pos_] = + next_batch_exists ? 
min(batch_size_, host_view_.extent(0) - offset_) : 0; + + switch (mem_type_) { + case cudaMemoryTypeManaged: +#if CUDA_VERSION >= 13000 + if (evict_ && batch_id_ > 1) { + // evict last active + CUdeviceptr dptrs[] = {device_ptr[next_buffer_pos_]}; + size_t sizes[] = {batch_size_ * host_view_.extent(1) * sizeof(T)}; + size_t prefetchLocIdxs[] = {0}; + RAFT_CUDA_TRY(cuMemDiscardBatchAsync( + dptrs, sizes, 1, &location_, prefetchLocIdxs, 1, 0, prefetch_stream_)); + } +#endif + // prefetch + device_ptr[next_buffer_pos_] = host_view_.data_handle() + offset_ * host_view_.extent(1); + if (initialize_) { + // managed API call to prefetch async +#if CUDA_VERSION >= 13000 + RAFT_CUDA_TRY(cudaMemPrefetchAsync( + device_ptr[next_buffer_pos_], + actual_batch_size_[next_buffer_pos_] * host_view_.extent(1) * sizeof(T), + location_, + 0, + prefetch_stream_)); +#else + RAFT_CUDA_TRY(cudaMemPrefetchAsync_v2( + device_ptr[next_buffer_pos_], + actual_batch_size_[next_buffer_pos_] * host_view_.extent(1) * sizeof(T), + location_, + 0, + prefetch_stream_)); +#endif + } else { + // managed API call to cuMemDiscardAndPrefetchBatchAsync (discard and prefetch batch) +#if CUDA_VERSION >= 13000 + CUdeviceptr dptrs[] = {device_ptr[next_buffer_pos_]}; + size_t sizes[] = {actual_batch_size_[next_buffer_pos_] * host_view_.extent(1) * + sizeof(T)}; + size_t prefetchLocIdxs[] = {0}; + RAFT_CUDA_TRY(cuMemDiscardAndPrefetchBatchAsync( + dptrs, sizes, 1, &location_, prefetchLocIdxs, 1, 0, prefetch_stream_)); +#endif + } + + break; + case cudaMemoryTypeHost: + case cudaMemoryTypeUnregistered: + if (host_writeback_ && batch_id_ > 1) { + writeback_stream_.synchronize(); + // copy back last active + uint32_t writeback_pos = (next_buffer_pos_ + num_buffers_ - 2) % num_buffers_; + uint64_t writeback_offset = (offset_ - 2 * batch_size_) * host_view_.extent(1); + raft::copy(host_view_.data_handle() + writeback_offset, + device_ptr[writeback_pos], + actual_batch_size_[writeback_pos] * host_view_.extent(1), + 
writeback_stream_); + } + if (initialize_) { + // prefetch next position + raft::copy(device_ptr[next_buffer_pos_], + host_view_.data_handle() + offset_ * host_view_.extent(1), + actual_batch_size_[next_buffer_pos_] * host_view_.extent(1), + prefetch_stream_); + } + + break; + case cudaMemoryTypeDevice: + // just move pointer to next position + device_ptr[next_buffer_pos_] = host_view_.data_handle() + offset_ * host_view_.extent(1); + break; + } + + offset_ += actual_batch_size_[next_buffer_pos_]; + // swap next_buffer_pos_ + next_buffer_pos_ = (next_buffer_pos_ + 1) % num_buffers_; + } + + return next_batch_exists; + } + + ~batched_device_view_from_host() noexcept + { + prefetch_stream_.synchronize(); + writeback_stream_.synchronize(); + raft::resource::sync_stream(res_); + + // if data is on host and for_write --> make sure to copy back last active + // if data is managed and evict --> evict last active + + // make sure to sync on prefetch & writeback stream & res + switch (mem_type_) { + case cudaMemoryTypeManaged: +#if CUDA_VERSION >= 13000 + if (evict_ && batch_id_ > 0) { + // managed API call to evict 2 + uint32_t evict_pos = (next_buffer_pos_ + num_buffers_ - 1) % num_buffers_; + CUdeviceptr dptrs[] = {device_ptr[evict_pos]}; + size_t sizes[] = {batch_size_ * host_view_.extent(1) * sizeof(T)}; + size_t prefetchLocIdxs[] = {0}; + RAFT_CUDA_TRY(cuMemDiscardBatchAsync( + dptrs, sizes, 1, &location_, prefetchLocIdxs, 1, 0, prefetch_stream_)); + } + prefetch_stream_.synchronize(); +#endif + break; + case cudaMemoryTypeHost: + case cudaMemoryTypeUnregistered: + if (host_writeback_ && batch_id_ > 0) { + // TODO managed API call to copy back last active + uint32_t writeback_pos = (next_buffer_pos_ + num_buffers_ - 1) % num_buffers_; + uint64_t writeback_offset = + (offset_ - actual_batch_size_[writeback_pos]) * host_view_.extent(1); + raft::copy(host_view_.data_handle() + writeback_offset, + device_ptr[writeback_pos], + actual_batch_size_[writeback_pos] * 
host_view_.extent(1), + writeback_stream_); + } + writeback_stream_.synchronize(); + break; + case cudaMemoryTypeDevice: break; + } + } + + /** + * Returns the next view of the batch + * + * This function will ensure the next batch is ready and will trigger the prefetch of the + * subsequent next batch + * + * @return The next view of the batch + */ + raft::device_matrix_view next_view() + { + RAFT_EXPECTS(batch_id_ * batch_size_ < host_view_.extent(0), "Batch index out of bounds"); + + // ensure current batch is ready + prefetch_stream_.synchronize(); + + // trigger prefetch of next batch + bool next_batch_exists = prefetch_next_batch(); + + batch_id_++; + + uint32_t current_pos = + (next_buffer_pos_ + num_buffers_ - (next_batch_exists ? 2 : 1)) % num_buffers_; + return raft::make_device_matrix_view( + device_ptr[current_pos], actual_batch_size_[current_pos], host_view_.extent(1)); + } + + private: + cudaMemoryType mem_type_; + const raft::resources& res_; + uint64_t batch_size_; + uint64_t offset_; + uint64_t num_buffers_; + bool initialize_; + rmm::cuda_stream_view prefetch_stream_; + rmm::cuda_stream_view writeback_stream_; + bool read_only_; + bool host_writeback_; + bool evict_; + int32_t next_buffer_pos_; + int32_t batch_id_; + cudaMemLocation location_; + std::optional> device_mem_[3]; + raft::host_matrix_view host_view_; + T* device_ptr[3]; + uint32_t actual_batch_size_[3]; +}; + } // namespace cuvs::neighbors::cagra::detail From a38ad525570d31882a1c86ff04eb679a6b1c4476 Mon Sep 17 00:00:00 2001 From: Malte Foerster Date: Mon, 9 Mar 2026 20:49:08 +0000 Subject: [PATCH 15/40] fix batched iterator --- cpp/src/neighbors/detail/cagra/graph_core.cuh | 123 +++---- cpp/src/neighbors/detail/cagra/utils.hpp | 313 ++++++++++-------- 2 files changed, 233 insertions(+), 203 deletions(-) diff --git a/cpp/src/neighbors/detail/cagra/graph_core.cuh b/cpp/src/neighbors/detail/cagra/graph_core.cuh index 28006fa133..ef8b1f8daf 100644 --- 
a/cpp/src/neighbors/detail/cagra/graph_core.cuh +++ b/cpp/src/neighbors/detail/cagra/graph_core.cuh @@ -22,6 +22,7 @@ #include #include +#include #include @@ -324,14 +325,14 @@ __device__ void thread_shift_array(T* array, uint64_t num) } template -__global__ void kern_merge_graph(IdxT* output_graph, // [batch_size, output_graph_degree] - const IdxT* const rev_graph, +__global__ void kern_merge_graph(IdxT* output_graph, // [batch_size, output_graph_degree] + const IdxT* const rev_graph, // [graph_size, output_graph_degree] uint32_t* const rev_graph_count, // [graph_size] const uint32_t graph_size, const uint32_t output_graph_degree, - const IdxT* const mst_graph, + const IdxT* const mst_graph, // [batch_size, output_graph_degree] const uint32_t mst_graph_degree, - const uint32_t* const mst_graph_num_edges_ptr, + const uint32_t* const mst_graph_num_edges_ptr, // [batch_size] const uint32_t batch_size, const uint32_t batch_id, bool guarantee_connectivity, @@ -350,12 +351,12 @@ __global__ void kern_merge_graph(IdxT* output_graph, // [batch_size, output_gra if (nid >= graph_size) { return; } - const auto mst_graph_num_edges = guarantee_connectivity ? mst_graph_num_edges_ptr[nid] : 0; + const auto mst_graph_num_edges = guarantee_connectivity ? mst_graph_num_edges_ptr[nid_batch] : 0; // If guarantee_connectivity == true, use a temporal list to merge the // neighbor lists of the graphs. 
if (guarantee_connectivity) { for (uint32_t i = lane_id; i < mst_graph_degree; i += raft::WarpSize) { - smem_sorted_output_graph[i] = mst_graph[nid * mst_graph_degree + i]; + smem_sorted_output_graph[i] = mst_graph[nid_batch * mst_graph_degree + i]; } __syncwarp(); for (uint32_t pruned_j = 0, output_j = mst_graph_num_edges; @@ -788,52 +789,54 @@ void merge_graph_gpu(raft::resources const& res, std::min(static_cast(graph_size), static_cast(256 * 1024)); const uint32_t num_batch = (graph_size + batch_size - 1) / batch_size; - bool output_device_accessible = is_ptr_device_accessible(output_graph_ptr); - auto d_output_graph = raft::make_device_mdarray( + batched_device_view_from_host d_output_graph( res, - default_ws_mr, - raft::make_extents(output_device_accessible ? 0 : batch_size, output_graph_degree)); + raft::make_host_matrix_view(output_graph_ptr, graph_size, output_graph_degree), + /*batch_size*/ batch_size, + /*host_writeback*/ true, + /*initialize*/ true, + /*hmm_as_managed*/ false); - device_matrix_view_from_host d_mst_graph( + batched_device_view_from_host d_mst_graph( res, raft::make_host_matrix_view( - mst_graph_ptr, guarantee_connectivity ? graph_size : 0, output_graph_degree)); + mst_graph_ptr, guarantee_connectivity ? graph_size : 0, output_graph_degree), + /*batch_size*/ batch_size, + /*host_writeback*/ false, + /*initialize*/ true, + /*hmm_as_managed*/ false); - device_matrix_view_from_host d_mst_graph_num_edges( + batched_device_view_from_host d_mst_graph_num_edges( res, - raft::make_host_matrix_view( - mst_graph_num_edges_ptr, guarantee_connectivity ? graph_size : 0, 1)); + raft::make_host_matrix_view( + mst_graph_ptr, guarantee_connectivity ? 
graph_size : 0, output_graph_degree), + /*batch_size*/ batch_size, + /*host_writeback*/ false, + /*initialize*/ true, + /*hmm_as_managed*/ false); const uint32_t num_warps = 4; const dim3 threads_merge(raft::WarpSize * num_warps, 1, 1); - const dim3 blocks_merge(batch_size / num_warps, 1, 1); + const dim3 blocks_merge(raft::ceildiv(batch_size, num_warps), 1, 1); const size_t merge_smem_size = num_warps * output_graph_degree * sizeof(IdxT); for (uint32_t i_batch = 0; i_batch < num_batch; i_batch++) { + auto mst_graph_view = d_mst_graph.next_view(); + auto mst_graph_num_edges_view = d_mst_graph_num_edges.next_view(); + auto output_view = d_output_graph.next_view(); kern_merge_graph <<>>( - output_device_accessible ? output_graph_ptr + (i_batch * batch_size * output_graph_degree) - : d_output_graph.data_handle(), + output_view.data_handle(), d_rev_graph_ptr, d_rev_graph_count_ptr, static_cast(graph_size), static_cast(output_graph_degree), - d_mst_graph.data_handle(), + mst_graph_view.data_handle(), static_cast(output_graph_degree), - d_mst_graph_num_edges.data_handle(), + mst_graph_num_edges_view.data_handle(), batch_size, i_batch, guarantee_connectivity, d_check_num_protected_edges.data_handle()); - - if (!output_device_accessible) { - size_t copy_size = - std::min(static_cast(batch_size), graph_size - i_batch * batch_size) * - output_graph_degree; - raft::copy(output_graph_ptr + i_batch * batch_size * output_graph_degree, - d_output_graph.data_handle(), - copy_size, - raft::resource::get_cuda_stream(res)); - } } bool check_num_protected_edges = true; @@ -879,28 +882,21 @@ void make_reverse_graph_gpu(raft::resources const& res, std::min(static_cast(graph_size), static_cast(256 * 1024)); const uint32_t num_batch = (graph_size + batch_size - 1) / batch_size; - bool output_device_accessible = is_ptr_device_accessible(output_graph_ptr); - auto d_output_graph = raft::make_device_mdarray( + batched_device_view_from_host d_output_graph( res, - default_ws_mr, - 
raft::make_extents(output_device_accessible ? 0 : batch_size, output_graph_degree)); + raft::make_host_matrix_view(output_graph_ptr, graph_size, output_graph_degree), + /*batch_size*/ batch_size, + /*host_writeback*/ false, + /*initialize*/ true, + /*hmm_as_managed*/ false); for (uint32_t i_batch = 0; i_batch < num_batch; i_batch++) { dim3 threads(256, 1, 1); dim3 blocks(1024, 1, 1); + auto output_view = d_output_graph.next_view(); - if (!output_device_accessible) { - size_t copy_size = - std::min(static_cast(batch_size), graph_size - i_batch * batch_size) * - output_graph_degree; - raft::copy(d_output_graph.data_handle(), - output_graph_ptr + i_batch * batch_size * output_graph_degree, - copy_size, - raft::resource::get_cuda_stream(res)); - } kern_rev_graph_batched<<>>( - output_device_accessible ? output_graph_ptr + (i_batch * batch_size * output_graph_degree) - : d_output_graph.data_handle(), + output_view.data_handle(), d_rev_graph_ptr, d_rev_graph_count_ptr, static_cast(graph_size), @@ -1543,38 +1539,35 @@ void prune_graph_gpu(raft::resources const& res, auto host_stats = raft::make_host_vector(2); raft::matrix::fill(res, dev_stats.view(), uint64_t(0)); - // device_matrix_view_from_host d_input_graph( - // res, raft::make_host_matrix_view(knn_graph_ptr, graph_size, - // knn_graph_degree)); - batched_device_view_from_host d_input_graph( res, raft::make_host_matrix_view(knn_graph_ptr, graph_size, knn_graph_degree), /*batch_size*/ graph_size, - /*read_only*/ true, /*host_writeback*/ false, /*initialize*/ true, - /*evict*/ true); + /*hmm_as_managed*/ true); auto input_view = d_input_graph.next_view(); - auto d_invalid_neighbor_list = raft::make_device_scalar(res, 0u); - - bool output_device_accessible = is_ptr_device_accessible(output_graph_ptr); - auto d_output_graph = raft::make_device_mdarray( + batched_device_view_from_host d_output_graph( res, - default_ws_mr, - raft::make_extents(output_device_accessible ? 
0 : batch_size, output_graph_degree)); + raft::make_host_matrix_view(output_graph_ptr, graph_size, output_graph_degree), + /*batch_size*/ batch_size, + /*host_writeback*/ true, + /*initialize*/ false, + /*hmm_as_managed*/ false); + + auto d_invalid_neighbor_list = raft::make_device_scalar(res, 0u); for (uint32_t i_batch = 0; i_batch < num_batch; i_batch++) { + auto output_view = d_output_graph.next_view(); const uint32_t num_warps = 4; const dim3 threads_prune(raft::WarpSize * num_warps, 1, 1); - const dim3 blocks_prune(batch_size / num_warps, 1, 1); + const dim3 blocks_prune(raft::ceildiv(batch_size, num_warps), 1, 1); const size_t prune_smem_size = num_warps * knn_graph_degree * (sizeof(IdxT) + sizeof(uint32_t)); kern_fused_prune <<>>( input_view.data_handle(), - output_device_accessible ? output_graph_ptr + i_batch * batch_size * output_graph_degree - : d_output_graph.data_handle(), + output_view.data_handle(), graph_size, knn_graph_degree, output_graph_degree, @@ -1583,16 +1576,6 @@ void prune_graph_gpu(raft::resources const& res, d_invalid_neighbor_list.data_handle(), dev_stats.data_handle()); - if (!output_device_accessible) { - size_t copy_size = - std::min(static_cast(batch_size), graph_size - i_batch * batch_size) * - output_graph_degree; - raft::copy(output_graph_ptr + i_batch * batch_size * output_graph_degree, - d_output_graph.data_handle(), - copy_size, - raft::resource::get_cuda_stream(res)); - } - raft::resource::sync_stream(res); RAFT_LOG_DEBUG( "# Pruning kNN Graph on GPUs (%.1lf %%)\r", diff --git a/cpp/src/neighbors/detail/cagra/utils.hpp b/cpp/src/neighbors/detail/cagra/utils.hpp index c3d15e59f4..df6ef1ce6f 100644 --- a/cpp/src/neighbors/detail/cagra/utils.hpp +++ b/cpp/src/neighbors/detail/cagra/utils.hpp @@ -16,6 +16,7 @@ #include #include #include +#include #include #include @@ -23,6 +24,7 @@ #include #include +#include #include namespace cuvs::neighbors::cagra::detail { @@ -328,7 +330,7 @@ void copy_with_padding( * @param read_only Whether 
the data is read only (only for managed memory) * @param host_writeback Whether to write back the data to the host (only for host memory) * @param initialize Whether to initialize the data (only for managed memory) - * @param evict Whether to evict the data (only for managed memory) + * @param discard Whether to discard the data (only for managed memory) * * @return The batched device view */ @@ -338,22 +340,24 @@ class batched_device_view_from_host { batched_device_view_from_host(raft::resources const& res, raft::host_matrix_view host_view, uint64_t batch_size, - bool read_only = false, bool host_writeback = false, bool initialize = true, - bool evict = false) + bool hmm_as_managed = false) : res_(res), host_view_(host_view), batch_size_(batch_size), offset_(0), - batch_id_(0), + batch_id_(-2), num_buffers_(2), - read_only_(read_only), host_writeback_(host_writeback), - next_buffer_pos_(0), - evict_(evict), - initialize_(initialize) + initialize_(initialize), + hmm_as_managed_(hmm_as_managed) { + if (host_view.extent(0) == 0) { + mem_type_ = cudaMemoryTypeDevice; + return; + } + cudaPointerAttributes attr; RAFT_CUDA_TRY(cudaPointerGetAttributes(&attr, host_view.data_handle())); mem_type_ = attr.type; @@ -361,27 +365,35 @@ class batched_device_view_from_host { // cudaMemoryTypeHost = 1 // cudaMemoryTypeDevice = 2 // cudaMemoryTypeManaged = 3 + // + // On HMM systems, unregistered (malloc) memory can have devicePointer != nullptr, + // meaning it's directly accessible from the GPU. 
Treat it like managed memory: + if (mem_type_ == cudaMemoryTypeUnregistered && attr.devicePointer != nullptr && + hmm_as_managed) { + mem_type_ = cudaMemoryTypeManaged; + } - prefetch_stream_ = raft::resource::get_cuda_stream(res); - writeback_stream_ = raft::resource::get_cuda_stream(res); - if (res.has_resource_factory(raft::resource::resource_type::CUDA_STREAM_POOL)) { - if (raft::resource::get_stream_pool_size(res) >= 1) { - prefetch_stream_ = raft::resource::get_stream_from_stream_pool(res); - writeback_stream_ = raft::resource::get_stream_from_stream_pool(res); - } + if (res.has_resource_factory(raft::resource::resource_type::CUDA_STREAM_POOL) && + raft::resource::get_stream_pool_size(res) >= 1) { + prefetch_stream_ = raft::resource::get_stream_from_stream_pool(res); + writeback_stream_ = raft::resource::get_stream_from_stream_pool(res); + } else { + local_stream_pool_ = std::make_shared(2); + prefetch_stream_ = local_stream_pool_.value()->get_stream(); + writeback_stream_ = local_stream_pool_.value()->get_stream(); } // allocations if (mem_type_ == cudaMemoryTypeHost || mem_type_ == cudaMemoryTypeUnregistered) { device_mem_[0].emplace(raft::make_device_mdarray( res, - raft::resource::get_large_workspace_resource(res), + raft::resource::get_workspace_resource(res), raft::make_extents(batch_size, host_view.extent(1)))); device_ptr[0] = device_mem_[0]->data_handle(); if (batch_size < static_cast(host_view.extent(0))) { device_mem_[1].emplace(raft::make_device_mdarray( res, - raft::resource::get_large_workspace_resource(res), + raft::resource::get_workspace_resource(res), raft::make_extents(batch_size, host_view.extent(1)))); device_ptr[1] = device_mem_[1]->data_handle(); } @@ -389,7 +401,7 @@ class batched_device_view_from_host { num_buffers_ = 3; device_mem_[2].emplace(raft::make_device_mdarray( res, - raft::resource::get_large_workspace_resource(res), + raft::resource::get_workspace_resource(res), raft::make_extents(batch_size, host_view.extent(1)))); 
device_ptr[2] = device_mem_[2]->data_handle(); } @@ -400,18 +412,9 @@ class batched_device_view_from_host { // location_.type = CU_MEM_LOCATION_TYPE_DEVICE; location_.type = cudaMemLocationTypeDevice; location_.id = static_cast(raft::resource::get_device_id(res_)); - if (read_only_) { -#if CUDA_VERSION >= 13000 - RAFT_CUDA_TRY(cudaMemAdvise(host_view_.data_handle(), - host_view_.extent(0) * host_view_.extent(1) * sizeof(T), - cudaMemAdviseSetReadMostly, - location_)); -#else - RAFT_CUDA_TRY(cudaMemAdvise_v2(host_view_.data_handle(), - host_view_.extent(0) * host_view_.extent(1) * sizeof(T), - cudaMemAdviseSetReadMostly, - location_)); -#endif + if (!host_writeback_) { + advise_read_mostly(host_view_.data_handle(), + host_view_.extent(0) * host_view_.extent(1) * sizeof(T)); // TODO maybe also reset upon destruction } } @@ -422,95 +425,72 @@ class batched_device_view_from_host { bool prefetch_next_batch() { - // this function will ensure the device_ptr [next_buffer_pos_] is pointing to the correct memory - // after the next synchronization with the prefetch stream + batch_id_++; + + // ensure previous batch at position batch_id_ is ready + prefetch_stream_.synchronize(); + if (host_writeback_) { writeback_stream_.synchronize(); } - // if data is on host and we are writing to it we will have to copy it back - // if data is on host we will have to copy it to the device_ptr + // this step will + // * write back data from batch_id_ - 1 + // * prefetch data for batch_id_ + 1 - // if data is managed and evict_ is true we can evict the data from device memory - // if data is managed we have to prefetch it + // if data is on host and host_writeback_ is true we will have to copy it back + // if data is on host and initialize_ is true we will have to copy it to the device_ptr + + // if data is managed and !host_writeback_ we can discard the data from device memory + // if data is managed and initialize_ is true we can prefetch it to the device + // if data is managed and 
!initialize_ we can discard and prefetch the data location + + // if data is on device only this is almost a noop, just prepping the pointers + + RAFT_EXPECTS(offset_ <= host_view_.extent(0), "Offset out of bounds"); bool next_batch_exists = offset_ < static_cast(host_view_.extent(0)); if (next_batch_exists) { - actual_batch_size_[next_buffer_pos_] = - next_batch_exists ? min(batch_size_, host_view_.extent(0) - offset_) : 0; + // synchronize to ensure all previous operations are completed + // in particular all work on batch_id_ - 1 + raft::resource::sync_stream(res_); + + int32_t prefetch_pos = (batch_id_ + 1) % num_buffers_; + actual_batch_size_[prefetch_pos] = min(batch_size_, host_view_.extent(0) - offset_); switch (mem_type_) { case cudaMemoryTypeManaged: -#if CUDA_VERSION >= 13000 - if (evict_ && batch_id_ > 1) { - // evict last active - CUdeviceptr dptrs[] = {device_ptr[next_buffer_pos_]}; - size_t sizes[] = {batch_size_ * host_view_.extent(1) * sizeof(T)}; - size_t prefetchLocIdxs[] = {0}; - RAFT_CUDA_TRY(cuMemDiscardBatchAsync( - dptrs, sizes, 1, &location_, prefetchLocIdxs, 1, 0, prefetch_stream_)); + if (!host_writeback_ && batch_id_ > 1) { + uint32_t discard_pos = (batch_id_ - 1) % num_buffers_; + size_t discard_size = batch_size_ * host_view_.extent(1) * sizeof(T); + discard_managed_region(device_ptr[discard_pos], discard_size); } -#endif - // prefetch - device_ptr[next_buffer_pos_] = host_view_.data_handle() + offset_ * host_view_.extent(1); - if (initialize_) { - // managed API call to prefetch async -#if CUDA_VERSION >= 13000 - RAFT_CUDA_TRY(cudaMemPrefetchAsync( - device_ptr[next_buffer_pos_], - actual_batch_size_[next_buffer_pos_] * host_view_.extent(1) * sizeof(T), - location_, - 0, - prefetch_stream_)); -#else - RAFT_CUDA_TRY(cudaMemPrefetchAsync_v2( - device_ptr[next_buffer_pos_], - actual_batch_size_[next_buffer_pos_] * host_view_.extent(1) * sizeof(T), - location_, - 0, - prefetch_stream_)); -#endif - } else { - // managed API call to 
cuMemDiscardAndPrefetchBatchAsync (discard and prefetch batch) -#if CUDA_VERSION >= 13000 - CUdeviceptr dptrs[] = {device_ptr[next_buffer_pos_]}; - size_t sizes[] = {actual_batch_size_[next_buffer_pos_] * host_view_.extent(1) * - sizeof(T)}; - size_t prefetchLocIdxs[] = {0}; - RAFT_CUDA_TRY(cuMemDiscardAndPrefetchBatchAsync( - dptrs, sizes, 1, &location_, prefetchLocIdxs, 1, 0, prefetch_stream_)); -#endif - } - + // prefetch next position + device_ptr[prefetch_pos] = host_view_.data_handle() + offset_ * host_view_.extent(1); + prefetch_managed_region( + device_ptr[prefetch_pos], + actual_batch_size_[prefetch_pos] * host_view_.extent(1) * sizeof(T)); break; case cudaMemoryTypeHost: case cudaMemoryTypeUnregistered: - if (host_writeback_ && batch_id_ > 1) { - writeback_stream_.synchronize(); + if (host_writeback_ && batch_id_ > 0) { // copy back last active - uint32_t writeback_pos = (next_buffer_pos_ + num_buffers_ - 2) % num_buffers_; - uint64_t writeback_offset = (offset_ - 2 * batch_size_) * host_view_.extent(1); - raft::copy(host_view_.data_handle() + writeback_offset, - device_ptr[writeback_pos], - actual_batch_size_[writeback_pos] * host_view_.extent(1), - writeback_stream_); + uint32_t writeback_pos = (batch_id_ - 1) % num_buffers_; + uint64_t writeback_offset = (batch_id_ - 1) * batch_size_; + writeback_from_device_to_host(device_ptr[writeback_pos], writeback_offset, batch_size_); } if (initialize_) { // prefetch next position - raft::copy(device_ptr[next_buffer_pos_], - host_view_.data_handle() + offset_ * host_view_.extent(1), - actual_batch_size_[next_buffer_pos_] * host_view_.extent(1), - prefetch_stream_); + prefetch_from_host_to_device( + device_ptr[prefetch_pos], offset_, actual_batch_size_[prefetch_pos]); } break; case cudaMemoryTypeDevice: // just move pointer to next position - device_ptr[next_buffer_pos_] = host_view_.data_handle() + offset_ * host_view_.extent(1); + device_ptr[prefetch_pos] = host_view_.data_handle() + offset_ * 
host_view_.extent(1); break; } - offset_ += actual_batch_size_[next_buffer_pos_]; - // swap next_buffer_pos_ - next_buffer_pos_ = (next_buffer_pos_ + 1) % num_buffers_; + offset_ += actual_batch_size_[prefetch_pos]; } return next_batch_exists; @@ -525,33 +505,36 @@ class batched_device_view_from_host { // if data is on host and for_write --> make sure to copy back last active // if data is managed and evict --> evict last active - // make sure to sync on prefetch & writeback stream & res + // make sure to sync on prefetch stream & res switch (mem_type_) { case cudaMemoryTypeManaged: -#if CUDA_VERSION >= 13000 - if (evict_ && batch_id_ > 0) { - // managed API call to evict 2 - uint32_t evict_pos = (next_buffer_pos_ + num_buffers_ - 1) % num_buffers_; - CUdeviceptr dptrs[] = {device_ptr[evict_pos]}; - size_t sizes[] = {batch_size_ * host_view_.extent(1) * sizeof(T)}; - size_t prefetchLocIdxs[] = {0}; - RAFT_CUDA_TRY(cuMemDiscardBatchAsync( - dptrs, sizes, 1, &location_, prefetchLocIdxs, 1, 0, prefetch_stream_)); + if (!host_writeback_) { + uint32_t discard_pos = batch_id_ % num_buffers_; + size_t discard_size_rows = actual_batch_size_[discard_pos]; + if (batch_id_ > 0) { + discard_pos = (batch_id_ - 1) % num_buffers_; + discard_size_rows += batch_size_; + } + discard_managed_region(device_ptr[discard_pos], + discard_size_rows * host_view_.extent(1) * sizeof(T)); } - prefetch_stream_.synchronize(); -#endif + writeback_stream_.synchronize(); break; case cudaMemoryTypeHost: case cudaMemoryTypeUnregistered: - if (host_writeback_ && batch_id_ > 0) { - // TODO managed API call to copy back last active - uint32_t writeback_pos = (next_buffer_pos_ + num_buffers_ - 1) % num_buffers_; - uint64_t writeback_offset = - (offset_ - actual_batch_size_[writeback_pos]) * host_view_.extent(1); - raft::copy(host_view_.data_handle() + writeback_offset, - device_ptr[writeback_pos], - actual_batch_size_[writeback_pos] * host_view_.extent(1), - writeback_stream_); + if (host_writeback_) { + 
uint32_t writeback_pos_last = batch_id_ % num_buffers_; + if (batch_id_ > 0) { + uint32_t writeback_pos = (batch_id_ - 1) % num_buffers_; + uint64_t writeback_offset = (batch_id_ - 1) * batch_size_; + writeback_from_device_to_host(device_ptr[writeback_pos], writeback_offset, batch_size_); + } + { + uint64_t writeback_offset_last = batch_id_ * batch_size_; + writeback_from_device_to_host(device_ptr[writeback_pos_last], + writeback_offset_last, + actual_batch_size_[writeback_pos_last]); + } } writeback_stream_.synchronize(); break; @@ -569,39 +552,103 @@ class batched_device_view_from_host { */ raft::device_matrix_view next_view() { - RAFT_EXPECTS(batch_id_ * batch_size_ < host_view_.extent(0), "Batch index out of bounds"); - - // ensure current batch is ready - prefetch_stream_.synchronize(); + // special case for empty host view + if (host_view_.extent(0) == 0) { + return raft::make_device_matrix_view(nullptr, 0, host_view_.extent(1)); + } // trigger prefetch of next batch bool next_batch_exists = prefetch_next_batch(); - batch_id_++; + RAFT_EXPECTS(batch_id_ * batch_size_ < host_view_.extent(0), "Batch index out of bounds"); - uint32_t current_pos = - (next_buffer_pos_ + num_buffers_ - (next_batch_exists ? 
2 : 1)) % num_buffers_; + uint32_t current_pos = batch_id_ % num_buffers_; return raft::make_device_matrix_view( device_ptr[current_pos], actual_batch_size_[current_pos], host_view_.extent(1)); } private: - cudaMemoryType mem_type_; - const raft::resources& res_; - uint64_t batch_size_; - uint64_t offset_; - uint64_t num_buffers_; - bool initialize_; + void advise_read_mostly(T* ptr, size_t size) + { +#if CUDA_VERSION >= 13000 + RAFT_CUDA_TRY(cudaMemAdvise(ptr, size, cudaMemAdviseSetReadMostly, location_)); +#else + RAFT_CUDA_TRY(cudaMemAdvise_v2(ptr, size, cudaMemAdviseSetReadMostly, location_)); +#endif + } + + void discard_managed_region(T* dev_ptr, size_t size) + { +#if CUDA_VERSION >= 13000 + void* dptrs[1] = {dev_ptr}; + size_t sizes[1] = {size}; + RAFT_CUDA_TRY(cudaMemDiscardBatchAsync(dptrs, sizes, 1, 0, writeback_stream_)); +#endif + // FIXME: CUDA12 does not support discard + } + + void prefetch_managed_region(T* dev_ptr, size_t size) + { +#if CUDA_VERSION >= 13000 + if (initialize_) { + RAFT_CUDA_TRY(cudaMemPrefetchAsync(dev_ptr, size, location_, 0, prefetch_stream_)); + } else { + void* dptrs[1] = {dev_ptr}; + size_t sizes[1] = {size}; + RAFT_CUDA_TRY( + cudaMemDiscardAndPrefetchBatchAsync(dptrs, sizes, 1, location_, 0, prefetch_stream_)); + } +#else + // FIXME: CUDA12 does not support discard - so we just prefetch + if (initialize_) { + RAFT_CUDA_TRY(cudaMemPrefetchAsync_v2(dev_ptr, size, location_, 0, prefetch_stream_)); + } else { + RAFT_CUDA_TRY(cudaMemPrefetchAsync_v2(dev_ptr, size, location_, 0, prefetch_stream_)); + } +#endif + } + + void prefetch_from_host_to_device(T* dev_ptr, size_t src_row_offset, size_t num_rows) + { + raft::copy(dev_ptr, + host_view_.data_handle() + src_row_offset * host_view_.extent(1), + num_rows * host_view_.extent(1), + prefetch_stream_); + } + + void writeback_from_device_to_host(T* dev_ptr, size_t dst_row_offset, size_t num_rows) + { + raft::copy(host_view_.data_handle() + dst_row_offset * host_view_.extent(1), + 
dev_ptr, + num_rows * host_view_.extent(1), + writeback_stream_); + } + + // stream pool for local streams + std::optional> local_stream_pool_; rmm::cuda_stream_view prefetch_stream_; rmm::cuda_stream_view writeback_stream_; - bool read_only_; - bool host_writeback_; - bool evict_; - int32_t next_buffer_pos_; + + // configuration + const raft::resources& res_; + bool initialize_; // initialize the data on the device + bool host_writeback_; // write back the data to the host + bool hmm_as_managed_; // treat unregistered memory as managed memory + + // batch position information + uint64_t batch_size_; int32_t batch_id_; + uint64_t offset_; + cudaMemLocation location_; - std::optional> device_mem_[3]; + + // input pointer information + cudaMemoryType mem_type_; raft::host_matrix_view host_view_; + + // internal device buffers + uint64_t num_buffers_; + std::optional> device_mem_[3]; T* device_ptr[3]; uint32_t actual_batch_size_[3]; }; From 89b0d1c25bbff782cf906be7d9b2dc58a5927116 Mon Sep 17 00:00:00 2001 From: Malte Foerster Date: Mon, 9 Mar 2026 21:57:25 +0000 Subject: [PATCH 16/40] implement fallback / simplify strategy --- cpp/src/neighbors/detail/cagra/graph_core.cuh | 18 +-- cpp/src/neighbors/detail/cagra/utils.hpp | 110 ++++++++++-------- 2 files changed, 66 insertions(+), 62 deletions(-) diff --git a/cpp/src/neighbors/detail/cagra/graph_core.cuh b/cpp/src/neighbors/detail/cagra/graph_core.cuh index ef8b1f8daf..a6e4c08350 100644 --- a/cpp/src/neighbors/detail/cagra/graph_core.cuh +++ b/cpp/src/neighbors/detail/cagra/graph_core.cuh @@ -794,8 +794,7 @@ void merge_graph_gpu(raft::resources const& res, raft::make_host_matrix_view(output_graph_ptr, graph_size, output_graph_degree), /*batch_size*/ batch_size, /*host_writeback*/ true, - /*initialize*/ true, - /*hmm_as_managed*/ false); + /*initialize*/ true); batched_device_view_from_host d_mst_graph( res, @@ -803,8 +802,7 @@ void merge_graph_gpu(raft::resources const& res, mst_graph_ptr, guarantee_connectivity ? 
graph_size : 0, output_graph_degree), /*batch_size*/ batch_size, /*host_writeback*/ false, - /*initialize*/ true, - /*hmm_as_managed*/ false); + /*initialize*/ true); batched_device_view_from_host d_mst_graph_num_edges( res, @@ -812,8 +810,7 @@ void merge_graph_gpu(raft::resources const& res, mst_graph_ptr, guarantee_connectivity ? graph_size : 0, output_graph_degree), /*batch_size*/ batch_size, /*host_writeback*/ false, - /*initialize*/ true, - /*hmm_as_managed*/ false); + /*initialize*/ true); const uint32_t num_warps = 4; const dim3 threads_merge(raft::WarpSize * num_warps, 1, 1); @@ -887,8 +884,7 @@ void make_reverse_graph_gpu(raft::resources const& res, raft::make_host_matrix_view(output_graph_ptr, graph_size, output_graph_degree), /*batch_size*/ batch_size, /*host_writeback*/ false, - /*initialize*/ true, - /*hmm_as_managed*/ false); + /*initialize*/ true); for (uint32_t i_batch = 0; i_batch < num_batch; i_batch++) { dim3 threads(256, 1, 1); @@ -1544,8 +1540,7 @@ void prune_graph_gpu(raft::resources const& res, raft::make_host_matrix_view(knn_graph_ptr, graph_size, knn_graph_degree), /*batch_size*/ graph_size, /*host_writeback*/ false, - /*initialize*/ true, - /*hmm_as_managed*/ true); + /*initialize*/ true); auto input_view = d_input_graph.next_view(); batched_device_view_from_host d_output_graph( @@ -1553,8 +1548,7 @@ void prune_graph_gpu(raft::resources const& res, raft::make_host_matrix_view(output_graph_ptr, graph_size, output_graph_degree), /*batch_size*/ batch_size, /*host_writeback*/ true, - /*initialize*/ false, - /*hmm_as_managed*/ false); + /*initialize*/ false); auto d_invalid_neighbor_list = raft::make_device_scalar(res, 0u); diff --git a/cpp/src/neighbors/detail/cagra/utils.hpp b/cpp/src/neighbors/detail/cagra/utils.hpp index df6ef1ce6f..8f6cfb063f 100644 --- a/cpp/src/neighbors/detail/cagra/utils.hpp +++ b/cpp/src/neighbors/detail/cagra/utils.hpp @@ -327,22 +327,25 @@ void copy_with_padding( * @param res The resources * @param host_view The 
host view to create the batched device view from * @param batch_size The batch size - * @param read_only Whether the data is read only (only for managed memory) * @param host_writeback Whether to write back the data to the host (only for host memory) * @param initialize Whether to initialize the data (only for managed memory) - * @param discard Whether to discard the data (only for managed memory) * * @return The batched device view */ template class batched_device_view_from_host { public: + enum class memory_strategy { + device_only, // data is on device only (no copy needed) + copy_device, // data is explicitly moved to/from device buffers + managed_only, // data is on managed memory (system managed) + }; + batched_device_view_from_host(raft::resources const& res, raft::host_matrix_view host_view, uint64_t batch_size, bool host_writeback = false, - bool initialize = true, - bool hmm_as_managed = false) + bool initialize = true) : res_(res), host_view_(host_view), batch_size_(batch_size), @@ -350,29 +353,23 @@ class batched_device_view_from_host { batch_id_(-2), num_buffers_(2), host_writeback_(host_writeback), - initialize_(initialize), - hmm_as_managed_(hmm_as_managed) + initialize_(initialize) { if (host_view.extent(0) == 0) { - mem_type_ = cudaMemoryTypeDevice; + mem_strategy_ = memory_strategy::device_only; return; } cudaPointerAttributes attr; RAFT_CUDA_TRY(cudaPointerGetAttributes(&attr, host_view.data_handle())); - mem_type_ = attr.type; - // cudaMemoryTypeUnregistered = 0 - // cudaMemoryTypeHost = 1 - // cudaMemoryTypeDevice = 2 - // cudaMemoryTypeManaged = 3 - // - // On HMM systems, unregistered (malloc) memory can have devicePointer != nullptr, - // meaning it's directly accessible from the GPU. 
Treat it like managed memory: - if (mem_type_ == cudaMemoryTypeUnregistered && attr.devicePointer != nullptr && - hmm_as_managed) { - mem_type_ = cudaMemoryTypeManaged; + switch (attr.type) { + case cudaMemoryTypeUnregistered: + case cudaMemoryTypeHost: + case cudaMemoryTypeManaged: mem_strategy_ = memory_strategy::copy_device; break; + case cudaMemoryTypeDevice: mem_strategy_ = memory_strategy::device_only; break; } + // setup streams if (res.has_resource_factory(raft::resource::resource_type::CUDA_STREAM_POOL) && raft::resource::get_stream_pool_size(res) >= 1) { prefetch_stream_ = raft::resource::get_stream_from_stream_pool(res); @@ -383,32 +380,48 @@ class batched_device_view_from_host { writeback_stream_ = local_stream_pool_.value()->get_stream(); } - // allocations - if (mem_type_ == cudaMemoryTypeHost || mem_type_ == cudaMemoryTypeUnregistered) { - device_mem_[0].emplace(raft::make_device_mdarray( - res, - raft::resource::get_workspace_resource(res), - raft::make_extents(batch_size, host_view.extent(1)))); - device_ptr[0] = device_mem_[0]->data_handle(); - if (batch_size < static_cast(host_view.extent(0))) { - device_mem_[1].emplace(raft::make_device_mdarray( - res, - raft::resource::get_workspace_resource(res), - raft::make_extents(batch_size, host_view.extent(1)))); - device_ptr[1] = device_mem_[1]->data_handle(); - } - if (host_writeback_ && batch_size * 2 < static_cast(host_view.extent(0))) { - num_buffers_ = 3; - device_mem_[2].emplace(raft::make_device_mdarray( + // buffer allocations + if (mem_strategy_ == memory_strategy::copy_device) { + try { + device_mem_[0].emplace(raft::make_device_mdarray( res, raft::resource::get_workspace_resource(res), raft::make_extents(batch_size, host_view.extent(1)))); - device_ptr[2] = device_mem_[2]->data_handle(); + device_ptr[0] = device_mem_[0]->data_handle(); + if (batch_size < static_cast(host_view.extent(0))) { + device_mem_[1].emplace(raft::make_device_mdarray( + res, + 
raft::resource::get_workspace_resource(res), + raft::make_extents(batch_size, host_view.extent(1)))); + device_ptr[1] = device_mem_[1]->data_handle(); + } + if (host_writeback_ && batch_size * 2 < static_cast(host_view.extent(0))) { + num_buffers_ = 3; + device_mem_[2].emplace(raft::make_device_mdarray( + res, + raft::resource::get_workspace_resource(res), + raft::make_extents(batch_size, host_view.extent(1)))); + device_ptr[2] = device_mem_[2]->data_handle(); + } + } catch (std::bad_alloc& e) { + RAFT_LOG_DEBUG("Insufficient memory for device buffers"); + if (attr.devicePointer != nullptr) { + mem_strategy_ = memory_strategy::managed_only; + } else { + throw std::bad_alloc(); + } + } catch (raft::logic_error& e) { + RAFT_LOG_DEBUG("Insufficient memory for device buffers (logic error)"); + if (attr.devicePointer != nullptr) { + mem_strategy_ = memory_strategy::managed_only; + } else { + throw raft::logic_error("Insufficient memory for device buffers (logic error)"); + } } } // if data is managed and not for_write_ we can set the attribute on the device ptr - if (mem_type_ == cudaMemoryTypeManaged) { + if (mem_strategy_ == memory_strategy::managed_only) { // location_.type = CU_MEM_LOCATION_TYPE_DEVICE; location_.type = cudaMemLocationTypeDevice; location_.id = static_cast(raft::resource::get_device_id(res_)); @@ -428,7 +441,7 @@ class batched_device_view_from_host { batch_id_++; // ensure previous batch at position batch_id_ is ready - prefetch_stream_.synchronize(); + if (initialize_) { prefetch_stream_.synchronize(); } if (host_writeback_) { writeback_stream_.synchronize(); } // this step will @@ -456,8 +469,8 @@ class batched_device_view_from_host { int32_t prefetch_pos = (batch_id_ + 1) % num_buffers_; actual_batch_size_[prefetch_pos] = min(batch_size_, host_view_.extent(0) - offset_); - switch (mem_type_) { - case cudaMemoryTypeManaged: + switch (mem_strategy_) { + case memory_strategy::managed_only: if (!host_writeback_ && batch_id_ > 1) { uint32_t 
discard_pos = (batch_id_ - 1) % num_buffers_; size_t discard_size = batch_size_ * host_view_.extent(1) * sizeof(T); @@ -469,8 +482,7 @@ class batched_device_view_from_host { device_ptr[prefetch_pos], actual_batch_size_[prefetch_pos] * host_view_.extent(1) * sizeof(T)); break; - case cudaMemoryTypeHost: - case cudaMemoryTypeUnregistered: + case memory_strategy::copy_device: if (host_writeback_ && batch_id_ > 0) { // copy back last active uint32_t writeback_pos = (batch_id_ - 1) % num_buffers_; @@ -484,7 +496,7 @@ class batched_device_view_from_host { } break; - case cudaMemoryTypeDevice: + case memory_strategy::device_only: // just move pointer to next position device_ptr[prefetch_pos] = host_view_.data_handle() + offset_ * host_view_.extent(1); break; @@ -506,8 +518,8 @@ class batched_device_view_from_host { // if data is managed and evict --> evict last active // make sure to sync on prefetch stream & res - switch (mem_type_) { - case cudaMemoryTypeManaged: + switch (mem_strategy_) { + case memory_strategy::managed_only: if (!host_writeback_) { uint32_t discard_pos = batch_id_ % num_buffers_; size_t discard_size_rows = actual_batch_size_[discard_pos]; @@ -520,8 +532,7 @@ class batched_device_view_from_host { } writeback_stream_.synchronize(); break; - case cudaMemoryTypeHost: - case cudaMemoryTypeUnregistered: + case memory_strategy::copy_device: if (host_writeback_) { uint32_t writeback_pos_last = batch_id_ % num_buffers_; if (batch_id_ > 0) { @@ -538,7 +549,7 @@ class batched_device_view_from_host { } writeback_stream_.synchronize(); break; - case cudaMemoryTypeDevice: break; + case memory_strategy::device_only: break; } } @@ -630,10 +641,10 @@ class batched_device_view_from_host { rmm::cuda_stream_view writeback_stream_; // configuration + memory_strategy mem_strategy_; const raft::resources& res_; bool initialize_; // initialize the data on the device bool host_writeback_; // write back the data to the host - bool hmm_as_managed_; // treat unregistered memory 
as managed memory // batch position information uint64_t batch_size_; @@ -643,7 +654,6 @@ class batched_device_view_from_host { cudaMemLocation location_; // input pointer information - cudaMemoryType mem_type_; raft::host_matrix_view host_view_; // internal device buffers From d0e3daefdfc7fcdec3ceaaa62a8d95134a726f15 Mon Sep 17 00:00:00 2001 From: Malte Foerster Date: Tue, 10 Mar 2026 17:31:23 +0000 Subject: [PATCH 17/40] add logging / remove stats compute --- cpp/src/neighbors/detail/cagra/graph_core.cuh | 51 +++------------ cpp/src/neighbors/detail/cagra/utils.hpp | 62 ++++++++++++------- 2 files changed, 46 insertions(+), 67 deletions(-) diff --git a/cpp/src/neighbors/detail/cagra/graph_core.cuh b/cpp/src/neighbors/detail/cagra/graph_core.cuh index a6e4c08350..b5e055820d 100644 --- a/cpp/src/neighbors/detail/cagra/graph_core.cuh +++ b/cpp/src/neighbors/detail/cagra/graph_core.cuh @@ -648,44 +648,6 @@ __global__ void kern_mst_opt_postprocessing(IdxT* outgoing_num_edges, // [graph } } -template -uint64_t pos_in_array(T val, const T* array, uint64_t num) -{ - for (uint64_t i = 0; i < num; i++) { - if (val == array[i]) { return i; } - } - return num; -} - -template -void shift_array(T* array, uint64_t num) -{ - for (uint64_t i = num; i > 0; i--) { - array[i] = array[i - 1]; - } -} - -template -void log_replaced_edges_stats(const IdxT* output_graph_ptr, - uint64_t graph_size, - uint64_t output_graph_degree) -{ - raft::common::nvtx::range block_scope( - "cagra::graph::optimize/stats"); - uint64_t num_replaced_edges = 0; -#pragma omp parallel for reduction(+ : num_replaced_edges) - for (uint64_t i = 0; i < graph_size; i++) { - for (uint64_t k = 0; k < output_graph_degree; k++) { - const uint64_t j = output_graph_ptr[k + (output_graph_degree * i)]; - const uint64_t pos = - pos_in_array(j, output_graph_ptr + (output_graph_degree * i), output_graph_degree); - if (pos == output_graph_degree) { num_replaced_edges += 1; } - } - } - RAFT_LOG_DEBUG("# Average number of 
replaced edges per node: %.2f", - (double)num_replaced_edges / graph_size); -} - template void log_incoming_edges_histogram(const IdxT* output_graph_ptr, uint64_t graph_size, @@ -755,7 +717,10 @@ void check_duplicates_and_out_of_range(const IdxT* output_graph_ptr, for (uint32_t k = j + 1; k < output_graph_degree; k++) { const auto neighbor_b = my_out_graph[k]; - if (neighbor_a == neighbor_b) { num_dup++; } + if (neighbor_a == neighbor_b) { + num_dup++; + break; + } } } } @@ -1606,10 +1571,10 @@ void prune_graph_gpu(raft::resources const& res, } // TODO allow pinned input for both knn_graph and new_graph -template +template void optimize(raft::resources const& res, - InOutMatrixView knn_graph, - InOutMatrixView new_graph, + InputMatrixView knn_graph, + OutputMatrixView new_graph, const bool guarantee_connectivity = true, const bool use_gpu = true) { @@ -1707,8 +1672,6 @@ void optimize(raft::resources const& res, if (is_ptr_host_accessible(new_graph.data_handle())) { // following checks require host access - log_replaced_edges_stats(new_graph.data_handle(), graph_size, output_graph_degree); - log_incoming_edges_histogram(new_graph.data_handle(), graph_size, output_graph_degree); check_duplicates_and_out_of_range( diff --git a/cpp/src/neighbors/detail/cagra/utils.hpp b/cpp/src/neighbors/detail/cagra/utils.hpp index 8f6cfb063f..75883a9636 100644 --- a/cpp/src/neighbors/detail/cagra/utils.hpp +++ b/cpp/src/neighbors/detail/cagra/utils.hpp @@ -360,25 +360,18 @@ class batched_device_view_from_host { return; } - cudaPointerAttributes attr; - RAFT_CUDA_TRY(cudaPointerGetAttributes(&attr, host_view.data_handle())); - switch (attr.type) { + RAFT_CUDA_TRY(cudaPointerGetAttributes(&attr_, host_view.data_handle())); + switch (attr_.type) { case cudaMemoryTypeUnregistered: case cudaMemoryTypeHost: case cudaMemoryTypeManaged: mem_strategy_ = memory_strategy::copy_device; break; case cudaMemoryTypeDevice: mem_strategy_ = memory_strategy::device_only; break; } - // setup streams - 
if (res.has_resource_factory(raft::resource::resource_type::CUDA_STREAM_POOL) && - raft::resource::get_stream_pool_size(res) >= 1) { - prefetch_stream_ = raft::resource::get_stream_from_stream_pool(res); - writeback_stream_ = raft::resource::get_stream_from_stream_pool(res); - } else { - local_stream_pool_ = std::make_shared(2); - prefetch_stream_ = local_stream_pool_.value()->get_stream(); - writeback_stream_ = local_stream_pool_.value()->get_stream(); - } + RAFT_LOG_DEBUG("Memory strategy: %d for type %d, size %zu", + static_cast(mem_strategy_), + static_cast(attr_.type), + host_view.extent(0) * host_view.extent(1) * sizeof(T)); // buffer allocations if (mem_strategy_ == memory_strategy::copy_device) { @@ -405,14 +398,14 @@ class batched_device_view_from_host { } } catch (std::bad_alloc& e) { RAFT_LOG_DEBUG("Insufficient memory for device buffers"); - if (attr.devicePointer != nullptr) { + if (attr_.devicePointer != nullptr) { mem_strategy_ = memory_strategy::managed_only; } else { throw std::bad_alloc(); } } catch (raft::logic_error& e) { RAFT_LOG_DEBUG("Insufficient memory for device buffers (logic error)"); - if (attr.devicePointer != nullptr) { + if (attr_.devicePointer != nullptr) { mem_strategy_ = memory_strategy::managed_only; } else { throw raft::logic_error("Insufficient memory for device buffers (logic error)"); @@ -420,6 +413,17 @@ class batched_device_view_from_host { } } + // setup streams + if (res.has_resource_factory(raft::resource::resource_type::CUDA_STREAM_POOL) && + raft::resource::get_stream_pool_size(res) >= 1) { + prefetch_stream_ = raft::resource::get_stream_from_stream_pool(res); + writeback_stream_ = raft::resource::get_stream_from_stream_pool(res); + } else { + local_stream_pool_ = std::make_shared(2); + prefetch_stream_ = local_stream_pool_.value()->get_stream(); + writeback_stream_ = local_stream_pool_.value()->get_stream(); + } + // if data is managed and not for_write_ we can set the attribute on the device ptr if (mem_strategy_ == 
memory_strategy::managed_only) { // location_.type = CU_MEM_LOCATION_TYPE_DEVICE; @@ -621,18 +625,29 @@ class batched_device_view_from_host { void prefetch_from_host_to_device(T* dev_ptr, size_t src_row_offset, size_t num_rows) { - raft::copy(dev_ptr, - host_view_.data_handle() + src_row_offset * host_view_.extent(1), - num_rows * host_view_.extent(1), - prefetch_stream_); + const size_t n_elem = num_rows * host_view_.extent(1); + const size_t n_bytes = n_elem * sizeof(T); + RAFT_CUDA_TRY(cudaHostRegister(host_view_.data_handle() + src_row_offset * host_view_.extent(1), + n_bytes, + cudaHostRegisterDefault)); + // use memcpy instead of raft::copy to avoid strange behavior with HMM/ATS memory + RAFT_CUDA_TRY(cudaMemcpyAsync(dev_ptr, + host_view_.data_handle() + src_row_offset * host_view_.extent(1), + n_bytes, + cudaMemcpyHostToDevice, + prefetch_stream_)); } void writeback_from_device_to_host(T* dev_ptr, size_t dst_row_offset, size_t num_rows) { - raft::copy(host_view_.data_handle() + dst_row_offset * host_view_.extent(1), - dev_ptr, - num_rows * host_view_.extent(1), - writeback_stream_); + const size_t n_elem = num_rows * host_view_.extent(1); + const size_t n_bytes = n_elem * sizeof(T); + // use memcpy instead of raft::copy to avoid strange behavior with HMM/ATS memory + RAFT_CUDA_TRY(cudaMemcpyAsync(host_view_.data_handle() + dst_row_offset * host_view_.extent(1), + dev_ptr, + n_bytes, + cudaMemcpyDeviceToHost, + writeback_stream_)); } // stream pool for local streams @@ -655,6 +670,7 @@ class batched_device_view_from_host { // input pointer information raft::host_matrix_view host_view_; + cudaPointerAttributes attr_; // internal device buffers uint64_t num_buffers_; From ec45fd251d90cd8713c58252d8258ebee3b700a8 Mon Sep 17 00:00:00 2001 From: Malte Foerster Date: Tue, 10 Mar 2026 22:46:18 +0000 Subject: [PATCH 18/40] add test, persist stream pool, cleanup --- cpp/src/neighbors/detail/cagra/utils.hpp | 214 ++++++++++-------- cpp/tests/CMakeLists.txt | 1 + 
.../test_batched_device_view_from_host.cu | 205 +++++++++++++++++ 3 files changed, 326 insertions(+), 94 deletions(-) create mode 100644 cpp/tests/neighbors/ann_cagra/test_batched_device_view_from_host.cu diff --git a/cpp/src/neighbors/detail/cagra/utils.hpp b/cpp/src/neighbors/detail/cagra/utils.hpp index 75883a9636..44d87d2993 100644 --- a/cpp/src/neighbors/detail/cagra/utils.hpp +++ b/cpp/src/neighbors/detail/cagra/utils.hpp @@ -322,15 +322,22 @@ void copy_with_padding( * writeback of the data Each batch can be referenced exactlyonce by calling the next_view() * function * + * Usage: + * ``` + * batched_device_view_from_host view(res, host_view, batch_size, host_writeback, + * initialize); while (view.next_view().extent(0) > 0) { auto device_view = view.next_view(); + * // use device_view + * } + * ``` + * + * The call to next_view() will + * * synchronize on all previous operations / increments batch_id_ + * * (optionally) write back the data of the previous batch to the host + * * (optionally) prefetch the data of the next batch + * * return the view of the current batch + * * @tparam T The type of the data * @tparam IdxT The type of the index - * @param res The resources - * @param host_view The host view to create the batched device view from - * @param batch_size The batch size - * @param host_writeback Whether to write back the data to the host (only for host memory) - * @param initialize Whether to initialize the data (only for managed memory) - * - * @return The batched device view */ template class batched_device_view_from_host { @@ -341,6 +348,18 @@ class batched_device_view_from_host { managed_only, // data is on managed memory (system managed) }; + /** + * Create a batched device view from a host view and will handle the prefetch and + * writeback of the data. Each batch can be referenced exactly once by calling the next_view() + * method. 
+ * + * @param res The resources to use + * @param host_view The host view to create the batched device view from + * @param batch_size The batch size + * @param host_writeback Whether to write back the data to the host (only for host memory) + * (default: false) + * @param initialize Whether to initialize the data (only for managed memory) (default: true) + */ batched_device_view_from_host(raft::resources const& res, raft::host_matrix_view host_view, uint64_t batch_size, @@ -360,6 +379,9 @@ class batched_device_view_from_host { return; } + RAFT_EXPECTS(host_writeback_ || initialize_, + "At least one of host_writeback or initialize must be true"); + RAFT_CUDA_TRY(cudaPointerGetAttributes(&attr_, host_view.data_handle())); switch (attr_.type) { case cudaMemoryTypeUnregistered: @@ -388,7 +410,8 @@ class batched_device_view_from_host { raft::make_extents(batch_size, host_view.extent(1)))); device_ptr[1] = device_mem_[1]->data_handle(); } - if (host_writeback_ && batch_size * 2 < static_cast(host_view.extent(0))) { + if (host_writeback_ && initialize_ && + batch_size * 2 < static_cast(host_view.extent(0))) { num_buffers_ = 3; device_mem_[2].emplace(raft::make_device_mdarray( res, @@ -397,15 +420,16 @@ class batched_device_view_from_host { device_ptr[2] = device_mem_[2]->data_handle(); } } catch (std::bad_alloc& e) { - RAFT_LOG_DEBUG("Insufficient memory for device buffers"); if (attr_.devicePointer != nullptr) { + RAFT_LOG_DEBUG("Insufficient memory for device buffers, switching to managed memory"); mem_strategy_ = memory_strategy::managed_only; } else { throw std::bad_alloc(); } } catch (raft::logic_error& e) { - RAFT_LOG_DEBUG("Insufficient memory for device buffers (logic error)"); if (attr_.devicePointer != nullptr) { + RAFT_LOG_DEBUG( + "Insufficient memory for device buffers (logic error), switching to managed memory"); mem_strategy_ = memory_strategy::managed_only; } else { throw raft::logic_error("Insufficient memory for device buffers (logic error)"); @@ 
-413,20 +437,18 @@ class batched_device_view_from_host { } } - // setup streams - if (res.has_resource_factory(raft::resource::resource_type::CUDA_STREAM_POOL) && - raft::resource::get_stream_pool_size(res) >= 1) { - prefetch_stream_ = raft::resource::get_stream_from_stream_pool(res); - writeback_stream_ = raft::resource::get_stream_from_stream_pool(res); - } else { - local_stream_pool_ = std::make_shared(2); - prefetch_stream_ = local_stream_pool_.value()->get_stream(); - writeback_stream_ = local_stream_pool_.value()->get_stream(); + // setup stream pool if not already present + size_t required_streams = host_writeback_ && initialize_ ? 2 : 1; + if (!res.has_resource_factory(raft::resource::resource_type::CUDA_STREAM_POOL) || + raft::resource::get_stream_pool_size(res) < required_streams) { + // always create at least 2 streams to account for subsequent iterator calls + raft::resource::set_cuda_stream_pool(res, std::make_shared(2)); } + prefetch_stream_ = raft::resource::get_stream_from_stream_pool(res); + writeback_stream_ = raft::resource::get_stream_from_stream_pool(res); // if data is managed and not for_write_ we can set the attribute on the device ptr if (mem_strategy_ == memory_strategy::managed_only) { - // location_.type = CU_MEM_LOCATION_TYPE_DEVICE; location_.type = cudaMemLocationTypeDevice; location_.id = static_cast(raft::resource::get_device_id(res_)); if (!host_writeback_) { @@ -440,6 +462,84 @@ class batched_device_view_from_host { prefetch_next_batch(); } + ~batched_device_view_from_host() noexcept + { + raft::resource::sync_stream(res_); + + // if data is on host and for_write --> make sure to copy back last active + // if data is managed and evict --> evict last active + + // make sure to sync on prefetch stream & res + switch (mem_strategy_) { + case memory_strategy::managed_only: + if (!host_writeback_) { + uint32_t discard_pos = batch_id_ % num_buffers_; + size_t discard_size_rows = actual_batch_size_[discard_pos]; + if (batch_id_ > 0) { + 
discard_pos = (batch_id_ - 1) % num_buffers_; + discard_size_rows += batch_size_; + } + discard_managed_region(device_ptr[discard_pos], + discard_size_rows * host_view_.extent(1) * sizeof(T)); + writeback_stream_.synchronize(); + } + break; + case memory_strategy::copy_device: + if (host_writeback_) { + uint32_t writeback_pos_last = batch_id_ % num_buffers_; + if (batch_id_ > 0) { + uint32_t writeback_pos = (batch_id_ - 1) % num_buffers_; + uint64_t writeback_offset = (batch_id_ - 1) * batch_size_; + writeback_from_device_to_host(device_ptr[writeback_pos], writeback_offset, batch_size_); + } + { + uint64_t writeback_offset_last = batch_id_ * batch_size_; + writeback_from_device_to_host(device_ptr[writeback_pos_last], + writeback_offset_last, + actual_batch_size_[writeback_pos_last]); + } + writeback_stream_.synchronize(); + } + break; + case memory_strategy::device_only: break; + } + } + + /** + * Returns the next view of the batch + * + * This function will ensure the next batch is ready and will trigger the prefetch of the + * subsequent next batch. If writeback is enabled, the last active batch will be written back to + * the host. + * + * @return The next view of the batch + */ + raft::device_matrix_view next_view() + { + bool end_of_data = static_cast((batch_id_ + 1) * batch_size_) >= + static_cast(host_view_.extent(0)); + + // special case for empty host view or last batch surpassed + if (end_of_data) { + return raft::make_device_matrix_view(nullptr, 0, host_view_.extent(1)); + } + + // trigger prefetch of next batch (also increments batch_id_) + prefetch_next_batch(); + + uint32_t current_pos = batch_id_ % num_buffers_; + return raft::make_device_matrix_view( + device_ptr[current_pos], actual_batch_size_[current_pos], host_view_.extent(1)); + } + + private: + /** + * Prefetch the next batch + * + * This function will prefetch the next batch and will handle the writeback of the data. 
+ * + * @return True if the next batch exists, false otherwise + */ bool prefetch_next_batch() { batch_id_++; @@ -512,77 +612,6 @@ class batched_device_view_from_host { return next_batch_exists; } - ~batched_device_view_from_host() noexcept - { - prefetch_stream_.synchronize(); - writeback_stream_.synchronize(); - raft::resource::sync_stream(res_); - - // if data is on host and for_write --> make sure to copy back last active - // if data is managed and evict --> evict last active - - // make sure to sync on prefetch stream & res - switch (mem_strategy_) { - case memory_strategy::managed_only: - if (!host_writeback_) { - uint32_t discard_pos = batch_id_ % num_buffers_; - size_t discard_size_rows = actual_batch_size_[discard_pos]; - if (batch_id_ > 0) { - discard_pos = (batch_id_ - 1) % num_buffers_; - discard_size_rows += batch_size_; - } - discard_managed_region(device_ptr[discard_pos], - discard_size_rows * host_view_.extent(1) * sizeof(T)); - } - writeback_stream_.synchronize(); - break; - case memory_strategy::copy_device: - if (host_writeback_) { - uint32_t writeback_pos_last = batch_id_ % num_buffers_; - if (batch_id_ > 0) { - uint32_t writeback_pos = (batch_id_ - 1) % num_buffers_; - uint64_t writeback_offset = (batch_id_ - 1) * batch_size_; - writeback_from_device_to_host(device_ptr[writeback_pos], writeback_offset, batch_size_); - } - { - uint64_t writeback_offset_last = batch_id_ * batch_size_; - writeback_from_device_to_host(device_ptr[writeback_pos_last], - writeback_offset_last, - actual_batch_size_[writeback_pos_last]); - } - } - writeback_stream_.synchronize(); - break; - case memory_strategy::device_only: break; - } - } - - /** - * Returns the next view of the batch - * - * This function will ensure the next batch is ready and will trigger the prefetch of the - * subsequent next batch - * - * @return The next view of the batch - */ - raft::device_matrix_view next_view() - { - // special case for empty host view - if (host_view_.extent(0) == 0) { - 
return raft::make_device_matrix_view(nullptr, 0, host_view_.extent(1)); - } - - // trigger prefetch of next batch - bool next_batch_exists = prefetch_next_batch(); - - RAFT_EXPECTS(batch_id_ * batch_size_ < host_view_.extent(0), "Batch index out of bounds"); - - uint32_t current_pos = batch_id_ % num_buffers_; - return raft::make_device_matrix_view( - device_ptr[current_pos], actual_batch_size_[current_pos], host_view_.extent(1)); - } - - private: void advise_read_mostly(T* ptr, size_t size) { #if CUDA_VERSION >= 13000 @@ -627,9 +656,6 @@ class batched_device_view_from_host { { const size_t n_elem = num_rows * host_view_.extent(1); const size_t n_bytes = n_elem * sizeof(T); - RAFT_CUDA_TRY(cudaHostRegister(host_view_.data_handle() + src_row_offset * host_view_.extent(1), - n_bytes, - cudaHostRegisterDefault)); // use memcpy instead of raft::copy to avoid strange behavior with HMM/ATS memory RAFT_CUDA_TRY(cudaMemcpyAsync(dev_ptr, host_view_.data_handle() + src_row_offset * host_view_.extent(1), diff --git a/cpp/tests/CMakeLists.txt b/cpp/tests/CMakeLists.txt index 35794adf9b..77fd18c7d3 100644 --- a/cpp/tests/CMakeLists.txt +++ b/cpp/tests/CMakeLists.txt @@ -173,6 +173,7 @@ ConfigureTest( ConfigureTest( NAME NEIGHBORS_ANN_CAGRA_HELPERS_TEST PATH neighbors/ann_cagra/test_optimize_uint32_t.cu + neighbors/ann_cagra/test_batched_device_view_from_host.cu GPUS 1 PERCENT 100 ) diff --git a/cpp/tests/neighbors/ann_cagra/test_batched_device_view_from_host.cu b/cpp/tests/neighbors/ann_cagra/test_batched_device_view_from_host.cu new file mode 100644 index 0000000000..1e1cc13093 --- /dev/null +++ b/cpp/tests/neighbors/ann_cagra/test_batched_device_view_from_host.cu @@ -0,0 +1,205 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2023-2026, NVIDIA CORPORATION. 
+ * SPDX-License-Identifier: Apache-2.0 + */ + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +#include "../../../src/neighbors/detail/cagra/utils.hpp" + +#include +#include +#include +#include + +namespace cuvs::neighbors::cagra { + +using IdxT = uint32_t; + +struct BatchConfig { + bool initialize; + bool host_writeback; +}; + +struct DimsConfig { + int64_t n_rows; + int64_t n_cols; + uint64_t batch_size; +}; + +class BatchedDeviceViewFromHostTest : public ::testing::Test { + protected: + void SetUp() override { raft::resource::sync_stream(res); } + + /** + * Run batched_device_view_from_host over host data, copy device views back, + * and verify against the input. + */ + template + void run_and_verify_batched(InputMatrixView input_view, + uint64_t batch_size, + bool host_writeback, + bool initialize) + { + int64_t n_rows = input_view.extent(0); + int64_t n_cols = input_view.extent(1); + + std::vector readback(n_rows * n_cols); + + int64_t total_processed = 0; + + { + cagra::detail::batched_device_view_from_host batched( + res, + raft::make_host_matrix_view(input_view.data_handle(), n_rows, n_cols), + batch_size, + host_writeback, + initialize); + while (true) { + auto dev_view = batched.next_view(); + if (dev_view.extent(0) == 0) break; + + if (initialize) { + raft::copy(readback.data() + total_processed * n_cols, + dev_view.data_handle(), + dev_view.extent(0) * dev_view.extent(1), + raft::resource::get_cuda_stream(res)); + } + if (host_writeback) { raft::matrix::fill(res, dev_view, IdxT(17)); } + total_processed += dev_view.extent(0); + } + } + raft::resource::sync_stream(res); + + EXPECT_EQ(total_processed, n_rows); + if (initialize) { + for (int64_t i = 0; i < n_rows * n_cols; ++i) { + EXPECT_EQ(readback[i], IdxT(13)) << "Mismatch (initialize) at index " << i; + } + } + if (host_writeback) { + auto readback_view = + 
raft::make_host_matrix_view(readback.data(), n_rows, n_cols); + raft::copy(res, readback_view, input_view); + raft::resource::sync_stream(res); + for (int64_t i = 0; i < n_rows * n_cols; ++i) { + EXPECT_EQ(readback[i], IdxT(17)) << "Mismatch (host_writeback) at index " << i; + } + } + } + + raft::resources res; +}; + +TEST_F(BatchedDeviceViewFromHostTest, EmptyView) +{ + auto host_empty = raft::make_host_matrix(0, 8); + auto host_view = host_empty.view(); + cagra::detail::batched_device_view_from_host batched( + res, host_view, /*batch_size=*/128, /*host_writeback=*/false, /*initialize=*/true); + + auto view = batched.next_view(); + EXPECT_EQ(view.extent(0), 0); + EXPECT_EQ(view.extent(1), 8); + EXPECT_EQ(view.data_handle(), nullptr); +} + +using BatchDimsParam = std::tuple; + +class BatchedDeviceViewFromHostParameterizedTest + : public BatchedDeviceViewFromHostTest, + public ::testing::WithParamInterface {}; + +TEST_P(BatchedDeviceViewFromHostParameterizedTest, VectorHostData) +{ + auto [batch_config, dims_config] = GetParam(); + auto [initialize, host_writeback] = batch_config; + auto [n_rows, n_cols, batch_size] = dims_config; + + std::vector host_data(n_rows * n_cols); + auto host_view = raft::make_host_matrix_view(host_data.data(), n_rows, n_cols); + + std::fill(host_view.data_handle(), host_view.data_handle() + n_rows * n_cols, IdxT(13)); + + run_and_verify_batched(host_view, batch_size, host_writeback, initialize); +} + +TEST_P(BatchedDeviceViewFromHostParameterizedTest, PinnedMemory) +{ + auto [batch_config, dims_config] = GetParam(); + auto [initialize, host_writeback] = batch_config; + auto [n_rows, n_cols, batch_size] = dims_config; + + auto host_matrix = raft::make_pinned_matrix(res, n_rows, n_cols); + auto host_view = host_matrix.view(); + + std::fill(host_view.data_handle(), host_view.data_handle() + n_rows * n_cols, IdxT(13)); + + run_and_verify_batched(host_view, batch_size, host_writeback, initialize); +} + 
+TEST_P(BatchedDeviceViewFromHostParameterizedTest, ManagedMemory) +{ + auto [batch_config, dims_config] = GetParam(); + auto [initialize, host_writeback] = batch_config; + auto [n_rows, n_cols, batch_size] = dims_config; + + auto host_matrix = raft::make_managed_matrix(res, n_rows, n_cols); + auto host_view = host_matrix.view(); + + std::fill(host_view.data_handle(), host_view.data_handle() + n_rows * n_cols, IdxT(13)); + + run_and_verify_batched(host_view, batch_size, host_writeback, initialize); +} + +TEST_P(BatchedDeviceViewFromHostParameterizedTest, DeviceMemory) +{ + auto [batch_config, dims_config] = GetParam(); + auto [initialize, host_writeback] = batch_config; + auto [n_rows, n_cols, batch_size] = dims_config; + + auto host_matrix = raft::make_device_matrix(res, n_rows, n_cols); + auto host_view = host_matrix.view(); + + raft::matrix::fill(res, host_view, IdxT(13)); + + run_and_verify_batched(host_view, batch_size, host_writeback, initialize); +} + +static const std::array kBatchConfigs = {{ + {/*initialize=*/true, /*host_writeback=*/false}, + {/*initialize=*/false, /*host_writeback=*/true}, + {/*initialize=*/true, /*host_writeback=*/true}, +}}; + +static const std::array kDimsConfigs = {{ + {/*n_rows=*/64, /*n_cols=*/32, /*batch_size=*/256}, // rows less than batch size, single batch + {/*n_rows=*/64, /*n_cols=*/32, /*batch_size=*/64}, // single batch + {/*n_rows=*/256, /*n_cols=*/32, /*batch_size=*/32}, // multiple batches + {/*n_rows=*/500, + /*n_cols=*/32, + /*batch_size=*/128}, // multiple batches, partial batch in the end +}}; + +INSTANTIATE_TEST_SUITE_P(BatchConfigs, + BatchedDeviceViewFromHostParameterizedTest, + ::testing::Combine(::testing::ValuesIn(kBatchConfigs), + ::testing::ValuesIn(kDimsConfigs))); + +} // namespace cuvs::neighbors::cagra From c412138a0dd6e3b81fa9bc4e10a1b546d71c5476 Mon Sep 17 00:00:00 2001 From: Malte Foerster Date: Wed, 11 Mar 2026 00:04:52 +0000 Subject: [PATCH 19/40] switch to cooperative groups as __reduce_min_sync 
causes issues --- cpp/src/neighbors/detail/cagra/graph_core.cuh | 26 ++++++++++++------- 1 file changed, 16 insertions(+), 10 deletions(-) diff --git a/cpp/src/neighbors/detail/cagra/graph_core.cuh b/cpp/src/neighbors/detail/cagra/graph_core.cuh index b5e055820d..2444350253 100644 --- a/cpp/src/neighbors/detail/cagra/graph_core.cuh +++ b/cpp/src/neighbors/detail/cagra/graph_core.cuh @@ -29,11 +29,16 @@ #include #include +#include +#include + #include #include #include #include +namespace cg = cooperative_groups; + namespace cuvs::neighbors::cagra::detail::graph { // unnamed namespace to avoid multiple definition error @@ -196,6 +201,9 @@ __global__ void kern_fused_prune(const IdxT* const knn_graph, // [graph_chunk_ { extern __shared__ unsigned char smem_buf[]; + cg::thread_block block = cg::this_thread_block(); + cg::thread_block_tile<32> warp = cg::tiled_partition<32>(block); + const uint32_t wid = threadIdx.x / raft::WarpSize; const uint32_t lane_id = threadIdx.x % raft::WarpSize; @@ -207,8 +215,7 @@ __global__ void kern_fused_prune(const IdxT* const knn_graph, // [graph_chunk_ uint64_t* const num_retain = stats; uint64_t* const num_full = stats + 1; - const unsigned warp_mask = 0xffffffff; - const uint32_t maxval16 = 0x0000ffff; + const uint32_t maxval16 = 0x0000ffff; const uint64_t nid_batch = blockIdx.x * num_warps + wid; const uint64_t nid = nid_batch + (batch_size * batch_id); @@ -255,11 +262,7 @@ __global__ void kern_fused_prune(const IdxT* const knn_graph, // [graph_chunk_ __syncwarp(); - num_edges_no_detour += __shfl_xor_sync(0xffffffff, num_edges_no_detour, 1); - num_edges_no_detour += __shfl_xor_sync(0xffffffff, num_edges_no_detour, 2); - num_edges_no_detour += __shfl_xor_sync(0xffffffff, num_edges_no_detour, 4); - num_edges_no_detour += __shfl_xor_sync(0xffffffff, num_edges_no_detour, 8); - num_edges_no_detour += __shfl_xor_sync(0xffffffff, num_edges_no_detour, 16); + num_edges_no_detour = cg::reduce(warp, num_edges_no_detour, cg::plus()); 
num_edges_no_detour = min(num_edges_no_detour, output_graph_degree); if (lane_id == 0) { @@ -280,7 +283,7 @@ __global__ void kern_fused_prune(const IdxT* const knn_graph, // [graph_chunk_ } uint32_t local_min_with_tag = (local_min << 16) | ((uint32_t)local_idx); - uint32_t warp_min_with_tag = __reduce_min_sync(warp_mask, local_min_with_tag); + uint32_t warp_min_with_tag = cg::reduce(warp, local_min_with_tag, cg::less()); uint32_t warp_min_count = warp_min_with_tag >> 16; uint32_t warp_local_idx = warp_min_with_tag & 0xffff; @@ -294,7 +297,7 @@ __global__ void kern_fused_prune(const IdxT* const knn_graph, // [graph_chunk_ for (uint32_t k = lane_id; k < knn_graph_degree; k += raft::WarpSize) { if (smem_indices[k] == selected_node) { smem_num_detour[k] = maxval16; } } - __syncwarp(warp_mask); + __syncwarp(); if (lane_id == 0) { output_graph_ptr[nid_batch * output_graph_degree + i] = selected_node; } } @@ -312,7 +315,10 @@ __device__ unsigned int warp_pos_in_array(T val, const T* array, uint64_t num) break; } } - ret = __reduce_min_sync(0xffffffff, ret); + + cg::thread_block block = cg::this_thread_block(); + cg::thread_block_tile<32> warp = cg::tiled_partition<32>(block); + ret = cg::reduce(warp, ret, cg::less()); return ret; } From ab01bab594e4337a9b6530a686e0d8642ce61866 Mon Sep 17 00:00:00 2001 From: Malte Foerster Date: Fri, 13 Mar 2026 18:43:55 +0000 Subject: [PATCH 20/40] back to column wise reverse graph creation to boost closer connections --- cpp/src/neighbors/detail/cagra/graph_core.cuh | 124 +++++++++++------- 1 file changed, 78 insertions(+), 46 deletions(-) diff --git a/cpp/src/neighbors/detail/cagra/graph_core.cuh b/cpp/src/neighbors/detail/cagra/graph_core.cuh index 88c13c139e..5d43da851b 100644 --- a/cpp/src/neighbors/detail/cagra/graph_core.cuh +++ b/cpp/src/neighbors/detail/cagra/graph_core.cuh @@ -171,24 +171,38 @@ __global__ void kern_sort(const DATA_T* const dataset, // [dataset_chunk_size, } } +template +__global__ void kern_make_rev_graph(const 
IdxT* const dest_nodes, // [graph_size] + IdxT* const rev_graph, // [size, degree] + uint32_t* const rev_graph_count, // [graph_size] + const uint32_t graph_size, + const uint32_t degree) +{ + const uint32_t tid = threadIdx.x + (blockDim.x * blockIdx.x); + const uint32_t tnum = blockDim.x * gridDim.x; + + for (uint32_t src_id = tid; src_id < graph_size; src_id += tnum) { + const IdxT dest_id = dest_nodes[src_id]; + if (dest_id >= graph_size) continue; + + const uint32_t pos = atomicAdd(rev_graph_count + dest_id, 1); + if (pos < degree) { rev_graph[pos + ((uint64_t)degree * dest_id)] = src_id; } + } +} + template -__global__ void kern_rev_graph_batched(const IdxT* const dest_nodes, // [batch_size, degree] - IdxT* const rev_graph, // [graph_size, degree] - uint32_t* const rev_graph_count, // [graph_size] - const uint32_t graph_size, - const uint32_t degree, - const uint32_t batch_size, - const uint32_t batch_id) +__global__ void kern_make_rev_graph_k(const IdxT* const output_graph, // [graph_size, degree] + IdxT* const rev_graph, // [graph_size, degree] + uint32_t* const rev_graph_count, // [graph_size] + const uint32_t graph_size, + const uint32_t degree, + uint64_t k) { const uint64_t tid = threadIdx.x + (blockDim.x * blockIdx.x); const uint64_t tnum = blockDim.x * gridDim.x; - const uint64_t block_batch_size = min(batch_size, graph_size - batch_id * batch_size); - - for (uint64_t idx = tid; idx < block_batch_size * degree; idx += tnum) { - const IdxT dest_id = dest_nodes[idx]; - const uint32_t src_id = idx / degree; - + for (uint64_t src_id = tid; src_id < graph_size; src_id += tnum) { + IdxT dest_id = output_graph[k + (degree * src_id)]; if (dest_id >= graph_size) continue; const uint32_t pos = atomicAdd(rev_graph_count + dest_id, 1); @@ -840,50 +854,67 @@ void make_reverse_graph_gpu(raft::resources const& res, uint64_t output_graph_degree) { raft::common::nvtx::range block_scope( - "cagra::graph::optimize/reverse"); + "cagra::graph::optimize/reverse2"); - auto 
default_ws_mr = raft::resource::get_workspace_resource(res); + auto d_rev_graph = + raft::make_device_vector_view(d_rev_graph_ptr, graph_size * output_graph_degree); + auto d_rev_graph_count = + raft::make_device_vector_view(d_rev_graph_count_ptr, graph_size); - raft::matrix::fill( - res, - raft::make_device_vector_view(d_rev_graph_ptr, graph_size * output_graph_degree), - IdxT(-1)); + // + // Make reverse graph + // + const double time_make_start = cur_time(); - raft::matrix::fill( - res, - raft::make_device_vector_view(d_rev_graph_count_ptr, graph_size), - uint32_t(0)); + raft::matrix::fill(res, d_rev_graph, IdxT(-1)); + raft::matrix::fill(res, d_rev_graph_count, uint32_t(0)); - const uint32_t batch_size = - std::min(static_cast(graph_size), static_cast(256 * 1024)); - const uint32_t num_batch = (graph_size + batch_size - 1) / batch_size; + if (is_ptr_host_accessible(output_graph_ptr)) { + auto d_dest_nodes = + raft::make_device_mdarray(res, raft::make_extents(graph_size)); - batched_device_view_from_host d_output_graph( - res, - raft::make_host_matrix_view(output_graph_ptr, graph_size, output_graph_degree), - /*batch_size*/ batch_size, - /*host_writeback*/ false, - /*initialize*/ true); - - for (uint32_t i_batch = 0; i_batch < num_batch; i_batch++) { + for (uint64_t k = 0; k < output_graph_degree; k++) { + RAFT_CUDA_TRY(cudaMemcpy2DAsync(d_dest_nodes.data_handle(), + sizeof(IdxT), + output_graph_ptr + k, + output_graph_degree * sizeof(IdxT), + 1 * sizeof(IdxT), + graph_size, + cudaMemcpyHostToDevice, + raft::resource::get_cuda_stream(res))); + + dim3 threads(256, 1, 1); + dim3 blocks(1024, 1, 1); + kern_make_rev_graph<<>>( + d_dest_nodes.data_handle(), + d_rev_graph.data_handle(), + d_rev_graph_count.data_handle(), + graph_size, + output_graph_degree); + RAFT_LOG_DEBUG("# Making reverse graph on GPUs: %lu / %u \r", k, output_graph_degree); + } + } else { + // output graph is fully device accessible, so we need no copy to device dim3 threads(256, 1, 1); dim3 
blocks(1024, 1, 1); - auto output_view = d_output_graph.next_view(); - - kern_rev_graph_batched<<>>( - output_view.data_handle(), - d_rev_graph_ptr, - d_rev_graph_count_ptr, - static_cast(graph_size), - static_cast(output_graph_degree), - static_cast(batch_size), - static_cast(i_batch)); + for (uint64_t k = 0; k < output_graph_degree; k++) { + kern_make_rev_graph_k<<>>( + output_graph_ptr, + d_rev_graph.data_handle(), + d_rev_graph_count.data_handle(), + graph_size, + output_graph_degree, + k); + } } raft::resource::sync_stream(res); RAFT_LOG_DEBUG("\n"); + + const double time_make_end = cur_time(); + RAFT_LOG_DEBUG("# Making reverse graph time: %.1lf ms", + (time_make_end - time_make_start) * 1000.0); } -} // namespace template void optimize(raft::resources const& res, InputMatrixView knn_graph, From 68f78839a5437d48f54d876b484efded20e8448d Mon Sep 17 00:00:00 2001 From: Malte Foerster Date: Fri, 13 Mar 2026 20:00:55 +0000 Subject: [PATCH 21/40] fix signness --- cpp/src/neighbors/detail/cagra/utils.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/src/neighbors/detail/cagra/utils.hpp b/cpp/src/neighbors/detail/cagra/utils.hpp index 44d87d2993..7dae487863 100644 --- a/cpp/src/neighbors/detail/cagra/utils.hpp +++ b/cpp/src/neighbors/detail/cagra/utils.hpp @@ -561,7 +561,7 @@ class batched_device_view_from_host { // if data is on device only this is almost a noop, just prepping the pointers - RAFT_EXPECTS(offset_ <= host_view_.extent(0), "Offset out of bounds"); + RAFT_EXPECTS(static_cast(offset_) <= host_view_.extent(0), "Offset out of bounds"); bool next_batch_exists = offset_ < static_cast(host_view_.extent(0)); From add206a7697aaf019543a43e39f763858992c5a2 Mon Sep 17 00:00:00 2001 From: Malte Foerster Date: Fri, 13 Mar 2026 22:51:05 +0000 Subject: [PATCH 22/40] stupid me trusting cursor to fix this --- cpp/src/neighbors/detail/cagra/utils.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/cpp/src/neighbors/detail/cagra/utils.hpp b/cpp/src/neighbors/detail/cagra/utils.hpp index 7dae487863..79d1ed1cae 100644 --- a/cpp/src/neighbors/detail/cagra/utils.hpp +++ b/cpp/src/neighbors/detail/cagra/utils.hpp @@ -561,7 +561,7 @@ class batched_device_view_from_host { // if data is on device only this is almost a noop, just prepping the pointers - RAFT_EXPECTS(static_cast(offset_) <= host_view_.extent(0), "Offset out of bounds"); + RAFT_EXPECTS(static_cast(offset_) <= host_view_.extent(0), "Offset out of bounds"); bool next_batch_exists = offset_ < static_cast(host_view_.extent(0)); From ef1ec1886213130f78e299dff80661262098f0c5 Mon Sep 17 00:00:00 2001 From: Malte Foerster Date: Mon, 16 Mar 2026 14:33:48 +0000 Subject: [PATCH 23/40] remove pointer arithmetic part1 --- cpp/src/neighbors/detail/cagra/graph_core.cuh | 111 ++++++++---------- 1 file changed, 49 insertions(+), 62 deletions(-) diff --git a/cpp/src/neighbors/detail/cagra/graph_core.cuh b/cpp/src/neighbors/detail/cagra/graph_core.cuh index 5d43da851b..995b0f29bf 100644 --- a/cpp/src/neighbors/detail/cagra/graph_core.cuh +++ b/cpp/src/neighbors/detail/cagra/graph_core.cuh @@ -758,17 +758,18 @@ void check_duplicates_and_out_of_range(const IdxT* output_graph_ptr, num_oor == 0, "%lu out-of-range index node(s) are found in the generated CAGRA graph", num_oor); } -template +template void merge_graph_gpu(raft::resources const& res, - IdxT* output_graph_ptr, - const IdxT* d_rev_graph_ptr, - uint32_t* d_rev_graph_count_ptr, - IdxT* mst_graph_ptr, - uint32_t* mst_graph_num_edges_ptr, - uint64_t graph_size, - uint64_t output_graph_degree, + OutputMatrixView output_graph, + raft::device_matrix_view d_rev_graph, + raft::device_vector_view d_rev_graph_count, + raft::host_matrix_view mst_graph, + raft::host_vector_view mst_graph_num_edges, bool guarantee_connectivity) { + const uint64_t graph_size = output_graph.extent(0); + const uint64_t output_graph_degree = output_graph.extent(1); + raft::common::nvtx::range 
block_scope( "cagra::graph::optimize/combine"); @@ -784,23 +785,22 @@ void merge_graph_gpu(raft::resources const& res, batched_device_view_from_host d_output_graph( res, - raft::make_host_matrix_view(output_graph_ptr, graph_size, output_graph_degree), + raft::make_host_matrix_view( + output_graph.data_handle(), graph_size, output_graph_degree), /*batch_size*/ batch_size, /*host_writeback*/ true, /*initialize*/ true); - batched_device_view_from_host d_mst_graph( - res, - raft::make_host_matrix_view( - mst_graph_ptr, guarantee_connectivity ? graph_size : 0, output_graph_degree), - /*batch_size*/ batch_size, - /*host_writeback*/ false, - /*initialize*/ true); + batched_device_view_from_host d_mst_graph(res, + mst_graph, + /*batch_size*/ batch_size, + /*host_writeback*/ false, + /*initialize*/ true); batched_device_view_from_host d_mst_graph_num_edges( res, raft::make_host_matrix_view( - mst_graph_ptr, guarantee_connectivity ? graph_size : 0, output_graph_degree), + mst_graph_num_edges.data_handle(), guarantee_connectivity ? 
graph_size : 0, 1), /*batch_size*/ batch_size, /*host_writeback*/ false, /*initialize*/ true); @@ -816,8 +816,8 @@ void merge_graph_gpu(raft::resources const& res, kern_merge_graph <<>>( output_view.data_handle(), - d_rev_graph_ptr, - d_rev_graph_count_ptr, + d_rev_graph.data_handle(), + d_rev_graph_count.data_handle(), static_cast(graph_size), static_cast(output_graph_degree), mst_graph_view.data_handle(), @@ -845,21 +845,17 @@ void merge_graph_gpu(raft::resources const& res, (merge_graph_end - merge_graph_start) * 1000.0); } -template +template void make_reverse_graph_gpu(raft::resources const& res, - IdxT* output_graph_ptr, - IdxT* d_rev_graph_ptr, - uint32_t* d_rev_graph_count_ptr, - uint64_t graph_size, - uint64_t output_graph_degree) + OutputMatrixView output_graph, + raft::device_matrix_view d_rev_graph, + raft::device_vector_view d_rev_graph_count) { - raft::common::nvtx::range block_scope( - "cagra::graph::optimize/reverse2"); + const uint64_t graph_size = output_graph.extent(0); + const uint64_t output_graph_degree = output_graph.extent(1); - auto d_rev_graph = - raft::make_device_vector_view(d_rev_graph_ptr, graph_size * output_graph_degree); - auto d_rev_graph_count = - raft::make_device_vector_view(d_rev_graph_count_ptr, graph_size); + raft::common::nvtx::range block_scope( + "cagra::graph::optimize/reverse"); // // Make reverse graph @@ -869,14 +865,14 @@ void make_reverse_graph_gpu(raft::resources const& res, raft::matrix::fill(res, d_rev_graph, IdxT(-1)); raft::matrix::fill(res, d_rev_graph_count, uint32_t(0)); - if (is_ptr_host_accessible(output_graph_ptr)) { + if (is_ptr_host_accessible(output_graph.data_handle())) { auto d_dest_nodes = raft::make_device_mdarray(res, raft::make_extents(graph_size)); for (uint64_t k = 0; k < output_graph_degree; k++) { RAFT_CUDA_TRY(cudaMemcpy2DAsync(d_dest_nodes.data_handle(), sizeof(IdxT), - output_graph_ptr + k, + output_graph.data_handle() + k, // host pointer output_graph_degree * sizeof(IdxT), 1 * 
sizeof(IdxT), graph_size, @@ -899,7 +895,7 @@ void make_reverse_graph_gpu(raft::resources const& res, dim3 blocks(1024, 1, 1); for (uint64_t k = 0; k < output_graph_degree; k++) { kern_make_rev_graph_k<<>>( - output_graph_ptr, + output_graph.data_handle(), d_rev_graph.data_handle(), d_rev_graph_count.data_handle(), graph_size, @@ -1520,14 +1516,15 @@ void mst_optimization(raft::resources const& res, // specified number of edges are picked up for each node, starting with the edge with // the lowest number of 2-hop detours. // -template +template void prune_graph_gpu(raft::resources const& res, - IdxT* knn_graph_ptr, - uint64_t graph_size, - uint64_t knn_graph_degree, - IdxT* output_graph_ptr, - uint64_t output_graph_degree) + InputMatrixView knn_graph, + OutputMatrixView output_graph) { + const uint64_t graph_size = output_graph.extent(0); + const uint64_t knn_graph_degree = knn_graph.extent(1); + const uint64_t output_graph_degree = output_graph.extent(1); + raft::common::nvtx::range block_scope( "cagra::graph::optimize/prune"); auto default_ws_mr = raft::resource::get_workspace_resource(res); @@ -1548,7 +1545,8 @@ void prune_graph_gpu(raft::resources const& res, batched_device_view_from_host d_input_graph( res, - raft::make_host_matrix_view(knn_graph_ptr, graph_size, knn_graph_degree), + raft::make_host_matrix_view( + knn_graph.data_handle(), graph_size, knn_graph_degree), /*batch_size*/ graph_size, /*host_writeback*/ false, /*initialize*/ true); @@ -1556,7 +1554,8 @@ void prune_graph_gpu(raft::resources const& res, batched_device_view_from_host d_output_graph( res, - raft::make_host_matrix_view(output_graph_ptr, graph_size, output_graph_degree), + raft::make_host_matrix_view( + output_graph.data_handle(), graph_size, output_graph_degree), /*batch_size*/ batch_size, /*host_writeback*/ true, /*initialize*/ false); @@ -1641,7 +1640,7 @@ void optimize(raft::resources const& res, const uint64_t knn_graph_degree = knn_graph.extent(1); const uint64_t output_graph_degree 
= new_graph.extent(1); const uint64_t graph_size = new_graph.extent(0); - // auto input_graph_ptr = knn_graph.data_handle(); + raft::common::nvtx::range fun_scope( "cagra::graph::optimize(%zu, %zu, %u)", graph_size, knn_graph_degree, output_graph_degree); @@ -1673,12 +1672,7 @@ void optimize(raft::resources const& res, // prune graph -- will always use GPU path { - prune_graph_gpu(res, - knn_graph.data_handle(), - graph_size, - knn_graph_degree, - new_graph.data_handle(), - output_graph_degree); + prune_graph_gpu(res, knn_graph, new_graph); } // reverse graph creation will always use the GPU @@ -1691,12 +1685,7 @@ void optimize(raft::resources const& res, const double time_make_start = cur_time(); - make_reverse_graph_gpu(res, - new_graph.data_handle(), - d_rev_graph.data_handle(), - d_rev_graph_count.data_handle(), - graph_size, - output_graph_degree); + make_reverse_graph_gpu(res, new_graph, d_rev_graph.view(), d_rev_graph_count.view()); const double time_make_end = cur_time(); RAFT_LOG_DEBUG("# Making reverse graph time: %.1lf ms", @@ -1705,13 +1694,11 @@ void optimize(raft::resources const& res, // merge graph -- will always use GPU path { merge_graph_gpu(res, - new_graph.data_handle(), - d_rev_graph.data_handle(), - d_rev_graph_count.data_handle(), - mst_graph.data_handle(), - mst_graph_num_edges.data_handle(), - graph_size, - output_graph_degree, + new_graph, + d_rev_graph.view(), + d_rev_graph_count.view(), + mst_graph.view(), + mst_graph_num_edges.view(), guarantee_connectivity); } From 01e133632ea01357cd580bcc84329d5be1c9efa0 Mon Sep 17 00:00:00 2001 From: Malte Foerster Date: Mon, 16 Mar 2026 15:57:24 +0000 Subject: [PATCH 24/40] remove pointer arithmetic part2 --- cpp/src/neighbors/detail/cagra/graph_core.cuh | 167 ++++++++---------- 1 file changed, 71 insertions(+), 96 deletions(-) diff --git a/cpp/src/neighbors/detail/cagra/graph_core.cuh b/cpp/src/neighbors/detail/cagra/graph_core.cuh index 995b0f29bf..678068b0af 100644 --- 
a/cpp/src/neighbors/detail/cagra/graph_core.cuh +++ b/cpp/src/neighbors/detail/cagra/graph_core.cuh @@ -171,55 +171,37 @@ __global__ void kern_sort(const DATA_T* const dataset, // [dataset_chunk_size, } } -template -__global__ void kern_make_rev_graph(const IdxT* const dest_nodes, // [graph_size] - IdxT* const rev_graph, // [size, degree] - uint32_t* const rev_graph_count, // [graph_size] - const uint32_t graph_size, - const uint32_t degree) -{ - const uint32_t tid = threadIdx.x + (blockDim.x * blockIdx.x); - const uint32_t tnum = blockDim.x * gridDim.x; - - for (uint32_t src_id = tid; src_id < graph_size; src_id += tnum) { - const IdxT dest_id = dest_nodes[src_id]; - if (dest_id >= graph_size) continue; - - const uint32_t pos = atomicAdd(rev_graph_count + dest_id, 1); - if (pos < degree) { rev_graph[pos + ((uint64_t)degree * dest_id)] = src_id; } - } -} - -template -__global__ void kern_make_rev_graph_k(const IdxT* const output_graph, // [graph_size, degree] - IdxT* const rev_graph, // [graph_size, degree] - uint32_t* const rev_graph_count, // [graph_size] - const uint32_t graph_size, - const uint32_t degree, - uint64_t k) +template +__global__ void kern_make_rev_graph_k( + OutputMatrixView output_graph, // [graph_size, degree] + raft::device_matrix_view rev_graph, // [graph_size, degree] + raft::device_vector_view rev_graph_count, // [graph_size] + uint64_t k) { const uint64_t tid = threadIdx.x + (blockDim.x * blockIdx.x); const uint64_t tnum = blockDim.x * gridDim.x; + const uint64_t graph_size = rev_graph.extent(0); + const uint32_t rev_graph_degree = rev_graph.extent(1); + const uint32_t output_graph_degree = output_graph.extent(1); + for (uint64_t src_id = tid; src_id < graph_size; src_id += tnum) { - IdxT dest_id = output_graph[k + (degree * src_id)]; + IdxT dest_id = output_graph(src_id, k); if (dest_id >= graph_size) continue; - const uint32_t pos = atomicAdd(rev_graph_count + dest_id, 1); - if (pos < degree) { rev_graph[(degree * dest_id) + pos] = 
static_cast(src_id); } + const uint32_t pos = atomicAdd(&rev_graph_count(dest_id), 1); + if (pos < rev_graph_degree) { rev_graph(dest_id, pos) = static_cast(src_id); } } } template -__global__ void kern_fused_prune(const IdxT* const knn_graph, // [graph_chunk_size, graph_degree] - IdxT* const output_graph_ptr, // [batch_size, output_graph_degree] - const uint32_t graph_size, - const uint32_t knn_graph_degree, - const uint32_t output_graph_degree, - const uint32_t batch_size, - const uint32_t batch_id, - uint32_t* const d_invalid_neighbor_list, - uint64_t* const stats) +__global__ void kern_fused_prune( + raft::device_matrix_view knn_graph, // [graph_chunk_size, graph_degree] + raft::device_matrix_view output_graph, // [batch_size, output_graph_degree] + const uint32_t batch_size, + const uint32_t batch_id, + uint32_t* const d_invalid_neighbor_list, + uint64_t* const stats) { extern __shared__ unsigned char smem_buf[]; @@ -229,6 +211,10 @@ __global__ void kern_fused_prune(const IdxT* const knn_graph, // [graph_chunk_ const uint32_t wid = threadIdx.x / raft::WarpSize; const uint32_t lane_id = threadIdx.x % raft::WarpSize; + const uint64_t graph_size = knn_graph.extent(0); + const uint32_t knn_graph_degree = knn_graph.extent(1); + const uint32_t output_graph_degree = output_graph.extent(1); + IdxT* const smem_indices = reinterpret_cast(smem_buf + wid * knn_graph_degree * sizeof(IdxT)); uint32_t* const smem_num_detour = reinterpret_cast( @@ -239,15 +225,16 @@ __global__ void kern_fused_prune(const IdxT* const knn_graph, // [graph_chunk_ const uint32_t maxval16 = 0x0000ffff; - const uint64_t nid_batch = blockIdx.x * num_warps + wid; - const uint64_t nid = nid_batch + (batch_size * batch_id); + const uint32_t nid_batch = blockIdx.x * num_warps + wid; + const uint64_t nid = static_cast(nid_batch) + + (static_cast(batch_size) * static_cast(batch_id)); if (nid >= graph_size) { return; } // Load this node's neighbor row into shared memory to reduce global reads for (uint32_t 
k = lane_id; k < knn_graph_degree; k += raft::WarpSize) { smem_num_detour[k] = 0; - smem_indices[k] = knn_graph[k + ((uint64_t)knn_graph_degree * nid)]; + smem_indices[k] = knn_graph(nid, k); if (smem_indices[k] == nid) { // Lower the priority of self-edge smem_num_detour[k] = knn_graph_degree; @@ -260,7 +247,7 @@ __global__ void kern_fused_prune(const IdxT* const knn_graph, // [graph_chunk_ const uint64_t iD = smem_indices[kAD]; if (iD >= graph_size) { continue; } for (uint32_t kDB = lane_id; kDB < knn_graph_degree; kDB += raft::WarpSize) { - const uint64_t iB_candidate = knn_graph[kDB + ((uint64_t)knn_graph_degree * iD)]; + const uint64_t iB_candidate = knn_graph(iD, kDB); for (uint32_t kAB = kAD + 1; kAB < knn_graph_degree; kAB++) { // if ( kDB < kAB ) { @@ -321,7 +308,7 @@ __global__ void kern_fused_prune(const IdxT* const knn_graph, // [graph_chunk_ } __syncwarp(); - if (lane_id == 0) { output_graph_ptr[nid_batch * output_graph_degree + i] = selected_node; } + if (lane_id == 0) { output_graph(nid_batch, i) = selected_node; } } } @@ -353,19 +340,21 @@ __device__ void thread_shift_array(T* array, uint64_t num) } template -__global__ void kern_merge_graph(IdxT* output_graph, // [batch_size, output_graph_degree] - const IdxT* const rev_graph, // [graph_size, output_graph_degree] - uint32_t* const rev_graph_count, // [graph_size] - const uint32_t graph_size, - const uint32_t output_graph_degree, - const IdxT* const mst_graph, // [batch_size, output_graph_degree] - const uint32_t mst_graph_degree, - const uint32_t* const mst_graph_num_edges_ptr, // [batch_size] - const uint32_t batch_size, - const uint32_t batch_id, - bool guarantee_connectivity, - bool* check_num_protected_edges) +__global__ void kern_merge_graph( + raft::device_matrix_view output_graph, // [batch_size, output_graph_degree] + raft::device_matrix_view rev_graph, // [graph_size, output_graph_degree] + raft::device_vector_view rev_graph_count, // [graph_size] + raft::device_matrix_view mst_graph, // 
[batch_size, output_graph_degree] + raft::device_matrix_view mst_graph_num_edges, // [batch_size, 1] + const uint32_t batch_size, + const uint32_t batch_id, + bool guarantee_connectivity, + bool* check_num_protected_edges) { + const uint64_t graph_size = rev_graph.extent(0); + const uint32_t output_graph_degree = output_graph.extent(1); + const uint32_t mst_graph_degree = mst_graph.extent(1); + extern __shared__ unsigned char smem_buf[]; const uint32_t wid = threadIdx.x / raft::WarpSize; @@ -374,23 +363,25 @@ __global__ void kern_merge_graph(IdxT* output_graph, // [batch_size, o IdxT* smem_sorted_output_graph = reinterpret_cast(smem_buf + wid * output_graph_degree * sizeof(IdxT)); - const uint64_t nid_batch = blockIdx.x * num_warps + wid; - const uint64_t nid = nid_batch + (batch_size * batch_id); + const uint32_t nid_batch = blockIdx.x * num_warps + wid; + const uint64_t nid = static_cast(nid_batch) + + (static_cast(batch_size) * static_cast(batch_id)); if (nid >= graph_size) { return; } - const auto mst_graph_num_edges = guarantee_connectivity ? mst_graph_num_edges_ptr[nid_batch] : 0; + const auto current_mst_graph_num_edges = + guarantee_connectivity ? mst_graph_num_edges(nid_batch, 0) : 0; // If guarantee_connectivity == true, use a temporal list to merge the // neighbor lists of the graphs. 
if (guarantee_connectivity) { for (uint32_t i = lane_id; i < mst_graph_degree; i += raft::WarpSize) { - smem_sorted_output_graph[i] = mst_graph[nid_batch * mst_graph_degree + i]; + smem_sorted_output_graph[i] = mst_graph(nid_batch, i); } __syncwarp(); - for (uint32_t pruned_j = 0, output_j = mst_graph_num_edges; + for (uint32_t pruned_j = 0, output_j = current_mst_graph_num_edges; (pruned_j < output_graph_degree) && (output_j < output_graph_degree); pruned_j++) { - const auto v = output_graph[output_graph_degree * nid_batch + pruned_j]; + const auto v = output_graph(nid_batch, pruned_j); unsigned int dup = 0; for (uint32_t m = lane_id; m < output_j; m += raft::WarpSize) { if (v == smem_sorted_output_graph[m]) { @@ -410,36 +401,36 @@ __global__ void kern_merge_graph(IdxT* output_graph, // [batch_size, o else { for (uint32_t i = lane_id; i < output_graph_degree; i += raft::WarpSize) { - smem_sorted_output_graph[i] = output_graph[output_graph_degree * nid_batch + i]; + smem_sorted_output_graph[i] = output_graph(nid_batch, i); } __syncwarp(); } - const auto num_protected_edges = max(mst_graph_num_edges, output_graph_degree / 2); + const auto num_protected_edges = max(current_mst_graph_num_edges, output_graph_degree / 2); if (num_protected_edges > output_graph_degree) { check_num_protected_edges[0] = false; } if (num_protected_edges == output_graph_degree) { return; } - auto kr = min(rev_graph_count[nid], output_graph_degree); + auto kr = min(rev_graph_count(nid), output_graph_degree); while (kr) { kr -= 1; - if (rev_graph[kr + (output_graph_degree * nid)] < graph_size) { - uint64_t pos = warp_pos_in_array( - rev_graph[kr + (output_graph_degree * nid)], smem_sorted_output_graph, output_graph_degree); + if (rev_graph(nid, kr) < graph_size) { + uint64_t pos = + warp_pos_in_array(rev_graph(nid, kr), smem_sorted_output_graph, output_graph_degree); if (pos < num_protected_edges) { continue; } uint64_t num_shift = pos - num_protected_edges; if (pos >= output_graph_degree) { 
num_shift = output_graph_degree - num_protected_edges - 1; } if (lane_id == 0) { thread_shift_array(smem_sorted_output_graph + num_protected_edges, num_shift); - smem_sorted_output_graph[num_protected_edges] = rev_graph[kr + (output_graph_degree * nid)]; + smem_sorted_output_graph[num_protected_edges] = rev_graph(nid, kr); } __syncwarp(); } } for (uint32_t i = lane_id; i < output_graph_degree; i += raft::WarpSize) { - output_graph[(output_graph_degree * nid_batch) + i] = smem_sorted_output_graph[i]; + output_graph(nid_batch, i) = smem_sorted_output_graph[i]; } } @@ -815,14 +806,11 @@ void merge_graph_gpu(raft::resources const& res, auto output_view = d_output_graph.next_view(); kern_merge_graph <<>>( - output_view.data_handle(), - d_rev_graph.data_handle(), - d_rev_graph_count.data_handle(), - static_cast(graph_size), - static_cast(output_graph_degree), - mst_graph_view.data_handle(), - static_cast(output_graph_degree), - mst_graph_num_edges_view.data_handle(), + output_view, + d_rev_graph, + d_rev_graph_count, + mst_graph_view, + mst_graph_num_edges_view, batch_size, i_batch, guarantee_connectivity, @@ -866,8 +854,7 @@ void make_reverse_graph_gpu(raft::resources const& res, raft::matrix::fill(res, d_rev_graph_count, uint32_t(0)); if (is_ptr_host_accessible(output_graph.data_handle())) { - auto d_dest_nodes = - raft::make_device_mdarray(res, raft::make_extents(graph_size)); + auto d_dest_nodes = raft::make_device_matrix(res, graph_size, 1); for (uint64_t k = 0; k < output_graph_degree; k++) { RAFT_CUDA_TRY(cudaMemcpy2DAsync(d_dest_nodes.data_handle(), @@ -881,12 +868,8 @@ void make_reverse_graph_gpu(raft::resources const& res, dim3 threads(256, 1, 1); dim3 blocks(1024, 1, 1); - kern_make_rev_graph<<>>( - d_dest_nodes.data_handle(), - d_rev_graph.data_handle(), - d_rev_graph_count.data_handle(), - graph_size, - output_graph_degree); + kern_make_rev_graph_k<<>>( + d_dest_nodes.view(), d_rev_graph, d_rev_graph_count, 0); RAFT_LOG_DEBUG("# Making reverse graph on GPUs: 
%lu / %u \r", k, output_graph_degree); } } else { @@ -895,12 +878,7 @@ void make_reverse_graph_gpu(raft::resources const& res, dim3 blocks(1024, 1, 1); for (uint64_t k = 0; k < output_graph_degree; k++) { kern_make_rev_graph_k<<>>( - output_graph.data_handle(), - d_rev_graph.data_handle(), - d_rev_graph_count.data_handle(), - graph_size, - output_graph_degree, - k); + output_graph, d_rev_graph, d_rev_graph_count, k); } } @@ -1570,11 +1548,8 @@ void prune_graph_gpu(raft::resources const& res, const size_t prune_smem_size = num_warps * knn_graph_degree * (sizeof(IdxT) + sizeof(uint32_t)); kern_fused_prune <<>>( - input_view.data_handle(), - output_view.data_handle(), - graph_size, - knn_graph_degree, - output_graph_degree, + input_view, + output_view, batch_size, i_batch, d_invalid_neighbor_list.data_handle(), From 5a24fb02471ac96146bb90091bfb920682a06aab Mon Sep 17 00:00:00 2001 From: Malte Foerster Date: Tue, 17 Mar 2026 18:13:02 +0000 Subject: [PATCH 25/40] fix mst graph usage --- cpp/src/neighbors/detail/cagra/graph_core.cuh | 69 +++++++++++-------- cpp/src/neighbors/detail/cagra/utils.hpp | 2 +- 2 files changed, 43 insertions(+), 28 deletions(-) diff --git a/cpp/src/neighbors/detail/cagra/graph_core.cuh b/cpp/src/neighbors/detail/cagra/graph_core.cuh index 678068b0af..bd76ce0d44 100644 --- a/cpp/src/neighbors/detail/cagra/graph_core.cuh +++ b/cpp/src/neighbors/detail/cagra/graph_core.cuh @@ -374,7 +374,7 @@ __global__ void kern_merge_graph( // If guarantee_connectivity == true, use a temporal list to merge the // neighbor lists of the graphs. 
if (guarantee_connectivity) { - for (uint32_t i = lane_id; i < mst_graph_degree; i += raft::WarpSize) { + for (uint32_t i = lane_id; i < current_mst_graph_num_edges; i += raft::WarpSize) { smem_sorted_output_graph[i] = mst_graph(nid_batch, i); } __syncwarp(); @@ -788,10 +788,10 @@ void merge_graph_gpu(raft::resources const& res, /*host_writeback*/ false, /*initialize*/ true); - batched_device_view_from_host d_mst_graph_num_edges( + batched_device_view_from_host d_mst_graph_num_edges( res, - raft::make_host_matrix_view( - mst_graph_num_edges.data_handle(), guarantee_connectivity ? graph_size : 0, 1), + raft::make_host_matrix_view( + mst_graph_num_edges.data_handle(), mst_graph_num_edges.extent(0), 1), /*batch_size*/ batch_size, /*host_writeback*/ false, /*initialize*/ true); @@ -1101,11 +1101,11 @@ void mst_opt_update_graph(IdxT* mst_graph_ptr, // an approximate MST. // * If the input kNN graph is disconnected, random connection is added to the largest cluster. // -template +template void mst_optimization(raft::resources const& res, InputMatrixView input_graph, - OutputMatrixView output_graph, - VectorView mst_graph_num_edges, + raft::host_matrix_view output_graph, + raft::host_vector_view mst_graph_num_edges, bool use_gpu = true) { if (use_gpu) { @@ -1118,9 +1118,6 @@ void mst_optimization(raft::resources const& res, const IdxT graph_size = input_graph.extent(0); const uint32_t input_graph_degree = input_graph.extent(1); const uint32_t output_graph_degree = output_graph.extent(1); - auto input_graph_ptr = input_graph.data_handle(); - auto output_graph_ptr = output_graph.data_handle(); - auto mst_graph_num_edges_ptr = mst_graph_num_edges.data_handle(); // Allocate temporal arrays const uint32_t mst_graph_degree = output_graph_degree; @@ -1238,9 +1235,20 @@ void mst_optimization(raft::resources const& res, } } else { // Copy rank-k edges from the input knn graph to 'candidate_edges' + if (is_ptr_host_accessible(input_graph.data_handle())) { #pragma omp parallel for - 
for (uint64_t i = 0; i < graph_size; i++) { - candidate_edges_ptr[i] = input_graph_ptr[k + (input_graph_degree * i)]; + for (uint64_t i = 0; i < graph_size; i++) { + candidate_edges_ptr[i] = input_graph(i, k); + } + } else { + // handle device knn graph + RAFT_CUDA_TRY(cudaMemcpy2D(candidate_edges_ptr, + sizeof(IdxT), + &input_graph(0, k), // host pointer + input_graph_degree * sizeof(IdxT), + 1 * sizeof(IdxT), + graph_size, + cudaMemcpyDeviceToHost)); } } @@ -1459,23 +1467,23 @@ void mst_optimization(raft::resources const& res, for (uint64_t i = 0; i < graph_size; i++) { uint64_t k = 0; for (uint64_t kj = 0; kj < mst_graph_degree; kj++) { - uint64_t j = mst_graph_ptr[(mst_graph_degree * i) + kj]; + uint64_t j = mst_graph(i, kj); if (j >= graph_size) continue; // Check to avoid duplication auto flag_match = false; for (uint64_t ki = 0; ki < k; ki++) { - if (j == output_graph_ptr[(output_graph_degree * i) + ki]) { + if (j == output_graph(i, ki)) { flag_match = true; break; } } if (flag_match) continue; - output_graph_ptr[(output_graph_degree * i) + k] = j; + output_graph(i, k) = j; k += 1; } - mst_graph_num_edges_ptr[i] = k; + mst_graph_num_edges(i) = k; } const double time_mst_opt_end = cur_time(); @@ -1621,26 +1629,24 @@ void optimize(raft::resources const& res, // MST optimization // currently, only using GPU path for MST optimization - auto mst_graph = raft::make_host_matrix(0, 0); - auto mst_graph_num_edges = raft::make_host_vector(0); + int64_t mst_graph_size = guarantee_connectivity ? 
graph_size : 0; + auto mst_graph = + raft::make_host_matrix(mst_graph_size, output_graph_degree); + auto mst_graph_num_edges = raft::make_host_vector(mst_graph_size); if (guarantee_connectivity) { - auto mst_graph_num_edges = raft::make_host_vector(graph_size); - auto mst_graph_num_edges_ptr = mst_graph_num_edges.data_handle(); #pragma omp parallel for for (uint64_t i = 0; i < graph_size; i++) { - mst_graph_num_edges_ptr[i] = 0; + mst_graph_num_edges(i) = 0; } raft::common::nvtx::range block_scope( "cagra::graph::optimize/check_connectivity"); - mst_graph = - raft::make_host_matrix(graph_size, output_graph_degree); RAFT_LOG_INFO("MST optimization is used to guarantee graph connectivity."); mst_optimization(res, knn_graph, mst_graph.view(), mst_graph_num_edges.view(), use_gpu); for (uint64_t i = 0; i < graph_size; i++) { if (i < 8 || i >= graph_size - 8) { - RAFT_LOG_DEBUG("# mst_graph_num_edges_ptr[%lu]: %u\n", i, mst_graph_num_edges_ptr[i]); + RAFT_LOG_DEBUG("# mst_graph_num_edges[%lu]: %u\n", i, mst_graph_num_edges(i)); } } } @@ -1651,9 +1657,18 @@ void optimize(raft::resources const& res, } // reverse graph creation will always use the GPU - auto d_rev_graph = raft::make_device_mdarray( - res, large_tmp_mr, raft::make_extents(graph_size, output_graph_degree)); - + // using default workspace resource for random access + // otherwise will be managed memory which is slow upon first access + auto d_rev_graph = raft::make_device_mdarray(res, raft::make_extents(0, 0)); + try { + d_rev_graph = raft::make_device_mdarray( + res, raft::make_extents(graph_size, output_graph_degree)); + } catch (const std::exception& e) { + RAFT_LOG_DEBUG( + "Failed to create device matrix for reverse graph, switching to large workspace resource"); + d_rev_graph = raft::make_device_mdarray( + res, large_tmp_mr, raft::make_extents(graph_size, output_graph_degree)); + } // This should use the default workspace resource for random access / atomics auto d_rev_graph_count = 
raft::make_device_mdarray( res, default_ws_mr, raft::make_extents(graph_size)); diff --git a/cpp/src/neighbors/detail/cagra/utils.hpp b/cpp/src/neighbors/detail/cagra/utils.hpp index 79d1ed1cae..6ff02a1c64 100644 --- a/cpp/src/neighbors/detail/cagra/utils.hpp +++ b/cpp/src/neighbors/detail/cagra/utils.hpp @@ -358,7 +358,7 @@ class batched_device_view_from_host { * @param batch_size The batch size * @param host_writeback Whether to write back the data to the host (only for host memory) * (default: false) - * @param initialize Whether to initialize the data (only for managed memory) (default: true) + * @param initialize Whether to initialize the data (default: true) */ batched_device_view_from_host(raft::resources const& res, raft::host_matrix_view host_view, From c033436c041eb27f321d5824a1aa43497008845a Mon Sep 17 00:00:00 2001 From: Malte Foerster Date: Tue, 17 Mar 2026 18:45:21 +0000 Subject: [PATCH 26/40] remove memcopy2D --- cpp/src/neighbors/detail/cagra/graph_core.cuh | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/cpp/src/neighbors/detail/cagra/graph_core.cuh b/cpp/src/neighbors/detail/cagra/graph_core.cuh index bd76ce0d44..cebe856811 100644 --- a/cpp/src/neighbors/detail/cagra/graph_core.cuh +++ b/cpp/src/neighbors/detail/cagra/graph_core.cuh @@ -855,16 +855,13 @@ void make_reverse_graph_gpu(raft::resources const& res, if (is_ptr_host_accessible(output_graph.data_handle())) { auto d_dest_nodes = raft::make_device_matrix(res, graph_size, 1); - + auto dest_nodes = make_host_vector(graph_size); for (uint64_t k = 0; k < output_graph_degree; k++) { - RAFT_CUDA_TRY(cudaMemcpy2DAsync(d_dest_nodes.data_handle(), - sizeof(IdxT), - output_graph.data_handle() + k, // host pointer - output_graph_degree * sizeof(IdxT), - 1 * sizeof(IdxT), - graph_size, - cudaMemcpyHostToDevice, - raft::resource::get_cuda_stream(res))); +#pragma omp parallel for + for (uint64_t i = 0; i < graph_size; i++) { + dest_nodes(i) = output_graph(i, k); + } + 
raft::copy(res, d_dest_nodes.view(), raft::make_const_mdspan(dest_nodes.view())); dim3 threads(256, 1, 1); dim3 blocks(1024, 1, 1); From a585dcb96b1144bf3e1a205460e99de67b6783e0 Mon Sep 17 00:00:00 2001 From: Malte Foerster Date: Tue, 21 Apr 2026 14:36:52 +0000 Subject: [PATCH 27/40] review suggestions --- cpp/src/neighbors/detail/cagra/graph_core.cuh | 62 +++++++++++++------ cpp/src/neighbors/detail/cagra/utils.hpp | 13 ++-- 2 files changed, 52 insertions(+), 23 deletions(-) diff --git a/cpp/src/neighbors/detail/cagra/graph_core.cuh b/cpp/src/neighbors/detail/cagra/graph_core.cuh index b8245636d2..0fc52bd2c3 100644 --- a/cpp/src/neighbors/detail/cagra/graph_core.cuh +++ b/cpp/src/neighbors/detail/cagra/graph_core.cuh @@ -203,6 +203,10 @@ __global__ void kern_fused_prune( uint32_t* const d_invalid_neighbor_list, uint64_t* const stats) { + // Check assumption we have exactly one warp per row of the batch + assert(blockDim.x == raft::WarpSize * num_warps); + assert(gridDim.x * num_warps == batch_size); + extern __shared__ unsigned char smem_buf[]; cg::thread_block block = cg::this_thread_block(); @@ -220,8 +224,10 @@ __global__ void kern_fused_prune( uint32_t* const smem_num_detour = reinterpret_cast( smem_buf + wid * knn_graph_degree * sizeof(IdxT) + num_warps * knn_graph_degree * sizeof(IdxT)); +#ifndef NDEBUG uint64_t* const num_retain = stats; uint64_t* const num_full = stats + 1; +#endif const uint32_t maxval16 = 0x0000ffff; @@ -274,12 +280,14 @@ __global__ void kern_fused_prune( num_edges_no_detour = cg::reduce(warp, num_edges_no_detour, cg::plus()); num_edges_no_detour = min(num_edges_no_detour, output_graph_degree); +#ifndef NDEBUG if (lane_id == 0) { atomicAdd((unsigned long long int*)num_retain, (unsigned long long int)num_edges_no_detour); if (num_edges_no_detour >= output_graph_degree) { atomicAdd((unsigned long long int*)num_full, 1); } } +#endif for (uint32_t i = 0; i < output_graph_degree; i++) { uint32_t local_min = maxval16; @@ -332,10 +340,19 @@ 
__device__ unsigned int warp_pos_in_array(T val, const T* array, uint64_t num) } template -__device__ void thread_shift_array(T* array, uint64_t num) +__device__ void warp_shift_array_one_right(uint32_t lane_id, T* array, uint64_t num) { - for (uint64_t i = num; i > 0; i--) { - array[i] = array[i - 1]; + if (num == 0) { return; } + for (auto chunk_end = static_cast(num); chunk_end >= 1; chunk_end -= 32) { + const int64_t chunk_start_lo = chunk_end - 31; + const int64_t chunk_start = (chunk_start_lo > 1) ? chunk_start_lo : 1; + const int64_t k = chunk_start + static_cast(lane_id); + T val{}; + const bool active = (k <= chunk_end); + if (active) { val = array[k - 1]; } + __syncwarp(); + if (active) { array[k] = val; } + __syncwarp(); } } @@ -351,6 +368,10 @@ __global__ void kern_merge_graph( bool guarantee_connectivity, bool* check_num_protected_edges) { + // Check assumption we have exactly one warp per row of the batch + assert(blockDim.x == raft::WarpSize * num_warps); + assert(gridDim.x * num_warps == batch_size); + const uint64_t graph_size = rev_graph.extent(0); const uint32_t output_graph_degree = output_graph.extent(1); const uint32_t mst_graph_degree = mst_graph.extent(1); @@ -415,16 +436,16 @@ __global__ void kern_merge_graph( while (kr) { kr -= 1; - if (rev_graph(nid, kr) < graph_size) { + const auto rev_graph_value = rev_graph(nid, kr); + if (rev_graph_value < graph_size) { uint64_t pos = - warp_pos_in_array(rev_graph(nid, kr), smem_sorted_output_graph, output_graph_degree); + warp_pos_in_array(rev_graph_value, smem_sorted_output_graph, output_graph_degree); if (pos < num_protected_edges) { continue; } uint64_t num_shift = pos - num_protected_edges; if (pos >= output_graph_degree) { num_shift = output_graph_degree - num_protected_edges - 1; } - if (lane_id == 0) { - thread_shift_array(smem_sorted_output_graph + num_protected_edges, num_shift); - smem_sorted_output_graph[num_protected_edges] = rev_graph(nid, kr); - } + warp_shift_array_one_right( + 
lane_id, smem_sorted_output_graph + num_protected_edges, num_shift); + if (lane_id == 0) { smem_sorted_output_graph[num_protected_edges] = rev_graph_value; } __syncwarp(); } } @@ -768,8 +789,8 @@ void merge_graph_gpu(raft::resources const& res, const double merge_graph_start = cur_time(); auto d_check_num_protected_edges = raft::make_device_scalar(res, true); - auto d_invalid_neighbor_list = raft::make_device_scalar(res, 0u); + // The batchsize should be divisible by the number of warps per block (currently 4) uint32_t batch_size = std::min(static_cast(graph_size), static_cast(256 * 1024)); const uint32_t num_batch = (graph_size + batch_size - 1) / batch_size; @@ -797,6 +818,7 @@ void merge_graph_gpu(raft::resources const& res, /*initialize*/ true); const uint32_t num_warps = 4; + RAFT_EXPECTS(batch_size % num_warps == 0, "batch_size must be divisible by num_warps"); const dim3 threads_merge(raft::WarpSize * num_warps, 1, 1); const dim3 blocks_merge(raft::ceildiv(batch_size, num_warps), 1, 1); const size_t merge_smem_size = num_warps * output_graph_degree * sizeof(IdxT); @@ -1512,6 +1534,7 @@ void prune_graph_gpu(raft::resources const& res, "cagra::graph::optimize/prune"); auto default_ws_mr = raft::resource::get_workspace_resource(res); + // The batchsize should be divisible by the number of warps per block (currently 4) uint32_t batch_size = std::min(static_cast(graph_size), static_cast(256 * 1024)); const uint32_t num_batch = (graph_size + batch_size - 1) / batch_size; @@ -1545,12 +1568,14 @@ void prune_graph_gpu(raft::resources const& res, auto d_invalid_neighbor_list = raft::make_device_scalar(res, 0u); + const uint32_t num_warps = 4; + RAFT_EXPECTS(batch_size % num_warps == 0, "batch_size must be divisible by num_warps"); + const dim3 threads_prune(raft::WarpSize * num_warps, 1, 1); + const dim3 blocks_prune(raft::ceildiv(batch_size, num_warps), 1, 1); + const size_t prune_smem_size = num_warps * knn_graph_degree * (sizeof(IdxT) + sizeof(uint32_t)); + for 
(uint32_t i_batch = 0; i_batch < num_batch; i_batch++) { - auto output_view = d_output_graph.next_view(); - const uint32_t num_warps = 4; - const dim3 threads_prune(raft::WarpSize * num_warps, 1, 1); - const dim3 blocks_prune(raft::ceildiv(batch_size, num_warps), 1, 1); - const size_t prune_smem_size = num_warps * knn_graph_degree * (sizeof(IdxT) + sizeof(uint32_t)); + auto output_view = d_output_graph.next_view(); kern_fused_prune <<>>( input_view, @@ -1601,8 +1626,8 @@ template block_scope( "cagra::graph::optimize/check_connectivity"); RAFT_LOG_INFO("MST optimization is used to guarantee graph connectivity."); - mst_optimization(res, knn_graph, mst_graph.view(), mst_graph_num_edges.view(), use_gpu); + mst_optimization( + res, knn_graph, mst_graph.view(), mst_graph_num_edges.view(), use_gpu_for_mst_optimization); for (uint64_t i = 0; i < graph_size; i++) { if (i < 8 || i >= graph_size - 8) { diff --git a/cpp/src/neighbors/detail/cagra/utils.hpp b/cpp/src/neighbors/detail/cagra/utils.hpp index 5d4b508d31..1f56d06612 100644 --- a/cpp/src/neighbors/detail/cagra/utils.hpp +++ b/cpp/src/neighbors/detail/cagra/utils.hpp @@ -321,14 +321,17 @@ void copy_with_padding( * Utility to create a batched device view from a host view * * This utility will create a batched device view from a host view and will handle the prefetch and - * writeback of the data Each batch can be referenced exactlyonce by calling the next_view() - * function + * writeback of the data. Each batch can be referenced exactly once by calling the next_view() + * function. 
* * Usage: * ``` - * batched_device_view_from_host view(res, host_view, batch_size, host_writeback, - * initialize); while (view.next_view().extent(0) > 0) { auto device_view = view.next_view(); - * // use device_view + * batched_device_view_from_host view( + * res, host_view, batch_size, host_writeback, initialize); + * for (;;) { + * auto device_view = view.next_view(); + * if (device_view.extent(0) == 0) { break; } + * // use device_view (one next_view() call per batch; never in the loop condition) * } * ``` * From 1f0ce37bef9366a769caf982b953331aac671616 Mon Sep 17 00:00:00 2001 From: Malte Foerster Date: Tue, 21 Apr 2026 20:46:28 +0000 Subject: [PATCH 28/40] fix merge conflict --- cpp/src/neighbors/detail/cagra/utils.hpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/cpp/src/neighbors/detail/cagra/utils.hpp b/cpp/src/neighbors/detail/cagra/utils.hpp index 89cfaace75..6e6d806036 100644 --- a/cpp/src/neighbors/detail/cagra/utils.hpp +++ b/cpp/src/neighbors/detail/cagra/utils.hpp @@ -405,13 +405,13 @@ class batched_device_view_from_host { try { device_mem_[0].emplace(raft::make_device_mdarray( res, - raft::resource::get_workspace_resource(res), + raft::resource::get_workspace_resource_ref(res), raft::make_extents(batch_size, host_view.extent(1)))); device_ptr[0] = device_mem_[0]->data_handle(); if (batch_size < static_cast(host_view.extent(0))) { device_mem_[1].emplace(raft::make_device_mdarray( res, - raft::resource::get_workspace_resource(res), + raft::resource::get_workspace_resource_ref(res), raft::make_extents(batch_size, host_view.extent(1)))); device_ptr[1] = device_mem_[1]->data_handle(); } @@ -420,7 +420,7 @@ class batched_device_view_from_host { num_buffers_ = 3; device_mem_[2].emplace(raft::make_device_mdarray( res, - raft::resource::get_workspace_resource(res), + raft::resource::get_workspace_resource_ref(res), raft::make_extents(batch_size, host_view.extent(1)))); device_ptr[2] = device_mem_[2]->data_handle(); } From 
24c606ec9c6dd805738a8a259bbd11dcd4f63d31 Mon Sep 17 00:00:00 2001 From: Malte Foerster Date: Wed, 22 Apr 2026 21:07:24 +0000 Subject: [PATCH 29/40] more review suggestions --- .../neighbors/detail/cagra/cagra_build.cuh | 5 +-- cpp/src/neighbors/detail/cagra/graph_core.cuh | 33 ++++++++----------- cpp/src/neighbors/detail/cagra/utils.hpp | 8 ----- 3 files changed, 16 insertions(+), 30 deletions(-) diff --git a/cpp/src/neighbors/detail/cagra/cagra_build.cuh b/cpp/src/neighbors/detail/cagra/cagra_build.cuh index 7f2127981f..b0bdfd5508 100644 --- a/cpp/src/neighbors/detail/cagra/cagra_build.cuh +++ b/cpp/src/neighbors/detail/cagra/cagra_build.cuh @@ -838,6 +838,7 @@ inline std::pair optimize_workspace_size(size_t n_rows, size_t prune_dev = batch_size * intermediate_degree * 1; // detour count (uint8_t) prune_dev += batch_size * sizeof(uint32_t); // d_num_detour_edges prune_dev += n_rows * intermediate_degree * index_size; // d_input_graph + prune_dev += 2 * batch_size * graph_degree * index_size; // d_output_graph(2*batch) // Reverse graph stage memory size_t rev_dev = n_rows * graph_degree * index_size; // d_rev_graph @@ -848,8 +849,8 @@ inline std::pair optimize_workspace_size(size_t n_rows, size_t combine_host = n_rows * sizeof(uint32_t) + graph_degree * sizeof(uint32_t); // in_edge_count + hist - // additional memory for combine stage on device - size_t combine_dev = n_rows * graph_degree * index_size; // d_output_graph + // additional memory for combine stage on device (3 batches) + size_t combine_dev = 3 * batch_size * graph_degree * index_size; // d_output_graph(3*batch) size_t total_host = mst_host + combine_host; size_t total_dev = std::max(prune_dev, rev_dev + combine_dev); diff --git a/cpp/src/neighbors/detail/cagra/graph_core.cuh b/cpp/src/neighbors/detail/cagra/graph_core.cuh index 611b38ca34..f14f25a799 100644 --- a/cpp/src/neighbors/detail/cagra/graph_core.cuh +++ b/cpp/src/neighbors/detail/cagra/graph_core.cuh @@ -203,9 +203,9 @@ __global__ void 
kern_fused_prune( uint32_t* const d_invalid_neighbor_list, uint64_t* const stats) { - // Check assumption we have exactly one warp per row of the batch + // Check assumption we have at least one warp per row of the batch assert(blockDim.x == raft::WarpSize * num_warps); - assert(gridDim.x * num_warps == batch_size); + assert(gridDim.x * num_warps >= batch_size); extern __shared__ unsigned char smem_buf[]; @@ -368,9 +368,9 @@ __global__ void kern_merge_graph( bool guarantee_connectivity, bool* check_num_protected_edges) { - // Check assumption we have exactly one warp per row of the batch + // Check assumption we have at least one warp per row of the batch assert(blockDim.x == raft::WarpSize * num_warps); - assert(gridDim.x * num_warps == batch_size); + assert(gridDim.x * num_warps >= batch_size); const uint64_t graph_size = rev_graph.extent(0); const uint32_t output_graph_degree = output_graph.extent(1); @@ -785,12 +785,10 @@ void merge_graph_gpu(raft::resources const& res, raft::common::nvtx::range block_scope( "cagra::graph::optimize/combine"); - auto default_ws_mr = raft::resource::get_workspace_resource(res); const double merge_graph_start = cur_time(); auto d_check_num_protected_edges = raft::make_device_scalar(res, true); - // The batchsize should be divisible by the number of warps per block (currently 4) uint32_t batch_size = std::min(static_cast(graph_size), static_cast(256 * 1024)); const uint32_t num_batch = (graph_size + batch_size - 1) / batch_size; @@ -818,7 +816,6 @@ void merge_graph_gpu(raft::resources const& res, /*initialize*/ true); const uint32_t num_warps = 4; - RAFT_EXPECTS(batch_size % num_warps == 0, "batch_size must be divisible by num_warps"); const dim3 threads_merge(raft::WarpSize * num_warps, 1, 1); const dim3 blocks_merge(raft::ceildiv(batch_size, num_warps), 1, 1); const size_t merge_smem_size = num_warps * output_graph_degree * sizeof(IdxT); @@ -870,8 +867,6 @@ void make_reverse_graph_gpu(raft::resources const& res, // // Make 
reverse graph // - const double time_make_start = cur_time(); - raft::matrix::fill(res, d_rev_graph, IdxT(-1)); raft::matrix::fill(res, d_rev_graph_count, uint32_t(0)); @@ -900,13 +895,6 @@ void make_reverse_graph_gpu(raft::resources const& res, output_graph, d_rev_graph, d_rev_graph_count, k); } } - - raft::resource::sync_stream(res); - RAFT_LOG_DEBUG("\n"); - - const double time_make_end = cur_time(); - RAFT_LOG_DEBUG("# Making reverse graph time: %.1lf ms", - (time_make_end - time_make_start) * 1000.0); } template block_scope( "cagra::graph::optimize/prune"); - auto default_ws_mr = raft::resource::get_workspace_resource(res); - // The batchsize should be divisible by the number of warps per block (currently 4) uint32_t batch_size = std::min(static_cast(graph_size), static_cast(256 * 1024)); const uint32_t num_batch = (graph_size + batch_size - 1) / batch_size; @@ -1569,7 +1561,6 @@ void prune_graph_gpu(raft::resources const& res, auto d_invalid_neighbor_list = raft::make_device_scalar(res, 0u); const uint32_t num_warps = 4; - RAFT_EXPECTS(batch_size % num_warps == 0, "batch_size must be divisible by num_warps"); const dim3 threads_prune(raft::WarpSize * num_warps, 1, 1); const dim3 blocks_prune(raft::ceildiv(batch_size, num_warps), 1, 1); const size_t prune_smem_size = num_warps * knn_graph_degree * (sizeof(IdxT) + sizeof(uint32_t)); @@ -1700,8 +1691,10 @@ void optimize(raft::resources const& res, make_reverse_graph_gpu(res, new_graph, d_rev_graph.view(), d_rev_graph_count.view()); + raft::resource::sync_stream(res); + const double time_make_end = cur_time(); - RAFT_LOG_DEBUG("# Making reverse graph time: %.1lf ms", + RAFT_LOG_DEBUG("\n# Making reverse graph time: %.1lf ms", (time_make_end - time_make_start) * 1000.0); // merge graph -- will always use GPU path diff --git a/cpp/src/neighbors/detail/cagra/utils.hpp b/cpp/src/neighbors/detail/cagra/utils.hpp index 6e6d806036..0f22d6cc45 100644 --- a/cpp/src/neighbors/detail/cagra/utils.hpp +++ 
b/cpp/src/neighbors/detail/cagra/utils.hpp @@ -161,14 +161,6 @@ struct gen_index_msb_1_mask { }; } // namespace utils -template -bool is_ptr_device_accessible(T* ptr) -{ - cudaPointerAttributes attr; - RAFT_CUDA_TRY(cudaPointerGetAttributes(&attr, ptr)); - return attr.devicePointer != nullptr; -} - template bool is_ptr_host_accessible(T* ptr) { From f6957630ccea34741898575f6b4b8b9cd435587a Mon Sep 17 00:00:00 2001 From: Malte Foerster Date: Thu, 23 Apr 2026 14:03:48 +0000 Subject: [PATCH 30/40] coderabbit suggestions --- .../neighbors/detail/cagra/cagra_build.cuh | 6 +++- cpp/src/neighbors/detail/cagra/graph_core.cuh | 28 +++++++++++-------- cpp/src/neighbors/detail/cagra/utils.hpp | 10 +++++++ 3 files changed, 32 insertions(+), 12 deletions(-) diff --git a/cpp/src/neighbors/detail/cagra/cagra_build.cuh b/cpp/src/neighbors/detail/cagra/cagra_build.cuh index b0bdfd5508..b7fb57b2dc 100644 --- a/cpp/src/neighbors/detail/cagra/cagra_build.cuh +++ b/cpp/src/neighbors/detail/cagra/cagra_build.cuh @@ -843,7 +843,7 @@ inline std::pair optimize_workspace_size(size_t n_rows, // Reverse graph stage memory size_t rev_dev = n_rows * graph_degree * index_size; // d_rev_graph rev_dev += n_rows * sizeof(uint32_t); // d_rev_graph_count - rev_dev += n_rows * sizeof(uint32_t); // d_dest_nodes + rev_dev += n_rows * index_size; // d_dest_nodes // Memory for merging graphs (host only optional) size_t combine_host = @@ -851,6 +851,10 @@ inline std::pair optimize_workspace_size(size_t n_rows, // additional memory for combine stage on device (3 batches) size_t combine_dev = 3 * batch_size * graph_degree * index_size; // d_output_graph(3*batch) + if (mst_optimize) { + combine_dev += 2 * batch_size * graph_degree * index_size; // d_mst_graph(2*batch) + combine_dev += 2 * batch_size * sizeof(uint32_t); // d_mst_graph_num_edges(2*batch) + } size_t total_host = mst_host + combine_host; size_t total_dev = std::max(prune_dev, rev_dev + combine_dev); diff --git 
a/cpp/src/neighbors/detail/cagra/graph_core.cuh b/cpp/src/neighbors/detail/cagra/graph_core.cuh index f14f25a799..a55401f9dd 100644 --- a/cpp/src/neighbors/detail/cagra/graph_core.cuh +++ b/cpp/src/neighbors/detail/cagra/graph_core.cuh @@ -221,8 +221,9 @@ __global__ void kern_fused_prune( IdxT* const smem_indices = reinterpret_cast(smem_buf + wid * knn_graph_degree * sizeof(IdxT)); - uint32_t* const smem_num_detour = reinterpret_cast( - smem_buf + wid * knn_graph_degree * sizeof(IdxT) + num_warps * knn_graph_degree * sizeof(IdxT)); + uint32_t* const smem_num_detour = + reinterpret_cast(smem_buf + num_warps * knn_graph_degree * sizeof(IdxT) + + wid * knn_graph_degree * sizeof(uint32_t)); #ifndef NDEBUG uint64_t* const num_retain = stats; @@ -429,7 +430,10 @@ __global__ void kern_merge_graph( const auto num_protected_edges = max(current_mst_graph_num_edges, output_graph_degree / 2); - if (num_protected_edges > output_graph_degree) { check_num_protected_edges[0] = false; } + if (num_protected_edges > output_graph_degree) { + check_num_protected_edges[0] = false; + return; + } if (num_protected_edges == output_graph_degree) { return; } auto kr = min(rev_graph_count(nid), output_graph_degree); @@ -750,7 +754,7 @@ void check_duplicates_and_out_of_range(const IdxT* output_graph_ptr, for (uint32_t j = 0; j < output_graph_degree; j++) { const auto neighbor_a = my_out_graph[j]; - if (neighbor_a > graph_size) { + if (neighbor_a >= graph_size) { num_oor++; continue; } @@ -1249,13 +1253,15 @@ void mst_optimization(raft::resources const& res, } } else { // handle device knn graph - RAFT_CUDA_TRY(cudaMemcpy2D(candidate_edges_ptr, - sizeof(IdxT), - &input_graph(0, k), // host pointer - input_graph_degree * sizeof(IdxT), - 1 * sizeof(IdxT), - graph_size, - cudaMemcpyDeviceToHost)); + RAFT_CUDA_TRY(cudaMemcpy2DAsync(candidate_edges_ptr, + sizeof(IdxT), + &input_graph(0, k), + input_graph_degree * sizeof(IdxT), + 1 * sizeof(IdxT), // width + graph_size, + cudaMemcpyDeviceToHost, + 
raft::resource::get_cuda_stream(res))); + raft::resource::sync_stream(res); // FIXME: use submdspan and raft::copy once supported /*auto column_view = cuda::std::submdspan(input_graph, diff --git a/cpp/src/neighbors/detail/cagra/utils.hpp b/cpp/src/neighbors/detail/cagra/utils.hpp index 0f22d6cc45..c5cbc1a6b3 100644 --- a/cpp/src/neighbors/detail/cagra/utils.hpp +++ b/cpp/src/neighbors/detail/cagra/utils.hpp @@ -376,6 +376,7 @@ class batched_device_view_from_host { return; } + RAFT_EXPECTS(batch_size_ > 0, "batch_size must be greater than zero for non-empty input"); RAFT_EXPECTS(host_writeback_ || initialize_, "At least one of host_writeback or initialize must be true"); @@ -418,6 +419,9 @@ class batched_device_view_from_host { } } catch (std::bad_alloc& e) { if (attr_.devicePointer != nullptr) { + for (auto& mem : device_mem_) { + mem.reset(); + } RAFT_LOG_DEBUG("Insufficient memory for device buffers, switching to managed memory"); mem_strategy_ = memory_strategy::managed_only; } else { @@ -425,6 +429,9 @@ class batched_device_view_from_host { } } catch (raft::logic_error& e) { if (attr_.devicePointer != nullptr) { + for (auto& mem : device_mem_) { + mem.reset(); + } RAFT_LOG_DEBUG( "Insufficient memory for device buffers (logic error), switching to managed memory"); mem_strategy_ = memory_strategy::managed_only; @@ -463,6 +470,9 @@ class batched_device_view_from_host { { raft::resource::sync_stream(res_); + // No view was handed to the caller, so no device-side modifications can require writeback. 
+ if (batch_id_ < 0) { return; } + // if data is on host and for_write --> make sure to copy back last active // if data is managed and evict --> evict last active From 023fbc680cc764aa54b2ecfba537034a65765478 Mon Sep 17 00:00:00 2001 From: Malte Foerster Date: Fri, 24 Apr 2026 16:55:38 +0000 Subject: [PATCH 31/40] trying to fix IllegalAccess --- cpp/src/neighbors/detail/cagra/graph_core.cuh | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/cpp/src/neighbors/detail/cagra/graph_core.cuh b/cpp/src/neighbors/detail/cagra/graph_core.cuh index a55401f9dd..99a52c0ee7 100644 --- a/cpp/src/neighbors/detail/cagra/graph_core.cuh +++ b/cpp/src/neighbors/detail/cagra/graph_core.cuh @@ -537,7 +537,7 @@ __global__ void kern_mst_opt_update_graph(IdxT* mst_graph, // [graph_size, grap // Check to avoid duplication for (uint64_t kl = 0; kl < graph_degree; kl++) { uint64_t m = mst_graph[(graph_degree * l) + kl]; - if (m > graph_size) continue; + if (m >= graph_size) continue; uint32_t rm = get_root_label(m, label); if (ri == rm) { ret = 0; @@ -793,8 +793,7 @@ void merge_graph_gpu(raft::resources const& res, auto d_check_num_protected_edges = raft::make_device_scalar(res, true); - uint32_t batch_size = - std::min(static_cast(graph_size), static_cast(256 * 1024)); + uint32_t batch_size = static_cast(std::min(graph_size, 256 * 1024)); const uint32_t num_batch = (graph_size + batch_size - 1) / batch_size; batched_device_view_from_host d_output_graph( @@ -845,7 +844,7 @@ void merge_graph_gpu(raft::resources const& res, d_check_num_protected_edges.data_handle(), 1, raft::resource::get_cuda_stream(res)); - + raft::resource::sync_stream(res); const auto merge_graph_end = cur_time(); RAFT_EXPECTS(check_num_protected_edges, "Failed to merge the MST, pruned, and reverse edge graphs. 
" @@ -888,6 +887,7 @@ void make_reverse_graph_gpu(raft::resources const& res, dim3 blocks(1024, 1, 1); kern_make_rev_graph_k<<>>( d_dest_nodes.view(), d_rev_graph, d_rev_graph_count, 0); + raft::resource::sync_stream(res); RAFT_LOG_DEBUG("# Making reverse graph on GPUs: %lu / %u \r", k, output_graph_degree); } } else { @@ -1074,7 +1074,7 @@ void mst_opt_update_graph(IdxT* mst_graph_ptr, // Check to avoid duplication for (uint64_t kl = 0; kl < mst_graph_degree; kl++) { uint64_t m = mst_graph_ptr[(mst_graph_degree * l) + kl]; - if (m > graph_size) continue; + if (m >= graph_size) continue; if (label_ptr[i] == label_ptr[m]) { ret = 0; break; @@ -1533,8 +1533,7 @@ void prune_graph_gpu(raft::resources const& res, raft::common::nvtx::range block_scope( "cagra::graph::optimize/prune"); - uint32_t batch_size = - std::min(static_cast(graph_size), static_cast(256 * 1024)); + uint32_t batch_size = static_cast(std::min(graph_size, 256 * 1024)); const uint32_t num_batch = (graph_size + batch_size - 1) / batch_size; RAFT_LOG_DEBUG("# Pruning kNN Graph on GPUs\r"); From 3437700e40e8baf3e1051d143763eac8b62539af Mon Sep 17 00:00:00 2001 From: Malte Foerster Date: Mon, 27 Apr 2026 16:39:48 +0000 Subject: [PATCH 32/40] try more fixes for V100/cuda12.2 --- cpp/src/neighbors/detail/cagra/graph_core.cuh | 110 +++++++++--------- cpp/src/neighbors/detail/cagra/utils.hpp | 13 ++- 2 files changed, 67 insertions(+), 56 deletions(-) diff --git a/cpp/src/neighbors/detail/cagra/graph_core.cuh b/cpp/src/neighbors/detail/cagra/graph_core.cuh index 99a52c0ee7..bcf7b3ec90 100644 --- a/cpp/src/neighbors/detail/cagra/graph_core.cuh +++ b/cpp/src/neighbors/detail/cagra/graph_core.cuh @@ -511,7 +511,8 @@ __global__ void kern_mst_opt_update_graph(IdxT* mst_graph, // [graph_size, grap ret = 0; auto kj = atomicAdd(incoming_num_edges + j, (IdxT)1); if (kj < incoming_max_edges[j]) { - auto ki = outgoing_num_edges[i]++; + auto ki = outgoing_num_edges[i]; + outgoing_num_edges[i] = ki + 1; 
mst_graph[(graph_degree * (i)) + ki] = j; // outgoing mst_graph[(graph_degree * (j + 1)) - 1 - kj] = i; // incoming ret = 1; @@ -566,30 +567,31 @@ __global__ void kern_mst_opt_labeling(IdxT* label, // [graph_size] uint64_t* stats) { const uint64_t i = threadIdx.x + (blockDim.x * blockIdx.x); - if (i >= graph_size) return; __shared__ uint32_t smem_updated[1]; if (threadIdx.x == 0) { smem_updated[0] = 0; } __syncthreads(); - for (uint64_t ki = 0; ki < graph_degree; ki++) { - uint64_t j = mst_graph[(graph_degree * i) + ki]; - if (j >= graph_size) continue; + if (i < graph_size) { + for (uint64_t ki = 0; ki < graph_degree; ki++) { + uint64_t j = mst_graph[(graph_degree * i) + ki]; + if (j >= graph_size) continue; - IdxT li = label[i]; - IdxT ri = get_root_label(i, label); - if (ri < li) { atomicMin(label + i, ri); } - IdxT lj = label[j]; - IdxT rj = get_root_label(j, label); - if (rj < lj) { atomicMin(label + j, rj); } - if (ri == rj) continue; - - if (ri > rj) { - atomicCAS(label + i, ri, rj); - } else if (rj > ri) { - atomicCAS(label + j, rj, ri); + IdxT li = label[i]; + IdxT ri = get_root_label(i, label); + if (ri < li) { atomicMin(label + i, ri); } + IdxT lj = label[j]; + IdxT rj = get_root_label(j, label); + if (rj < lj) { atomicMin(label + j, rj); } + if (ri == rj) continue; + + if (ri > rj) { + atomicCAS(label + i, ri, rj); + } else if (rj > ri) { + atomicCAS(label + j, rj, ri); + } + smem_updated[0] = 1; } - smem_updated[0] = 1; } __syncthreads(); @@ -603,18 +605,19 @@ __global__ void kern_mst_opt_cluster_size(IdxT* cluster_size, // [graph_size] uint64_t* stats) { const uint64_t i = threadIdx.x + (blockDim.x * blockIdx.x); - if (i >= graph_size) return; __shared__ uint64_t smem_num_clusters[1]; if (threadIdx.x == 0) { smem_num_clusters[0] = 0; } __syncthreads(); - IdxT ri = get_root_label(i, label); - if (ri == i) { - atomicAdd((unsigned long long int*)smem_num_clusters, 1); - } else { - atomicAdd(cluster_size + ri, cluster_size[i]); - cluster_size[i] = 0; + if 
(i < graph_size) { + IdxT ri = get_root_label(i, label); + if (ri == i) { + atomicAdd((unsigned long long int*)smem_num_clusters, 1); + } else { + atomicAdd(cluster_size + ri, cluster_size[i]); + cluster_size[i] = 0; + } } __syncthreads(); @@ -634,8 +637,6 @@ __global__ void kern_mst_opt_postprocessing(IdxT* outgoing_num_edges, // [graph uint64_t* stats) { const uint64_t i = threadIdx.x + (blockDim.x * blockIdx.x); - if (i >= graph_size) return; - __shared__ uint64_t smem_cluster_size_min[1]; __shared__ uint64_t smem_cluster_size_max[1]; __shared__ uint64_t smem_total_outgoing_edges[1]; @@ -648,34 +649,36 @@ __global__ void kern_mst_opt_postprocessing(IdxT* outgoing_num_edges, // [graph } __syncthreads(); - // Adjust incoming_num_edges - if (incoming_num_edges[i] > incoming_max_edges[i]) { - incoming_num_edges[i] = incoming_max_edges[i]; - } - - // Calculate min/max of cluster_size - if (cluster_size[i] > 0) { - if (smem_cluster_size_min[0] > cluster_size[i]) { - atomicMin((unsigned long long int*)smem_cluster_size_min, - (unsigned long long int)(cluster_size[i])); + if (i < graph_size) { + // Adjust incoming_num_edges + if (incoming_num_edges[i] > incoming_max_edges[i]) { + incoming_num_edges[i] = incoming_max_edges[i]; } - if (smem_cluster_size_max[0] < cluster_size[i]) { - atomicMax((unsigned long long int*)smem_cluster_size_max, - (unsigned long long int)(cluster_size[i])); + + // Calculate min/max of cluster_size + if (cluster_size[i] > 0) { + if (smem_cluster_size_min[0] > cluster_size[i]) { + atomicMin((unsigned long long int*)smem_cluster_size_min, + (unsigned long long int)(cluster_size[i])); + } + if (smem_cluster_size_max[0] < cluster_size[i]) { + atomicMax((unsigned long long int*)smem_cluster_size_max, + (unsigned long long int)(cluster_size[i])); + } } - } - // Calculate total number of outgoing/incoming edges - atomicAdd((unsigned long long int*)smem_total_outgoing_edges, - (unsigned long long int)(outgoing_num_edges[i])); - atomicAdd((unsigned long 
long int*)smem_total_incoming_edges, - (unsigned long long int)(incoming_num_edges[i])); - - // Adjust incoming/outgoing_max_edges - if (outgoing_num_edges[i] == outgoing_max_edges[i]) { - if (outgoing_num_edges[i] + incoming_num_edges[i] < graph_degree) { - outgoing_max_edges[i] += 1; - incoming_max_edges[i] -= 1; + // Calculate total number of outgoing/incoming edges + atomicAdd((unsigned long long int*)smem_total_outgoing_edges, + (unsigned long long int)(outgoing_num_edges[i])); + atomicAdd((unsigned long long int*)smem_total_incoming_edges, + (unsigned long long int)(incoming_num_edges[i])); + + // Adjust incoming/outgoing_max_edges + if (outgoing_num_edges[i] == outgoing_max_edges[i]) { + if (outgoing_num_edges[i] + incoming_num_edges[i] < graph_degree) { + outgoing_max_edges[i] += 1; + incoming_max_edges[i] -= 1; + } } } @@ -1594,14 +1597,15 @@ void prune_graph_gpu(raft::resources const& res, d_invalid_neighbor_list.data_handle(), 1, raft::resource::get_cuda_stream(res)); + raft::copy(res, host_stats.view(), raft::make_const_mdspan(dev_stats.view())); raft::resource::sync_stream(res); + RAFT_EXPECTS( invalid_neighbor_list == 0, "Could not generate an intermediate CAGRA graph because the initial kNN graph contains too " "many invalid or duplicated neighbor nodes. 
This error can occur, for example, if too many " "overflows occur during the norm computation between the dataset vectors."); - raft::copy(res, host_stats.view(), raft::make_const_mdspan(dev_stats.view())); num_keep = host_stats.data_handle()[0]; num_full = host_stats.data_handle()[1]; diff --git a/cpp/src/neighbors/detail/cagra/utils.hpp b/cpp/src/neighbors/detail/cagra/utils.hpp index c5cbc1a6b3..5e205ceafa 100644 --- a/cpp/src/neighbors/detail/cagra/utils.hpp +++ b/cpp/src/neighbors/detail/cagra/utils.hpp @@ -166,7 +166,14 @@ bool is_ptr_host_accessible(T* ptr) { cudaPointerAttributes attr; RAFT_CUDA_TRY(cudaPointerGetAttributes(&attr, ptr)); - return attr.hostPointer != nullptr; + // Pageable host memory comes back as cudaMemoryTypeUnregistered with both + // hostPointer and devicePointer set to nullptr, but it is still readable + // from the host. Treat anything that is not strictly device-only as host + // accessible. Without this, a pageable host pointer was previously routed + // into the "device-only" code paths in graph_core.cuh, which dereferenced + // a pageable host address from the GPU and caused cudaErrorIllegalAddress + // on systems without HMM (e.g. V100 with CUDA 12.2 drivers). 
+ return attr.type != cudaMemoryTypeDevice; } /** @@ -418,7 +425,7 @@ class batched_device_view_from_host { device_ptr[2] = device_mem_[2]->data_handle(); } } catch (std::bad_alloc& e) { - if (attr_.devicePointer != nullptr) { + if (attr_.devicePointer != nullptr && attr_.type == cudaMemoryTypeManaged) { for (auto& mem : device_mem_) { mem.reset(); } @@ -428,7 +435,7 @@ class batched_device_view_from_host { throw std::bad_alloc(); } } catch (raft::logic_error& e) { - if (attr_.devicePointer != nullptr) { + if (attr_.devicePointer != nullptr && attr_.type == cudaMemoryTypeManaged) { for (auto& mem : device_mem_) { mem.reset(); } From f04022cceafbe5ea8e3ba049de1807c7090db354 Mon Sep 17 00:00:00 2001 From: Malte Foerster Date: Fri, 1 May 2026 01:12:12 +0000 Subject: [PATCH 33/40] refactor remove all device pointer arithmetic from batch_device_view, utilize submdspan for passthrough --- cpp/src/neighbors/detail/cagra/graph_core.cuh | 184 ++++--- cpp/src/neighbors/detail/cagra/utils.hpp | 465 +++++++----------- .../test_batched_device_view_from_host.cu | 137 ++++-- 3 files changed, 389 insertions(+), 397 deletions(-) diff --git a/cpp/src/neighbors/detail/cagra/graph_core.cuh b/cpp/src/neighbors/detail/cagra/graph_core.cuh index bcf7b3ec90..eeaf35d4f9 100644 --- a/cpp/src/neighbors/detail/cagra/graph_core.cuh +++ b/cpp/src/neighbors/detail/cagra/graph_core.cuh @@ -194,14 +194,18 @@ __global__ void kern_make_rev_graph_k( } } -template -__global__ void kern_fused_prune( - raft::device_matrix_view knn_graph, // [graph_chunk_size, graph_degree] - raft::device_matrix_view output_graph, // [batch_size, output_graph_degree] - const uint32_t batch_size, - const uint32_t batch_id, - uint32_t* const d_invalid_neighbor_list, - uint64_t* const stats) +// KnnGraphView and OutputGraphView are mdspans with element type IdxT and +// dynamic 2D extents. 
They may have layout_right (raft::device_matrix_view) or +// layout_stride (the result of cuda::std::submdspan), and any accessor that is +// device-accessible (default_accessor, raft::host_device_accessor with a +// device-accessible memory_type, etc.). +template +__global__ void kern_fused_prune(KnnGraphView knn_graph, // [graph_chunk_size, graph_degree] + OutputGraphView output_graph, // [batch_size, output_graph_degree] + const uint32_t batch_size, + const uint32_t batch_id, + uint32_t* const d_invalid_neighbor_list, + uint64_t* const stats) { // Check assumption we have at least one warp per row of the batch assert(blockDim.x == raft::WarpSize * num_warps); @@ -357,13 +361,20 @@ __device__ void warp_shift_array_one_right(uint32_t lane_id, T* array, uint64_t } } -template +// OutputGraphView, MstGraphView and MstNumEdgesView are 2D mdspans that may +// have layout_right or layout_stride and any device-accessible accessor; see +// the comment on kern_fused_prune for details. +template __global__ void kern_merge_graph( - raft::device_matrix_view output_graph, // [batch_size, output_graph_degree] + OutputGraphView output_graph, // [batch_size, output_graph_degree] raft::device_matrix_view rev_graph, // [graph_size, output_graph_degree] raft::device_vector_view rev_graph_count, // [graph_size] - raft::device_matrix_view mst_graph, // [batch_size, output_graph_degree] - raft::device_matrix_view mst_graph_num_edges, // [batch_size, 1] + MstGraphView mst_graph, // [batch_size, output_graph_degree] + MstNumEdgesView mst_graph_num_edges, // [batch_size, 1] const uint32_t batch_size, const uint32_t batch_id, bool guarantee_connectivity, @@ -777,14 +788,16 @@ void check_duplicates_and_out_of_range(const IdxT* output_graph_ptr, num_oor == 0, "%lu out-of-range index node(s) are found in the generated CAGRA graph", num_oor); } -template -void merge_graph_gpu(raft::resources const& res, - OutputMatrixView output_graph, - raft::device_matrix_view d_rev_graph, - 
raft::device_vector_view d_rev_graph_count, - raft::host_matrix_view mst_graph, - raft::host_vector_view mst_graph_num_edges, - bool guarantee_connectivity) +template +void merge_graph_gpu( + raft::resources const& res, + raft::mdspan, raft::row_major, AccessorOutputGraph> + output_graph, + raft::device_matrix_view d_rev_graph, + raft::device_vector_view d_rev_graph_count, + raft::host_matrix_view mst_graph, + raft::host_vector_view mst_graph_num_edges, + bool guarantee_connectivity) { const uint64_t graph_size = output_graph.extent(0); const uint64_t output_graph_degree = output_graph.extent(1); @@ -799,27 +812,32 @@ void merge_graph_gpu(raft::resources const& res, uint32_t batch_size = static_cast(std::min(graph_size, 256 * 1024)); const uint32_t num_batch = (graph_size + batch_size - 1) / batch_size; - batched_device_view_from_host d_output_graph( - res, - raft::make_host_matrix_view( - output_graph.data_handle(), graph_size, output_graph_degree), - /*batch_size*/ batch_size, - /*host_writeback*/ true, - /*initialize*/ true); - - batched_device_view_from_host d_mst_graph(res, - mst_graph, - /*batch_size*/ batch_size, - /*host_writeback*/ false, - /*initialize*/ true); - - batched_device_view_from_host d_mst_graph_num_edges( - res, - raft::make_host_matrix_view( - mst_graph_num_edges.data_handle(), mst_graph_num_edges.extent(0), 1), - /*batch_size*/ batch_size, - /*host_writeback*/ false, - /*initialize*/ true); + batched_device_view d_output_graph(res, + output_graph, + /*batch_size*/ batch_size, + /*host_writeback*/ true, + /*initialize*/ true); + + batched_device_view< + IdxT, + int64_t, + raft::host_device_accessor, raft::memory_type::host>> + d_mst_graph(res, + mst_graph, + /*batch_size*/ batch_size, + /*host_writeback*/ false, + /*initialize*/ true); + + batched_device_view< + uint32_t, + int64_t, + raft::host_device_accessor, raft::memory_type::host>> + d_mst_graph_num_edges(res, + raft::make_host_matrix_view( + mst_graph_num_edges.data_handle(), 
mst_graph_num_edges.extent(0), 1l), + /*batch_size*/ batch_size, + /*host_writeback*/ false, + /*initialize*/ true); const uint32_t num_warps = 4; const dim3 threads_merge(raft::WarpSize * num_warps, 1, 1); @@ -858,11 +876,13 @@ void merge_graph_gpu(raft::resources const& res, (merge_graph_end - merge_graph_start) * 1000.0); } -template -void make_reverse_graph_gpu(raft::resources const& res, - OutputMatrixView output_graph, - raft::device_matrix_view d_rev_graph, - raft::device_vector_view d_rev_graph_count) +template +void make_reverse_graph_gpu( + raft::resources const& res, + raft::mdspan, raft::row_major, AccessorOutputGraph> + output_graph, + raft::device_matrix_view d_rev_graph, + raft::device_vector_view d_rev_graph_count) { const uint64_t graph_size = output_graph.extent(0); const uint64_t output_graph_degree = output_graph.extent(1); @@ -876,9 +896,9 @@ void make_reverse_graph_gpu(raft::resources const& res, raft::matrix::fill(res, d_rev_graph, IdxT(-1)); raft::matrix::fill(res, d_rev_graph_count, uint32_t(0)); - if (is_ptr_host_accessible(output_graph.data_handle())) { + if constexpr (!AccessorOutputGraph::is_device_accessible) { auto d_dest_nodes = raft::make_device_matrix(res, graph_size, 1); - auto dest_nodes = make_host_vector(graph_size); + auto dest_nodes = raft::make_host_vector(graph_size); for (uint64_t k = 0; k < output_graph_degree; k++) { #pragma omp parallel for for (uint64_t i = 0; i < graph_size; i++) { @@ -1115,12 +1135,13 @@ void mst_opt_update_graph(IdxT* mst_graph_ptr, // an approximate MST. // * If the input kNN graph is disconnected, random connection is added to the largest cluster. 
// -template -void mst_optimization(raft::resources const& res, - InputMatrixView input_graph, - raft::host_matrix_view output_graph, - raft::host_vector_view mst_graph_num_edges, - bool use_gpu = true) +template +void mst_optimization( + raft::resources const& res, + raft::mdspan, raft::row_major, AccessorKnnGraph> input_graph, + raft::host_matrix_view output_graph, + raft::host_vector_view mst_graph_num_edges, + bool use_gpu = true) { if (use_gpu) { RAFT_LOG_DEBUG("# MST optimization on GPU"); @@ -1249,7 +1270,7 @@ void mst_optimization(raft::resources const& res, } } else { // Copy rank-k edges from the input knn graph to 'candidate_edges' - if (is_ptr_host_accessible(input_graph.data_handle())) { + if constexpr (AccessorKnnGraph::is_host_accessible) { #pragma omp parallel for for (uint64_t i = 0; i < graph_size; i++) { candidate_edges_ptr[i] = input_graph(i, k); @@ -1524,10 +1545,12 @@ void mst_optimization(raft::resources const& res, // specified number of edges are picked up for each node, starting with the edge with // the lowest number of 2-hop detours. 
// -template -void prune_graph_gpu(raft::resources const& res, - InputMatrixView knn_graph, - OutputMatrixView output_graph) +template +void prune_graph_gpu( + raft::resources const& res, + raft::mdspan, raft::row_major, AccessorKnnGraph> knn_graph, + raft::mdspan, raft::row_major, AccessorOutputGraph> + output_graph) { const uint64_t graph_size = output_graph.extent(0); const uint64_t knn_graph_degree = knn_graph.extent(1); @@ -1549,22 +1572,18 @@ void prune_graph_gpu(raft::resources const& res, auto host_stats = raft::make_host_vector(2); raft::matrix::fill(res, dev_stats.view(), uint64_t(0)); - batched_device_view_from_host d_input_graph( - res, - raft::make_host_matrix_view( - knn_graph.data_handle(), graph_size, knn_graph_degree), - /*batch_size*/ graph_size, - /*host_writeback*/ false, - /*initialize*/ true); + batched_device_view d_input_graph(res, + knn_graph, + /*batch_size*/ graph_size, + /*host_writeback*/ false, + /*initialize*/ true); auto input_view = d_input_graph.next_view(); - batched_device_view_from_host d_output_graph( - res, - raft::make_host_matrix_view( - output_graph.data_handle(), graph_size, output_graph_degree), - /*batch_size*/ batch_size, - /*host_writeback*/ true, - /*initialize*/ false); + batched_device_view d_output_graph(res, + output_graph, + /*batch_size*/ batch_size, + /*host_writeback*/ true, + /*initialize*/ false); auto d_invalid_neighbor_list = raft::make_device_scalar(res, 0u); @@ -1622,12 +1641,17 @@ void prune_graph_gpu(raft::resources const& res, } // namespace -template -void optimize(raft::resources const& res, - InputMatrixView knn_graph, - OutputMatrixView new_graph, - const bool guarantee_connectivity = true, - const bool use_gpu_for_mst_optimization = true) +template , raft::memory_type::host>, + typename AccessorOutputGraph = + raft::host_device_accessor, raft::memory_type::host>> +void optimize( + raft::resources const& res, + raft::mdspan, raft::row_major, AccessorKnnGraph> knn_graph, + raft::mdspan, 
raft::row_major, AccessorOutputGraph> new_graph, + const bool guarantee_connectivity = true, + const bool use_gpu_for_mst_optimization = true) { RAFT_LOG_DEBUG( "# Pruning kNN graph (size=%lu, degree=%lu)\n", knn_graph.extent(0), knn_graph.extent(1)); @@ -1719,7 +1743,7 @@ void optimize(raft::resources const& res, raft::resource::sync_stream(res); - if (is_ptr_host_accessible(new_graph.data_handle())) { + if constexpr (AccessorOutputGraph::is_host_accessible) { // following checks require host access log_incoming_edges_histogram(new_graph.data_handle(), graph_size, output_graph_degree); diff --git a/cpp/src/neighbors/detail/cagra/utils.hpp b/cpp/src/neighbors/detail/cagra/utils.hpp index 5e205ceafa..ea1d830192 100644 --- a/cpp/src/neighbors/detail/cagra/utils.hpp +++ b/cpp/src/neighbors/detail/cagra/utils.hpp @@ -161,21 +161,6 @@ struct gen_index_msb_1_mask { }; } // namespace utils -template -bool is_ptr_host_accessible(T* ptr) -{ - cudaPointerAttributes attr; - RAFT_CUDA_TRY(cudaPointerGetAttributes(&attr, ptr)); - // Pageable host memory comes back as cudaMemoryTypeUnregistered with both - // hostPointer and devicePointer set to nullptr, but it is still readable - // from the host. Treat anything that is not strictly device-only as host - // accessible. Without this, a pageable host pointer was previously routed - // into the "device-only" code paths in graph_core.cuh, which dereferenced - // a pageable host address from the GPU and caused cudaErrorIllegalAddress - // on systems without HMM (e.g. V100 with CUDA 12.2 drivers). - return attr.type != cudaMemoryTypeDevice; -} - /** * Utility to sync memory from a host_matrix_view to a device_matrix_view * @@ -317,60 +302,94 @@ void copy_with_padding( } /** - * Utility to create a batched device view from a host view + * Iterate a 2D mdspan in row-batches with overlapping prefetch/writeback. 
* - * This utility will create a batched device view from a host view and will handle the prefetch and - * writeback of the data. Each batch can be referenced exactly once by calling the next_view() - * function. + * The strategy is selected at compile time from + * `AccessorInputView::is_device_accessible`: + * * passthrough: each batch is `cuda::std::submdspan(input_view_, ...)`. + * No buffering, no arithmetic on `input_view_.data_handle()`. + * * copy_device: each batch is staged through an internal device buffer + * via `cudaMemcpyAsync`. When `host_writeback` is set, every returned + * batch is copied back to `input_view_` on destruction (and earlier + * batches incrementally as iteration progresses). * * Usage: * ``` - * batched_device_view_from_host view( - * res, host_view, batch_size, host_writeback, initialize); + * batched_device_view view( + * res, input_view, batch_size, host_writeback, initialize); * for (;;) { * auto device_view = view.next_view(); - * if (device_view.extent(0) == 0) { break; } - * // use device_view (one next_view() call per batch; never in the loop condition) + * if (device_view.extent(0) == 0) { break; } // sole stop condition + * // ... call next_view() exactly once per batch; never in the loop condition * } * ``` * - * The call to next_view() will - * * synchronize on all previous operations / increments batch_id_ - * * (optionally) write back the data of the previous batch to the host - * * (optionally) prefetch the data of the next batch - * * return the view of the current batch + * Differences vs `cuvs::neighbors::detail::utils::batch_load_iterator` + * (cpp/src/neighbors/detail/ann_utils.cuh): + * * Input typing: this class takes a typed mdspan and decides the strategy + * at compile time from the accessor; `batch_load_iterator` takes + * `(T const*, n_rows, row_width)` and decides at runtime via + * `cudaPointerGetAttributes` (also forces a copy for HMM/ATS sources). 
+ * * API: this class exposes a single-pass `next_view()` that returns an + * mdspan; `batch_load_iterator` is an STL iterator (begin/end, + * operator++/--, random access) that yields a `batch` wrapping a + * `T const*`. + * * Mutability: returned views here are mutable T* and support host + * writeback on destruction; `batch_load_iterator` is read-only and never + * copies back. + * * Pipelining: copy_device uses 2-3 device buffers and dedicated prefetch + * and writeback streams for true overlap; `batch_load_iterator` uses + * 1-2 buffers and exposes a synchronous `prefetch_next_batch()` that the + * caller must interleave with kernels on a separate stream. * - * @tparam T The type of the data - * @tparam IdxT The type of the index + * @tparam T element type + * @tparam IdxT index type for the input mdspan extents + * @tparam AccessorInputView accessor of the input mdspan */ -template -class batched_device_view_from_host { +template +class batched_device_view { + using input_view_t = + raft::mdspan, raft::row_major, AccessorInputView>; + public: - enum class memory_strategy { - device_only, // data is on device only (no copy needed) - copy_device, // data is explicitly moved to/from device buffers - managed_only, // data is on managed memory (system managed) - }; + // Compile-time strategy switch; see class-level documentation for semantics. + static constexpr bool kPassthrough = AccessorInputView::is_device_accessible; + + static_assert(kPassthrough || + cuda::std::is_convertible_v, + "copy_device path issues cudaMemcpyAsync against input_view_.data_handle() " + "in both directions, so AccessorInputView::data_handle_type must be " + "convertible to T*. To lift this, route prefetch/writeback through " + "raft::copy on submdspans of input_view_."); + + // Result of submdspan(layout_right, tuple, full_extent): + // layout_stride with AccessorInputView preserved. 
+ using next_view_passthrough_type = + decltype(cuda::std::submdspan(std::declval(), + std::declval>(), + cuda::std::full_extent)); + + // Internal contiguous row-major device buffer. + using next_view_copy_type = raft::device_matrix_view; /** - * Create a batched device view from a host view and will handle the prefetch and - * writeback of the data. Each batch can be referenced exactly once by calling the next_view() - * method. + * @param res raft resources (must outlive this object) + * @param input_view mdspan to iterate over + * @param batch_size rows per batch (must be > 0 if input_view is non-empty) + * @param host_writeback copy each batch back to input_view_ after use + * (no-op for passthrough) + * @param initialize stage each batch's contents into the device buffer + * before returning it (no-op for passthrough) * - * @param res The resources to use - * @param host_view The host view to create the batched device view from - * @param batch_size The batch size - * @param host_writeback Whether to write back the data to the host (only for host memory) - * (default: false) - * @param initialize Whether to initialize the data (default: true) + * At least one of host_writeback / initialize must be true for non-empty input. 
*/ - batched_device_view_from_host(raft::resources const& res, - raft::host_matrix_view host_view, - uint64_t batch_size, - bool host_writeback = false, - bool initialize = true) + batched_device_view(raft::resources const& res, + input_view_t input_view, + uint64_t batch_size, + bool host_writeback = false, + bool initialize = true) : res_(res), - host_view_(host_view), + input_view_(input_view), batch_size_(batch_size), offset_(0), batch_id_(-2), @@ -378,246 +397,174 @@ class batched_device_view_from_host { host_writeback_(host_writeback), initialize_(initialize) { - if (host_view.extent(0) == 0) { - mem_strategy_ = memory_strategy::device_only; - return; - } + if (input_view.extent(0) == 0) { return; } RAFT_EXPECTS(batch_size_ > 0, "batch_size must be greater than zero for non-empty input"); RAFT_EXPECTS(host_writeback_ || initialize_, "At least one of host_writeback or initialize must be true"); - RAFT_CUDA_TRY(cudaPointerGetAttributes(&attr_, host_view.data_handle())); - switch (attr_.type) { - case cudaMemoryTypeUnregistered: - case cudaMemoryTypeHost: - case cudaMemoryTypeManaged: mem_strategy_ = memory_strategy::copy_device; break; - case cudaMemoryTypeDevice: mem_strategy_ = memory_strategy::device_only; break; - } - - RAFT_LOG_DEBUG("Memory strategy: %d for type %d, size %zu", - static_cast(mem_strategy_), - static_cast(attr_.type), - host_view.extent(0) * host_view.extent(1) * sizeof(T)); + RAFT_LOG_DEBUG("Memory strategy: %s for matrix of type %s, dimensions %zu x %zu", + kPassthrough ? 
"passthrough" : "copy_device", + typeid(T).name(), + input_view.extent(0), + input_view.extent(1)); - // buffer allocations - if (mem_strategy_ == memory_strategy::copy_device) { + // device buffers (copy_device only) + if constexpr (!kPassthrough) { try { device_mem_[0].emplace(raft::make_device_mdarray( res, raft::resource::get_workspace_resource_ref(res), - raft::make_extents(batch_size, host_view.extent(1)))); + raft::make_extents(batch_size, input_view.extent(1)))); device_ptr[0] = device_mem_[0]->data_handle(); - if (batch_size < static_cast(host_view.extent(0))) { + if (batch_size < static_cast(input_view.extent(0))) { device_mem_[1].emplace(raft::make_device_mdarray( res, raft::resource::get_workspace_resource_ref(res), - raft::make_extents(batch_size, host_view.extent(1)))); + raft::make_extents(batch_size, input_view.extent(1)))); device_ptr[1] = device_mem_[1]->data_handle(); } if (host_writeback_ && initialize_ && - batch_size * 2 < static_cast(host_view.extent(0))) { + batch_size * 2 < static_cast(input_view.extent(0))) { num_buffers_ = 3; device_mem_[2].emplace(raft::make_device_mdarray( res, raft::resource::get_workspace_resource_ref(res), - raft::make_extents(batch_size, host_view.extent(1)))); + raft::make_extents(batch_size, input_view.extent(1)))); device_ptr[2] = device_mem_[2]->data_handle(); } } catch (std::bad_alloc& e) { - if (attr_.devicePointer != nullptr && attr_.type == cudaMemoryTypeManaged) { - for (auto& mem : device_mem_) { - mem.reset(); - } - RAFT_LOG_DEBUG("Insufficient memory for device buffers, switching to managed memory"); - mem_strategy_ = memory_strategy::managed_only; - } else { - throw std::bad_alloc(); - } + throw std::bad_alloc(); } catch (raft::logic_error& e) { - if (attr_.devicePointer != nullptr && attr_.type == cudaMemoryTypeManaged) { - for (auto& mem : device_mem_) { - mem.reset(); - } - RAFT_LOG_DEBUG( - "Insufficient memory for device buffers (logic error), switching to managed memory"); - mem_strategy_ = 
memory_strategy::managed_only; - } else { - throw raft::logic_error("Insufficient memory for device buffers (logic error)"); - } + throw raft::logic_error("Insufficient memory for device buffers (logic error)"); } } - // setup stream pool if not already present + // ensure a stream pool with one stream per concurrent direction size_t required_streams = host_writeback_ && initialize_ ? 2 : 1; if (!res.has_resource_factory(raft::resource::resource_type::CUDA_STREAM_POOL) || raft::resource::get_stream_pool_size(res) < required_streams) { - // always create at least 2 streams to account for subsequent iterator calls raft::resource::set_cuda_stream_pool(res, std::make_shared(2)); } prefetch_stream_ = raft::resource::get_stream_from_stream_pool(res); writeback_stream_ = raft::resource::get_stream_from_stream_pool(res); - // if data is managed and not for_write_ we can set the attribute on the device ptr - if (mem_strategy_ == memory_strategy::managed_only) { - location_.type = cudaMemLocationTypeDevice; - location_.id = static_cast(raft::resource::get_device_id(res_)); - if (!host_writeback_) { - advise_read_mostly(host_view_.data_handle(), - host_view_.extent(0) * host_view_.extent(1) * sizeof(T)); - // TODO maybe also reset upon destruction - } - } - - // prefetch next batch (0) + // prime batch 0 prefetch_next_batch(); } - ~batched_device_view_from_host() noexcept + ~batched_device_view() noexcept { raft::resource::sync_stream(res_); - // No view was handed to the caller, so no device-side modifications can require writeback. + // Nothing was returned to the caller -> nothing to write back. 
if (batch_id_ < 0) { return; } - // if data is on host and for_write --> make sure to copy back last active - // if data is managed and evict --> evict last active - - // make sure to sync on prefetch stream & res - switch (mem_strategy_) { - case memory_strategy::managed_only: - if (!host_writeback_) { - uint32_t discard_pos = batch_id_ % num_buffers_; - size_t discard_size_rows = actual_batch_size_[discard_pos]; - if (batch_id_ > 0) { - discard_pos = (batch_id_ - 1) % num_buffers_; - discard_size_rows += batch_size_; - } - discard_managed_region(device_ptr[discard_pos], - discard_size_rows * host_view_.extent(1) * sizeof(T)); - writeback_stream_.synchronize(); + // Passthrough has no internal buffer and writes through input_view_ directly. + if constexpr (!kPassthrough) { + if (host_writeback_) { + uint32_t writeback_pos_last = batch_id_ % num_buffers_; + if (batch_id_ > 0) { + uint32_t writeback_pos = (batch_id_ - 1) % num_buffers_; + uint64_t writeback_offset = (batch_id_ - 1) * batch_size_; + writeback_from_device_to_host(device_ptr[writeback_pos], writeback_offset, batch_size_); } - break; - case memory_strategy::copy_device: - if (host_writeback_) { - uint32_t writeback_pos_last = batch_id_ % num_buffers_; - if (batch_id_ > 0) { - uint32_t writeback_pos = (batch_id_ - 1) % num_buffers_; - uint64_t writeback_offset = (batch_id_ - 1) * batch_size_; - writeback_from_device_to_host(device_ptr[writeback_pos], writeback_offset, batch_size_); - } - { - uint64_t writeback_offset_last = batch_id_ * batch_size_; - writeback_from_device_to_host(device_ptr[writeback_pos_last], - writeback_offset_last, - actual_batch_size_[writeback_pos_last]); - } - writeback_stream_.synchronize(); + { + uint64_t writeback_offset_last = batch_id_ * batch_size_; + writeback_from_device_to_host(device_ptr[writeback_pos_last], + writeback_offset_last, + actual_batch_size_[writeback_pos_last]); } - break; - case memory_strategy::device_only: break; + writeback_stream_.synchronize(); + } } } 
/** - * Returns the next view of the batch + * Advance to the next batch and return its view. + * + * In copy_device mode this also kicks off prefetch of the batch after it + * and writeback of the previously returned batch (when host_writeback is + * set). In passthrough mode the call is essentially a slice update. * - * This function will ensure the next batch is ready and will trigger the prefetch of the - * subsequent next batch. If writeback is enabled, the last active batch will be written back to - * the host. + * Return type is `next_view_passthrough_type` or `next_view_copy_type`; + * both are 2D mdspans of T with the same element- and extent-API surface. + * Iteration ends when extent(0) == 0; this is the only legal stop signal. * - * @return The next view of the batch + * Must be called exactly once per batch (never in a loop condition). */ - raft::device_matrix_view next_view() + auto next_view() { bool end_of_data = static_cast((batch_id_ + 1) * batch_size_) >= - static_cast(host_view_.extent(0)); + static_cast(input_view_.extent(0)); + + if constexpr (kPassthrough) { + // Slice goes through the accessor; data_handle() is never used. 
+ if (end_of_data) { + auto end = static_cast(input_view_.extent(0)); + return cuda::std::submdspan( + input_view_, cuda::std::tuple{end, end}, cuda::std::full_extent); + } - // special case for empty host view or last batch surpassed - if (end_of_data) { - return raft::make_device_matrix_view(nullptr, 0, host_view_.extent(1)); - } + prefetch_next_batch(); - // trigger prefetch of next batch (also increments batch_id_) - prefetch_next_batch(); + uint32_t current_pos = batch_id_ % num_buffers_; + auto first = static_cast(batch_id_ * batch_size_); + auto last = static_cast(first + actual_batch_size_[current_pos]); + return cuda::std::submdspan( + input_view_, cuda::std::tuple{first, last}, cuda::std::full_extent); + } else { + auto cols = static_cast(input_view_.extent(1)); - uint32_t current_pos = batch_id_ % num_buffers_; - return raft::make_device_matrix_view( - device_ptr[current_pos], actual_batch_size_[current_pos], host_view_.extent(1)); + if (end_of_data) { return next_view_copy_type{nullptr, IdxT{0}, cols}; } + + prefetch_next_batch(); + + uint32_t current_pos = batch_id_ % num_buffers_; + return next_view_copy_type{ + device_ptr[current_pos], static_cast(actual_batch_size_[current_pos]), cols}; + } } private: /** - * Prefetch the next batch + * Advance batch_id_ to the next batch and stage I/O around it (copy_device only): + * * prefetch the (batch_id_ + 1)-th batch when initialize_ is set + * * write back the (batch_id_ - 1)-th batch when host_writeback_ is set + * In passthrough mode this only updates batch_id_ and offset_ bookkeeping. * - * This function will prefetch the next batch and will handle the writeback of the data. - * - * @return True if the next batch exists, false otherwise + * @return true if more input rows remain after this advance. 
*/ bool prefetch_next_batch() { batch_id_++; - // ensure previous batch at position batch_id_ is ready + // wait for the prior round's prefetch / writeback before reusing the slots if (initialize_) { prefetch_stream_.synchronize(); } if (host_writeback_) { writeback_stream_.synchronize(); } - // this step will - // * write back data from batch_id_ - 1 - // * prefetch data for batch_id_ + 1 - - // if data is on host and host_writeback_ is true we will have to copy it back - // if data is on host and initialize_ is true we will have to copy it to the device_ptr - - // if data is managed and !host_writeback_ we can discard the data from device memory - // if data is managed and initialize_ is true we can prefetch it to the device - // if data is managed and !initialize_ we can discard and prefetch the data location + RAFT_EXPECTS(static_cast(offset_) <= input_view_.extent(0), "Offset out of bounds"); - // if data is on device only this is almost a noop, just prepping the pointers - - RAFT_EXPECTS(static_cast(offset_) <= host_view_.extent(0), "Offset out of bounds"); - - bool next_batch_exists = offset_ < static_cast(host_view_.extent(0)); + bool next_batch_exists = offset_ < static_cast(input_view_.extent(0)); if (next_batch_exists) { - // synchronize to ensure all previous operations are completed - // in particular all work on batch_id_ - 1 + // make sure all kernels touching batch_id_ - 1 (the now-writable slot) + // are done before we either overwrite it (initialize) or read it back raft::resource::sync_stream(res_); int32_t prefetch_pos = (batch_id_ + 1) % num_buffers_; - actual_batch_size_[prefetch_pos] = min(batch_size_, host_view_.extent(0) - offset_); - - switch (mem_strategy_) { - case memory_strategy::managed_only: - if (!host_writeback_ && batch_id_ > 1) { - uint32_t discard_pos = (batch_id_ - 1) % num_buffers_; - size_t discard_size = batch_size_ * host_view_.extent(1) * sizeof(T); - discard_managed_region(device_ptr[discard_pos], discard_size); - } - // 
prefetch next position - device_ptr[prefetch_pos] = host_view_.data_handle() + offset_ * host_view_.extent(1); - prefetch_managed_region( - device_ptr[prefetch_pos], - actual_batch_size_[prefetch_pos] * host_view_.extent(1) * sizeof(T)); - break; - case memory_strategy::copy_device: - if (host_writeback_ && batch_id_ > 0) { - // copy back last active - uint32_t writeback_pos = (batch_id_ - 1) % num_buffers_; - uint64_t writeback_offset = (batch_id_ - 1) * batch_size_; - writeback_from_device_to_host(device_ptr[writeback_pos], writeback_offset, batch_size_); - } - if (initialize_) { - // prefetch next position - prefetch_from_host_to_device( - device_ptr[prefetch_pos], offset_, actual_batch_size_[prefetch_pos]); - } - - break; - case memory_strategy::device_only: - // just move pointer to next position - device_ptr[prefetch_pos] = host_view_.data_handle() + offset_ * host_view_.extent(1); - break; + actual_batch_size_[prefetch_pos] = min(batch_size_, input_view_.extent(0) - offset_); + + if constexpr (!kPassthrough) { + if (host_writeback_ && batch_id_ > 0) { + uint32_t writeback_pos = (batch_id_ - 1) % num_buffers_; + uint64_t writeback_offset = (batch_id_ - 1) * batch_size_; + writeback_from_device_to_host(device_ptr[writeback_pos], writeback_offset, batch_size_); + } + if (initialize_) { + prefetch_from_host_to_device( + device_ptr[prefetch_pos], offset_, actual_batch_size_[prefetch_pos]); + } } offset_ += actual_batch_size_[prefetch_pos]; @@ -626,93 +573,49 @@ class batched_device_view_from_host { return next_batch_exists; } - void advise_read_mostly(T* ptr, size_t size) - { -#if CUDA_VERSION >= 13000 - RAFT_CUDA_TRY(cudaMemAdvise(ptr, size, cudaMemAdviseSetReadMostly, location_)); -#else - RAFT_CUDA_TRY(cudaMemAdvise_v2(ptr, size, cudaMemAdviseSetReadMostly, location_)); -#endif - } - - void discard_managed_region(T* dev_ptr, size_t size) - { -#if CUDA_VERSION >= 13000 - void* dptrs[1] = {dev_ptr}; - size_t sizes[1] = {size}; - 
RAFT_CUDA_TRY(cudaMemDiscardBatchAsync(dptrs, sizes, 1, 0, writeback_stream_)); -#endif - // FIXME: CUDA12 does not support discard - } - - void prefetch_managed_region(T* dev_ptr, size_t size) - { -#if CUDA_VERSION >= 13000 - if (initialize_) { - RAFT_CUDA_TRY(cudaMemPrefetchAsync(dev_ptr, size, location_, 0, prefetch_stream_)); - } else { - void* dptrs[1] = {dev_ptr}; - size_t sizes[1] = {size}; - RAFT_CUDA_TRY( - cudaMemDiscardAndPrefetchBatchAsync(dptrs, sizes, 1, location_, 0, prefetch_stream_)); - } -#else - // FIXME: CUDA12 does not support discard - so we just prefetch - if (initialize_) { - RAFT_CUDA_TRY(cudaMemPrefetchAsync_v2(dev_ptr, size, location_, 0, prefetch_stream_)); - } else { - RAFT_CUDA_TRY(cudaMemPrefetchAsync_v2(dev_ptr, size, location_, 0, prefetch_stream_)); - } -#endif - } - void prefetch_from_host_to_device(T* dev_ptr, size_t src_row_offset, size_t num_rows) { - const size_t n_elem = num_rows * host_view_.extent(1); + const size_t n_elem = num_rows * input_view_.extent(1); const size_t n_bytes = n_elem * sizeof(T); // use memcpy instead of raft::copy to avoid strange behavior with HMM/ATS memory - RAFT_CUDA_TRY(cudaMemcpyAsync(dev_ptr, - host_view_.data_handle() + src_row_offset * host_view_.extent(1), - n_bytes, - cudaMemcpyHostToDevice, - prefetch_stream_)); + RAFT_CUDA_TRY( + cudaMemcpyAsync(dev_ptr, + input_view_.data_handle() + src_row_offset * input_view_.extent(1), + n_bytes, + cudaMemcpyHostToDevice, + prefetch_stream_)); } void writeback_from_device_to_host(T* dev_ptr, size_t dst_row_offset, size_t num_rows) { - const size_t n_elem = num_rows * host_view_.extent(1); + const size_t n_elem = num_rows * input_view_.extent(1); const size_t n_bytes = n_elem * sizeof(T); // use memcpy instead of raft::copy to avoid strange behavior with HMM/ATS memory - RAFT_CUDA_TRY(cudaMemcpyAsync(host_view_.data_handle() + dst_row_offset * host_view_.extent(1), - dev_ptr, - n_bytes, - cudaMemcpyDeviceToHost, - writeback_stream_)); + RAFT_CUDA_TRY( + 
cudaMemcpyAsync(input_view_.data_handle() + dst_row_offset * input_view_.extent(1), + dev_ptr, + n_bytes, + cudaMemcpyDeviceToHost, + writeback_stream_)); } - // stream pool for local streams - std::optional> local_stream_pool_; + // streams (copy_device only; unused in passthrough) rmm::cuda_stream_view prefetch_stream_; rmm::cuda_stream_view writeback_stream_; // configuration - memory_strategy mem_strategy_; const raft::resources& res_; - bool initialize_; // initialize the data on the device - bool host_writeback_; // write back the data to the host + bool initialize_; + bool host_writeback_; - // batch position information + // iteration state uint64_t batch_size_; - int32_t batch_id_; - uint64_t offset_; + int32_t batch_id_; // -2 before any prefetch; >= 0 once a batch has been returned + uint64_t offset_; // first row of the upcoming-prefetch batch, in input_view_ rows - cudaMemLocation location_; - - // input pointer information - raft::host_matrix_view host_view_; - cudaPointerAttributes attr_; + input_view_t input_view_; - // internal device buffers + // device buffers (copy_device only) uint64_t num_buffers_; std::optional> device_mem_[3]; T* device_ptr[3]; diff --git a/cpp/tests/neighbors/ann_cagra/test_batched_device_view_from_host.cu b/cpp/tests/neighbors/ann_cagra/test_batched_device_view_from_host.cu index 1e1cc13093..f9a91cca3c 100644 --- a/cpp/tests/neighbors/ann_cagra/test_batched_device_view_from_host.cu +++ b/cpp/tests/neighbors/ann_cagra/test_batched_device_view_from_host.cu @@ -31,6 +31,14 @@ namespace cuvs::neighbors::cagra { using IdxT = uint32_t; +using DeviceAccessor = + raft::host_device_accessor, raft::memory_type::device>; +using HostAccessor = + raft::host_device_accessor, raft::memory_type::host>; +using PinnedAccessor = + raft::host_device_accessor, raft::memory_type::pinned>; +using ManagedAccessor = + raft::host_device_accessor, raft::memory_type::managed>; struct BatchConfig { bool initialize; @@ -43,19 +51,20 @@ struct 
DimsConfig { uint64_t batch_size; }; -class BatchedDeviceViewFromHostTest : public ::testing::Test { +class BatchedDeviceViewTest : public ::testing::Test { protected: void SetUp() override { raft::resource::sync_stream(res); } /** - * Run batched_device_view_from_host over host data, copy device views back, + * Run batched_device_view over host data, copy device views back, * and verify against the input. */ - template - void run_and_verify_batched(InputMatrixView input_view, - uint64_t batch_size, - bool host_writeback, - bool initialize) + template + void run_and_verify_batched( + raft::mdspan, raft::row_major, AccessorInputView> input_view, + uint64_t batch_size, + bool host_writeback, + bool initialize) { int64_t n_rows = input_view.extent(0); int64_t n_cols = input_view.extent(1); @@ -65,12 +74,8 @@ class BatchedDeviceViewFromHostTest : public ::testing::Test { int64_t total_processed = 0; { - cagra::detail::batched_device_view_from_host batched( - res, - raft::make_host_matrix_view(input_view.data_handle(), n_rows, n_cols), - batch_size, - host_writeback, - initialize); + cagra::detail::batched_device_view batched( + res, input_view, batch_size, host_writeback, initialize); while (true) { auto dev_view = batched.next_view(); if (dev_view.extent(0) == 0) break; @@ -81,7 +86,20 @@ class BatchedDeviceViewFromHostTest : public ::testing::Test { dev_view.extent(0) * dev_view.extent(1), raft::resource::get_cuda_stream(res)); } - if (host_writeback) { raft::matrix::fill(res, dev_view, IdxT(17)); } + if (host_writeback) { + // Re-wrap as a plain device_matrix_view to strip the (potentially + // layout_stride / pinned- or managed-accessor) shape that the + // passthrough path would otherwise hand us, so raft::matrix::fill's + // device_matrix_view overload accepts the call. dev_view is always + // exhaustive (contiguous row range of a row-major matrix), so + // (data_handle, extent(0), extent(1)) describes the same memory. 
+ // This should eventually be fixed by adding a more generic + // overload to raft::matrix::fill. + raft::matrix::fill(res, + raft::make_device_matrix_view( + dev_view.data_handle(), dev_view.extent(0), dev_view.extent(1)), + IdxT(17)); + } total_processed += dev_view.extent(0); } } @@ -107,11 +125,11 @@ class BatchedDeviceViewFromHostTest : public ::testing::Test { raft::resources res; }; -TEST_F(BatchedDeviceViewFromHostTest, EmptyView) +TEST_F(BatchedDeviceViewTest, EmptyViewFromHost) { auto host_empty = raft::make_host_matrix(0, 8); auto host_view = host_empty.view(); - cagra::detail::batched_device_view_from_host batched( + cagra::detail::batched_device_view batched( res, host_view, /*batch_size=*/128, /*host_writeback=*/false, /*initialize=*/true); auto view = batched.next_view(); @@ -120,13 +138,25 @@ TEST_F(BatchedDeviceViewFromHostTest, EmptyView) EXPECT_EQ(view.data_handle(), nullptr); } +TEST_F(BatchedDeviceViewTest, EmptyViewFromDevice) +{ + auto device_empty = raft::make_device_matrix(res, 0, 8); + auto device_view = device_empty.view(); + cagra::detail::batched_device_view batched( + res, device_view, /*batch_size=*/128, /*host_writeback=*/false, /*initialize=*/true); + + auto view = batched.next_view(); + EXPECT_EQ(view.extent(0), 0); + EXPECT_EQ(view.extent(1), 8); + EXPECT_EQ(view.data_handle(), nullptr); +} + using BatchDimsParam = std::tuple; -class BatchedDeviceViewFromHostParameterizedTest - : public BatchedDeviceViewFromHostTest, - public ::testing::WithParamInterface {}; +class BatchedDeviceViewParameterizedTest : public BatchedDeviceViewTest, + public ::testing::WithParamInterface {}; -TEST_P(BatchedDeviceViewFromHostParameterizedTest, VectorHostData) +TEST_P(BatchedDeviceViewParameterizedTest, VectorHostData) { auto [batch_config, dims_config] = GetParam(); auto [initialize, host_writeback] = batch_config; @@ -140,46 +170,81 @@ TEST_P(BatchedDeviceViewFromHostParameterizedTest, VectorHostData) run_and_verify_batched(host_view, batch_size, 
host_writeback, initialize); } -TEST_P(BatchedDeviceViewFromHostParameterizedTest, PinnedMemory) +TEST_P(BatchedDeviceViewParameterizedTest, PinnedMemory) { auto [batch_config, dims_config] = GetParam(); auto [initialize, host_writeback] = batch_config; auto [n_rows, n_cols, batch_size] = dims_config; - auto host_matrix = raft::make_pinned_matrix(res, n_rows, n_cols); - auto host_view = host_matrix.view(); + auto pinned_matrix = raft::make_pinned_matrix(res, n_rows, n_cols); + // auto pinned_view = pinned_matrix.view(); + auto pinned_view = + raft::mdspan, raft::row_major, PinnedAccessor>( + pinned_matrix.data_handle(), n_rows, n_cols); + std::fill(pinned_view.data_handle(), pinned_view.data_handle() + n_rows * n_cols, IdxT(13)); - std::fill(host_view.data_handle(), host_view.data_handle() + n_rows * n_cols, IdxT(13)); + run_and_verify_batched(pinned_view, batch_size, host_writeback, initialize); +} - run_and_verify_batched(host_view, batch_size, host_writeback, initialize); +TEST_P(BatchedDeviceViewParameterizedTest, PinnedMemoryForcedToHost) +{ + auto [batch_config, dims_config] = GetParam(); + auto [initialize, host_writeback] = batch_config; + auto [n_rows, n_cols, batch_size] = dims_config; + + auto pinned_matrix = raft::make_pinned_matrix(res, n_rows, n_cols); + + auto pinned_view = + raft::mdspan, raft::row_major, HostAccessor>( + pinned_matrix.data_handle(), n_rows, n_cols); + + std::fill(pinned_view.data_handle(), pinned_view.data_handle() + n_rows * n_cols, IdxT(13)); + run_and_verify_batched(pinned_view, batch_size, host_writeback, initialize); } -TEST_P(BatchedDeviceViewFromHostParameterizedTest, ManagedMemory) +TEST_P(BatchedDeviceViewParameterizedTest, ManagedMemory) { auto [batch_config, dims_config] = GetParam(); auto [initialize, host_writeback] = batch_config; auto [n_rows, n_cols, batch_size] = dims_config; - auto host_matrix = raft::make_managed_matrix(res, n_rows, n_cols); - auto host_view = host_matrix.view(); + auto managed_matrix = 
raft::make_managed_matrix(res, n_rows, n_cols); + auto managed_view = managed_matrix.view(); - std::fill(host_view.data_handle(), host_view.data_handle() + n_rows * n_cols, IdxT(13)); + std::fill(managed_view.data_handle(), managed_view.data_handle() + n_rows * n_cols, IdxT(13)); - run_and_verify_batched(host_view, batch_size, host_writeback, initialize); + run_and_verify_batched(managed_view, batch_size, host_writeback, initialize); } -TEST_P(BatchedDeviceViewFromHostParameterizedTest, DeviceMemory) +TEST_P(BatchedDeviceViewParameterizedTest, ManagedMemoryForcedToHost) { auto [batch_config, dims_config] = GetParam(); auto [initialize, host_writeback] = batch_config; auto [n_rows, n_cols, batch_size] = dims_config; - auto host_matrix = raft::make_device_matrix(res, n_rows, n_cols); - auto host_view = host_matrix.view(); + auto managed_matrix = raft::make_managed_matrix(res, n_rows, n_cols); - raft::matrix::fill(res, host_view, IdxT(13)); + auto managed_view = + raft::mdspan, raft::row_major, HostAccessor>( + managed_matrix.data_handle(), n_rows, n_cols); - run_and_verify_batched(host_view, batch_size, host_writeback, initialize); + std::fill(managed_view.data_handle(), managed_view.data_handle() + n_rows * n_cols, IdxT(13)); + + run_and_verify_batched(managed_view, batch_size, host_writeback, initialize); +} + +TEST_P(BatchedDeviceViewParameterizedTest, DeviceMemory) +{ + auto [batch_config, dims_config] = GetParam(); + auto [initialize, host_writeback] = batch_config; + auto [n_rows, n_cols, batch_size] = dims_config; + + auto device_matrix = raft::make_device_matrix(res, n_rows, n_cols); + auto device_view = device_matrix.view(); + + raft::matrix::fill(res, device_view, IdxT(13)); + + run_and_verify_batched(device_view, batch_size, host_writeback, initialize); } static const std::array kBatchConfigs = {{ @@ -198,7 +263,7 @@ static const std::array kDimsConfigs = {{ }}; INSTANTIATE_TEST_SUITE_P(BatchConfigs, - BatchedDeviceViewFromHostParameterizedTest, + 
BatchedDeviceViewParameterizedTest, ::testing::Combine(::testing::ValuesIn(kBatchConfigs), ::testing::ValuesIn(kDimsConfigs))); From cf86064e40f8aad2137a26e0cc0609a1afa5ad7b Mon Sep 17 00:00:00 2001 From: Malte Foerster Date: Fri, 1 May 2026 02:31:32 +0000 Subject: [PATCH 34/40] simplify batched_view to 2 buffers 1 copy stream --- .../neighbors/detail/cagra/cagra_build.cuh | 2 +- cpp/src/neighbors/detail/cagra/graph_core.cuh | 6 +- cpp/src/neighbors/detail/cagra/utils.hpp | 263 ++++++++++-------- cpp/tests/CMakeLists.txt | 3 +- ...om_host.cu => test_batched_device_view.cu} | 5 + 5 files changed, 162 insertions(+), 117 deletions(-) rename cpp/tests/neighbors/ann_cagra/{test_batched_device_view_from_host.cu => test_batched_device_view.cu} (97%) diff --git a/cpp/src/neighbors/detail/cagra/cagra_build.cuh b/cpp/src/neighbors/detail/cagra/cagra_build.cuh index b7fb57b2dc..74dce643ea 100644 --- a/cpp/src/neighbors/detail/cagra/cagra_build.cuh +++ b/cpp/src/neighbors/detail/cagra/cagra_build.cuh @@ -850,7 +850,7 @@ inline std::pair optimize_workspace_size(size_t n_rows, n_rows * sizeof(uint32_t) + graph_degree * sizeof(uint32_t); // in_edge_count + hist // additional memory for combine stage on device (3 batches) - size_t combine_dev = 3 * batch_size * graph_degree * index_size; // d_output_graph(3*batch) + size_t combine_dev = 2 * batch_size * graph_degree * index_size; // d_output_graph(2*batch) if (mst_optimize) { combine_dev += 2 * batch_size * graph_degree * index_size; // d_mst_graph(2*batch) combine_dev += 2 * batch_size * sizeof(uint32_t); // d_mst_graph_num_edges(2*batch) diff --git a/cpp/src/neighbors/detail/cagra/graph_core.cuh b/cpp/src/neighbors/detail/cagra/graph_core.cuh index eeaf35d4f9..2fbf05cf60 100644 --- a/cpp/src/neighbors/detail/cagra/graph_core.cuh +++ b/cpp/src/neighbors/detail/cagra/graph_core.cuh @@ -858,6 +858,10 @@ void merge_graph_gpu( i_batch, guarantee_connectivity, d_check_num_protected_edges.data_handle()); + + 
d_output_graph.prefetch_next(); + d_mst_graph.prefetch_next(); + d_mst_graph_num_edges.prefetch_next(); } bool check_num_protected_edges = true; @@ -1603,7 +1607,7 @@ void prune_graph_gpu( d_invalid_neighbor_list.data_handle(), dev_stats.data_handle()); - raft::resource::sync_stream(res); + d_output_graph.prefetch_next(); RAFT_LOG_DEBUG( "# Pruning kNN Graph on GPUs (%.1lf %%)\r", (double)std::min((i_batch + 1) * batch_size, graph_size) / graph_size * 100); diff --git a/cpp/src/neighbors/detail/cagra/utils.hpp b/cpp/src/neighbors/detail/cagra/utils.hpp index ea1d830192..5c38f31684 100644 --- a/cpp/src/neighbors/detail/cagra/utils.hpp +++ b/cpp/src/neighbors/detail/cagra/utils.hpp @@ -302,7 +302,7 @@ void copy_with_padding( } /** - * Iterate a 2D mdspan in row-batches with overlapping prefetch/writeback. + * Iterate a 2D mdspan in row-batches with overlapping copies and kernel work. * * The strategy is selected at compile time from * `AccessorInputView::is_device_accessible`: @@ -310,8 +310,22 @@ void copy_with_padding( * No buffering, no arithmetic on `input_view_.data_handle()`. * * copy_device: each batch is staged through an internal device buffer * via `cudaMemcpyAsync`. When `host_writeback` is set, every returned - * batch is copied back to `input_view_` on destruction (and earlier - * batches incrementally as iteration progresses). + * batch is copied back to `input_view_`; flushing happens lazily during + * subsequent `prefetch_next()` calls and the destructor flushes the tail. + * + * Concurrency model (copy_device): + * * Two device buffers, ping-ponged across iterations. + * * One non-res_ stream `copy_stream_` carries both D2H writebacks and + * H2D prefetches in FIFO order. + * * `next_view()` drains `copy_stream_` so the caller sees a fully-staged + * slot; it does NOT touch res_'s stream. 
+ * * `prefetch_next()` queues D2H of the previous-iter batch and H2D of the + * next-iter batch on `copy_stream_` (slots different from the running + * kernel's slot), then ends with `sync_stream(res_)`. The host stall on + * that final sync overlaps with both the kernel on `res_` and the copies + * on `copy_stream_` -- so for pageable host the host-bound + * pageable->pinned phase of `cudaMemcpyAsync` also overlaps with the + * kernel. * * Usage: * ``` @@ -319,8 +333,9 @@ void copy_with_padding( * res, input_view, batch_size, host_writeback, initialize); * for (;;) { * auto device_view = view.next_view(); - * if (device_view.extent(0) == 0) { break; } // sole stop condition - * // ... call next_view() exactly once per batch; never in the loop condition + * if (device_view.extent(0) == 0) { break; } // sole stop condition + * kernel<<<..., raft::resource::get_cuda_stream(res)>>>(device_view); + * view.prefetch_next(); // pair with next_view(); copies overlap with kernel * } * ``` * @@ -330,17 +345,21 @@ void copy_with_padding( * at compile time from the accessor; `batch_load_iterator` takes * `(T const*, n_rows, row_width)` and decides at runtime via * `cudaPointerGetAttributes` (also forces a copy for HMM/ATS sources). - * * API: this class exposes a single-pass `next_view()` that returns an - * mdspan; `batch_load_iterator` is an STL iterator (begin/end, - * operator++/--, random access) that yields a `batch` wrapping a - * `T const*`. + * * API shape: both classes split iteration from prefetch -- here it's + * `next_view()` + `prefetch_next()`; in `batch_load_iterator` it's + * `operator*` + `prefetch_next_batch()`. `batch_load_iterator` is also + * an STL iterator with begin/end and random access; this class is + * single-pass and returns mdspans directly. * * Mutability: returned views here are mutable T* and support host * writeback on destruction; `batch_load_iterator` is read-only and never * copies back. 
- * * Pipelining: copy_device uses 2-3 device buffers and dedicated prefetch - * and writeback streams for true overlap; `batch_load_iterator` uses - * 1-2 buffers and exposes a synchronous `prefetch_next_batch()` that the - * caller must interleave with kernels on a separate stream. + * * Pipelining: copy_device uses 2 device buffers and a single non-res_ + * stream that carries both directions; cross-iter ordering of D2H and + * H2D on the same slot is enforced by FIFO, and overlap with the kernel + * is achieved by issuing copies *before* the trailing `sync_stream(res_)` + * in `prefetch_next()`. `batch_load_iterator` uses 1-2 buffers, never + * writes back, and the caller is responsible for using a separate stream + * for kernels in order to overlap with `prefetch_next_batch()`. * * @tparam T element type * @tparam IdxT index type for the input mdspan extents @@ -391,9 +410,9 @@ class batched_device_view { : res_(res), input_view_(input_view), batch_size_(batch_size), - offset_(0), - batch_id_(-2), - num_buffers_(2), + batch_id_(-1), + next_prefetched_(false), + last_flushed_batch_id_(-1), host_writeback_(host_writeback), initialize_(initialize) { @@ -409,7 +428,11 @@ class batched_device_view { input_view.extent(0), input_view.extent(1)); - // device buffers (copy_device only) + // device buffers (copy_device only). Two slots suffice: at any iter K + // the kernel runs on slot K%nb while prefetch_next() queues a D2H of + // slot (K-1)%nb and an H2D into slot (K+1)%nb on the *same* copy_stream_, + // so D2H and H2D ordering on one slot is enforced by FIFO -- no third + // buffer needed. 
if constexpr (!kPassthrough) { try { device_mem_[0].emplace(raft::make_device_mdarray( @@ -424,15 +447,6 @@ class batched_device_view { raft::make_extents(batch_size, input_view.extent(1)))); device_ptr[1] = device_mem_[1]->data_handle(); } - if (host_writeback_ && initialize_ && - batch_size * 2 < static_cast(input_view.extent(0))) { - num_buffers_ = 3; - device_mem_[2].emplace(raft::make_device_mdarray( - res, - raft::resource::get_workspace_resource_ref(res), - raft::make_extents(batch_size, input_view.extent(1)))); - device_ptr[2] = device_mem_[2]->data_handle(); - } } catch (std::bad_alloc& e) { throw std::bad_alloc(); } catch (raft::logic_error& e) { @@ -440,137 +454,160 @@ class batched_device_view { } } - // ensure a stream pool with one stream per concurrent direction - size_t required_streams = host_writeback_ && initialize_ ? 2 : 1; + // One non-res_ stream is enough: D2H and H2D for a given iter are queued + // back-to-back on this stream and run concurrently with the user's kernel + // on res_'s stream. if (!res.has_resource_factory(raft::resource::resource_type::CUDA_STREAM_POOL) || - raft::resource::get_stream_pool_size(res) < required_streams) { - raft::resource::set_cuda_stream_pool(res, std::make_shared(2)); + raft::resource::get_stream_pool_size(res) < 1) { + raft::resource::set_cuda_stream_pool(res, std::make_shared(1)); } - prefetch_stream_ = raft::resource::get_stream_from_stream_pool(res); - writeback_stream_ = raft::resource::get_stream_from_stream_pool(res); + copy_stream_ = raft::resource::get_stream_from_stream_pool(res); - // prime batch 0 - prefetch_next_batch(); + // Prime batch 0 (slot 0 staged on copy_stream_; first next_view() syncs). + issue_prefetch_for_next_batch(); } ~batched_device_view() noexcept { raft::resource::sync_stream(res_); - // Nothing was returned to the caller -> nothing to write back. 
+ // Nothing was ever returned (empty input or no next_view() call); streams + // may still be default-constructed so bail out early. if (batch_id_ < 0) { return; } - // Passthrough has no internal buffer and writes through input_view_ directly. if constexpr (!kPassthrough) { if (host_writeback_) { - uint32_t writeback_pos_last = batch_id_ % num_buffers_; - if (batch_id_ > 0) { - uint32_t writeback_pos = (batch_id_ - 1) % num_buffers_; - uint64_t writeback_offset = (batch_id_ - 1) * batch_size_; - writeback_from_device_to_host(device_ptr[writeback_pos], writeback_offset, batch_size_); - } - { - uint64_t writeback_offset_last = batch_id_ * batch_size_; - writeback_from_device_to_host(device_ptr[writeback_pos_last], - writeback_offset_last, - actual_batch_size_[writeback_pos_last]); + // Each prefetch_next() flushes batch (batch_id_ - 1) lazily; that + // leaves the most recent 1 batch (normal exit on empty next_view) or + // up to 2 batches (early break without final prefetch_next()) still + // pending. Flush whatever's left on copy_stream_ FIFO. + for (int32_t i = last_flushed_batch_id_ + 1; i <= batch_id_; ++i) { + uint32_t pos = i % 2; + uint64_t off = static_cast(i) * batch_size_; + writeback_from_device_to_host(device_ptr[pos], off, actual_batch_size_[pos]); } - writeback_stream_.synchronize(); + copy_stream_.synchronize(); } } } /** - * Advance to the next batch and return its view. + * Return the view of the batch staged by the constructor or the most recent + * `prefetch_next()`. After this call, that batch is the "current" batch; its + * slot is owned by the caller until `prefetch_next()` advances the pipeline. * - * In copy_device mode this also kicks off prefetch of the batch after it - * and writeback of the previously returned batch (when host_writeback is - * set). In passthrough mode the call is essentially a slice update. 
- * - * Return type is `next_view_passthrough_type` or `next_view_copy_type`; - * both are 2D mdspans of T with the same element- and extent-API surface. + * Return type is `next_view_passthrough_type` or `next_view_copy_type`; both + * are 2D mdspans of T with the same element- and extent-API surface. * Iteration ends when extent(0) == 0; this is the only legal stop signal. * - * Must be called exactly once per batch (never in a loop condition). + * Pair every non-empty `next_view()` with a `prefetch_next()` call. Skipping + * the pairing simply stops iteration at the current batch. */ auto next_view() { - bool end_of_data = static_cast((batch_id_ + 1) * batch_size_) >= - static_cast(input_view_.extent(0)); - if constexpr (kPassthrough) { - // Slice goes through the accessor; data_handle() is never used. - if (end_of_data) { + // Passthrough has no buffer; the slice goes through the accessor. + if (!next_prefetched_) { auto end = static_cast(input_view_.extent(0)); return cuda::std::submdspan( input_view_, cuda::std::tuple{end, end}, cuda::std::full_extent); } + ++batch_id_; + next_prefetched_ = false; - prefetch_next_batch(); - - uint32_t current_pos = batch_id_ % num_buffers_; + uint32_t current_pos = batch_id_ % 2; auto first = static_cast(batch_id_ * batch_size_); auto last = static_cast(first + actual_batch_size_[current_pos]); return cuda::std::submdspan( input_view_, cuda::std::tuple{first, last}, cuda::std::full_extent); } else { auto cols = static_cast(input_view_.extent(1)); + if (!next_prefetched_) { return next_view_copy_type{nullptr, IdxT{0}, cols}; } - if (end_of_data) { return next_view_copy_type{nullptr, IdxT{0}, cols}; } + // Drain the copies queued by the previous prefetch_next() / ctor: this + // ensures the slot we're about to hand out is fully staged AND that the + // writeback for the older batch (if any) has finished before we return + // -- which matters for slot recycling on subsequent iterations. 
+ copy_stream_.synchronize(); - prefetch_next_batch(); + ++batch_id_; + next_prefetched_ = false; - uint32_t current_pos = batch_id_ % num_buffers_; + uint32_t current_pos = batch_id_ % 2; return next_view_copy_type{ device_ptr[current_pos], static_cast(actual_batch_size_[current_pos]), cols}; } } - private: /** - * Advance batch_id_ to the next batch and stage I/O around it (copy_device only): - * * prefetch the (batch_id_ + 1)-th batch when initialize_ is set - * * write back the (batch_id_ - 1)-th batch when host_writeback_ is set - * In passthrough mode this only updates batch_id_ and offset_ bookkeeping. + * Advance the prefetch pipeline -- call once after each non-empty + * `next_view()`, AFTER launching the kernel on res_'s stream. + * + * In copy_device mode this: + * 1. Queues D2H of batch (batch_id_ - 1) on copy_stream_ (the slot the + * kernel is NOT on; data is from the previous iter's kernel which has + * already been sync'd on res_). Skipped on the first iteration when no + * previous batch exists. + * 2. Queues H2D of batch (batch_id_ + 1) on copy_stream_ (also a different + * slot from the running kernel). + * 3. Calls sync_stream(res_) at the *end*. + * + * Steps 1-2 run on copy_stream_ concurrently with the just-launched kernel + * on res_'s stream. The host stall during cudaMemcpyAsync's pageable->pinned + * staging (for pageable host sources/destinations) and the host stall on + * step 3 both overlap with the kernel: that is what makes the pipeline + * actually asynchronous, even for plain pageable host memory. * - * @return true if more input rows remain after this advance. + * In passthrough mode this is pure bookkeeping (no copies, no syncs). 
*/ - bool prefetch_next_batch() + void prefetch_next() { - batch_id_++; - - // wait for the prior round's prefetch / writeback before reusing the slots - if (initialize_) { prefetch_stream_.synchronize(); } - if (host_writeback_) { writeback_stream_.synchronize(); } - - RAFT_EXPECTS(static_cast(offset_) <= input_view_.extent(0), "Offset out of bounds"); + if constexpr (!kPassthrough) { + if (host_writeback_ && batch_id_ - 1 > last_flushed_batch_id_) { + // Writeback batch (batch_id_ - 1) -- the slot the kernel is *not* on. + // The corresponding kernel (kernel-(batch_id_-1)) finished at the end + // of the previous prefetch_next() (sync_stream(res_)), so its writes + // are globally visible. + uint32_t pos = (batch_id_ - 1) % 2; + uint64_t off = static_cast(batch_id_ - 1) * batch_size_; + writeback_from_device_to_host(device_ptr[pos], off, actual_batch_size_[pos]); + last_flushed_batch_id_ = batch_id_ - 1; + } + } - bool next_batch_exists = offset_ < static_cast(input_view_.extent(0)); + issue_prefetch_for_next_batch(); - if (next_batch_exists) { - // make sure all kernels touching batch_id_ - 1 (the now-writable slot) - // are done before we either overwrite it (initialize) or read it back - raft::resource::sync_stream(res_); + if constexpr (!kPassthrough) { + // Wait for the kernel we paired with this prefetch_next() before + // returning. + if (input_view_.extent(0) == 0) raft::resource::sync_stream(res_); + } + } - int32_t prefetch_pos = (batch_id_ + 1) % num_buffers_; - actual_batch_size_[prefetch_pos] = min(batch_size_, input_view_.extent(0) - offset_); + private: + /** + * Stage batch (batch_id_ + 1) into slot ((batch_id_ + 1) % 2) on + * copy_stream_, after any prior op on the stream (FIFO). Pure bookkeeping + * in passthrough or !initialize_ mode. Sets next_prefetched_ accordingly. 
+ */ + void issue_prefetch_for_next_batch() + { + uint64_t target_offset = static_cast(batch_id_ + 1) * batch_size_; + if (target_offset >= static_cast(input_view_.extent(0))) { + next_prefetched_ = false; + return; + } + int32_t prefetch_pos = (batch_id_ + 1) % 2; + actual_batch_size_[prefetch_pos] = + static_cast(min(batch_size_, input_view_.extent(0) - target_offset)); - if constexpr (!kPassthrough) { - if (host_writeback_ && batch_id_ > 0) { - uint32_t writeback_pos = (batch_id_ - 1) % num_buffers_; - uint64_t writeback_offset = (batch_id_ - 1) * batch_size_; - writeback_from_device_to_host(device_ptr[writeback_pos], writeback_offset, batch_size_); - } - if (initialize_) { - prefetch_from_host_to_device( - device_ptr[prefetch_pos], offset_, actual_batch_size_[prefetch_pos]); - } + if constexpr (!kPassthrough) { + if (initialize_) { + prefetch_from_host_to_device( + device_ptr[prefetch_pos], target_offset, actual_batch_size_[prefetch_pos]); } - - offset_ += actual_batch_size_[prefetch_pos]; } - - return next_batch_exists; + next_prefetched_ = true; } void prefetch_from_host_to_device(T* dev_ptr, size_t src_row_offset, size_t num_rows) @@ -583,7 +620,7 @@ class batched_device_view { input_view_.data_handle() + src_row_offset * input_view_.extent(1), n_bytes, cudaMemcpyHostToDevice, - prefetch_stream_)); + copy_stream_)); } void writeback_from_device_to_host(T* dev_ptr, size_t dst_row_offset, size_t num_rows) @@ -596,12 +633,12 @@ class batched_device_view { dev_ptr, n_bytes, cudaMemcpyDeviceToHost, - writeback_stream_)); + copy_stream_)); } - // streams (copy_device only; unused in passthrough) - rmm::cuda_stream_view prefetch_stream_; - rmm::cuda_stream_view writeback_stream_; + // single non-res_ stream that carries both H2D prefetches and D2H writebacks + // in FIFO order (copy_device only; unused in passthrough) + rmm::cuda_stream_view copy_stream_; // configuration const raft::resources& res_; @@ -610,16 +647,16 @@ class batched_device_view { // iteration 
state uint64_t batch_size_; - int32_t batch_id_; // -2 before any prefetch; >= 0 once a batch has been returned - uint64_t offset_; // first row of the upcoming-prefetch batch, in input_view_ rows + int32_t batch_id_; // most-recently-returned batch id; -1 if none returned + bool next_prefetched_; // slot for batch_id_+1 holds staged data + int32_t last_flushed_batch_id_; // highest batch id whose writeback has been issued; -1 if none input_view_t input_view_; // device buffers (copy_device only) - uint64_t num_buffers_; - std::optional> device_mem_[3]; - T* device_ptr[3]; - uint32_t actual_batch_size_[3]; + std::optional> device_mem_[2]; + T* device_ptr[2]; + uint32_t actual_batch_size_[2]; }; } // namespace cuvs::neighbors::cagra::detail diff --git a/cpp/tests/CMakeLists.txt b/cpp/tests/CMakeLists.txt index 4be381689e..34c067e9f5 100644 --- a/cpp/tests/CMakeLists.txt +++ b/cpp/tests/CMakeLists.txt @@ -182,8 +182,7 @@ ConfigureTest( ConfigureTest( NAME NEIGHBORS_ANN_CAGRA_HELPERS_TEST - PATH neighbors/ann_cagra/test_optimize_uint32_t.cu - neighbors/ann_cagra/test_batched_device_view_from_host.cu + PATH neighbors/ann_cagra/test_optimize_uint32_t.cu neighbors/ann_cagra/test_batched_device_view.cu GPUS 1 PERCENT 100 ) diff --git a/cpp/tests/neighbors/ann_cagra/test_batched_device_view_from_host.cu b/cpp/tests/neighbors/ann_cagra/test_batched_device_view.cu similarity index 97% rename from cpp/tests/neighbors/ann_cagra/test_batched_device_view_from_host.cu rename to cpp/tests/neighbors/ann_cagra/test_batched_device_view.cu index f9a91cca3c..6733ac784a 100644 --- a/cpp/tests/neighbors/ann_cagra/test_batched_device_view_from_host.cu +++ b/cpp/tests/neighbors/ann_cagra/test_batched_device_view.cu @@ -101,6 +101,11 @@ class BatchedDeviceViewTest : public ::testing::Test { IdxT(17)); } total_processed += dev_view.extent(0); + + // Pair next_view() with prefetch_next(): the next batch's H2D and the + // previous batch's D2H run on copy_stream_ concurrently with the + // 
raft::copy / raft::matrix::fill kernels we just queued on res_. + batched.prefetch_next(); } } raft::resource::sync_stream(res); From 821eae6e5065218ea926ede1cc1dcda7d3f51a3f Mon Sep 17 00:00:00 2001 From: Malte Foerster Date: Mon, 4 May 2026 08:56:00 +0000 Subject: [PATCH 35/40] stream-sync fix typo --- cpp/src/neighbors/detail/cagra/utils.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/src/neighbors/detail/cagra/utils.hpp b/cpp/src/neighbors/detail/cagra/utils.hpp index 5c38f31684..b8ea4e3d54 100644 --- a/cpp/src/neighbors/detail/cagra/utils.hpp +++ b/cpp/src/neighbors/detail/cagra/utils.hpp @@ -580,7 +580,7 @@ class batched_device_view { if constexpr (!kPassthrough) { // Wait for the kernel we paired with this prefetch_next() before // returning. - if (input_view_.extent(0) == 0) raft::resource::sync_stream(res_); + if (input_view_.extent(0) != 0) raft::resource::sync_stream(res_); } } From 6211ff3e2cceec8d67979210682ef22594f235e2 Mon Sep 17 00:00:00 2001 From: Malte Foerster Date: Thu, 7 May 2026 23:35:06 +0000 Subject: [PATCH 36/40] merge into batch_load_iterator --- cpp/src/cluster/detail/kmeans_batched.cuh | 9 +- cpp/src/neighbors/detail/ann_utils.cuh | 872 ++++++++++++++---- cpp/src/neighbors/detail/cagra/add_nodes.cuh | 7 +- .../neighbors/detail/cagra/cagra_build.cuh | 16 +- cpp/src/neighbors/detail/cagra/graph_core.cuh | 110 ++- cpp/src/neighbors/detail/cagra/utils.hpp | 358 ------- cpp/src/neighbors/detail/nn_descent.cuh | 9 +- .../neighbors/detail/vamana/vamana_build.cuh | 16 +- cpp/src/neighbors/detail/vpq_dataset.cuh | 25 +- cpp/src/neighbors/ivf_flat/ivf_flat_build.cuh | 36 +- cpp/src/neighbors/ivf_pq/ivf_pq_build.cuh | 14 +- cpp/src/neighbors/ivf_pq/ivf_pq_transform.cuh | 15 +- .../neighbors/scann/detail/scann_build.cuh | 16 +- cpp/tests/CMakeLists.txt | 2 +- ...ce_view.cu => test_batch_load_iterator.cu} | 192 +++- 15 files changed, 1000 insertions(+), 697 deletions(-) rename 
cpp/tests/neighbors/ann_cagra/{test_batched_device_view.cu => test_batch_load_iterator.cu} (52%) diff --git a/cpp/src/cluster/detail/kmeans_batched.cuh b/cpp/src/cluster/detail/kmeans_batched.cuh index 6c0b9a3253..8d566fae59 100644 --- a/cpp/src/cluster/detail/kmeans_batched.cuh +++ b/cpp/src/cluster/detail/kmeans_batched.cuh @@ -317,8 +317,13 @@ void fit(raft::resources const& handle, // Reset per-iteration state T prior_cluster_cost = 0; - cuvs::spatial::knn::detail::utils::batch_load_iterator data_batches( - X.data_handle(), n_samples, n_features, streaming_batch_size, stream); + auto data_batches = cuvs::spatial::knn::detail::utils::make_batch_load_iterator( + handle, + X.data_handle(), + static_cast(n_samples), + static_cast(n_features), + streaming_batch_size, + stream); for (n_iter[0] = 1; n_iter[0] <= iter_params.max_iter; ++n_iter[0]) { RAFT_LOG_DEBUG("KMeans batched: Iteration %d", n_iter[0]); diff --git a/cpp/src/neighbors/detail/ann_utils.cuh b/cpp/src/neighbors/detail/ann_utils.cuh index a7872f87a0..4089c58a03 100644 --- a/cpp/src/neighbors/detail/ann_utils.cuh +++ b/cpp/src/neighbors/detail/ann_utils.cuh @@ -6,8 +6,14 @@ #pragma once #include +#include +#include #include #include +#include +#include +#include +#include #include #include #include @@ -21,6 +27,9 @@ #include #include +#include +#include +#include namespace cuvs::spatial::knn::detail::utils { @@ -378,166 +387,369 @@ void copy_selected(IdxT n_rows, } /** - * A batch input iterator over the data source. - * Given an input pointer, it decides whether the current device has the access to the data and - * gives it back to the user in batches. Three scenarios are possible: + * Helper that returns the stream to use for non-blocking copies and a flag indicating whether + * cross-stream pipelining (prefetch / writeback) can be enabled. * - * 1. if `source == nullptr`: then `batch.data() == nullptr` - * 2. 
if `source` is accessible from the device, `batch.data()` points directly at the source at - * the proper offsets on each iteration. - * 3. if `source` is not accessible from the device, `batch.data()` points to an intermediate - * buffer; the corresponding data is copied in the given `stream` on every iterator dereference - * (i.e. batches can be skipped). Dereferencing the same batch two times in a row does not force - * the raft::copy. + * If `res` has a CUDA stream pool with at least one stream, the first pool stream is used and + * `true` is returned (prefetch can run concurrently with kernels on res's main stream). Otherwise + * the main stream itself is returned with `false`, and the caller should treat prefetch as a + * no-op (no overlap is possible on a single stream). + */ +inline auto get_prefetch_stream(raft::resources const& res) + -> std::pair +{ + if (res.has_resource_factory(raft::resource::resource_type::CUDA_STREAM_POOL) && + raft::resource::get_stream_pool_size(res) >= 1) { + return {raft::resource::get_stream_from_stream_pool(res), true}; + } + return {raft::resource::get_cuda_stream(res), false}; +} + +/** + * Iterate a 2D mdspan in row-batches with optional pipelined H2D / D2H copies and kernel work. * - * In all three scenarios, the number of iterations, batch offsets and sizes are the same. + * Strategy is selected at compile time from `MdspanT::accessor_type::is_device_accessible`: + * * passthrough: each batch is a row-range of `input_view_` directly; no buffering, no copy. + * * copy_device: each batch is staged through one or two internal device buffers via + * `cudaMemcpyAsync` on the caller-supplied `copy_stream`. With `prefetch=true`, two buffers + * are used as a ring so that the next batch's H2D and the previous batch's D2H can overlap + * with the user kernel running on `res`'s main stream. * - * The iterator can be reused. If the number of iterations is one, at most one raft::copy will ever - * be invoked (i.e. 
small datasets are not reloaded multiple times). + * Three orthogonal flags control behavior of the copy_device strategy: + * * `prefetch`: if true, allocate two device buffers and pipeline copies using + * `prefetch_next_batch()`. If false, copies happen synchronously at `operator*` and one buffer + * is allocated. + * * `initialize`: if true, stage source rows H2D into the buffer before yielding the batch. + * If false, the buffer is handed out uninitialized (kernel produces the data from scratch). + * * `host_writeback`: if true, queue D2H of every advanced batch back to `input_view_`. + * Pending writebacks are flushed on destruction. * - * In the case of pageable host buffer input, the iterator is by default (almost) synchronous due to - * the behavior of raft::copy. In order to achieve kernel and copy overlapping, a - * prefetch_next_batch (synchronous) API is provided. Note that since prefetch API is synchronous, - * user may want to schedule kernel, which is asynchronous, first. User is responsible to properly - * manage the order of prefetch and kernel to ensure overlapping. + * `initialize` and `host_writeback` are independent: it is legal to skip H2D + * (`initialize=false, host_writeback=true`) when the kernel produces the result from scratch. + * At least one of them must be true. + * + * Stream model: + * * The user passes `copy_stream`. With `prefetch=true`, this should be a stream distinct from + * `res`'s main stream (use `get_prefetch_stream(res)`); otherwise no real overlap is possible. + * * `prefetch_next_batch()` queues D2H of the just-completed batch (if dirty) followed by + * H2D of the next batch (if `initialize`) on `copy_stream`. With prefetch enabled it then + * calls `sync_stream(res)` so the host stall on the main stream overlaps with the copies on + * `copy_stream`. With prefetch disabled, it synchronizes `copy_stream` directly. + * * `operator*` drains `copy_stream` so the slot is fully staged before the caller dereferences. 
+ * + * Iteration ends when `operator++` reaches `n_iters_`. The iterator can be reused via `reset()`. + * + * Usage with prefetch (matches the legacy `batch_load_iterator` pattern): + * ``` + * auto [copy_stream, enable_prefetch] = utils::get_prefetch_stream(res); + * utils::batch_load_iterator iter(res, view, batch_size, copy_stream, mr, enable_prefetch); + * iter.prefetch_next_batch(); + * for (auto const& batch : iter) { + * kernel<<<..., raft::resource::get_cuda_stream(res)>>>(batch.data(), ...); + * iter.prefetch_next_batch(); + * } + * ``` + * + * Usage with writeback (replaces `batched_device_view`): + * ``` + * auto [copy_stream, enable_prefetch] = utils::get_prefetch_stream(res); + * utils::batch_load_iterator iter(res, view, batch_size, copy_stream, mr, enable_prefetch, + * /\*initialize=*\/false, /\*host_writeback=*\/true); + * iter.prefetch_next_batch(); + * for (auto& batch : iter) { + * kernel<<<..., raft::resource::get_cuda_stream(res)>>>(batch.view()); + * iter.prefetch_next_batch(); + * } + * ``` */ -template +template struct batch_load_iterator { - using size_type = size_t; + using mdspan_type = MdspanT; + using accessor_type = typename MdspanT::accessor_type; + using element_type = typename MdspanT::element_type; + using index_type = typename MdspanT::index_type; + using value_type_d = std::remove_const_t; + using size_type = size_t; - /** A single batch of data residing in device memory. */ + static constexpr bool kPassthrough = accessor_type::is_device_accessible; + + // Type returned by `view()` for the passthrough strategy: a 2D submdspan of `input_view_` over a + // contiguous row range. Built without ever calling `data_handle()` on the input mdspan, so this + // stays valid even for future device mdspans whose accessor exposes no raw pointer. + // (Per the mdspan spec, slicing a `layout_right` with a `tuple{lo, hi}` over the leading dim + // yields a `layout_stride` mdspan with the input's accessor preserved.) 
+ using passthrough_view_type = + decltype(cuda::std::submdspan(std::declval(), + std::declval>(), + cuda::std::full_extent)); + // Type returned by `view()` for the copy_device strategy: a row-major exhaustive device view + // over the iterator's internal device buffer. + using copy_view_type = raft::device_matrix_view; + using batch_view_type = std::conditional_t; + + /** A single batch of data residing in (or accessible from) device memory. */ struct batch { ~batch() noexcept { - /* - If there's no copy, there's no allocation owned by the batch. - If there's no allocation, there's no guarantee that the device pointer is stream-ordered. - If there's no stream order guarantee, we must synchronize with the stream before the batch is - destroyed to make sure all GPU operations in that stream finish earlier. - */ - if (!does_copy()) { RAFT_CUDA_TRY_NO_THROW(cudaStreamSynchronize(stream_)); } + if constexpr (!kPassthrough) { + // Flush any pending writeback for the slot still held in dev_ptr_. + // The "other" slot's writeback (if any) was issued at the last load() that swapped to it. + if (host_writeback_ && source_ != nullptr && dirty_cur_ && pos_.has_value()) { + queue_d2h(dev_ptr_, *pos_); + dirty_cur_ = false; + } + } + // Stream is shared with the iterator; it must be sync'd before the underlying buffers (or, + // in the passthrough case, the source mdspan) can be safely reused. + copy_stream_.synchronize(); } - /** Logical width of a single row in a batch, in elements of type `T`. 
*/ [[nodiscard]] auto row_width() const -> size_type { return row_width_; } - /** Logical offset of the batch, in rows (`row_width()`) */ [[nodiscard]] auto offset() const -> size_type { return pos_.value_or(0) * batch_size_; } - /** Logical size of the batch, in rows (`row_width()`) */ [[nodiscard]] auto size() const -> size_type { return batch_len_; } - /** Logical size of the batch, in rows (`row_width()`) */ - [[nodiscard]] auto data() const -> const T* { return const_cast(dev_ptr_); } - /** Whether this batch copies the data (i.e. the source is inaccessible from the device). */ - [[nodiscard]] auto does_copy() const -> bool { return needs_copy_; } + [[nodiscard]] auto does_copy() const -> bool { return !kPassthrough; } + + /** + * 2D view of the staged batch. + * + * Passthrough: a `cuda::std::submdspan` of `input_view_` over the active row range. The + * implementation never calls `data_handle()` on `input_view_`; the mdspan's accessor is + * preserved end-to-end, which is the contract that lets future device mdspans without a raw + * pointer flow through this code path unchanged. + * + * Copy_device: a `device_matrix_view` over the internal device buffer (row-major exhaustive). + */ + [[nodiscard]] auto view() const -> batch_view_type + { + if constexpr (kPassthrough) { + const index_type row_lo = static_cast(pos_.value_or(0) * batch_size_); + const index_type row_hi = static_cast(row_lo + batch_len_); + return cuda::std::submdspan( + input_view_, cuda::std::tuple{row_lo, row_hi}, cuda::std::full_extent); + } else { + return raft::make_device_matrix_view( + dev_ptr_, static_cast(batch_len_), static_cast(row_width_)); + } + } + + /** + * Raw device pointer of the staged batch. Provided for backward compatibility with raw-pointer + * call sites. In passthrough mode this forwards to `view().data_handle()`, which means it + * relies on the input mdspan's accessor exposing a pointer. 
Future device mdspans without a + * raw pointer should call `view()` instead and treat the result as an mdspan. + */ + [[nodiscard]] auto data() const -> element_type* + { + if constexpr (kPassthrough) { + return view().data_handle(); + } else { + return dev_ptr_; + } + } private: - batch(const T* source, - size_type n_rows, - size_type row_width, + template + friend struct batch_load_iterator; + + // Helper: only call `data_handle()` on the input mdspan in copy_device mode. In passthrough + // mode we keep `source_` at `nullptr` (it is never read) so the iterator imposes no + // raw-pointer requirement on the input accessor. + static auto get_source(MdspanT input_view) noexcept -> element_type* + { + if constexpr (kPassthrough) { + return nullptr; + } else { + return input_view.data_handle(); + } + } + + batch(raft::resources const& res, + MdspanT input_view, size_type batch_size, - rmm::cuda_stream_view stream, + rmm::cuda_stream_view copy_stream, rmm::device_async_resource_ref mr, - bool prefetch = false) - : stream_(stream), - buf_0_(0, stream, mr), - buf_1_(0, stream, mr), - source_(source), - dev_ptr_(nullptr), - n_rows_(n_rows), - row_width_(row_width), - batch_size_(std::min(batch_size, n_rows)), - pos_(std::nullopt), - prefetch_pos_(std::nullopt), - n_iters_(raft::div_rounding_up_safe(n_rows, batch_size)), - needs_copy_(false), - prefetch_(prefetch) + bool prefetch, + bool initialize, + bool host_writeback) + : copy_stream_(copy_stream), + res_(&res), + input_view_(input_view), + source_(get_source(input_view)), + n_rows_(static_cast(input_view.extent(0))), + row_width_(static_cast(input_view.extent(1))), + batch_size_(std::min(batch_size, std::max(n_rows_, 1))), + n_iters_(n_rows_ == 0 ? 
0 : raft::div_rounding_up_safe(n_rows_, batch_size_)), + prefetch_(prefetch), + initialize_(initialize), + host_writeback_(host_writeback), + buf_0_(0, copy_stream, mr), + buf_1_(0, copy_stream, mr) { - if (source_ == nullptr) { return; } - cudaPointerAttributes attr; - RAFT_CUDA_TRY(cudaPointerGetAttributes(&attr, source_)); - dev_ptr_ = reinterpret_cast(attr.devicePointer); - - if (dev_ptr_ == nullptr) { needs_copy_ = true; } - if (attr.type != cudaMemoryTypeDevice) { - // Although data might be accessible on device through HMM or ATS, - // it is preferred to copy the dataset explicitly when it is not device data. - needs_copy_ = true; - } - if (needs_copy_) { - buf_0_.resize(row_width_ * batch_size_, stream); - dev_ptr_ = buf_0_.data(); + if (n_rows_ == 0) { return; } + RAFT_EXPECTS(initialize_ || host_writeback_, + "At least one of initialize or host_writeback must be true"); + RAFT_EXPECTS(!host_writeback_ || !std::is_const_v, + "host_writeback=true requires a non-const element type"); + + if constexpr (!kPassthrough) { + if (source_ == nullptr) { + // Null source: yield batches with the right offsets/sizes but data() == nullptr. + // Skip allocation and never queue copies. 
+ return; + } + buf_0_.resize(row_width_ * batch_size_, copy_stream); + dev_ptr_ = reinterpret_cast(buf_0_.data()); if (prefetch_) { - buf_1_.resize(row_width_ * batch_size_, stream); - prefetch_dev_ptr_ = buf_1_.data(); + buf_1_.resize(row_width_ * batch_size_, copy_stream); + prefetch_dev_ptr_ = reinterpret_cast(buf_1_.data()); } } } - rmm::cuda_stream_view stream_; - rmm::device_uvector buf_0_; - rmm::device_uvector buf_1_; - const T* source_; - size_type n_rows_; - size_type row_width_; - size_type batch_size_; - size_type n_iters_; - bool needs_copy_; - bool prefetch_; - - std::optional pos_; - std::optional prefetch_pos_; - size_type batch_len_; - T* dev_ptr_; - T* prefetch_dev_ptr_; - friend class batch_load_iterator; /** - * Changes the state of the batch to point at the `pos` index. - * If necessary, copies the data from the source in the registered stream. + * Make this batch represent position `pos`. In copy_device mode this synchronously stages + * H2D if needed; in passthrough mode this is pure bookkeeping (the per-batch view is + * recomputed on demand by `view()` via `cuda::std::submdspan`, never via pointer arithmetic + * on the input mdspan). + * No-op if the buffer already holds `pos`. Iteration end is signaled by `pos >= n_iters_`. */ - void load(const size_type& pos) + void load(size_type pos) { - // No-op if the data is already loaded, or it's the end of the input. 
- if (pos == pos_ || pos >= n_iters_) { return; } - pos_.emplace(pos); - batch_len_ = std::min(batch_size_, n_rows_ - std::min(offset(), n_rows_)); - if (source_ == nullptr) { return; } - if (needs_copy_) { - if (size() > 0) { - RAFT_LOG_TRACE("batch_load_iterator::copy(offset = %zu, size = %zu, row_width = %zu)", - size_t(offset()), - size_t(size()), - size_t(row_width())); - if (prefetch_ && prefetch_pos_ == pos_) { - std::swap(dev_ptr_, prefetch_dev_ptr_); - } else { - raft::copy(dev_ptr_, source_ + offset() * row_width(), size() * row_width(), stream_); - } - } + if (n_iters_ == 0) { return; } + if (pos == pos_) { return; } + if (pos >= n_iters_) { return; } + + const size_type row_offset = pos * batch_size_; + const size_type len = + std::min(batch_size_, n_rows_ - std::min(row_offset, n_rows_)); + + if constexpr (kPassthrough) { + // Passthrough: just record the new slice; view() will compute the submdspan. + pos_.emplace(pos); + batch_len_ = len; + return; } else { - dev_ptr_ = const_cast(source_) + offset() * row_width(); + if (source_ == nullptr) { + pos_.emplace(pos); + batch_len_ = len; + // dev_ptr_ remains nullptr (or the empty-source buffer state). + return; + } + + // Always issue D2H of the slot we're about to leave (or recycle) BEFORE swapping in + // / overwriting it with new data. With prefetch=true the prior kernel has already been + // sync'd by the previous prefetch_next_batch()'s sync_stream(res); with prefetch=false + // copies serialize on a single stream so D2H precedes H2D into the same buffer. + if (host_writeback_ && dirty_cur_ && pos_.has_value()) { + queue_d2h(dev_ptr_, *pos_); + dirty_cur_ = false; + } + if (prefetch_ && prefetch_pos_.has_value() && *prefetch_pos_ == pos) { + // Swap to the prefetched slot. The previously-current slot moves into prefetch_dev_ptr_; + // its writeback (if any) was issued just above. 
+ std::swap(dev_ptr_, prefetch_dev_ptr_); + prefetch_pos_.reset(); + // Drain copy_stream so the swapped-in slot is fully staged before user reads. + copy_stream_.synchronize(); + } else { + if (initialize_) { queue_h2d(dev_ptr_, row_offset, len); } + copy_stream_.synchronize(); + } + pos_.emplace(pos); + batch_len_ = len; + if (host_writeback_) { + // Every advanced batch is implicitly dirty: the user kernel will write to it before + // the next load() / prefetch() recycles the slot. + dirty_cur_ = true; + } } } /** - * Helper function for prefetch. NOP if prefetch option is not enabled. This API is synchronous. + * Queue H2D for `pos` into the not-currently-visible slot, plus D2H of the previously + * dirtied (just-completed) slot. No-op if prefetch is disabled, source is null, or + * `pos >= n_iters_`. + * + * With prefetch enabled this is followed by `sync_stream(res)` so the host-side memcpy + * stall on `copy_stream` overlaps with the user kernel on `res`'s main stream. */ - void prefetch(const size_type& pos) + void prefetch(size_type pos) { - if (pos >= n_iters_ || !prefetch_ || !needs_copy_ || source_ == nullptr) { return; } - size_type prefetch_offset = batch_size_ * pos; - size_type prefetch_size = std::min(batch_size_, n_rows_ - std::min(prefetch_offset, n_rows_)); - raft::common::nvtx::push_range( - "batch_load_iterator::prefetch(offset = %zu, size = %zu, row_width = %zu)", - size_t(prefetch_offset), - size_t(prefetch_size), - size_t(row_width())); - raft::copy(prefetch_dev_ptr_, - source_ + prefetch_offset * row_width(), - prefetch_size * row_width(), - stream_); - raft::common::nvtx::pop_range(); - stream_.synchronize(); + if constexpr (kPassthrough) { return; } + if (n_iters_ == 0 || pos >= n_iters_ || source_ == nullptr) { return; } + if (!prefetch_) { + // No-op: in non-pipelined mode load() does the staging synchronously in operator*. + return; + } + + // Issue H2D of `pos` into prefetch_dev_ptr_ (the slot the user kernel is NOT on). 
+ // Writeback of the "other" slot is unnecessary here because it was already issued at the + // last load() that recycled it. + if (initialize_) { + const size_type row_offset = pos * batch_size_; + const size_type len = + std::min(batch_size_, n_rows_ - std::min(row_offset, n_rows_)); + queue_h2d(prefetch_dev_ptr_, row_offset, len); + } prefetch_pos_.emplace(pos); + + // Wait for the kernel paired with this prefetch_next_batch() before returning, so the next + // operator* can safely swap-and-read the slot. Do this AFTER queueing copies, so the host + // stall overlaps with both the kernel and the copies. + raft::resource::sync_stream(*res_); + } + + void queue_h2d(element_type* dst, size_type src_row_offset, size_type num_rows) + { + if (num_rows == 0) { return; } + const size_t n_bytes = num_rows * row_width_ * sizeof(value_type_d); + // dst is `element_type*` (potentially `const T*`), but it points into a non-const internal + // buffer (`rmm::device_uvector`); the const-cast restores the writable view. + // Use cudaMemcpyAsync directly (rather than raft::copy) to avoid issues with + // HMM/ATS-mapped host pointers being misclassified. 
+ RAFT_CUDA_TRY(cudaMemcpyAsync(const_cast(dst), + source_ + src_row_offset * row_width_, + n_bytes, + cudaMemcpyHostToDevice, + copy_stream_)); } + + void queue_d2h(element_type* src, size_type pos) + { + const size_type row_offset = pos * batch_size_; + const size_type num_rows = + std::min(batch_size_, n_rows_ - std::min(row_offset, n_rows_)); + if (num_rows == 0) { return; } + const size_t n_bytes = num_rows * row_width_ * sizeof(value_type_d); + RAFT_CUDA_TRY(cudaMemcpyAsync(const_cast(source_) + row_offset * row_width_, + src, + n_bytes, + cudaMemcpyDeviceToHost, + copy_stream_)); + } + + rmm::cuda_stream_view copy_stream_; + raft::resources const* res_; + MdspanT input_view_; + element_type* source_; + size_type n_rows_; + size_type row_width_; + size_type batch_size_; + size_type n_iters_; + bool prefetch_; + bool initialize_; + bool host_writeback_; + + rmm::device_uvector buf_0_; + rmm::device_uvector buf_1_; + + // Slot bookkeeping (only meaningful for !kPassthrough). + element_type* dev_ptr_ = nullptr; + element_type* prefetch_dev_ptr_ = nullptr; + std::optional pos_; + std::optional prefetch_pos_; + size_type batch_len_ = 0; + bool dirty_cur_ = false; }; using value_type = batch; @@ -545,72 +757,78 @@ struct batch_load_iterator { using pointer = const value_type*; /** - * Create a batch iterator over the data `source`. - * - * For convenience, the data `source` is read in logical units of size `row_width`; batch sizes - * and offsets are calculated in logical rows. Hence, can interpret the data as a contiguous - * row-major matrix of size [n_rows, row_width], and the batches are the sub-matrices of size - * [x<=batch_size, n_rows]. + * Construct an iterator over `input_view`. * - * If prefetch option is enabled, the batch_load_iterator could help to achieve overlapping with - * prefetch_next_batch() with other workloads. This is useful if source buffer is in host memory. 
- * To achieve overlapping, the other workloads have to be async and scheduled before - * prefetch_next_batch(). Users also need to use a different stream for the workloads. E.g., - * utils::batch_load_iterator batches(..., stream_1, ..., true); - * batches.prefetch_next_batch(); - * for (const auto& batch : batches) { - * // The following kernel and prefetch_next_batch() could be overlapped. - * kernel<<<..., stream_2>>>(...); - * batches.prefetch_next_batch(); - * } - * - * @param source the input data -- host, device, or nullptr. - * @param n_rows the size of the input in logical rows. - * @param row_width the size of the logical row in the elements of type `T`. - * @param batch_size the desired size of the batch. - * @param stream the ordering for the host->device copies, if applicable. - * @param mr a custom memory resource for the intermediate buffer, if applicable. - * @param prefetch enable prefetch feature in order to achieve kernel/copy overlapping. + * @param res raft resources (must outlive the iterator) + * @param input_view typed mdspan to iterate; row-major; passthrough vs copy is decided + * at compile time from the accessor. + * @param batch_size desired batch size in rows. Clamped to n_rows. + * @param copy_stream stream used for H2D / D2H copies in copy_device mode. Pass a non-main + * stream (see `get_prefetch_stream`) to enable real overlap. + * @param mr memory resource for the internal device buffer(s). + * @param prefetch enable 2-buffer pipelining via `prefetch_next_batch()`. + * @param initialize stage H2D source rows before yielding each batch (default true). + * @param host_writeback queue D2H of every advanced batch back to `input_view`. + * At least one of `initialize` / `host_writeback` must be true for + * non-empty input. 
*/ - batch_load_iterator( - const T* source, - size_type n_rows, - size_type row_width, - size_type batch_size, - rmm::cuda_stream_view stream, - rmm::device_async_resource_ref mr = rmm::mr::get_current_device_resource_ref(), - bool prefetch = false) - : cur_batch_(new batch(source, n_rows, row_width, batch_size, stream, mr, prefetch)), + batch_load_iterator(raft::resources const& res, + MdspanT input_view, + size_type batch_size, + rmm::cuda_stream_view copy_stream, + rmm::device_async_resource_ref mr, + bool prefetch = false, + bool initialize = true, + bool host_writeback = false) + : cur_batch_(new value_type( + res, input_view, batch_size, copy_stream, mr, prefetch, initialize, host_writeback)), cur_pos_(0), cur_prefetch_pos_(0) { } - /** - * Whether this iterator copies the data on every iteration - * (i.e. the source is inaccessible from the device). - */ - [[nodiscard]] auto does_copy() const -> bool { return cur_batch_->does_copy(); } - /** Reset the iterator position and prefetch position to `begin()` */ + + /** Convenience overload that uses `get_workspace_resource_ref(res)` as the memory resource. */ + batch_load_iterator(raft::resources const& res, + MdspanT input_view, + size_type batch_size, + rmm::cuda_stream_view copy_stream, + bool prefetch = false, + bool initialize = true, + bool host_writeback = false) + : batch_load_iterator(res, + input_view, + batch_size, + copy_stream, + raft::resource::get_workspace_resource_ref(res), + prefetch, + initialize, + host_writeback) + { + } + + /** Whether iteration copies the data on each step (i.e. not passthrough). */ + [[nodiscard]] auto does_copy() const -> bool { return !kPassthrough; } + /** Reset the iterator (and prefetch) position to begin(). Reusable iteration. */ void reset() { cur_pos_ = 0; cur_prefetch_pos_ = 0; } - /** Reset the iterator position and prefetch position to `end()` */ + /** Reset the iterator (and prefetch) position to end(). 
*/ void reset_to_end() { cur_pos_ = cur_batch_->n_iters_; cur_prefetch_pos_ = cur_batch_->n_iters_; } - [[nodiscard]] auto begin() const -> const batch_load_iterator + [[nodiscard]] auto begin() const -> const batch_load_iterator { - batch_load_iterator x(*this); + batch_load_iterator x(*this); x.reset(); return x; } - [[nodiscard]] auto end() const -> const batch_load_iterator + [[nodiscard]] auto end() const -> const batch_load_iterator { - batch_load_iterator x(*this); + batch_load_iterator x(*this); x.reset_to_end(); return x; } @@ -624,36 +842,35 @@ struct batch_load_iterator { cur_batch_->load(cur_pos_); return cur_batch_.get(); } - /* Prefetch next batch. Users are responsible for calling this method to enable kernel/copy - * overlapping. Note that this API is synchronous. */ + /** Issue the prefetch for the next-but-one batch. See class doc for stream semantics. */ void prefetch_next_batch() { cur_batch_->prefetch(cur_prefetch_pos_++); } - friend auto operator==(const batch_load_iterator& x, const batch_load_iterator& y) -> bool + friend auto operator==(const batch_load_iterator& x, const batch_load_iterator& y) -> bool { return x.cur_batch_ == y.cur_batch_ && x.cur_pos_ == y.cur_pos_; - }; - friend auto operator!=(const batch_load_iterator& x, const batch_load_iterator& y) -> bool + } + friend auto operator!=(const batch_load_iterator& x, const batch_load_iterator& y) -> bool { return x.cur_batch_ != y.cur_batch_ || x.cur_pos_ != y.cur_pos_; - }; - auto operator++() -> batch_load_iterator& + } + auto operator++() -> batch_load_iterator& { ++cur_pos_; return *this; } - auto operator++(int) -> batch_load_iterator + auto operator++(int) -> batch_load_iterator { - batch_load_iterator x(*this); + batch_load_iterator x(*this); ++cur_pos_; return x; } - auto operator--() -> batch_load_iterator& + auto operator--() -> batch_load_iterator& { --cur_pos_; return *this; } - auto operator--(int) -> batch_load_iterator + auto operator--(int) -> batch_load_iterator { - 
batch_load_iterator x(*this); + batch_load_iterator x(*this); --cur_pos_; return x; } @@ -664,4 +881,299 @@ struct batch_load_iterator { size_type cur_prefetch_pos_; }; +/** + * Runtime-dispatched wrapper over `batch_load_iterator` that takes a raw pointer and + * picks the host- or device-typed mdspan instantiation based on `cudaPointerGetAttributes`, + * preserving the legacy "force copy unless cudaMemoryTypeDevice" policy. + * + * Use this at call sites that don't know statically whether `ptr` is host or device memory. + * Sites that already hold a typed mdspan should use `batch_load_iterator` directly to + * pick the strategy at compile time. + */ +template +class batch_load_iterator_dyn { + using HostMd = raft::host_matrix_view; + using DeviceMd = raft::device_matrix_view; + using HostIter = batch_load_iterator; + using DeviceIter = batch_load_iterator; + + public: + using size_type = size_t; + + /** Uniform batch-proxy view across host/device branches. */ + class batch { + public: + [[nodiscard]] auto data() const -> T* { return dev_ptr_; } + [[nodiscard]] auto size() const -> size_type { return batch_len_; } + [[nodiscard]] auto offset() const -> size_type { return offset_; } + [[nodiscard]] auto row_width() const -> size_type { return row_width_; } + [[nodiscard]] auto does_copy() const -> bool { return does_copy_; } + [[nodiscard]] auto view() const -> raft::device_matrix_view + { + return raft::make_device_matrix_view( + dev_ptr_, static_cast(batch_len_), static_cast(row_width_)); + } + + private: + template + friend class batch_load_iterator_dyn; + T* dev_ptr_ = nullptr; + size_type batch_len_ = 0; + size_type offset_ = 0; + size_type row_width_ = 0; + bool does_copy_ = false; + }; + + using value_type = batch; + using reference = const value_type&; + using pointer = const value_type*; + + /** + * Construct via runtime pointer dispatch. 
+ * + * If `ptr` is a pure-device pointer (`cudaMemoryTypeDevice` with non-null `devicePointer`), + * the device-accessor branch is selected (passthrough). Otherwise (host, pinned, managed, + * unregistered, HMM/ATS, or `nullptr`), the host-accessor branch is selected (copy_device). + */ + batch_load_iterator_dyn(raft::resources const& res, + T* ptr, + IdxT n_rows, + IdxT row_width, + size_type batch_size, + rmm::cuda_stream_view copy_stream, + rmm::device_async_resource_ref mr, + bool prefetch = false, + bool initialize = true, + bool host_writeback = false) + : impl_(make_impl(res, + ptr, + n_rows, + row_width, + batch_size, + copy_stream, + mr, + prefetch, + initialize, + host_writeback)), + proxy_(std::make_shared()) + { + } + + /** Convenience overload that uses `get_workspace_resource_ref(res)` as the memory resource. */ + batch_load_iterator_dyn(raft::resources const& res, + T* ptr, + IdxT n_rows, + IdxT row_width, + size_type batch_size, + rmm::cuda_stream_view copy_stream, + bool prefetch = false, + bool initialize = true, + bool host_writeback = false) + : batch_load_iterator_dyn(res, + ptr, + n_rows, + row_width, + batch_size, + copy_stream, + raft::resource::get_workspace_resource_ref(res), + prefetch, + initialize, + host_writeback) + { + } + + [[nodiscard]] auto does_copy() const -> bool + { + return std::visit([](auto const& it) { return it.does_copy(); }, impl_); + } + void reset() + { + std::visit([](auto& it) { it.reset(); }, impl_); + } + void reset_to_end() + { + std::visit([](auto& it) { it.reset_to_end(); }, impl_); + } + [[nodiscard]] auto begin() const -> batch_load_iterator_dyn + { + batch_load_iterator_dyn x(*this); + x.reset(); + return x; + } + [[nodiscard]] auto end() const -> batch_load_iterator_dyn + { + batch_load_iterator_dyn x(*this); + x.reset_to_end(); + return x; + } + [[nodiscard]] auto operator*() const -> reference + { + std::visit( + [this](auto const& it) { + auto const& b = *it; + proxy_->dev_ptr_ = const_cast(b.data()); + 
proxy_->batch_len_ = b.size(); + proxy_->offset_ = b.offset(); + proxy_->row_width_ = b.row_width(); + proxy_->does_copy_ = b.does_copy(); + }, + impl_); + return *proxy_; + } + [[nodiscard]] auto operator->() const -> pointer + { + (void)**this; + return proxy_.get(); + } + void prefetch_next_batch() + { + std::visit([](auto& it) { it.prefetch_next_batch(); }, impl_); + } + friend auto operator==(const batch_load_iterator_dyn& x, const batch_load_iterator_dyn& y) -> bool + { + return x.impl_ == y.impl_; + } + friend auto operator!=(const batch_load_iterator_dyn& x, const batch_load_iterator_dyn& y) -> bool + { + return !(x == y); + } + auto operator++() -> batch_load_iterator_dyn& + { + std::visit([](auto& it) { ++it; }, impl_); + return *this; + } + auto operator++(int) -> batch_load_iterator_dyn + { + batch_load_iterator_dyn x(*this); + ++(*this); + return x; + } + auto operator--() -> batch_load_iterator_dyn& + { + std::visit([](auto& it) { --it; }, impl_); + return *this; + } + auto operator--(int) -> batch_load_iterator_dyn + { + batch_load_iterator_dyn x(*this); + --(*this); + return x; + } + + private: + std::variant impl_; + // Shared proxy: copies of the iterator share storage so that `*it1++` and `*it2` consistently + // observe the same backing buffer state, mirroring the legacy shared-batch contract. 
+ std::shared_ptr proxy_; + + static auto make_impl(raft::resources const& res, + T* ptr, + IdxT n_rows, + IdxT row_width, + size_type batch_size, + rmm::cuda_stream_view copy_stream, + rmm::device_async_resource_ref mr, + bool prefetch, + bool initialize, + bool host_writeback) -> std::variant + { + bool is_pure_device = false; + if (ptr != nullptr) { + cudaPointerAttributes attr{}; + RAFT_CUDA_TRY(cudaPointerGetAttributes(&attr, ptr)); + is_pure_device = (attr.type == cudaMemoryTypeDevice) && (attr.devicePointer != nullptr); + } + if (is_pure_device) { + return DeviceIter(res, + raft::make_device_matrix_view(ptr, n_rows, row_width), + batch_size, + copy_stream, + mr, + prefetch, + initialize, + host_writeback); + } + return HostIter(res, + raft::make_host_matrix_view(ptr, n_rows, row_width), + batch_size, + copy_stream, + mr, + prefetch, + initialize, + host_writeback); + } +}; + +// Locally-rolled `type_identity_t` so this header compiles in TUs that build with C++17 +// (e.g. the cuvs C tests). Equivalent to `std::type_identity_t` (C++20). +namespace detail { +template +struct type_identity { + using type = T; +}; +template +using type_identity_t = typename type_identity::type; +} // namespace detail + +/** + * Builder for `batch_load_iterator_dyn`. Use at sites that have a raw pointer with + * unknown memory location and want the legacy "force copy unless cudaMemoryTypeDevice" semantics. + * + * `ptr` is taken by `T const*` so callers can pass either a `T*` or a `const T*` (matching the + * legacy `batch_load_iterator(const T* source, ...)` API). The iterator's element type is `T` + * (non-const), so `batch.data()` returns `T*` for kernels that expect a non-const view of the + * source. Const-correctness at the API boundary is the caller's responsibility. 
+ * + * `n_rows` / `row_width` are placed in non-deduced contexts so that `IdxT` is taken from the + * explicit template argument (or its `int64_t` default) and the integer arguments are implicitly + * converted to `IdxT`, regardless of their incoming integer type. + */ +template +auto make_batch_load_iterator(raft::resources const& res, + T const* ptr, + detail::type_identity_t n_rows, + detail::type_identity_t row_width, + size_t batch_size, + rmm::cuda_stream_view copy_stream, + rmm::device_async_resource_ref mr, + bool prefetch = false, + bool initialize = true, + bool host_writeback = false) -> batch_load_iterator_dyn +{ + return batch_load_iterator_dyn(res, + const_cast(ptr), + n_rows, + row_width, + batch_size, + copy_stream, + mr, + prefetch, + initialize, + host_writeback); +} + +/** Convenience overload that uses `get_workspace_resource_ref(res)` as the memory resource. */ +template +auto make_batch_load_iterator(raft::resources const& res, + T const* ptr, + detail::type_identity_t n_rows, + detail::type_identity_t row_width, + size_t batch_size, + rmm::cuda_stream_view copy_stream, + bool prefetch = false, + bool initialize = true, + bool host_writeback = false) -> batch_load_iterator_dyn +{ + return make_batch_load_iterator(res, + ptr, + n_rows, + row_width, + batch_size, + copy_stream, + raft::resource::get_workspace_resource_ref(res), + prefetch, + initialize, + host_writeback); +} + } // namespace cuvs::spatial::knn::detail::utils diff --git a/cpp/src/neighbors/detail/cagra/add_nodes.cuh b/cpp/src/neighbors/detail/cagra/add_nodes.cuh index f198cf957d..5d0a6654e9 100644 --- a/cpp/src/neighbors/detail/cagra/add_nodes.cuh +++ b/cpp/src/neighbors/detail/cagra/add_nodes.cuh @@ -84,10 +84,11 @@ void add_node_core( auto host_neighbor_indices = raft::make_host_matrix(max_search_batch_size, base_degree); - cuvs::spatial::knn::detail::utils::batch_load_iterator additional_dataset_batch( + auto additional_dataset_batch = 
cuvs::spatial::knn::detail::utils::make_batch_load_iterator( + handle, additional_dataset_view.data_handle(), - num_add, - additional_dataset_view.stride(0), + static_cast(num_add), + static_cast(additional_dataset_view.stride(0)), max_search_batch_size, raft::resource::get_cuda_stream(handle), mr); diff --git a/cpp/src/neighbors/detail/cagra/cagra_build.cuh b/cpp/src/neighbors/detail/cagra/cagra_build.cuh index 74dce643ea..d659aa246f 100644 --- a/cpp/src/neighbors/detail/cagra/cagra_build.cuh +++ b/cpp/src/neighbors/detail/cagra/cagra_build.cuh @@ -1729,11 +1729,12 @@ void build_knn_graph( bool first = true; const auto start_clock = std::chrono::system_clock::now(); - cuvs::spatial::knn::detail::utils::batch_load_iterator vec_batches( + auto vec_batches = cuvs::spatial::knn::detail::utils::make_batch_load_iterator( + res, dataset.data_handle(), - dataset.extent(0), - dataset.extent(1), - static_cast(max_queries), + static_cast(dataset.extent(0)), + static_cast(dataset.extent(1)), + static_cast(max_queries), raft::resource::get_cuda_stream(res), workspace_mr); @@ -2114,10 +2115,11 @@ auto iterative_build_graph( // Search. // Since there are many queries, divide them into batches and search them. 
- cuvs::spatial::knn::detail::utils::batch_load_iterator query_batch( + auto query_batch = cuvs::spatial::knn::detail::utils::make_batch_load_iterator( + res, dev_query_view.data_handle(), - curr_query_size, - dev_query_view.extent(1), + static_cast(curr_query_size), + static_cast(dev_query_view.extent(1)), max_chunk_size, raft::resource::get_cuda_stream(res), raft::resource::get_workspace_resource_ref(res)); diff --git a/cpp/src/neighbors/detail/cagra/graph_core.cuh b/cpp/src/neighbors/detail/cagra/graph_core.cuh index 2fbf05cf60..d0cc88d614 100644 --- a/cpp/src/neighbors/detail/cagra/graph_core.cuh +++ b/cpp/src/neighbors/detail/cagra/graph_core.cuh @@ -812,41 +812,45 @@ void merge_graph_gpu( uint32_t batch_size = static_cast(std::min(graph_size, 256 * 1024)); const uint32_t num_batch = (graph_size + batch_size - 1) / batch_size; - batched_device_view d_output_graph(res, - output_graph, - /*batch_size*/ batch_size, - /*host_writeback*/ true, - /*initialize*/ true); - - batched_device_view< - IdxT, - int64_t, - raft::host_device_accessor, raft::memory_type::host>> - d_mst_graph(res, - mst_graph, - /*batch_size*/ batch_size, - /*host_writeback*/ false, - /*initialize*/ true); - - batched_device_view< - uint32_t, - int64_t, - raft::host_device_accessor, raft::memory_type::host>> - d_mst_graph_num_edges(res, - raft::make_host_matrix_view( - mst_graph_num_edges.data_handle(), mst_graph_num_edges.extent(0), 1l), - /*batch_size*/ batch_size, - /*host_writeback*/ false, - /*initialize*/ true); + namespace bli = cuvs::spatial::knn::detail::utils; + auto [copy_stream, enable_prefetch] = bli::get_prefetch_stream(res); + auto workspace_mr = raft::resource::get_workspace_resource_ref(res); + + bli::batch_load_iterator< + raft::mdspan, raft::row_major, AccessorOutputGraph>> + d_output_graph(res, + output_graph, + batch_size, + copy_stream, + workspace_mr, + enable_prefetch, + /*initialize=*/true, + /*host_writeback=*/true); + + bli::batch_load_iterator> d_mst_graph( + res, 
mst_graph, batch_size, copy_stream, workspace_mr, enable_prefetch); + + bli::batch_load_iterator> d_mst_graph_num_edges( + res, + raft::make_host_matrix_view( + mst_graph_num_edges.data_handle(), mst_graph_num_edges.extent(0), 1l), + batch_size, + copy_stream, + workspace_mr, + enable_prefetch); + + d_output_graph.prefetch_next_batch(); + d_mst_graph.prefetch_next_batch(); + d_mst_graph_num_edges.prefetch_next_batch(); const uint32_t num_warps = 4; const dim3 threads_merge(raft::WarpSize * num_warps, 1, 1); const dim3 blocks_merge(raft::ceildiv(batch_size, num_warps), 1, 1); const size_t merge_smem_size = num_warps * output_graph_degree * sizeof(IdxT); for (uint32_t i_batch = 0; i_batch < num_batch; i_batch++) { - auto mst_graph_view = d_mst_graph.next_view(); - auto mst_graph_num_edges_view = d_mst_graph_num_edges.next_view(); - auto output_view = d_output_graph.next_view(); + auto mst_graph_view = (*d_mst_graph).view(); + auto mst_graph_num_edges_view = (*d_mst_graph_num_edges).view(); + auto output_view = (*d_output_graph).view(); kern_merge_graph <<>>( output_view, @@ -859,9 +863,12 @@ void merge_graph_gpu( guarantee_connectivity, d_check_num_protected_edges.data_handle()); - d_output_graph.prefetch_next(); - d_mst_graph.prefetch_next(); - d_mst_graph_num_edges.prefetch_next(); + d_output_graph.prefetch_next_batch(); + d_mst_graph.prefetch_next_batch(); + d_mst_graph_num_edges.prefetch_next_batch(); + ++d_output_graph; + ++d_mst_graph; + ++d_mst_graph_num_edges; } bool check_num_protected_edges = true; @@ -1576,18 +1583,28 @@ void prune_graph_gpu( auto host_stats = raft::make_host_vector(2); raft::matrix::fill(res, dev_stats.view(), uint64_t(0)); - batched_device_view d_input_graph(res, - knn_graph, - /*batch_size*/ graph_size, - /*host_writeback*/ false, - /*initialize*/ true); - auto input_view = d_input_graph.next_view(); - - batched_device_view d_output_graph(res, - output_graph, - /*batch_size*/ batch_size, - /*host_writeback*/ true, - /*initialize*/ 
false); + namespace bli = cuvs::spatial::knn::detail::utils; + auto [copy_stream, enable_prefetch] = bli::get_prefetch_stream(res); + auto workspace_mr = raft::resource::get_workspace_resource_ref(res); + + // Single-batch read-only iterator for the input graph (graph_size rows fit in one batch). + bli::batch_load_iterator< + raft::mdspan, raft::row_major, AccessorKnnGraph>> + d_input_graph(res, knn_graph, graph_size, copy_stream, workspace_mr); + d_input_graph.prefetch_next_batch(); + auto input_view = (*d_input_graph).view(); + + bli::batch_load_iterator< + raft::mdspan, raft::row_major, AccessorOutputGraph>> + d_output_graph(res, + output_graph, + batch_size, + copy_stream, + workspace_mr, + enable_prefetch, + /*initialize=*/false, + /*host_writeback=*/true); + d_output_graph.prefetch_next_batch(); auto d_invalid_neighbor_list = raft::make_device_scalar(res, 0u); @@ -1597,7 +1614,7 @@ void prune_graph_gpu( const size_t prune_smem_size = num_warps * knn_graph_degree * (sizeof(IdxT) + sizeof(uint32_t)); for (uint32_t i_batch = 0; i_batch < num_batch; i_batch++) { - auto output_view = d_output_graph.next_view(); + auto output_view = (*d_output_graph).view(); kern_fused_prune <<>>( input_view, @@ -1607,7 +1624,8 @@ void prune_graph_gpu( d_invalid_neighbor_list.data_handle(), dev_stats.data_handle()); - d_output_graph.prefetch_next(); + d_output_graph.prefetch_next_batch(); + ++d_output_graph; RAFT_LOG_DEBUG( "# Pruning kNN Graph on GPUs (%.1lf %%)\r", (double)std::min((i_batch + 1) * batch_size, graph_size) / graph_size * 100); diff --git a/cpp/src/neighbors/detail/cagra/utils.hpp b/cpp/src/neighbors/detail/cagra/utils.hpp index b8ea4e3d54..58bf68bb43 100644 --- a/cpp/src/neighbors/detail/cagra/utils.hpp +++ b/cpp/src/neighbors/detail/cagra/utils.hpp @@ -301,362 +301,4 @@ void copy_with_padding( } } -/** - * Iterate a 2D mdspan in row-batches with overlapping copies and kernel work. 
- * - * The strategy is selected at compile time from - * `AccessorInputView::is_device_accessible`: - * * passthrough: each batch is `cuda::std::submdspan(input_view_, ...)`. - * No buffering, no arithmetic on `input_view_.data_handle()`. - * * copy_device: each batch is staged through an internal device buffer - * via `cudaMemcpyAsync`. When `host_writeback` is set, every returned - * batch is copied back to `input_view_`; flushing happens lazily during - * subsequent `prefetch_next()` calls and the destructor flushes the tail. - * - * Concurrency model (copy_device): - * * Two device buffers, ping-ponged across iterations. - * * One non-res_ stream `copy_stream_` carries both D2H writebacks and - * H2D prefetches in FIFO order. - * * `next_view()` drains `copy_stream_` so the caller sees a fully-staged - * slot; it does NOT touch res_'s stream. - * * `prefetch_next()` queues D2H of the previous-iter batch and H2D of the - * next-iter batch on `copy_stream_` (slots different from the running - * kernel's slot), then ends with `sync_stream(res_)`. The host stall on - * that final sync overlaps with both the kernel on `res_` and the copies - * on `copy_stream_` -- so for pageable host the host-bound - * pageable->pinned phase of `cudaMemcpyAsync` also overlaps with the - * kernel. 
- * - * Usage: - * ``` - * batched_device_view view( - * res, input_view, batch_size, host_writeback, initialize); - * for (;;) { - * auto device_view = view.next_view(); - * if (device_view.extent(0) == 0) { break; } // sole stop condition - * kernel<<<..., raft::resource::get_cuda_stream(res)>>>(device_view); - * view.prefetch_next(); // pair with next_view(); copies overlap with kernel - * } - * ``` - * - * Differences vs `cuvs::neighbors::detail::utils::batch_load_iterator` - * (cpp/src/neighbors/detail/ann_utils.cuh): - * * Input typing: this class takes a typed mdspan and decides the strategy - * at compile time from the accessor; `batch_load_iterator` takes - * `(T const*, n_rows, row_width)` and decides at runtime via - * `cudaPointerGetAttributes` (also forces a copy for HMM/ATS sources). - * * API shape: both classes split iteration from prefetch -- here it's - * `next_view()` + `prefetch_next()`; in `batch_load_iterator` it's - * `operator*` + `prefetch_next_batch()`. `batch_load_iterator` is also - * an STL iterator with begin/end and random access; this class is - * single-pass and returns mdspans directly. - * * Mutability: returned views here are mutable T* and support host - * writeback on destruction; `batch_load_iterator` is read-only and never - * copies back. - * * Pipelining: copy_device uses 2 device buffers and a single non-res_ - * stream that carries both directions; cross-iter ordering of D2H and - * H2D on the same slot is enforced by FIFO, and overlap with the kernel - * is achieved by issuing copies *before* the trailing `sync_stream(res_)` - * in `prefetch_next()`. `batch_load_iterator` uses 1-2 buffers, never - * writes back, and the caller is responsible for using a separate stream - * for kernels in order to overlap with `prefetch_next_batch()`. 
- * - * @tparam T element type - * @tparam IdxT index type for the input mdspan extents - * @tparam AccessorInputView accessor of the input mdspan - */ -template -class batched_device_view { - using input_view_t = - raft::mdspan, raft::row_major, AccessorInputView>; - - public: - // Compile-time strategy switch; see class-level documentation for semantics. - static constexpr bool kPassthrough = AccessorInputView::is_device_accessible; - - static_assert(kPassthrough || - cuda::std::is_convertible_v, - "copy_device path issues cudaMemcpyAsync against input_view_.data_handle() " - "in both directions, so AccessorInputView::data_handle_type must be " - "convertible to T*. To lift this, route prefetch/writeback through " - "raft::copy on submdspans of input_view_."); - - // Result of submdspan(layout_right, tuple, full_extent): - // layout_stride with AccessorInputView preserved. - using next_view_passthrough_type = - decltype(cuda::std::submdspan(std::declval(), - std::declval>(), - cuda::std::full_extent)); - - // Internal contiguous row-major device buffer. - using next_view_copy_type = raft::device_matrix_view; - - /** - * @param res raft resources (must outlive this object) - * @param input_view mdspan to iterate over - * @param batch_size rows per batch (must be > 0 if input_view is non-empty) - * @param host_writeback copy each batch back to input_view_ after use - * (no-op for passthrough) - * @param initialize stage each batch's contents into the device buffer - * before returning it (no-op for passthrough) - * - * At least one of host_writeback / initialize must be true for non-empty input. 
- */ - batched_device_view(raft::resources const& res, - input_view_t input_view, - uint64_t batch_size, - bool host_writeback = false, - bool initialize = true) - : res_(res), - input_view_(input_view), - batch_size_(batch_size), - batch_id_(-1), - next_prefetched_(false), - last_flushed_batch_id_(-1), - host_writeback_(host_writeback), - initialize_(initialize) - { - if (input_view.extent(0) == 0) { return; } - - RAFT_EXPECTS(batch_size_ > 0, "batch_size must be greater than zero for non-empty input"); - RAFT_EXPECTS(host_writeback_ || initialize_, - "At least one of host_writeback or initialize must be true"); - - RAFT_LOG_DEBUG("Memory strategy: %s for matrix of type %s, dimensions %zu x %zu", - kPassthrough ? "passthrough" : "copy_device", - typeid(T).name(), - input_view.extent(0), - input_view.extent(1)); - - // device buffers (copy_device only). Two slots suffice: at any iter K - // the kernel runs on slot K%nb while prefetch_next() queues a D2H of - // slot (K-1)%nb and an H2D into slot (K+1)%nb on the *same* copy_stream_, - // so D2H and H2D ordering on one slot is enforced by FIFO -- no third - // buffer needed. 
- if constexpr (!kPassthrough) { - try { - device_mem_[0].emplace(raft::make_device_mdarray( - res, - raft::resource::get_workspace_resource_ref(res), - raft::make_extents(batch_size, input_view.extent(1)))); - device_ptr[0] = device_mem_[0]->data_handle(); - if (batch_size < static_cast(input_view.extent(0))) { - device_mem_[1].emplace(raft::make_device_mdarray( - res, - raft::resource::get_workspace_resource_ref(res), - raft::make_extents(batch_size, input_view.extent(1)))); - device_ptr[1] = device_mem_[1]->data_handle(); - } - } catch (std::bad_alloc& e) { - throw std::bad_alloc(); - } catch (raft::logic_error& e) { - throw raft::logic_error("Insufficient memory for device buffers (logic error)"); - } - } - - // One non-res_ stream is enough: D2H and H2D for a given iter are queued - // back-to-back on this stream and run concurrently with the user's kernel - // on res_'s stream. - if (!res.has_resource_factory(raft::resource::resource_type::CUDA_STREAM_POOL) || - raft::resource::get_stream_pool_size(res) < 1) { - raft::resource::set_cuda_stream_pool(res, std::make_shared(1)); - } - copy_stream_ = raft::resource::get_stream_from_stream_pool(res); - - // Prime batch 0 (slot 0 staged on copy_stream_; first next_view() syncs). - issue_prefetch_for_next_batch(); - } - - ~batched_device_view() noexcept - { - raft::resource::sync_stream(res_); - - // Nothing was ever returned (empty input or no next_view() call); streams - // may still be default-constructed so bail out early. - if (batch_id_ < 0) { return; } - - if constexpr (!kPassthrough) { - if (host_writeback_) { - // Each prefetch_next() flushes batch (batch_id_ - 1) lazily; that - // leaves the most recent 1 batch (normal exit on empty next_view) or - // up to 2 batches (early break without final prefetch_next()) still - // pending. Flush whatever's left on copy_stream_ FIFO. 
- for (int32_t i = last_flushed_batch_id_ + 1; i <= batch_id_; ++i) { - uint32_t pos = i % 2; - uint64_t off = static_cast(i) * batch_size_; - writeback_from_device_to_host(device_ptr[pos], off, actual_batch_size_[pos]); - } - copy_stream_.synchronize(); - } - } - } - - /** - * Return the view of the batch staged by the constructor or the most recent - * `prefetch_next()`. After this call, that batch is the "current" batch; its - * slot is owned by the caller until `prefetch_next()` advances the pipeline. - * - * Return type is `next_view_passthrough_type` or `next_view_copy_type`; both - * are 2D mdspans of T with the same element- and extent-API surface. - * Iteration ends when extent(0) == 0; this is the only legal stop signal. - * - * Pair every non-empty `next_view()` with a `prefetch_next()` call. Skipping - * the pairing simply stops iteration at the current batch. - */ - auto next_view() - { - if constexpr (kPassthrough) { - // Passthrough has no buffer; the slice goes through the accessor. - if (!next_prefetched_) { - auto end = static_cast(input_view_.extent(0)); - return cuda::std::submdspan( - input_view_, cuda::std::tuple{end, end}, cuda::std::full_extent); - } - ++batch_id_; - next_prefetched_ = false; - - uint32_t current_pos = batch_id_ % 2; - auto first = static_cast(batch_id_ * batch_size_); - auto last = static_cast(first + actual_batch_size_[current_pos]); - return cuda::std::submdspan( - input_view_, cuda::std::tuple{first, last}, cuda::std::full_extent); - } else { - auto cols = static_cast(input_view_.extent(1)); - if (!next_prefetched_) { return next_view_copy_type{nullptr, IdxT{0}, cols}; } - - // Drain the copies queued by the previous prefetch_next() / ctor: this - // ensures the slot we're about to hand out is fully staged AND that the - // writeback for the older batch (if any) has finished before we return - // -- which matters for slot recycling on subsequent iterations. 
- copy_stream_.synchronize(); - - ++batch_id_; - next_prefetched_ = false; - - uint32_t current_pos = batch_id_ % 2; - return next_view_copy_type{ - device_ptr[current_pos], static_cast(actual_batch_size_[current_pos]), cols}; - } - } - - /** - * Advance the prefetch pipeline -- call once after each non-empty - * `next_view()`, AFTER launching the kernel on res_'s stream. - * - * In copy_device mode this: - * 1. Queues D2H of batch (batch_id_ - 1) on copy_stream_ (the slot the - * kernel is NOT on; data is from the previous iter's kernel which has - * already been sync'd on res_). Skipped on the first iteration when no - * previous batch exists. - * 2. Queues H2D of batch (batch_id_ + 1) on copy_stream_ (also a different - * slot from the running kernel). - * 3. Calls sync_stream(res_) at the *end*. - * - * Steps 1-2 run on copy_stream_ concurrently with the just-launched kernel - * on res_'s stream. The host stall during cudaMemcpyAsync's pageable->pinned - * staging (for pageable host sources/destinations) and the host stall on - * step 3 both overlap with the kernel: that is what makes the pipeline - * actually asynchronous, even for plain pageable host memory. - * - * In passthrough mode this is pure bookkeeping (no copies, no syncs). - */ - void prefetch_next() - { - if constexpr (!kPassthrough) { - if (host_writeback_ && batch_id_ - 1 > last_flushed_batch_id_) { - // Writeback batch (batch_id_ - 1) -- the slot the kernel is *not* on. - // The corresponding kernel (kernel-(batch_id_-1)) finished at the end - // of the previous prefetch_next() (sync_stream(res_)), so its writes - // are globally visible. 
- uint32_t pos = (batch_id_ - 1) % 2; - uint64_t off = static_cast(batch_id_ - 1) * batch_size_; - writeback_from_device_to_host(device_ptr[pos], off, actual_batch_size_[pos]); - last_flushed_batch_id_ = batch_id_ - 1; - } - } - - issue_prefetch_for_next_batch(); - - if constexpr (!kPassthrough) { - // Wait for the kernel we paired with this prefetch_next() before - // returning. - if (input_view_.extent(0) != 0) raft::resource::sync_stream(res_); - } - } - - private: - /** - * Stage batch (batch_id_ + 1) into slot ((batch_id_ + 1) % 2) on - * copy_stream_, after any prior op on the stream (FIFO). Pure bookkeeping - * in passthrough or !initialize_ mode. Sets next_prefetched_ accordingly. - */ - void issue_prefetch_for_next_batch() - { - uint64_t target_offset = static_cast(batch_id_ + 1) * batch_size_; - if (target_offset >= static_cast(input_view_.extent(0))) { - next_prefetched_ = false; - return; - } - int32_t prefetch_pos = (batch_id_ + 1) % 2; - actual_batch_size_[prefetch_pos] = - static_cast(min(batch_size_, input_view_.extent(0) - target_offset)); - - if constexpr (!kPassthrough) { - if (initialize_) { - prefetch_from_host_to_device( - device_ptr[prefetch_pos], target_offset, actual_batch_size_[prefetch_pos]); - } - } - next_prefetched_ = true; - } - - void prefetch_from_host_to_device(T* dev_ptr, size_t src_row_offset, size_t num_rows) - { - const size_t n_elem = num_rows * input_view_.extent(1); - const size_t n_bytes = n_elem * sizeof(T); - // use memcpy instead of raft::copy to avoid strange behavior with HMM/ATS memory - RAFT_CUDA_TRY( - cudaMemcpyAsync(dev_ptr, - input_view_.data_handle() + src_row_offset * input_view_.extent(1), - n_bytes, - cudaMemcpyHostToDevice, - copy_stream_)); - } - - void writeback_from_device_to_host(T* dev_ptr, size_t dst_row_offset, size_t num_rows) - { - const size_t n_elem = num_rows * input_view_.extent(1); - const size_t n_bytes = n_elem * sizeof(T); - // use memcpy instead of raft::copy to avoid strange behavior with 
HMM/ATS memory - RAFT_CUDA_TRY( - cudaMemcpyAsync(input_view_.data_handle() + dst_row_offset * input_view_.extent(1), - dev_ptr, - n_bytes, - cudaMemcpyDeviceToHost, - copy_stream_)); - } - - // single non-res_ stream that carries both H2D prefetches and D2H writebacks - // in FIFO order (copy_device only; unused in passthrough) - rmm::cuda_stream_view copy_stream_; - - // configuration - const raft::resources& res_; - bool initialize_; - bool host_writeback_; - - // iteration state - uint64_t batch_size_; - int32_t batch_id_; // most-recently-returned batch id; -1 if none returned - bool next_prefetched_; // slot for batch_id_+1 holds staged data - int32_t last_flushed_batch_id_; // highest batch id whose writeback has been issued; -1 if none - - input_view_t input_view_; - - // device buffers (copy_device only) - std::optional> device_mem_[2]; - T* device_ptr[2]; - uint32_t actual_batch_size_[2]; -}; - } // namespace cuvs::neighbors::cagra::detail diff --git a/cpp/src/neighbors/detail/nn_descent.cuh b/cpp/src/neighbors/detail/nn_descent.cuh index 2e50aefc15..0adc3373fc 100644 --- a/cpp/src/neighbors/detail/nn_descent.cuh +++ b/cpp/src/neighbors/detail/nn_descent.cuh @@ -1507,8 +1507,13 @@ void GNND::build(Data_t* data, RAFT_CUDA_TRY(cudaPointerGetAttributes(&data_ptr_attr, data)); size_t batch_size = (data_ptr_attr.devicePointer == nullptr) ? 
100000 : nrow_; - cuvs::spatial::knn::detail::utils::batch_load_iterator vec_batches{ - data, static_cast(nrow_), build_config_.dataset_dim, batch_size, stream}; + auto vec_batches = cuvs::spatial::knn::detail::utils::make_batch_load_iterator( + res, + data, + static_cast(nrow_), + static_cast(build_config_.dataset_dim), + batch_size, + stream); for (auto const& batch : vec_batches) { if (d_data_float_.has_value()) { preprocess_data_kernel<< build( // TODO: with scaling workspace we could choose the batch size dynamically constexpr uint32_t kReasonableMaxBatchSize = 65536; const uint32_t max_batch_size = std::min(n_rows, kReasonableMaxBatchSize); - for (const auto& batch : cuvs::spatial::knn::detail::utils::batch_load_iterator( - dataset.data_handle(), - n_rows, - dim, - max_batch_size, - raft::resource::get_cuda_stream(res), - raft::resource::get_workspace_resource_ref(res))) { + auto _vamana_batches = cuvs::spatial::knn::detail::utils::make_batch_load_iterator( + res, + dataset.data_handle(), + static_cast(n_rows), + static_cast(dim), + static_cast(max_batch_size), + raft::resource::get_cuda_stream(res), + raft::resource::get_workspace_resource_ref(res)); + for (const auto& batch : _vamana_batches) { // perform rotation auto dataset_rotated = raft::make_device_matrix(res, batch.size(), dim); if constexpr (std::is_same_v) { diff --git a/cpp/src/neighbors/detail/vpq_dataset.cuh b/cpp/src/neighbors/detail/vpq_dataset.cuh index 75785da5d2..d4410afa69 100644 --- a/cpp/src/neighbors/detail/vpq_dataset.cuh +++ b/cpp/src/neighbors/detail/vpq_dataset.cuh @@ -506,13 +506,15 @@ void process_and_fill_codes( return; } - for (const auto& batch : cuvs::spatial::knn::detail::utils::batch_load_iterator( - dataset.data_handle(), - n_rows, - dim, - max_batch_size, - stream, - rmm::mr::get_current_device_resource_ref())) { + auto _vpq_batches_codes = cuvs::spatial::knn::detail::utils::make_batch_load_iterator( + res, + dataset.data_handle(), + static_cast(n_rows), + 
static_cast(dim), + static_cast(max_batch_size), + stream, + rmm::mr::get_current_device_resource_ref()); + for (const auto& batch : _vpq_batches_codes) { auto batch_view = raft::make_device_matrix_view(batch.data(), ix_t(batch.size()), dim); auto batch_labels_view = raft::make_device_vector_view(nullptr, 0); if (inline_vq_labels) { @@ -899,11 +901,12 @@ void process_and_fill_codes_subspaces( enable_prefetch_stream = true; copy_stream = raft::resource::get_stream_from_stream_pool(res); } - auto vec_batches = cuvs::spatial::knn::detail::utils::batch_load_iterator( + auto vec_batches = cuvs::spatial::knn::detail::utils::make_batch_load_iterator( + res, dataset.data_handle(), - n_rows, - dim, - max_batch_size, + static_cast(n_rows), + static_cast(dim), + static_cast(max_batch_size), copy_stream, raft::resource::get_workspace_resource_ref(res), enable_prefetch_stream); diff --git a/cpp/src/neighbors/ivf_flat/ivf_flat_build.cuh b/cpp/src/neighbors/ivf_flat/ivf_flat_build.cuh index 35005b4279..fffe5134ae 100644 --- a/cpp/src/neighbors/ivf_flat/ivf_flat_build.cuh +++ b/cpp/src/neighbors/ivf_flat/ivf_flat_build.cuh @@ -203,13 +203,15 @@ void extend(raft::resources const& handle, } } // Predict the cluster labels for the new data, in batches if necessary - utils::batch_load_iterator vec_batches(new_vectors, - n_rows, - index->dim(), - max_batch_size, - copy_stream, - raft::resource::get_workspace_resource_ref(handle), - enable_prefetch); + auto vec_batches = + utils::make_batch_load_iterator(handle, + new_vectors, + n_rows, + IdxT{index->dim()}, + max_batch_size, + copy_stream, + raft::resource::get_workspace_resource_ref(handle), + enable_prefetch); vec_batches.prefetch_next_batch(); for (const auto& batch : vec_batches) { @@ -296,17 +298,19 @@ void extend(raft::resources const& handle, raft::make_device_vector_view(list_sizes_ptr, n_lists), raft::make_device_vector_view(old_list_sizes_dev.data_handle(), n_lists)); - utils::batch_load_iterator vec_indices(new_indices, - 
n_rows, - 1, - max_batch_size, - stream, - raft::resource::get_workspace_resource_ref(handle)); + auto vec_indices = + utils::make_batch_load_iterator(handle, + new_indices, + n_rows, + IdxT{1}, + max_batch_size, + stream, + raft::resource::get_workspace_resource_ref(handle)); vec_batches.reset(); vec_batches.prefetch_next_batch(); - utils::batch_load_iterator idx_batch = vec_indices.begin(); - size_t next_report_offset = 0; - size_t d_report_offset = n_rows * 5 / 100; + auto idx_batch = vec_indices.begin(); + size_t next_report_offset = 0; + size_t d_report_offset = n_rows * 5 / 100; for (const auto& batch : vec_batches) { auto batch_data_view = raft::make_device_matrix_view(batch.data(), batch.size(), index->dim()); diff --git a/cpp/src/neighbors/ivf_pq/ivf_pq_build.cuh b/cpp/src/neighbors/ivf_pq/ivf_pq_build.cuh index 2ff8195b8b..fc99e2d34c 100644 --- a/cpp/src/neighbors/ivf_pq/ivf_pq_build.cuh +++ b/cpp/src/neighbors/ivf_pq/ivf_pq_build.cuh @@ -1078,8 +1078,14 @@ void extend(raft::resources const& handle, } } // Predict the cluster labels for the new data, in batches if necessary - utils::batch_load_iterator vec_batches( - new_vectors, n_rows, index->dim(), max_batch_size, copy_stream, device_memory, enable_prefetch); + auto vec_batches = utils::make_batch_load_iterator(handle, + new_vectors, + n_rows, + index->dim(), + max_batch_size, + copy_stream, + device_memory, + enable_prefetch); // Release the placeholder memory, because we don't intend to allocate any more long-living // temporary buffers before we allocate the index data. // This memory could potentially speed up UVM accesses, if any. 
@@ -1175,8 +1181,8 @@ void extend(raft::resources const& handle, // By this point, the index state is updated and valid except it doesn't contain the new data // Fill the extended index with the new data (possibly, in batches) - utils::batch_load_iterator idx_batches( - new_indices, n_rows, 1, max_batch_size, stream, batches_mr); + auto idx_batches = utils::make_batch_load_iterator( + handle, new_indices, n_rows, IdxT{1}, max_batch_size, stream, batches_mr); vec_batches.reset(); vec_batches.prefetch_next_batch(); for (const auto& vec_batch : vec_batches) { diff --git a/cpp/src/neighbors/ivf_pq/ivf_pq_transform.cuh b/cpp/src/neighbors/ivf_pq/ivf_pq_transform.cuh index 1d55692488..b5ea8afab4 100644 --- a/cpp/src/neighbors/ivf_pq/ivf_pq_transform.cuh +++ b/cpp/src/neighbors/ivf_pq/ivf_pq_transform.cuh @@ -140,13 +140,14 @@ void transform(raft::resources const& res, constexpr size_t max_batch_size = 65536; rmm::device_async_resource_ref device_memory = raft::resource::get_workspace_resource_ref(res); - utils::batch_load_iterator vec_batches(dataset.data_handle(), - n_rows, - index.dim(), - max_batch_size, - copy_stream, - device_memory, - enable_prefetch); + auto vec_batches = utils::make_batch_load_iterator(res, + dataset.data_handle(), + n_rows, + IdxT{index.dim()}, + max_batch_size, + copy_stream, + device_memory, + enable_prefetch); vec_batches.prefetch_next_batch(); for (const auto& batch : vec_batches) { diff --git a/cpp/src/neighbors/scann/detail/scann_build.cuh b/cpp/src/neighbors/scann/detail/scann_build.cuh index 66ebd994b2..c01e50bc83 100644 --- a/cpp/src/neighbors/scann/detail/scann_build.cuh +++ b/cpp/src/neighbors/scann/detail/scann_build.cuh @@ -102,13 +102,15 @@ index build( } } - utils::batch_load_iterator dataset_vec_batches(dataset.data_handle(), - dataset.extent(0), - dataset.extent(1), - max_batch_size, - copy_stream, - device_memory, - enable_prefetch); + auto dataset_vec_batches = + utils::make_batch_load_iterator(res, + dataset.data_handle(), + 
static_cast(dataset.extent(0)), + static_cast(dataset.extent(1)), + max_batch_size, + copy_stream, + device_memory, + enable_prefetch); dataset_vec_batches.reset(); dataset_vec_batches.prefetch_next_batch(); diff --git a/cpp/tests/CMakeLists.txt b/cpp/tests/CMakeLists.txt index 34c067e9f5..2f46abb812 100644 --- a/cpp/tests/CMakeLists.txt +++ b/cpp/tests/CMakeLists.txt @@ -182,7 +182,7 @@ ConfigureTest( ConfigureTest( NAME NEIGHBORS_ANN_CAGRA_HELPERS_TEST - PATH neighbors/ann_cagra/test_optimize_uint32_t.cu neighbors/ann_cagra/test_batched_device_view.cu + PATH neighbors/ann_cagra/test_optimize_uint32_t.cu neighbors/ann_cagra/test_batch_load_iterator.cu GPUS 1 PERCENT 100 ) diff --git a/cpp/tests/neighbors/ann_cagra/test_batched_device_view.cu b/cpp/tests/neighbors/ann_cagra/test_batch_load_iterator.cu similarity index 52% rename from cpp/tests/neighbors/ann_cagra/test_batched_device_view.cu rename to cpp/tests/neighbors/ann_cagra/test_batch_load_iterator.cu index 6733ac784a..c0df2b8826 100644 --- a/cpp/tests/neighbors/ann_cagra/test_batched_device_view.cu +++ b/cpp/tests/neighbors/ann_cagra/test_batch_load_iterator.cu @@ -5,7 +5,6 @@ #include -#include #include #include #include @@ -17,11 +16,12 @@ #include #include +#include #include #include #include -#include "../../../src/neighbors/detail/cagra/utils.hpp" +#include "../../../src/neighbors/detail/ann_utils.cuh" #include #include @@ -30,6 +30,8 @@ namespace cuvs::neighbors::cagra { +namespace bli = cuvs::spatial::knn::detail::utils; + using IdxT = uint32_t; using DeviceAccessor = raft::host_device_accessor, raft::memory_type::device>; @@ -51,13 +53,20 @@ struct DimsConfig { uint64_t batch_size; }; -class BatchedDeviceViewTest : public ::testing::Test { +class BatchLoadIteratorTest : public ::testing::Test { protected: - void SetUp() override { raft::resource::sync_stream(res); } + void SetUp() override + { + // Provide a stream pool so get_prefetch_stream(res) returns a non-main stream and + // the iterator 
exercises its real pipelined path. + raft::resource::set_cuda_stream_pool(res, std::make_shared(1)); + raft::resource::sync_stream(res); + } /** - * Run batched_device_view over host data, copy device views back, - * and verify against the input. + * Run batch_load_iterator over input_view, copy device views back, and verify against the + * input. Mirrors the old batched_device_view test: writeback fills with IdxT(17), readback + * (when initialize==true) verifies pre-fill of IdxT(13). */ template void run_and_verify_batched( @@ -66,6 +75,8 @@ class BatchedDeviceViewTest : public ::testing::Test { bool host_writeback, bool initialize) { + using mdspan_t = decltype(input_view); + int64_t n_rows = input_view.extent(0); int64_t n_cols = input_view.extent(1); @@ -73,12 +84,24 @@ class BatchedDeviceViewTest : public ::testing::Test { int64_t total_processed = 0; + auto [copy_stream, enable_prefetch] = bli::get_prefetch_stream(res); + auto workspace_mr = raft::resource::get_workspace_resource_ref(res); + { - cagra::detail::batched_device_view batched( - res, input_view, batch_size, host_writeback, initialize); - while (true) { - auto dev_view = batched.next_view(); - if (dev_view.extent(0) == 0) break; + bli::batch_load_iterator iter(res, + input_view, + batch_size, + copy_stream, + workspace_mr, + enable_prefetch, + initialize, + host_writeback); + iter.prefetch_next_batch(); + + for (auto& batch : iter) { + if (batch.size() == 0) break; + + auto dev_view = batch.view(); if (initialize) { raft::copy(readback.data() + total_processed * n_cols, @@ -87,14 +110,12 @@ class BatchedDeviceViewTest : public ::testing::Test { raft::resource::get_cuda_stream(res)); } if (host_writeback) { - // Re-wrap as a plain device_matrix_view to strip the (potentially - // layout_stride / pinned- or managed-accessor) shape that the - // passthrough path would otherwise hand us, so raft::matrix::fill's - // device_matrix_view overload accepts the call. 
dev_view is always - // exhaustive (contiguous row range of a row-major matrix), so - // (data_handle, extent(0), extent(1)) describes the same memory. - // This should eventually be fixed by adding a more generic - // overload to raft::matrix::fill. + // The passthrough strategy returns `cuda::std::submdspan(input_view, ...)`, which + // may be `layout_stride` and/or carry a non-default accessor (e.g. PinnedAccessor), + // neither of which `raft::matrix::fill`'s `device_matrix_view` overload accepts. + // Re-wrap as a plain device_matrix_view since the slice is always exhaustive + // (a contiguous row range of a row-major input). This re-wrap should eventually be + // removed once raft::matrix::fill grows a more generic mdspan overload. raft::matrix::fill(res, raft::make_device_matrix_view( dev_view.data_handle(), dev_view.extent(0), dev_view.extent(1)), @@ -102,10 +123,7 @@ class BatchedDeviceViewTest : public ::testing::Test { } total_processed += dev_view.extent(0); - // Pair next_view() with prefetch_next(): the next batch's H2D and the - // previous batch's D2H run on copy_stream_ concurrently with the - // raft::copy / raft::matrix::fill kernels we just queued on res_. 
- batched.prefetch_next(); + iter.prefetch_next_batch(); } } raft::resource::sync_stream(res); @@ -130,38 +148,38 @@ class BatchedDeviceViewTest : public ::testing::Test { raft::resources res; }; -TEST_F(BatchedDeviceViewTest, EmptyViewFromHost) +TEST_F(BatchLoadIteratorTest, EmptyViewFromHost) { auto host_empty = raft::make_host_matrix(0, 8); auto host_view = host_empty.view(); - cagra::detail::batched_device_view batched( - res, host_view, /*batch_size=*/128, /*host_writeback=*/false, /*initialize=*/true); - auto view = batched.next_view(); - EXPECT_EQ(view.extent(0), 0); - EXPECT_EQ(view.extent(1), 8); - EXPECT_EQ(view.data_handle(), nullptr); + auto [copy_stream, enable_prefetch] = bli::get_prefetch_stream(res); + auto workspace_mr = raft::resource::get_workspace_resource_ref(res); + + bli::batch_load_iterator> iter( + res, host_view, /*batch_size=*/128, copy_stream, workspace_mr, enable_prefetch); + EXPECT_TRUE(iter.begin() == iter.end()); } -TEST_F(BatchedDeviceViewTest, EmptyViewFromDevice) +TEST_F(BatchLoadIteratorTest, EmptyViewFromDevice) { auto device_empty = raft::make_device_matrix(res, 0, 8); auto device_view = device_empty.view(); - cagra::detail::batched_device_view batched( - res, device_view, /*batch_size=*/128, /*host_writeback=*/false, /*initialize=*/true); - auto view = batched.next_view(); - EXPECT_EQ(view.extent(0), 0); - EXPECT_EQ(view.extent(1), 8); - EXPECT_EQ(view.data_handle(), nullptr); + auto [copy_stream, enable_prefetch] = bli::get_prefetch_stream(res); + auto workspace_mr = raft::resource::get_workspace_resource_ref(res); + + bli::batch_load_iterator> iter( + res, device_view, /*batch_size=*/128, copy_stream, workspace_mr, enable_prefetch); + EXPECT_TRUE(iter.begin() == iter.end()); } using BatchDimsParam = std::tuple; -class BatchedDeviceViewParameterizedTest : public BatchedDeviceViewTest, +class BatchLoadIteratorParameterizedTest : public BatchLoadIteratorTest, public ::testing::WithParamInterface {}; 
-TEST_P(BatchedDeviceViewParameterizedTest, VectorHostData) +TEST_P(BatchLoadIteratorParameterizedTest, VectorHostData) { auto [batch_config, dims_config] = GetParam(); auto [initialize, host_writeback] = batch_config; @@ -175,14 +193,13 @@ TEST_P(BatchedDeviceViewParameterizedTest, VectorHostData) run_and_verify_batched(host_view, batch_size, host_writeback, initialize); } -TEST_P(BatchedDeviceViewParameterizedTest, PinnedMemory) +TEST_P(BatchLoadIteratorParameterizedTest, PinnedMemory) { auto [batch_config, dims_config] = GetParam(); auto [initialize, host_writeback] = batch_config; auto [n_rows, n_cols, batch_size] = dims_config; auto pinned_matrix = raft::make_pinned_matrix(res, n_rows, n_cols); - // auto pinned_view = pinned_matrix.view(); auto pinned_view = raft::mdspan, raft::row_major, PinnedAccessor>( pinned_matrix.data_handle(), n_rows, n_cols); @@ -191,7 +208,7 @@ TEST_P(BatchedDeviceViewParameterizedTest, PinnedMemory) run_and_verify_batched(pinned_view, batch_size, host_writeback, initialize); } -TEST_P(BatchedDeviceViewParameterizedTest, PinnedMemoryForcedToHost) +TEST_P(BatchLoadIteratorParameterizedTest, PinnedMemoryForcedToHost) { auto [batch_config, dims_config] = GetParam(); auto [initialize, host_writeback] = batch_config; @@ -207,7 +224,7 @@ TEST_P(BatchedDeviceViewParameterizedTest, PinnedMemoryForcedToHost) run_and_verify_batched(pinned_view, batch_size, host_writeback, initialize); } -TEST_P(BatchedDeviceViewParameterizedTest, ManagedMemory) +TEST_P(BatchLoadIteratorParameterizedTest, ManagedMemory) { auto [batch_config, dims_config] = GetParam(); auto [initialize, host_writeback] = batch_config; @@ -221,7 +238,7 @@ TEST_P(BatchedDeviceViewParameterizedTest, ManagedMemory) run_and_verify_batched(managed_view, batch_size, host_writeback, initialize); } -TEST_P(BatchedDeviceViewParameterizedTest, ManagedMemoryForcedToHost) +TEST_P(BatchLoadIteratorParameterizedTest, ManagedMemoryForcedToHost) { auto [batch_config, dims_config] = GetParam(); 
auto [initialize, host_writeback] = batch_config; @@ -238,7 +255,7 @@ TEST_P(BatchedDeviceViewParameterizedTest, ManagedMemoryForcedToHost) run_and_verify_batched(managed_view, batch_size, host_writeback, initialize); } -TEST_P(BatchedDeviceViewParameterizedTest, DeviceMemory) +TEST_P(BatchLoadIteratorParameterizedTest, DeviceMemory) { auto [batch_config, dims_config] = GetParam(); auto [initialize, host_writeback] = batch_config; @@ -252,6 +269,89 @@ TEST_P(BatchedDeviceViewParameterizedTest, DeviceMemory) run_and_verify_batched(device_view, batch_size, host_writeback, initialize); } +/** + * Drive the runtime-dispatched wrapper. Verifies that: + * * a host pointer dispatches to the copy_device branch (does_copy() == true), and + * * a device pointer dispatches to passthrough (does_copy() == false), + * and that initialize-only iteration yields the expected pre-fill on each batch. + */ +TEST_F(BatchLoadIteratorTest, MakeBatchLoadIteratorHostPtr) +{ + const int64_t n_rows = 256; + const int64_t n_cols = 32; + const size_t batch_size_rows = 64; + + std::vector host_data(n_rows * n_cols, IdxT(13)); + auto [copy_stream, enable_prefetch] = bli::get_prefetch_stream(res); + + auto iter = bli::make_batch_load_iterator(res, + host_data.data(), + n_rows, + n_cols, + batch_size_rows, + copy_stream, + raft::resource::get_workspace_resource_ref(res), + enable_prefetch); + EXPECT_TRUE(iter.does_copy()); + + std::vector readback(n_rows * n_cols, IdxT(0)); + int64_t total = 0; + iter.prefetch_next_batch(); + for (auto const& batch : iter) { + if (batch.size() == 0) break; + raft::copy(readback.data() + total * n_cols, + batch.data(), + batch.size() * n_cols, + raft::resource::get_cuda_stream(res)); + total += batch.size(); + iter.prefetch_next_batch(); + } + raft::resource::sync_stream(res); + EXPECT_EQ(total, n_rows); + for (int64_t i = 0; i < n_rows * n_cols; ++i) { + EXPECT_EQ(readback[i], IdxT(13)) << "Mismatch at index " << i; + } +} + +TEST_F(BatchLoadIteratorTest, 
MakeBatchLoadIteratorDevicePtr) +{ + const int64_t n_rows = 256; + const int64_t n_cols = 32; + const size_t batch_size_rows = 64; + + auto device_matrix = raft::make_device_matrix(res, n_rows, n_cols); + raft::matrix::fill(res, device_matrix.view(), IdxT(13)); + + auto [copy_stream, enable_prefetch] = bli::get_prefetch_stream(res); + auto iter = bli::make_batch_load_iterator(res, + device_matrix.data_handle(), + n_rows, + n_cols, + batch_size_rows, + copy_stream, + raft::resource::get_workspace_resource_ref(res), + enable_prefetch); + EXPECT_FALSE(iter.does_copy()); + + std::vector readback(n_rows * n_cols, IdxT(0)); + int64_t total = 0; + iter.prefetch_next_batch(); + for (auto const& batch : iter) { + if (batch.size() == 0) break; + raft::copy(readback.data() + total * n_cols, + batch.data(), + batch.size() * n_cols, + raft::resource::get_cuda_stream(res)); + total += batch.size(); + iter.prefetch_next_batch(); + } + raft::resource::sync_stream(res); + EXPECT_EQ(total, n_rows); + for (int64_t i = 0; i < n_rows * n_cols; ++i) { + EXPECT_EQ(readback[i], IdxT(13)) << "Mismatch at index " << i; + } +} + static const std::array kBatchConfigs = {{ {/*initialize=*/true, /*host_writeback=*/false}, {/*initialize=*/false, /*host_writeback=*/true}, @@ -268,7 +368,7 @@ static const std::array kDimsConfigs = {{ }}; INSTANTIATE_TEST_SUITE_P(BatchConfigs, - BatchedDeviceViewParameterizedTest, + BatchLoadIteratorParameterizedTest, ::testing::Combine(::testing::ValuesIn(kBatchConfigs), ::testing::ValuesIn(kDimsConfigs))); From cc8d89232b61221023a1c02b2ed566ae9c71801e Mon Sep 17 00:00:00 2001 From: Malte Foerster Date: Fri, 8 May 2026 00:59:42 +0000 Subject: [PATCH 37/40] more review suggestions --- .../neighbors/detail/cagra/cagra_build.cuh | 5 ++-- cpp/src/neighbors/detail/cagra/graph_core.cuh | 24 +++++++++---------- 2 files changed, 15 insertions(+), 14 deletions(-) diff --git a/cpp/src/neighbors/detail/cagra/cagra_build.cuh b/cpp/src/neighbors/detail/cagra/cagra_build.cuh 
index d659aa246f..96ff8344d3 100644 --- a/cpp/src/neighbors/detail/cagra/cagra_build.cuh +++ b/cpp/src/neighbors/detail/cagra/cagra_build.cuh @@ -831,10 +831,11 @@ inline std::pair optimize_workspace_size(size_t n_rows, mst_host += (graph_degree - 1) * (graph_degree - 1) * index_size; // iB_candidates } - // Prune stage memory - // We neglect 8 bytes (both on host and device) for stats + // batchsize for both prune and combine stages size_t batch_size = std::min(static_cast(256 * 1024), n_rows); + // Prune stage memory + // We neglect 8 bytes (both on host and device) for stats size_t prune_dev = batch_size * intermediate_degree * 1; // detour count (uint8_t) prune_dev += batch_size * sizeof(uint32_t); // d_num_detour_edges prune_dev += n_rows * intermediate_degree * index_size; // d_input_graph diff --git a/cpp/src/neighbors/detail/cagra/graph_core.cuh b/cpp/src/neighbors/detail/cagra/graph_core.cuh index d0cc88d614..4cfdfc8fce 100644 --- a/cpp/src/neighbors/detail/cagra/graph_core.cuh +++ b/cpp/src/neighbors/detail/cagra/graph_core.cuh @@ -251,7 +251,7 @@ __global__ void kern_fused_prune(KnnGraphView knn_graph, // [graph_chunk_ smem_num_detour[k] = knn_graph_degree; } } - __syncwarp(); + warp.sync(); // count number of detours (A->D->B) for (uint32_t kAD = 0; kAD < knn_graph_degree - 1; kAD++) { @@ -270,7 +270,7 @@ __global__ void kern_fused_prune(KnnGraphView knn_graph, // [graph_chunk_ } } } - __syncwarp(); + warp.sync(); } uint32_t num_edges_no_detour = 0; @@ -280,7 +280,7 @@ __global__ void kern_fused_prune(KnnGraphView knn_graph, // [graph_chunk_ if (smem_indices[k] >= graph_size) { smem_num_detour[k] = maxval16; } } - __syncwarp(); + warp.sync(); num_edges_no_detour = cg::reduce(warp, num_edges_no_detour, cg::plus()); num_edges_no_detour = min(num_edges_no_detour, output_graph_degree); @@ -319,7 +319,7 @@ __global__ void kern_fused_prune(KnnGraphView knn_graph, // [graph_chunk_ for (uint32_t k = lane_id; k < knn_graph_degree; k += raft::WarpSize) { if 
(smem_indices[k] == selected_node) { smem_num_detour[k] = maxval16; } } - __syncwarp(); + warp.sync(); if (lane_id == 0) { output_graph(nid_batch, i) = selected_node; } } @@ -348,16 +348,16 @@ template __device__ void warp_shift_array_one_right(uint32_t lane_id, T* array, uint64_t num) { if (num == 0) { return; } - for (auto chunk_end = static_cast(num); chunk_end >= 1; chunk_end -= 32) { + for (auto chunk_end = static_cast(num); chunk_end >= 1; chunk_end -= 31) { const int64_t chunk_start_lo = chunk_end - 31; - const int64_t chunk_start = (chunk_start_lo > 1) ? chunk_start_lo : 1; + const int64_t chunk_start = (chunk_start_lo > 0) ? chunk_start_lo : 0; const int64_t k = chunk_start + static_cast(lane_id); T val{}; - const bool active = (k <= chunk_end); - if (active) { val = array[k - 1]; } - __syncwarp(); - if (active) { array[k] = val; } - __syncwarp(); + const bool read_active = (k <= chunk_end); + if (read_active) { val = array[k]; } + const T shifted = raft::shfl_up(val, 1); + const bool write_active = (lane_id > 0) && read_active; + if (write_active) { array[k] = shifted; } } } @@ -423,7 +423,7 @@ __global__ void kern_merge_graph( } } - unsigned int warp_dup = __ballot_sync(0xffffffff, dup); + auto warp_dup = raft::ballot(dup); if (warp_dup == 0) { if (lane_id == 0) smem_sorted_output_graph[output_j] = v; output_j++; From 3b70439f7484830b4858564939872444d92c3743 Mon Sep 17 00:00:00 2001 From: Malte Foerster Date: Fri, 8 May 2026 11:36:44 +0000 Subject: [PATCH 38/40] more review suggestions --- cpp/src/neighbors/detail/ann_utils.cuh | 15 ++- cpp/src/neighbors/detail/cagra/graph_core.cuh | 123 ++++++++++++------ 2 files changed, 92 insertions(+), 46 deletions(-) diff --git a/cpp/src/neighbors/detail/ann_utils.cuh b/cpp/src/neighbors/detail/ann_utils.cuh index 4089c58a03..d4aa18efce 100644 --- a/cpp/src/neighbors/detail/ann_utils.cuh +++ b/cpp/src/neighbors/detail/ann_utils.cuh @@ -418,7 +418,9 @@ inline auto get_prefetch_stream(raft::resources const& res) * 
Three orthogonal flags control behavior of the copy_device strategy: * * `prefetch`: if true, allocate two device buffers and pipeline copies using * `prefetch_next_batch()`. If false, copies happen synchronously at `operator*` and one buffer - * is allocated. + * is allocated. As an optimization, when `n_iters_ <= 1` (i.e. `batch_size >= + * input_view.extent(0)`) prefetching is internally downgraded to false and only one buffer is + * allocated, since there is no "next batch" to overlap with. * * `initialize`: if true, stage source rows H2D into the buffer before yielding the batch. * If false, the buffer is handed out uninitialized (kernel produces the data from scratch). * * `host_writeback`: if true, queue D2H of every advanced batch back to `input_view_`. @@ -474,8 +476,7 @@ struct batch_load_iterator { static constexpr bool kPassthrough = accessor_type::is_device_accessible; // Type returned by `view()` for the passthrough strategy: a 2D submdspan of `input_view_` over a - // contiguous row range. Built without ever calling `data_handle()` on the input mdspan, so this - // stays valid even for future device mdspans whose accessor exposes no raw pointer. + // contiguous row range. Built without ever calling `data_handle()` on the input mdspan. // (Per the mdspan spec, slicing a `layout_right` with a `tuple{lo, hi}` over the leading dim // yields a `layout_stride` mdspan with the input's accessor preserved.) using passthrough_view_type = @@ -599,9 +600,15 @@ struct batch_load_iterator { } buf_0_.resize(row_width_ * batch_size_, copy_stream); dev_ptr_ = reinterpret_cast(buf_0_.data()); - if (prefetch_) { + // The second buffer is only useful when there is more than one batch to overlap. With + // n_iters_ <= 1, there is no "next batch" to stage while a kernel runs on the current + // one, so prefetching offers no benefit. Downgrade `prefetch_` to false to skip the + // buf_1_ allocation and have `prefetch()` / `load()` take the single-buffer fast path. 
+ if (prefetch_ && n_iters_ > 1) { buf_1_.resize(row_width_ * batch_size_, copy_stream); prefetch_dev_ptr_ = reinterpret_cast(buf_1_.data()); + } else { + prefetch_ = false; } } } diff --git a/cpp/src/neighbors/detail/cagra/graph_core.cuh b/cpp/src/neighbors/detail/cagra/graph_core.cuh index 4cfdfc8fce..a7ea1bbc91 100644 --- a/cpp/src/neighbors/detail/cagra/graph_core.cuh +++ b/cpp/src/neighbors/detail/cagra/graph_core.cuh @@ -378,7 +378,7 @@ __global__ void kern_merge_graph( const uint32_t batch_size, const uint32_t batch_id, bool guarantee_connectivity, - bool* check_num_protected_edges) + uint32_t* check_num_protected_edges) { // Check assumption we have at least one warp per row of the batch assert(blockDim.x == raft::WarpSize * num_warps); @@ -442,7 +442,7 @@ __global__ void kern_merge_graph( const auto num_protected_edges = max(current_mst_graph_num_edges, output_graph_degree / 2); if (num_protected_edges > output_graph_degree) { - check_num_protected_edges[0] = false; + check_num_protected_edges[0] = 0u; return; } if (num_protected_edges == output_graph_degree) { return; } @@ -706,13 +706,29 @@ __global__ void kern_mst_opt_postprocessing(IdxT* outgoing_num_edges, // [graph } } -template -void log_incoming_edges_histogram(const IdxT* output_graph_ptr, - uint64_t graph_size, - uint64_t output_graph_degree) +template +void log_incoming_edges_histogram( + raft::resources const& res, + raft::mdspan, raft::row_major, AccessorOutputGraph> + output_graph) { raft::common::nvtx::range block_scope( "cagra::graph::optimize/check_edges"); + const uint64_t graph_size = output_graph.extent(0); + const uint64_t output_graph_degree = output_graph.extent(1); + + // copy to host if required + IdxT* output_graph_ptr = nullptr; + int64_t buffer_size = AccessorOutputGraph::is_host_accessible ? 
0 : graph_size; + auto host_copy_output_graph = + raft::make_host_matrix(buffer_size, output_graph_degree); + if constexpr (AccessorOutputGraph::is_host_accessible) { + output_graph_ptr = output_graph.data_handle(); + } else { + raft::copy(res, host_copy_output_graph.view(), output_graph); + output_graph_ptr = host_copy_output_graph.data_handle(); + } + auto in_edge_count = raft::make_host_vector(graph_size); auto in_edge_count_ptr = in_edge_count.data_handle(); #pragma omp parallel for @@ -753,13 +769,30 @@ void log_incoming_edges_histogram(const IdxT* output_graph_ptr, } } -template -void check_duplicates_and_out_of_range(const IdxT* output_graph_ptr, - uint64_t graph_size, - uint64_t output_graph_degree) +template +void check_duplicates_and_out_of_range( + raft::resources const& res, + raft::mdspan, raft::row_major, AccessorOutputGraph> + output_graph) { raft::common::nvtx::range block_scope( "cagra::graph::optimize/check_duplicates"); + + const uint64_t graph_size = output_graph.extent(0); + const uint64_t output_graph_degree = output_graph.extent(1); + + // copy to host if required + IdxT* output_graph_ptr = nullptr; + int64_t buffer_size = AccessorOutputGraph::is_host_accessible ? 
0 : graph_size; + auto host_copy_output_graph = + raft::make_host_matrix(buffer_size, output_graph_degree); + if constexpr (AccessorOutputGraph::is_host_accessible) { + output_graph_ptr = output_graph.data_handle(); + } else { + raft::copy(res, host_copy_output_graph.view(), output_graph); + output_graph_ptr = host_copy_output_graph.data_handle(); + } + uint64_t num_dup = 0; uint64_t num_oor = 0; #pragma omp parallel for reduction(+ : num_dup) reduction(+ : num_oor) @@ -807,8 +840,12 @@ void merge_graph_gpu( const double merge_graph_start = cur_time(); - auto d_check_num_protected_edges = raft::make_device_scalar(res, true); + auto d_check_num_protected_edges = raft::make_device_scalar(res, 1u); + // The batchsize is statically set to 256 * 1024 which corresponds to 256MB for a graph + // degree of 128 and 16byte index type. This is a trade-off between memory usage and performance. + // When choosing dynamically based on available memory, we would also need to modify the static + // size assumption in the cagra_build.cuh::optimize_workspace_size function.
uint32_t batch_size = static_cast(std::min(graph_size, 256 * 1024)); const uint32_t num_batch = (graph_size + batch_size - 1) / batch_size; @@ -871,11 +908,10 @@ void merge_graph_gpu( ++d_mst_graph_num_edges; } - bool check_num_protected_edges = true; - raft::copy(&check_num_protected_edges, - d_check_num_protected_edges.data_handle(), - 1, - raft::resource::get_cuda_stream(res)); + uint32_t check_num_protected_edges = 1u; + raft::copy(res, + raft::make_host_scalar_view(&check_num_protected_edges), + d_check_num_protected_edges.view()); raft::resource::sync_stream(res); const auto merge_graph_end = cur_time(); RAFT_EXPECTS(check_num_protected_edges, @@ -907,7 +943,15 @@ void make_reverse_graph_gpu( raft::matrix::fill(res, d_rev_graph, IdxT(-1)); raft::matrix::fill(res, d_rev_graph_count, uint32_t(0)); - if constexpr (!AccessorOutputGraph::is_device_accessible) { + if constexpr (AccessorOutputGraph::is_device_accessible) { + // output graph is fully device accessible, so we need no copy to device + dim3 threads(256, 1, 1); + dim3 blocks(1024, 1, 1); + for (uint64_t k = 0; k < output_graph_degree; k++) { + kern_make_rev_graph_k<<>>( + output_graph, d_rev_graph, d_rev_graph_count, k); + } + } else { auto d_dest_nodes = raft::make_device_matrix(res, graph_size, 1); auto dest_nodes = raft::make_host_vector(graph_size); for (uint64_t k = 0; k < output_graph_degree; k++) { @@ -924,14 +968,6 @@ void make_reverse_graph_gpu( raft::resource::sync_stream(res); RAFT_LOG_DEBUG("# Making reverse graph on GPUs: %lu / %u \r", k, output_graph_degree); } - } else { - // output graph is fully device accessible, so we need no copy to device - dim3 threads(256, 1, 1); - dim3 blocks(1024, 1, 1); - for (uint64_t k = 0; k < output_graph_degree; k++) { - kern_make_rev_graph_k<<>>( - output_graph, d_rev_graph, d_rev_graph_count, k); - } } } @@ -1570,6 +1606,10 @@ void prune_graph_gpu( raft::common::nvtx::range block_scope( "cagra::graph::optimize/prune"); + // The batchsize is statically set 
to 256 * 1024 which corresponds to 256MB for a graph + // degree of 128 and 16byte index type. This is a trade-off between memory usage and performance. + // When choosing dynamically based on available memory, we would also need to modify the static + // size assumption in the cagra_build.cuh::optimize_workspace_size function. uint32_t batch_size = static_cast(std::min(graph_size, 256 * 1024)); const uint32_t num_batch = (graph_size + batch_size - 1) / batch_size; @@ -1577,10 +1617,10 @@ void prune_graph_gpu( const double prune_start = cur_time(); - uint64_t num_keep __attribute__((unused)) = 0; - uint64_t num_full __attribute__((unused)) = 0; - auto dev_stats = raft::make_device_vector(res, 2); - auto host_stats = raft::make_host_vector(2); + [[maybe_unused]] uint64_t num_keep = 0; + [[maybe_unused]] uint64_t num_full = 0; + auto dev_stats = raft::make_device_vector(res, 2); + auto host_stats = raft::make_host_vector(2); raft::matrix::fill(res, dev_stats.view(), uint64_t(0)); namespace bli = cuvs::spatial::knn::detail::utils; @@ -1591,7 +1631,6 @@ void prune_graph_gpu( bli::batch_load_iterator< raft::mdspan, raft::row_major, AccessorKnnGraph>> d_input_graph(res, knn_graph, graph_size, copy_stream, workspace_mr); - d_input_graph.prefetch_next_batch(); auto input_view = (*d_input_graph).view(); bli::batch_load_iterator< @@ -1683,6 +1722,12 @@ void optimize( // temporary memory for small arrays, e.g. 
everything <= O(batchsize * graph_degree) auto default_ws_mr = raft::resource::get_workspace_resource_ref(res); + // create a stream pool if not already present + if (!res.has_resource_factory(raft::resource::resource_type::CUDA_STREAM_POOL) || + raft::resource::get_stream_pool_size(res) < 1) { + raft::resource::set_cuda_stream_pool(res, std::make_shared(1)); + } + RAFT_EXPECTS(knn_graph.extent(0) == new_graph.extent(0), "Each input array is expected to have the same number of rows"); RAFT_EXPECTS(new_graph.extent(1) <= knn_graph.extent(1), @@ -1703,10 +1748,6 @@ void optimize( auto mst_graph_num_edges = raft::make_host_vector(mst_graph_size); if (guarantee_connectivity) { -#pragma omp parallel for - for (uint64_t i = 0; i < graph_size; i++) { - mst_graph_num_edges(i) = 0; - } raft::common::nvtx::range block_scope( "cagra::graph::optimize/check_connectivity"); RAFT_LOG_INFO("MST optimization is used to guarantee graph connectivity."); @@ -1731,7 +1772,7 @@ void optimize( auto d_rev_graph = raft::make_device_mdarray(res, raft::make_extents(0, 0)); try { d_rev_graph = raft::make_device_mdarray( - res, raft::make_extents(graph_size, output_graph_degree)); + res, default_ws_mr, raft::make_extents(graph_size, output_graph_degree)); } catch (const std::exception& e) { RAFT_LOG_DEBUG( "Failed to create device matrix for reverse graph, switching to large workspace resource"); @@ -1765,14 +1806,12 @@ void optimize( raft::resource::sync_stream(res); - if constexpr (AccessorOutputGraph::is_host_accessible) { - // following checks require host access - log_incoming_edges_histogram(new_graph.data_handle(), graph_size, output_graph_degree); + // These host-side checks are expensive (O(N*D^2)) and only used as debug + // diagnostics, so only run them when debug logging is active at runtime. 
+ if (raft::default_logger().should_log(rapids_logger::level_enum::debug)) { + log_incoming_edges_histogram(res, new_graph); - check_duplicates_and_out_of_range( - new_graph.data_handle(), graph_size, output_graph_degree); - } else { - RAFT_LOG_DEBUG("Output graph is on GPU, skipping checks"); + check_duplicates_and_out_of_range(res, new_graph); } } From 533d19bb3881f0b0f801c4d60b13a0d98f9d0b21 Mon Sep 17 00:00:00 2001 From: Malte Foerster Date: Fri, 8 May 2026 13:32:12 +0000 Subject: [PATCH 39/40] fix async writeback without initialization --- cpp/src/neighbors/detail/ann_utils.cuh | 150 ++++++++++++++++++------- 1 file changed, 110 insertions(+), 40 deletions(-) diff --git a/cpp/src/neighbors/detail/ann_utils.cuh b/cpp/src/neighbors/detail/ann_utils.cuh index d4aa18efce..bdb2c310c6 100644 --- a/cpp/src/neighbors/detail/ann_utils.cuh +++ b/cpp/src/neighbors/detail/ann_utils.cuh @@ -430,14 +430,24 @@ inline auto get_prefetch_stream(raft::resources const& res) * (`initialize=false, host_writeback=true`) when the kernel produces the result from scratch. * At least one of them must be true. * - * Stream model: - * * The user passes `copy_stream`. With `prefetch=true`, this should be a stream distinct from - * `res`'s main stream (use `get_prefetch_stream(res)`); otherwise no real overlap is possible. - * * `prefetch_next_batch()` queues D2H of the just-completed batch (if dirty) followed by - * H2D of the next batch (if `initialize`) on `copy_stream`. With prefetch enabled it then - * calls `sync_stream(res)` so the host stall on the main stream overlaps with the copies on - * `copy_stream`. With prefetch disabled, it synchronizes `copy_stream` directly. - * * `operator*` drains `copy_stream` so the slot is fully staged before the caller dereferences. + * Stream model (prefetch=true): + * * The user passes `copy_stream`, which should be distinct from `res`'s main stream (use + * `get_prefetch_stream(res)`); otherwise no real overlap is possible. 
+ * * Writeback is done with a 2-iteration delay so that the D2H of batch `i` is queued in the + * `prefetch_next_batch()` of iteration `i+2` -- right before the same slot is overwritten by + * the next H2D. This way the D2H reads a slot whose user kernel (`kernel_i`) is two + * iterations old and therefore guaranteed finished, without needing CUDA events for + * cross-stream synchronization. + * * `prefetch_next_batch()` queues, on `copy_stream`, in order: (1) D2H of the prefetch slot's + * stale kernel output (if `host_writeback`), then (2) H2D of `pos` into the prefetch slot + * (if `initialize`); then it calls `sync_stream(res)` so the host stall for the previous kernel + * overlaps with the just-queued copies. + * * `operator*` (next iteration's `load()`) swaps the ring slots and synchronizes + * `copy_stream` so the swapped-in slot is fully staged before the user's next kernel runs. + * + * Stream model (prefetch=false): + * * Single buffer; copies happen synchronously inside `operator*` with `copy_stream` then + * synchronized from the host. No overlap is attempted. * * Iteration ends when `operator++` reaches `n_iters_`. The iterator can be reused via `reset()`. * @@ -493,11 +503,26 @@ struct batch_load_iterator { ~batch() noexcept { if constexpr (!kPassthrough) { - // Flush any pending writeback for the slot still held in dev_ptr_. - // The "other" slot's writeback (if any) was issued at the last load() that swapped to it. - if (host_writeback_ && source_ != nullptr && dirty_cur_ && pos_.has_value()) { - queue_d2h(dev_ptr_, *pos_); - dirty_cur_ = false; + if (host_writeback_ && source_ != nullptr) { + // Two slots may still hold un-flushed kernel output: + // * `prefetch_dev_ptr_` for the batch the loop didn't reach (its prefetch_next_batch() + // that would have queued the D2H never happened). + // * `dev_ptr_` for the most-recently-loaded batch (its writeback is always deferred to + // the destructor since no future iteration recycles its slot).
+ // The user kernel that wrote either slot may still be in flight on `res`'s main stream + // when this destructor runs, so host-stall on it before issuing the D2Hs to avoid a + // read-while-write race. + const bool has_pending = + (prefetch_ && prefetch_dirty_pos_.has_value()) || (dirty_cur_ && pos_.has_value()); + if (has_pending) { raft::resource::sync_stream(*res_); } + if (prefetch_ && prefetch_dirty_pos_.has_value()) { + queue_d2h(prefetch_dev_ptr_, *prefetch_dirty_pos_); + prefetch_dirty_pos_.reset(); + } + if (dirty_cur_ && pos_.has_value()) { + queue_d2h(dev_ptr_, *pos_); + dirty_cur_ = false; + } } } // Stream is shared with the iterator; it must be sync'd before the underlying buffers (or, @@ -614,10 +639,21 @@ struct batch_load_iterator { } /** - * Make this batch represent position `pos`. In copy_device mode this synchronously stages - * H2D if needed; in passthrough mode this is pure bookkeeping (the per-batch view is - * recomputed on demand by `view()` via `cuda::std::submdspan`, never via pointer arithmetic - * on the input mdspan). + * Make this batch represent position `pos`. + * + * Passthrough: pure bookkeeping; the per-batch view is recomputed on demand by `view()` via + * `cuda::std::submdspan`, never via pointer arithmetic on the input mdspan. + * + * Copy_device, prefetch=true: swap the ring slots and host-sync `copy_stream` so the + * swapped-in slot is fully staged before the user's next kernel runs. The slot we swapped + * out (now `prefetch_dev_ptr_`) carried the just-completed kernel's writes; if `host_writeback` + * is on, its writeback is recorded for the *next* `prefetch_next_batch()` to flush -- when + * that slot is about to be overwritten by a new H2D, by which time the kernel that wrote it + * is two iterations old and guaranteed finished (the legacy 2-iteration-delay model). + * + * Copy_device, prefetch=false: synchronously stage H2D into the single buffer and host-sync + * `copy_stream`. 
+ * * No-op if the buffer already holds `pos`. Iteration end is signaled by `pos >= n_iters_`. */ void load(size_type pos) @@ -643,22 +679,29 @@ struct batch_load_iterator { return; } - // Always issue D2H of the slot we're about to leave (or recycle) BEFORE swapping in - // / overwriting it with new data. With prefetch=true the prior kernel has already been - // sync'd by the previous prefetch_next_batch()'s sync_stream(res); with prefetch=false - // copies serialize on a single stream so D2H precedes H2D into the same buffer. - if (host_writeback_ && dirty_cur_ && pos_.has_value()) { - queue_d2h(dev_ptr_, *pos_); - dirty_cur_ = false; - } if (prefetch_ && prefetch_pos_.has_value() && *prefetch_pos_ == pos) { - // Swap to the prefetched slot. The previously-current slot moves into prefetch_dev_ptr_; - // its writeback (if any) was issued just above. + // Hand-off from the prefetch slot. Before swapping, transfer the soon-to-be-stale + // dev_ptr_'s dirty state into prefetch_dirty_pos_: after the swap, the slot that just + // received `kernel_(prev pos_)`'s writes becomes prefetch_dev_ptr_, and its D2H will be + // queued by the *next* prefetch_next_batch() right before the slot is overwritten by a + // new H2D. By that point, the kernel that wrote it is two iterations old -- finished + // by main-stream FIFO, no events required. + if (host_writeback_ && dirty_cur_ && pos_.has_value()) { + prefetch_dirty_pos_ = pos_; + } else { + prefetch_dirty_pos_.reset(); + } std::swap(dev_ptr_, prefetch_dev_ptr_); prefetch_pos_.reset(); - // Drain copy_stream so the swapped-in slot is fully staged before user reads. + // Ensure prefetch_next_batch()'s queued H2D into this slot (and any prior D2H of the + // slot from the previous overwrite) finished before the user kernel reads it. copy_stream_.synchronize(); } else { + // Non-pipelined fast path (prefetch_=false, or prefetch_pos_ didn't match). 
+ if (host_writeback_ && dirty_cur_ && pos_.has_value()) { + queue_d2h(dev_ptr_, *pos_); + dirty_cur_ = false; + } if (initialize_) { queue_h2d(dev_ptr_, row_offset, len); } copy_stream_.synchronize(); } @@ -666,19 +709,29 @@ struct batch_load_iterator { batch_len_ = len; if (host_writeback_) { // Every advanced batch is implicitly dirty: the user kernel will write to it before - // the next load() / prefetch() recycles the slot. + // the next prefetch_next_batch() (or destructor) recycles the slot. dirty_cur_ = true; } } } /** - * Queue H2D for `pos` into the not-currently-visible slot, plus D2H of the previously - * dirtied (just-completed) slot. No-op if prefetch is disabled, source is null, or - * `pos >= n_iters_`. + * Cross-stream pipelining step (called by the user *after* enqueuing the kernel for the + * current batch). + * + * On `copy_stream`, queues -- in this order, into the *prefetch* slot only: + * 1. D2H of the prefetch slot's stale kernel output (if `host_writeback` and the slot was + * written by a kernel two iterations ago). The kernel that produced that output is + * already finished by main-stream FIFO (the user has since enqueued the kernel for + * the slot currently in `dev_ptr_`), so no event is needed -- the D2H reads a slot the + * main stream is no longer touching. + * 2. H2D of `pos` into the prefetch slot (if `initialize`). + * + * Then host-stalls on `res`'s main stream so the host has waited out the just-enqueued user + * kernel by the time control returns; meanwhile the D2H/H2D queued above runs concurrently + * on `copy_stream`, overlapping the writeback with the kernel. * - * With prefetch enabled this is followed by `sync_stream(res)` so the host-side memcpy - * stall on `copy_stream` overlaps with the user kernel on `res`'s main stream. + * No-op if prefetch is disabled, source is null, passthrough, or `pos >= n_iters_`. 
*/ void prefetch(size_type pos) { @@ -689,9 +742,18 @@ struct batch_load_iterator { return; } - // Issue H2D of `pos` into prefetch_dev_ptr_ (the slot the user kernel is NOT on). - // Writeback of the "other" slot is unnecessary here because it was already issued at the - // last load() that recycled it. + // 2-iteration-delayed writeback: the prefetch slot still holds kernel output from two + // iterations ago (from when this slot was last `dev_ptr_`); D2H it now -- right before we + // overwrite the slot with the next H2D. The kernel that wrote it has already finished on the main + // stream, so this D2H runs concurrently with the just-enqueued kernel for `dev_ptr_` + // without racing (different slots). + if (host_writeback_ && prefetch_dirty_pos_.has_value()) { + queue_d2h(prefetch_dev_ptr_, *prefetch_dirty_pos_); + prefetch_dirty_pos_.reset(); + } + + // H2D of the next batch into the prefetch slot. Sequenced after the D2H on `copy_stream`, + // so the H2D doesn't clobber the slot before its prior contents have been read out. if (initialize_) { const size_type row_offset = pos * batch_size_; const size_type len = @@ -700,9 +762,10 @@ struct batch_load_iterator { } prefetch_pos_.emplace(pos); - // Wait for the kernel paired with this prefetch_next_batch() before returning, so the next - // operator* can safely swap-and-read the slot. Do this AFTER queueing copies, so the host - // stall overlaps with both the kernel and the copies. + // Host-stall on the user's main stream: the just-enqueued kernel runs concurrently with + // the copies above. By the time control returns, the kernel is done and the next load() + // can safely read its writes (still in `dev_ptr_`'s slot, which becomes prefetch_dev_ptr_ + // after the upcoming swap). raft::resource::sync_stream(*res_); } @@ -753,8 +816,15 @@ struct batch_load_iterator { // Slot bookkeeping (only meaningful for !kPassthrough).
element_type* dev_ptr_ = nullptr; element_type* prefetch_dev_ptr_ = nullptr; + // `pos_`: batch index currently held in `dev_ptr_`'s slot (the user-facing slot). + // `prefetch_pos_`: batch index pre-staged in `prefetch_dev_ptr_`'s slot via H2D. + // `prefetch_dirty_pos_`: batch index whose kernel output is sitting in `prefetch_dev_ptr_`'s + // slot and still needs to be D2H'd back to host. Set by `load()`'s + // swap when the slot it just retired held kernel writes; consumed + // (D2H'd) by the *next* `prefetch()` -- the legacy 2-iteration delay. std::optional pos_; std::optional prefetch_pos_; + std::optional prefetch_dirty_pos_; size_type batch_len_ = 0; bool dirty_cur_ = false; }; From 28372a3ec18f45833611849ee1d431fedc97df1a Mon Sep 17 00:00:00 2001 From: Malte Foerster Date: Tue, 12 May 2026 21:26:59 +0000 Subject: [PATCH 40/40] more suggestions --- cpp/src/neighbors/detail/cagra/graph_core.cuh | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/cpp/src/neighbors/detail/cagra/graph_core.cuh b/cpp/src/neighbors/detail/cagra/graph_core.cuh index a7ea1bbc91..2c21018deb 100644 --- a/cpp/src/neighbors/detail/cagra/graph_core.cuh +++ b/cpp/src/neighbors/detail/cagra/graph_core.cuh @@ -1673,10 +1673,8 @@ void prune_graph_gpu( RAFT_LOG_DEBUG("\n"); uint32_t invalid_neighbor_list = 0; - raft::copy(&invalid_neighbor_list, - d_invalid_neighbor_list.data_handle(), - 1, - raft::resource::get_cuda_stream(res)); + raft::copy( + res, raft::make_host_scalar_view(&invalid_neighbor_list), d_invalid_neighbor_list.view()); raft::copy(res, host_stats.view(), raft::make_const_mdspan(dev_stats.view())); raft::resource::sync_stream(res);