Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
64 commits
Select commit Hold shift + click to select a range
609b0f3
prune kernel smem
mfoerste4 Feb 16, 2026
a320e0e
reduce copies within reverse graph compute
mfoerste4 Feb 18, 2026
6d1a618
optimize() draft move more compute to GPU
mfoerste4 Feb 19, 2026
77ab079
Merge branch 'rapidsai:main' into cagra_optimize
mfoerste4 Feb 19, 2026
008e0fb
Merge branch 'rapidsai:main' into cagra_optimize
mfoerste4 Feb 20, 2026
822faea
some fixes, cleanup
mfoerste4 Feb 20, 2026
8ed1497
Merge branch 'main' into cagra_optimize
mfoerste4 Feb 24, 2026
9b1f741
some fixes
mfoerste4 Feb 25, 2026
ecf3b1d
extract prune into separate function
mfoerste4 Feb 27, 2026
972d278
extract optimize components
mfoerste4 Mar 2, 2026
5e9ebc5
enable both host/device inout graphs for optimize
mfoerste4 Mar 2, 2026
8f24d9d
resolve conflicts
mfoerste4 Mar 2, 2026
40977e2
smaller fixes
mfoerste4 Mar 2, 2026
14e9f3e
bugfix
mfoerste4 Mar 3, 2026
416558d
fuse and simplify pruning, remove CPU path
mfoerste4 Mar 5, 2026
d8d8bd8
cleanup merge, remove CPU path
mfoerste4 Mar 5, 2026
00c4204
batch reverse creation
mfoerste4 Mar 6, 2026
9e63a7c
add prefetch view to handle managed & host
mfoerste4 Mar 6, 2026
a38ad52
fix batched iterator
mfoerste4 Mar 9, 2026
89b0d1c
implement fallback / simplify strategy
mfoerste4 Mar 9, 2026
d0e3dae
add logging / remove stats compute
mfoerste4 Mar 10, 2026
ec45fd2
add test, persist stream pool, cleanup
mfoerste4 Mar 10, 2026
e43b51b
Merge branch 'main' into cagra_optimize
mfoerste4 Mar 10, 2026
c412138
switch to cooperative groups as __reduce_min_sync causes issues
mfoerste4 Mar 11, 2026
b035ea0
Merge branch 'cagra_optimize' of github.com:mfoerste4/cuvs into cagra…
mfoerste4 Mar 11, 2026
ab01bab
back to column wise reverse graph creation to boost closer connections
mfoerste4 Mar 13, 2026
139774f
Merge branch 'main' into cagra_optimize
mfoerste4 Mar 13, 2026
68f7883
fix signness
mfoerste4 Mar 13, 2026
add206a
stupid me trusting cursor to fix this
mfoerste4 Mar 13, 2026
ef1ec18
remove pointer arithmetic part1
mfoerste4 Mar 16, 2026
01e1336
remove pointer arithmetic part2
mfoerste4 Mar 16, 2026
5a24fb0
fix mst graph usage
mfoerste4 Mar 17, 2026
c033436
remove memcopy2D
mfoerste4 Mar 17, 2026
cc7efa7
Merge branch 'main' into cagra_optimize
mfoerste4 Mar 17, 2026
0b2a2dd
Merge branch 'main' into cagra_optimize
mfoerste4 Mar 23, 2026
333fe46
Merge branch 'main' into cagra_optimize
mfoerste4 Mar 25, 2026
7d31bb3
Merge branch 'main' into cagra_optimize
mfoerste4 Mar 30, 2026
ed4fb8b
Merge branch 'main' into cagra_optimize
mfoerste4 Apr 17, 2026
a585dcb
review suggestions
mfoerste4 Apr 21, 2026
b381d32
Merge branch 'main' into cagra_optimize
mfoerste4 Apr 21, 2026
1f0ce37
fix merge conflict
mfoerste4 Apr 21, 2026
24c606e
more review suggestions
mfoerste4 Apr 22, 2026
a61ff0f
Merge branch 'main' into cagra_optimize
mfoerste4 Apr 22, 2026
f695763
coderabbit suggestions
mfoerste4 Apr 23, 2026
023fbc6
trying to fix IllegalAccess
mfoerste4 Apr 24, 2026
3437700
try more fixes for V100/cuda12.2
mfoerste4 Apr 27, 2026
b72ad17
Merge branch 'main' into cagra_optimize
mfoerste4 Apr 28, 2026
f04022c
refactor remove all device pointer arithmetic from batch_device_view,…
mfoerste4 May 1, 2026
cf86064
simplify batched_view to 2 buffers 1 copy stream
mfoerste4 May 1, 2026
aa250a7
Merge branch 'main' into cagra_optimize
mfoerste4 May 1, 2026
821eae6
stream-sync fix typo
mfoerste4 May 4, 2026
cd7be32
Merge branch 'main' into cagra_optimize
mfoerste4 May 4, 2026
22b32cb
Merge branch 'main' into cagra_optimize
mfoerste4 May 7, 2026
6211ff3
merge into batch_load_iterator
mfoerste4 May 7, 2026
cc8d892
more review suggestions
mfoerste4 May 8, 2026
fad99af
Merge branch 'cagra_optimize' of github.com:mfoerste4/cuvs into cagra…
mfoerste4 May 8, 2026
0b6ea72
fix merge conflict within kmeans
mfoerste4 May 8, 2026
3b70439
more review suggestions
mfoerste4 May 8, 2026
533d19b
fix async writeback without initialization
mfoerste4 May 8, 2026
d0a4cfd
Merge branch 'main' into cagra_optimize
mfoerste4 May 11, 2026
28372a3
more suggestions
mfoerste4 May 12, 2026
7322903
Merge branch 'cagra_optimize' of github.com:mfoerste4/cuvs into cagra…
mfoerste4 May 12, 2026
bfc0520
Merge branch 'main' into cagra_optimize
mfoerste4 May 12, 2026
5369eb6
Merge branch 'main' into cagra_optimize
mfoerste4 May 13, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 7 additions & 6 deletions cpp/src/cluster/detail/kmeans.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -694,14 +694,15 @@ void kmeans_fit(

rmm::device_uvector<char> batch_workspace(streaming_batch_size, stream);

cuvs::spatial::knn::detail::utils::batch_load_iterator<DataT> data_batches(
X.data_handle(), n_samples, n_features, streaming_batch_size, stream);
auto data_batches = cuvs::spatial::knn::detail::utils::make_batch_load_iterator<DataT>(
handle, X.data_handle(), n_samples, n_features, streaming_batch_size, stream);
// Host-path weight batches: only materialized when weights are provided and
// the data resides on host
std::optional<cuvs::spatial::knn::detail::utils::batch_load_iterator<DataT>> weight_batches;
std::optional<cuvs::spatial::knn::detail::utils::batch_load_iterator_dyn<DataT>> weight_batches;
if constexpr (!data_on_device) {
if (weight_ptr != nullptr) {
weight_batches.emplace(weight_ptr, n_samples, 1, streaming_batch_size, stream);
weight_batches = cuvs::spatial::knn::detail::utils::make_batch_load_iterator<DataT>(
handle, weight_ptr, n_samples, IndexT{1}, streaming_batch_size, stream);
} else {
raft::matrix::fill(handle, batch_weights_buf.view(), DataT{1});
}
Expand Down Expand Up @@ -833,7 +834,7 @@ void kmeans_fit(
raft::make_device_matrix_view<DataT, IndexT>(new_centroids_ptr, n_clusters, n_features);

data_batches.reset();
using wt_iter_t = cuvs::spatial::knn::detail::utils::batch_load_iterator<DataT>;
using wt_iter_t = cuvs::spatial::knn::detail::utils::batch_load_iterator_dyn<DataT>;
std::optional<wt_iter_t> wt_it;
if (weight_batches.has_value()) {
weight_batches->reset();
Expand Down Expand Up @@ -932,7 +933,7 @@ void kmeans_fit(

iter_inertia = DataT{0};
data_batches.reset();
using wt_iter_t = cuvs::spatial::knn::detail::utils::batch_load_iterator<DataT>;
using wt_iter_t = cuvs::spatial::knn::detail::utils::batch_load_iterator_dyn<DataT>;
std::optional<wt_iter_t> wt_it;
if (weight_batches.has_value()) {
weight_batches->reset();
Expand Down
947 changes: 768 additions & 179 deletions cpp/src/neighbors/detail/ann_utils.cuh

Large diffs are not rendered by default.

7 changes: 4 additions & 3 deletions cpp/src/neighbors/detail/cagra/add_nodes.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -84,10 +84,11 @@ void add_node_core(
auto host_neighbor_indices =
raft::make_host_matrix<IdxT, std::int64_t>(max_search_batch_size, base_degree);

cuvs::spatial::knn::detail::utils::batch_load_iterator<T> additional_dataset_batch(
auto additional_dataset_batch = cuvs::spatial::knn::detail::utils::make_batch_load_iterator<T>(
handle,
additional_dataset_view.data_handle(),
num_add,
additional_dataset_view.stride(0),
static_cast<std::int64_t>(num_add),
static_cast<std::int64_t>(additional_dataset_view.stride(0)),
max_search_batch_size,
raft::resource::get_cuda_stream(handle),
mr);
Expand Down
47 changes: 27 additions & 20 deletions cpp/src/neighbors/detail/cagra/cagra_build.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -831,29 +831,34 @@ inline std::pair<size_t, size_t> optimize_workspace_size(size_t n_rows,
mst_host += (graph_degree - 1) * (graph_degree - 1) * index_size; // iB_candidates
}

// batch size for both prune and combine stages
size_t batch_size = std::min(static_cast<size_t>(256 * 1024), n_rows);
Comment thread
mfoerste4 marked this conversation as resolved.

// Prune stage memory
// We neglect 8 bytes (both on host and device) for stats
size_t prune_host = n_rows * intermediate_degree * sizeof(uint8_t); // detour count

size_t prune_dev = n_rows * intermediate_degree * 1; // detour count (uint8_t)
prune_dev += n_rows * sizeof(uint32_t); // d_num_detour_edges
prune_dev += n_rows * intermediate_degree * index_size; // d_input_graph
size_t prune_dev = batch_size * intermediate_degree * 1; // detour count (uint8_t)
prune_dev += batch_size * sizeof(uint32_t); // d_num_detour_edges
prune_dev += n_rows * intermediate_degree * index_size; // d_input_graph
prune_dev += 2 * batch_size * graph_degree * index_size; // d_output_graph(2*batch)

// Reverse graph stage memory
size_t rev_host = n_rows * graph_degree * index_size; // rev_graph
rev_host += n_rows * sizeof(uint32_t); // rev_graph_count
rev_host += n_rows * index_size; // dest_nodes

size_t rev_dev = n_rows * graph_degree * index_size; // d_rev_graph
rev_dev += n_rows * sizeof(uint32_t); // d_rev_graph_count
rev_dev += n_rows * sizeof(uint32_t); // d_dest_nodes
rev_dev += n_rows * index_size; // d_dest_nodes

Comment thread
mfoerste4 marked this conversation as resolved.
// Memory for merging graphs (host only)
// Memory for merging graphs (host only, optional)
size_t combine_host =
n_rows * sizeof(uint32_t) + graph_degree * sizeof(uint32_t); // in_edge_count + hist

size_t total_host = mst_host + std::max({prune_host, rev_host, combine_host});
size_t total_dev = std::max(prune_dev, rev_dev);
// additional memory for combine stage on device (3 batches)
size_t combine_dev = 2 * batch_size * graph_degree * index_size; // d_output_graph(2*batch)
if (mst_optimize) {
combine_dev += 2 * batch_size * graph_degree * index_size; // d_mst_graph(2*batch)
combine_dev += 2 * batch_size * sizeof(uint32_t); // d_mst_graph_num_edges(2*batch)
}

size_t total_host = mst_host + combine_host;
size_t total_dev = std::max(prune_dev, rev_dev + combine_dev);
Comment thread
mfoerste4 marked this conversation as resolved.

return std::make_pair(total_host, total_dev);
}
Expand Down Expand Up @@ -1725,11 +1730,12 @@ void build_knn_graph(
bool first = true;
const auto start_clock = std::chrono::system_clock::now();

cuvs::spatial::knn::detail::utils::batch_load_iterator<DataT> vec_batches(
auto vec_batches = cuvs::spatial::knn::detail::utils::make_batch_load_iterator<DataT>(
res,
dataset.data_handle(),
dataset.extent(0),
dataset.extent(1),
static_cast<int64_t>(max_queries),
static_cast<int64_t>(dataset.extent(0)),
static_cast<int64_t>(dataset.extent(1)),
static_cast<size_t>(max_queries),
raft::resource::get_cuda_stream(res),
workspace_mr);

Expand Down Expand Up @@ -2110,10 +2116,11 @@ auto iterative_build_graph(

// Search.
// Since there are many queries, divide them into batches and search them.
cuvs::spatial::knn::detail::utils::batch_load_iterator<T> query_batch(
auto query_batch = cuvs::spatial::knn::detail::utils::make_batch_load_iterator<T>(
res,
dev_query_view.data_handle(),
curr_query_size,
dev_query_view.extent(1),
static_cast<int64_t>(curr_query_size),
static_cast<int64_t>(dev_query_view.extent(1)),
max_chunk_size,
raft::resource::get_cuda_stream(res),
raft::resource::get_workspace_resource_ref(res));
Expand Down
Loading
Loading