diff --git a/cpp/src/sampling/detail/deduplicate_edges_by_minor_impl.cuh b/cpp/src/sampling/detail/deduplicate_edges_by_minor_impl.cuh index 67343a38e3..c41a9c4b50 100644 --- a/cpp/src/sampling/detail/deduplicate_edges_by_minor_impl.cuh +++ b/cpp/src/sampling/detail/deduplicate_edges_by_minor_impl.cuh @@ -18,14 +18,11 @@ #include #include -#include -#include -#include -#include #include #include #include +#include namespace cugraph { namespace detail { @@ -35,18 +32,30 @@ std::tuple, rmm::device_uvector, arithmetic_device_uvector_t, std::optional>, - std::optional>, + rmm::device_uvector, + rmm::device_uvector, + arithmetic_device_uvector_t, + arithmetic_device_uvector_t, std::optional>> deduplicate_edges_by_minor(raft::handle_t const& handle, graph_view_t const& graph_view, rmm::device_uvector&& result_majors, rmm::device_uvector&& result_minors, - arithmetic_device_uvector_t&& tmp_edge_indices, - std::optional>&& result_labels, - bool call_from_sampling) + arithmetic_device_uvector_t&& result_edge_property, + arithmetic_device_uvector_t&& result_types, + std::optional>&& result_labels) { - std::optional> resample_majors{std::nullopt}; - std::optional> resample_major_labels{std::nullopt}; + bool const has_edge_property = !std::holds_alternative(result_edge_property); + bool const has_types = !std::holds_alternative(result_types); + bool const has_labels = result_labels.has_value(); + if (has_types) { + CUGRAPH_EXPECTS(std::holds_alternative>(result_types), + "result_types must be rmm::device_uvector when present."); + } + if (has_edge_property) { + CUGRAPH_EXPECTS(std::holds_alternative>(result_edge_property), + "result_edge_property must be rmm::device_uvector when present."); + } size_t total_edges = result_majors.size(); @@ -58,353 +67,254 @@ deduplicate_edges_by_minor(raft::handle_t const& handle, if (total_edges == 0) { return std::make_tuple(std::move(result_majors), std::move(result_minors), - std::move(tmp_edge_indices), + std::move(result_edge_property), 
std::move(result_labels), - std::move(resample_majors), - std::move(resample_major_labels)); + rmm::device_uvector(0, handle.get_stream()), + rmm::device_uvector(0, handle.get_stream()), + arithmetic_device_uvector_t{std::monostate{}}, + arithmetic_device_uvector_t{std::monostate{}}, + std::optional>{std::nullopt}); } // 1. Shuffle the edges to GPUs by minor vertex id if multi-gpu - std::optional> keep_ranks{std::nullopt}; - rmm::device_uvector keep_majors(result_minors.size(), handle.get_stream()); - std::optional> keep_labels{std::nullopt}; - rmm::device_uvector keep_minors(result_minors.size(), handle.get_stream()); - rmm::device_uvector keep_positions(result_minors.size(), handle.get_stream()); - - raft::copy(keep_majors.begin(), result_majors.begin(), result_majors.size(), handle.get_stream()); - raft::copy(keep_minors.begin(), result_minors.begin(), result_minors.size(), handle.get_stream()); - - if (result_labels) { - keep_labels = - std::make_optional>(result_labels->size(), handle.get_stream()); - raft::copy( - keep_labels->begin(), result_labels->begin(), result_labels->size(), handle.get_stream()); - } - - thrust::sequence( - handle.get_thrust_policy(), keep_positions.begin(), keep_positions.end(), size_t{0}); - if constexpr (multi_gpu) { - keep_ranks = - std::make_optional>(keep_positions.size(), handle.get_stream()); - thrust::fill(handle.get_thrust_policy(), - keep_ranks->begin(), - keep_ranks->end(), - handle.get_comms().get_rank()); - std::vector shuffle_properties{}; - shuffle_properties.push_back(std::move(keep_majors)); - shuffle_properties.push_back(std::move(*keep_ranks)); - shuffle_properties.push_back(std::move(keep_positions)); - if (result_labels) { shuffle_properties.push_back(std::move(*keep_labels)); } + shuffle_properties.push_back(std::move(result_majors)); + if (has_edge_property) { shuffle_properties.push_back(std::move(result_edge_property)); } + if (has_types) { shuffle_properties.push_back(std::move(result_types)); } + if 
(has_labels) { shuffle_properties.push_back(std::move(*result_labels)); } - std::tie(keep_minors, shuffle_properties) = + std::tie(result_minors, shuffle_properties) = shuffle_int_vertices(handle, - std::move(keep_minors), + std::move(result_minors), std::move(shuffle_properties), graph_view.vertex_partition_range_lasts(), std::nullopt); - keep_majors = std::move(std::get>(shuffle_properties[0])); - keep_ranks = std::move(std::get>(shuffle_properties[1])); - keep_positions = std::move(std::get>(shuffle_properties[2])); - if (result_labels) { - keep_labels = std::move(std::get>(shuffle_properties[3])); + result_majors = std::move(std::get>(shuffle_properties[0])); + size_t shuffle_prop_idx{1}; + if (has_edge_property) { + result_edge_property = std::move(shuffle_properties[shuffle_prop_idx++]); + } + if (has_types) { result_types = std::move(shuffle_properties[shuffle_prop_idx++]); } + if (has_labels) { + result_labels = + std::move(std::get>(shuffle_properties[shuffle_prop_idx++])); } } - // 2. Now all edges that lead to a visited minor are on this GPU, we can - // Sort by minor vertex id and identify duplicates. - rmm::device_uvector local_positions(keep_minors.size(), handle.get_stream()); - thrust::sequence( - handle.get_thrust_policy(), local_positions.begin(), local_positions.end(), size_t{0}); - - // FIXME: After we refactor sample_edges_to_unvisited_neighbors to use the partial results we - // can remove the majors from this sort. Until then we need to sort by (label, minor, - // major, position) to guarantee that each iteration of the outer loop makes progress. - if (keep_labels) { - thrust::sort( - handle.get_thrust_policy(), - thrust::make_zip_iterator( - keep_labels->begin(), keep_minors.begin(), keep_majors.begin(), local_positions.begin()), - thrust::make_zip_iterator( - keep_labels->end(), keep_minors.end(), keep_majors.end(), local_positions.end())); + // 2. 
Sort the edges by minor vertex id and identify duplicates + if (has_edge_property) { + auto& property = std::get>(result_edge_property); + if (has_types) { + auto& types = std::get>(result_types); + if (has_labels) { + thrust::sort(handle.get_thrust_policy(), + thrust::make_zip_iterator(result_labels->begin(), + result_minors.begin(), + result_majors.begin(), + property.begin(), + types.begin()), + thrust::make_zip_iterator(result_labels->end(), + result_minors.end(), + result_majors.end(), + property.end(), + types.end())); + } else { + thrust::sort( + handle.get_thrust_policy(), + thrust::make_zip_iterator( + result_minors.begin(), result_majors.begin(), property.begin(), types.begin()), + thrust::make_zip_iterator( + result_minors.end(), result_majors.end(), property.end(), types.end())); + } + } else if (has_labels) { + thrust::sort( + handle.get_thrust_policy(), + thrust::make_zip_iterator( + result_labels->begin(), result_minors.begin(), result_majors.begin(), property.begin()), + thrust::make_zip_iterator( + result_labels->end(), result_minors.end(), result_majors.end(), property.end())); + } else { + thrust::sort( + handle.get_thrust_policy(), + thrust::make_zip_iterator(result_minors.begin(), result_majors.begin(), property.begin()), + thrust::make_zip_iterator(result_minors.end(), result_majors.end(), property.end())); + } } else { - thrust::sort( - handle.get_thrust_policy(), - thrust::make_zip_iterator(keep_minors.begin(), keep_majors.begin(), local_positions.begin()), - thrust::make_zip_iterator(keep_minors.end(), keep_majors.end(), local_positions.end())); + if (has_types) { + auto& types = std::get>(result_types); + if (has_labels) { + thrust::sort( + handle.get_thrust_policy(), + thrust::make_zip_iterator( + result_labels->begin(), result_minors.begin(), result_majors.begin(), types.begin()), + thrust::make_zip_iterator( + result_labels->end(), result_minors.end(), result_majors.end(), types.end())); + } else { + thrust::sort( + 
handle.get_thrust_policy(), + thrust::make_zip_iterator(result_minors.begin(), result_majors.begin(), types.begin()), + thrust::make_zip_iterator(result_minors.end(), result_majors.end(), types.end())); + } + } else if (has_labels) { + thrust::sort( + handle.get_thrust_policy(), + thrust::make_zip_iterator( + result_labels->begin(), result_minors.begin(), result_majors.begin()), + thrust::make_zip_iterator(result_labels->end(), result_minors.end(), result_majors.end())); + } else { + thrust::sort(handle.get_thrust_policy(), + thrust::make_zip_iterator(result_minors.begin(), result_majors.begin()), + thrust::make_zip_iterator(result_minors.end(), result_majors.end())); + } } // 3. Mark the edges to keep locally size_t keep_count{0}; rmm::device_uvector keep_flags(0, handle.get_stream()); - - // We'll keep the first edge for each (label, minor) pair. - if (keep_labels) { - std::tie(keep_count, keep_flags) = - detail::mark_entries(handle, - keep_minors.size(), - detail::is_first_in_run_tbegin(), keep_minors.begin()))>{ - thrust::make_zip_iterator(keep_labels->begin(), keep_minors.begin())}); + if (has_labels) { + std::tie(keep_count, keep_flags) = detail::mark_entries( + handle, + result_minors.size(), + detail::is_first_in_run_tbegin(), + result_minors.begin()))>{ + thrust::make_zip_iterator(result_labels->begin(), result_minors.begin())}); } else { std::tie(keep_count, keep_flags) = detail::mark_entries( handle, - keep_minors.size(), - detail::is_first_in_run_t{keep_minors.begin()}); + result_minors.size(), + detail::is_first_in_run_t{result_minors.begin()}); } - size_t global_remove_count{keep_minors.size() - keep_count}; - if constexpr (multi_gpu) { - global_remove_count = host_scalar_allreduce(handle.get_comms(), - (keep_minors.size() - keep_count), - raft::comms::op_t::SUM, - handle.get_stream()); + // split to new result_majors and discarded_majors, then minors, edge_property, types and labels + rmm::device_uvector discarded_majors(0, handle.get_stream()); + 
rmm::device_uvector discarded_minors(0, handle.get_stream()); + arithmetic_device_uvector_t discarded_edge_property{std::monostate{}}; + arithmetic_device_uvector_t discarded_types{std::monostate{}}; + if (has_edge_property) { + discarded_edge_property = rmm::device_uvector(0, handle.get_stream()); } + if (has_types) { discarded_types = rmm::device_uvector(0, handle.get_stream()); } + std::optional> discarded_labels{std::nullopt}; + + if (keep_count < result_minors.size()) { + size_t const discard_count = result_minors.size() - keep_count; + raft::device_span const keep_flags_span{keep_flags.data(), keep_flags.size()}; + + discarded_majors.resize(discard_count, handle.get_stream()); + detail::copy_if_mask_unset(handle, + result_majors.begin(), + result_majors.end(), + keep_flags.begin(), + discarded_majors.begin()); + result_majors = + detail::keep_marked_entries(handle, std::move(result_majors), keep_flags_span, keep_count); + + discarded_minors.resize(discard_count, handle.get_stream()); + detail::copy_if_mask_unset(handle, + result_minors.begin(), + result_minors.end(), + keep_flags.begin(), + discarded_minors.begin()); + result_minors = + detail::keep_marked_entries(handle, std::move(result_minors), keep_flags_span, keep_count); + + if (has_edge_property) { + auto& property = std::get>(result_edge_property); + rmm::device_uvector discarded(discard_count, handle.get_stream()); + detail::copy_if_mask_unset( + handle, property.begin(), property.end(), keep_flags.begin(), discarded.begin()); + property = + detail::keep_marked_entries(handle, std::move(property), keep_flags_span, keep_count); + discarded_edge_property = std::move(discarded); + } - // 4. If we have any duplicates on any GPU we need to remove them - if (global_remove_count > 0) { - bool skip_shuffle_back_to_ranks{false}; - - if (call_from_sampling) { - // When called from sampling, we need to skip all edges that come from any major vertex that - // we are going to skip. 
- resample_majors = - std::make_optional>(keep_majors.size(), handle.get_stream()); - raft::copy( - resample_majors->begin(), keep_majors.begin(), keep_majors.size(), handle.get_stream()); - - if (keep_labels) { - resample_major_labels = std::make_optional>( - keep_labels->size(), handle.get_stream()); - raft::copy(resample_major_labels->begin(), - keep_labels->begin(), - keep_labels->size(), - handle.get_stream()); - } - - detail::invert_flags(handle, - raft::device_span{keep_flags.data(), keep_flags.size()}, - keep_minors.size()); - - resample_majors = detail::keep_marked_entries( - handle, - std::move(*resample_majors), - raft::device_span{keep_flags.data(), keep_flags.size()}, - (keep_minors.size() - keep_count)); - if (resample_major_labels) { - resample_major_labels = detail::keep_marked_entries( - handle, - std::move(*resample_major_labels), - raft::device_span{keep_flags.data(), keep_flags.size()}, - (keep_minors.size() - keep_count)); - } - - if constexpr (multi_gpu) { - std::vector shuffle_properties{}; - if (resample_major_labels) { - shuffle_properties.push_back(std::move(*resample_major_labels)); - } - std::tie(resample_majors, shuffle_properties) = - cugraph::shuffle_int_vertices(handle, - std::move(*resample_majors), - std::move(shuffle_properties), - graph_view.vertex_partition_range_lasts()); - if (resample_major_labels) { - resample_major_labels = - std::move(std::get>(shuffle_properties[0])); - } - } + if (has_types) { + auto& types = std::get>(result_types); + auto& discarded = std::get>(discarded_types); + discarded.resize(discard_count, handle.get_stream()); + detail::copy_if_mask_unset( + handle, types.begin(), types.end(), keep_flags.begin(), discarded.begin()); + } - if (resample_major_labels) { - auto new_begin = - thrust::make_zip_iterator(resample_major_labels->begin(), resample_majors->begin()); - thrust::sort(handle.get_thrust_policy(), new_begin, new_begin + resample_majors->size()); - auto new_end = thrust::unique( - 
handle.get_thrust_policy(), new_begin, new_begin + resample_majors->size()); - resample_majors->resize(cuda::std::distance(new_begin, new_end), handle.get_stream()); - resample_major_labels->resize(cuda::std::distance(new_begin, new_end), handle.get_stream()); - } else { - thrust::sort(handle.get_thrust_policy(), resample_majors->begin(), resample_majors->end()); - auto new_end = thrust::unique( - handle.get_thrust_policy(), resample_majors->begin(), resample_majors->end()); - resample_majors->resize(cuda::std::distance(resample_majors->begin(), new_end), - handle.get_stream()); - } + if (has_labels) { + discarded_labels = + std::make_optional(rmm::device_uvector(discard_count, handle.get_stream())); + detail::copy_if_mask_unset(handle, + result_labels->begin(), + result_labels->end(), + keep_flags.begin(), + discarded_labels->begin()); + *result_labels = + detail::keep_marked_entries(handle, std::move(*result_labels), keep_flags_span, keep_count); + } + } - rmm::device_uvector resample_majors_gathered(0, handle.get_stream()); - rmm::device_uvector resample_major_labels_gathered(0, handle.get_stream()); - raft::device_span resample_majors_span{resample_majors->data(), - resample_majors->size()}; - raft::device_span resample_major_labels_span{ - resample_major_labels ? resample_major_labels->data() : nullptr, - resample_major_labels ? 
resample_major_labels->size() : size_t{0}}; - - if constexpr (multi_gpu) { - auto& major_comm = handle.get_subcomm(cugraph::partition_manager::minor_comm_name()); - - resample_majors_gathered = device_allgatherv(handle, major_comm, resample_majors_span); - resample_majors_span = raft::device_span{resample_majors_gathered.data(), - resample_majors_gathered.size()}; - if (resample_major_labels) { - resample_major_labels_gathered = - device_allgatherv(handle, - major_comm, - raft::device_span{resample_major_labels->data(), - resample_major_labels->size()}); - resample_major_labels_span = raft::device_span{ - resample_major_labels_gathered.data(), resample_major_labels_gathered.size()}; - } - } + if (has_types) { + auto& types = std::get>(result_types); + types.resize(0, handle.get_stream()); + types.shrink_to_fit(handle.get_stream()); + } - // Now we'll regenerate keep_flags to remove all entries that come from an entry in - // resample_majors This will include all of the current marked entries and potentially some - // new entries - - if (resample_major_labels) { - std::tie(keep_count, keep_flags) = - detail::mark_entries(handle, - result_majors.size(), - [majors = result_majors.data(), - labels = result_labels->data(), - resample_majors_span, - resample_major_labels_span] __device__(auto index) { - return !thrust::binary_search( - thrust::seq, - thrust::make_zip_iterator(resample_major_labels_span.begin(), - resample_majors_span.begin()), - thrust::make_zip_iterator(resample_major_labels_span.end(), - resample_majors_span.end()), - cuda::std::make_tuple(labels[index], majors[index])); - }); - } else { - std::tie(keep_count, keep_flags) = detail::mark_entries( - handle, - result_majors.size(), - [majors = result_majors.data(), resample_majors_span] __device__(auto index) { - return !thrust::binary_search( - thrust::seq, resample_majors_span.begin(), resample_majors_span.end(), majors[index]); - }); - } + // 4. 
Shuffle edges back to the source-owner GPU (inverse of step 1's shuffle by minor). + if constexpr (multi_gpu) { + std::vector shuffle_properties{}; + shuffle_properties.push_back(std::move(result_minors)); + if (has_edge_property) { shuffle_properties.push_back(std::move(result_edge_property)); } + if (has_labels) { shuffle_properties.push_back(std::move(*result_labels)); } - if constexpr (multi_gpu) { - resample_majors_gathered.resize(0, handle.get_stream()); - resample_majors_gathered.shrink_to_fit(handle.get_stream()); - if (resample_major_labels) { - resample_major_labels_gathered.resize(0, handle.get_stream()); - resample_major_labels_gathered.shrink_to_fit(handle.get_stream()); - } - } + std::tie(result_majors, shuffle_properties) = + shuffle_int_vertices(handle, + std::move(result_majors), + std::move(shuffle_properties), + graph_view.vertex_partition_range_lasts(), + std::nullopt); - raft::device_span const result_keep_flags{keep_flags.data(), - keep_flags.size()}; - result_majors = detail::keep_marked_entries( - handle, std::move(result_majors), result_keep_flags, keep_count); - result_minors = detail::keep_marked_entries( - handle, std::move(result_minors), result_keep_flags, keep_count); - if (!std::holds_alternative(tmp_edge_indices)) { - cugraph::variant_type_dispatch( - tmp_edge_indices, [&handle, result_keep_flags, keep_count](auto& property) { - property = detail::keep_marked_entries( - handle, std::move(property), result_keep_flags, keep_count); - }); - } - if (result_labels) { - *result_labels = detail::keep_marked_entries( - handle, std::move(*result_labels), result_keep_flags, keep_count); - } - skip_shuffle_back_to_ranks = true; - } else { - // Gather to reflect the original positions of the edges - rmm::device_uvector tmp(local_positions.size(), handle.get_stream()); - thrust::gather(handle.get_thrust_policy(), - local_positions.begin(), - local_positions.end(), - keep_positions.data(), - tmp.begin()); - keep_positions = std::move(tmp); + 
result_minors = std::move(std::get>(shuffle_properties[0])); + size_t shuffle_prop_idx{1}; + if (has_edge_property) { + result_edge_property = std::move(shuffle_properties[shuffle_prop_idx++]); + } + if (has_labels) { + result_labels = + std::move(std::get>(shuffle_properties[shuffle_prop_idx++])); } - if (!skip_shuffle_back_to_ranks) { - // 5. Now keep_ranks and keep_positions need to be updated to reflect the new keep_flags and - // shuffled back to the original ranks (for multi-gpu) - if (keep_count < keep_minors.size()) { - if constexpr (multi_gpu) { - keep_ranks = detail::keep_marked_entries( - handle, - std::move(*keep_ranks), - raft::device_span{keep_flags.data(), keep_flags.size()}, - keep_count); - } - keep_positions = detail::keep_marked_entries( - handle, - std::move(keep_positions), - raft::device_span{keep_flags.data(), keep_flags.size()}, - keep_count); - } - - if constexpr (multi_gpu) { - std::tie(std::ignore, keep_positions, std::ignore) = - groupby_gpu_id_and_shuffle_kv_pairs(handle.get_comms(), - keep_ranks->begin(), - keep_ranks->end(), - keep_positions.begin(), - cuda::std::identity{}, - handle.get_stream()); - } + shuffle_properties.clear(); + shuffle_properties.push_back(std::move(discarded_minors)); + if (has_edge_property) { shuffle_properties.push_back(std::move(discarded_edge_property)); } + if (has_types) { shuffle_properties.push_back(std::move(discarded_types)); } + if (has_labels) { shuffle_properties.push_back(std::move(*discarded_labels)); } - // 6. 
Now we can remove values from results arrays - { - rmm::device_uvector tmp(keep_positions.size(), handle.get_stream()); - thrust::gather(handle.get_thrust_policy(), - keep_positions.begin(), - keep_positions.end(), - result_majors.data(), - tmp.begin()); - result_majors = std::move(tmp); - } - { - rmm::device_uvector tmp(keep_positions.size(), handle.get_stream()); - thrust::gather(handle.get_thrust_policy(), - keep_positions.begin(), - keep_positions.end(), - result_minors.data(), - tmp.begin()); - result_minors = std::move(tmp); - } + std::tie(discarded_majors, shuffle_properties) = + shuffle_int_vertices(handle, + std::move(discarded_majors), + std::move(shuffle_properties), + graph_view.vertex_partition_range_lasts(), + std::nullopt); - if (!std::holds_alternative(tmp_edge_indices)) { - cugraph::variant_type_dispatch( - tmp_edge_indices, [&handle, &keep_positions](auto& property) { - using T = typename std::remove_reference::type::value_type; - rmm::device_uvector tmp(keep_positions.size(), handle.get_stream()); - thrust::gather(handle.get_thrust_policy(), - keep_positions.begin(), - keep_positions.end(), - property.data(), - tmp.begin()); - property = std::move(tmp); - }); - } - if (result_labels) { - rmm::device_uvector tmp(keep_positions.size(), handle.get_stream()); - thrust::gather(handle.get_thrust_policy(), - keep_positions.begin(), - keep_positions.end(), - result_labels->begin(), - tmp.begin()); - *result_labels = std::move(tmp); - } + discarded_minors = std::move(std::get>(shuffle_properties[0])); + shuffle_prop_idx = 1; + if (has_edge_property) { + discarded_edge_property = std::move(shuffle_properties[shuffle_prop_idx++]); + } + if (has_types) { discarded_types = std::move(shuffle_properties[shuffle_prop_idx++]); } + if (has_labels) { + discarded_labels = + std::move(std::get>(shuffle_properties[shuffle_prop_idx++])); } } return std::make_tuple(std::move(result_majors), std::move(result_minors), - std::move(tmp_edge_indices), + 
std::move(result_edge_property), std::move(result_labels), - std::move(resample_majors), - std::move(resample_major_labels)); + std::move(discarded_majors), + std::move(discarded_minors), + std::move(discarded_edge_property), + std::move(discarded_types), + std::move(discarded_labels)); } } // namespace detail diff --git a/cpp/src/sampling/detail/deduplicate_edges_by_minor_mg_v32_e32.cu b/cpp/src/sampling/detail/deduplicate_edges_by_minor_mg_v32_e32.cu index ac4fdaec4e..c84f3e9f52 100644 --- a/cpp/src/sampling/detail/deduplicate_edges_by_minor_mg_v32_e32.cu +++ b/cpp/src/sampling/detail/deduplicate_edges_by_minor_mg_v32_e32.cu @@ -14,7 +14,10 @@ template CUGRAPH_EXPORT std::tuple, rmm::device_uvector, arithmetic_device_uvector_t, std::optional>, - std::optional>, + rmm::device_uvector, + rmm::device_uvector, + arithmetic_device_uvector_t, + arithmetic_device_uvector_t, std::optional>> deduplicate_edges_by_minor( raft::handle_t const&, @@ -22,8 +25,8 @@ deduplicate_edges_by_minor( rmm::device_uvector&&, rmm::device_uvector&&, arithmetic_device_uvector_t&&, - std::optional>&&, - bool); + arithmetic_device_uvector_t&&, + std::optional>&&); } // namespace detail } // namespace cugraph diff --git a/cpp/src/sampling/detail/deduplicate_edges_by_minor_mg_v64_e64.cu b/cpp/src/sampling/detail/deduplicate_edges_by_minor_mg_v64_e64.cu index 9c5bb569ab..2c3bb4e4b2 100644 --- a/cpp/src/sampling/detail/deduplicate_edges_by_minor_mg_v64_e64.cu +++ b/cpp/src/sampling/detail/deduplicate_edges_by_minor_mg_v64_e64.cu @@ -14,7 +14,10 @@ template CUGRAPH_EXPORT std::tuple, rmm::device_uvector, arithmetic_device_uvector_t, std::optional>, - std::optional>, + rmm::device_uvector, + rmm::device_uvector, + arithmetic_device_uvector_t, + arithmetic_device_uvector_t, std::optional>> deduplicate_edges_by_minor( raft::handle_t const&, @@ -22,8 +25,8 @@ deduplicate_edges_by_minor( rmm::device_uvector&&, rmm::device_uvector&&, arithmetic_device_uvector_t&&, - std::optional>&&, - bool); + 
arithmetic_device_uvector_t&&, + std::optional>&&); } // namespace detail } // namespace cugraph diff --git a/cpp/src/sampling/detail/deduplicate_edges_by_minor_sg_v32_e32.cu b/cpp/src/sampling/detail/deduplicate_edges_by_minor_sg_v32_e32.cu index c75e6b255a..9c391f8ca6 100644 --- a/cpp/src/sampling/detail/deduplicate_edges_by_minor_sg_v32_e32.cu +++ b/cpp/src/sampling/detail/deduplicate_edges_by_minor_sg_v32_e32.cu @@ -14,7 +14,10 @@ template CUGRAPH_EXPORT std::tuple, rmm::device_uvector, arithmetic_device_uvector_t, std::optional>, - std::optional>, + rmm::device_uvector, + rmm::device_uvector, + arithmetic_device_uvector_t, + arithmetic_device_uvector_t, std::optional>> deduplicate_edges_by_minor( raft::handle_t const&, @@ -22,8 +25,8 @@ deduplicate_edges_by_minor( rmm::device_uvector&&, rmm::device_uvector&&, arithmetic_device_uvector_t&&, - std::optional>&&, - bool); + arithmetic_device_uvector_t&&, + std::optional>&&); } // namespace detail } // namespace cugraph diff --git a/cpp/src/sampling/detail/deduplicate_edges_by_minor_sg_v64_e64.cu b/cpp/src/sampling/detail/deduplicate_edges_by_minor_sg_v64_e64.cu index 59293b0280..7411792e09 100644 --- a/cpp/src/sampling/detail/deduplicate_edges_by_minor_sg_v64_e64.cu +++ b/cpp/src/sampling/detail/deduplicate_edges_by_minor_sg_v64_e64.cu @@ -14,7 +14,10 @@ template CUGRAPH_EXPORT std::tuple, rmm::device_uvector, arithmetic_device_uvector_t, std::optional>, - std::optional>, + rmm::device_uvector, + rmm::device_uvector, + arithmetic_device_uvector_t, + arithmetic_device_uvector_t, std::optional>> deduplicate_edges_by_minor( raft::handle_t const&, @@ -22,8 +25,8 @@ deduplicate_edges_by_minor( rmm::device_uvector&&, rmm::device_uvector&&, arithmetic_device_uvector_t&&, - std::optional>&&, - bool); + arithmetic_device_uvector_t&&, + std::optional>&&); } // namespace detail } // namespace cugraph diff --git a/cpp/src/sampling/detail/gather_one_hop_impl.cuh b/cpp/src/sampling/detail/gather_one_hop_impl.cuh index 
c1e9209b6d..a4a1281fac 100644 --- a/cpp/src/sampling/detail/gather_one_hop_impl.cuh +++ b/cpp/src/sampling/detail/gather_one_hop_impl.cuh @@ -401,15 +401,21 @@ gather_one_hop_edgelist_to_unvisited_neighbors( } } - std::tie( - result_majors, result_minors, tmp_edge_indices, result_labels, std::ignore, std::ignore) = - deduplicate_edges_by_minor(handle, - graph_view, - std::move(result_majors), - std::move(result_minors), - std::move(tmp_edge_indices), - std::move(result_labels), - false); + std::tie(result_majors, + result_minors, + tmp_edge_indices, + result_labels, + std::ignore, + std::ignore, + std::ignore, + std::ignore, + std::ignore) = deduplicate_edges_by_minor(handle, + graph_view, + std::move(result_majors), + std::move(result_minors), + std::move(tmp_edge_indices), + arithmetic_device_uvector_t{std::monostate{}}, + std::move(result_labels)); std::tie(visited_minors, visited_minor_labels) = detail::update_dst_visited_vertices_and_labels( diff --git a/cpp/src/sampling/detail/sample_edges.cuh b/cpp/src/sampling/detail/sample_edges.cuh index bbcaec8ed6..43af3fb349 100644 --- a/cpp/src/sampling/detail/sample_edges.cuh +++ b/cpp/src/sampling/detail/sample_edges.cuh @@ -9,6 +9,7 @@ #include "sampling_utils.hpp" #include +#include #include #include #include @@ -19,21 +20,33 @@ #include #include #include -#include +#include +#include #include #include #include +#include #include #include #include +#include +#include +#include +#include +#include +#include +#include +#include #include +#include #include #include #include +#include namespace cugraph { namespace detail { @@ -624,6 +637,43 @@ sample_with_one_property( std::move(majors), std::move(minors), std::move(sampled_properties), std::move(labels)); } +template +rmm::device_uvector gather_edge_types_for_sampled_edgelist( + raft::handle_t const& handle, + graph_view_t const& graph_view, + edge_property_view_t edge_type_view, + raft::device_span majors, + raft::device_span minors, + arithmetic_device_uvector_t& 
multi_edge_index) +{ + CUGRAPH_EXPECTS(std::holds_alternative>(multi_edge_index), + "Multi-edge indices must be of type edge_t"); + + using edge_type_t = int32_t; + + constexpr bool store_transposed = false; + + rmm::device_uvector edge_types(majors.size(), handle.get_stream()); + + cugraph::edge_bucket_t edge_list( + handle, graph_view.is_multigraph()); + + auto& indices = std::get>(multi_edge_index); + edge_list.insert( + majors.begin(), majors.end(), minors.begin(), std::make_optional(indices.begin())); + + cugraph::transform_gather_e(handle, + graph_view, + edge_list, + edge_src_dummy_property_t{}.view(), + edge_dst_dummy_property_t{}.view(), + edge_type_view, + return_edge_property_t{}, + edge_types.begin()); + + return edge_types; +} + template = 1, "Ks must be non-empty."); + + if (Ks.size() > 1) { + CUGRAPH_EXPECTS(edge_type_view.has_value(), "heterogeneous sampling requires edge_type_view."); + } + rmm::device_uvector result_majors(0, handle.get_stream()); rmm::device_uvector result_minors(0, handle.get_stream()); arithmetic_device_uvector_t result_properties{std::monostate{}}; std::optional> result_labels{std::nullopt}; bool sample_and_append{true}; - std::optional> resample_active_majors{std::nullopt}; - std::optional> resample_active_major_labels{std::nullopt}; + rmm::device_uvector carryover_frontier_majors(0, handle.get_stream()); + std::optional> carryover_frontier_labels{std::nullopt}; + std::optional> carryover_frontier_types{std::nullopt}; + rmm::device_uvector carryover_frontier_capacity(0, handle.get_stream()); - cugraph::key_bucket_view_t active_bucket_view = - key_bucket_view; + auto active_bucket_view = key_bucket_view; // FIXME: We could explore increasing the rate of convergency by oversampling to allow // for some duplicates to be discarded. This would allow some vertices to still have the // first sampling result be sufficient. For now we'll leave this as a future optimization. 
+ size_t disjoint_resample_iteration = 0; while (sample_and_append) { + ++disjoint_resample_iteration; + std::optional> sampled_labels{std::nullopt}; rmm::device_uvector sampled_majors(0, handle.get_stream()); rmm::device_uvector sampled_minors(0, handle.get_stream()); @@ -687,8 +747,9 @@ sample_unvisited_with_one_property( edge_property_view, sample_unvisited_edge_biases_op_t{ raft::device_span{visited_minors.data(), visited_minors.size()}, - raft::device_span{visited_minor_labels->data(), - visited_minor_labels->size()}}, + visited_minor_labels ? cuda::std::make_optional(raft::device_span{ + visited_minor_labels->data(), visited_minor_labels->size()}) + : cuda::std::nullopt}, edge_type_view ? std::make_optional( std::get>( *edge_type_view)) @@ -710,8 +771,9 @@ sample_unvisited_with_one_property( edge_property_view, sample_unvisited_edge_biases_op_t{ raft::device_span{visited_minors.data(), visited_minors.size()}, - raft::device_span{visited_minor_labels->data(), - visited_minor_labels->size()}}, + visited_minor_labels ? cuda::std::make_optional(raft::device_span{ + visited_minor_labels->data(), visited_minor_labels->size()}) + : cuda::std::nullopt}, edge_type_view ? std::make_optional( std::get>( *edge_type_view)) @@ -733,8 +795,9 @@ sample_unvisited_with_one_property( edge_property_view, sample_unvisited_edge_biases_op_t{ raft::device_span{visited_minors.data(), visited_minors.size()}, - raft::device_span{visited_minor_labels->data(), - visited_minor_labels->size()}}, + visited_minor_labels ? cuda::std::make_optional(raft::device_span{ + visited_minor_labels->data(), visited_minor_labels->size()}) + : cuda::std::nullopt}, edge_type_view ? 
std::make_optional( std::get>( *edge_type_view)) @@ -745,19 +808,364 @@ sample_unvisited_with_one_property( with_replacement); } + arithmetic_device_uvector_t sampled_types{std::monostate{}}; + bool const gather_sampled_edge_types = + edge_type_view.has_value() && (Ks.size() > 1) && + !std::holds_alternative(sampled_property); + if (gather_sampled_edge_types) { + auto const edge_type_prop = + std::get>(*edge_type_view); + sampled_types = gather_edge_types_for_sampled_edgelist( + handle, + graph_view, + edge_type_prop, + raft::device_span{sampled_majors.data(), sampled_majors.size()}, + raft::device_span{sampled_minors.data(), sampled_minors.size()}, + sampled_property); + } + + if (carryover_frontier_capacity.size() > 0) { + rmm::device_uvector majors = std::move(sampled_majors); + rmm::device_uvector minors = std::move(sampled_minors); + arithmetic_device_uvector_t prop = std::move(sampled_property); + arithmetic_device_uvector_t types = std::move(sampled_types); + std::optional> labels = std::move(sampled_labels); + + rmm::device_uvector random_numbers = + rmm::device_uvector(majors.size(), handle.get_stream()); + uniform_random_fill( + handle.get_stream(), random_numbers.data(), random_numbers.size(), 0.0f, 1.0f, rng_state); + + size_t keep_count{0}; + rmm::device_uvector keep_flags(0, handle.get_stream()); + + if (carryover_frontier_types) { + auto& type_vec = std::get>(types); + + if (labels) { + thrust::sort_by_key( + handle.get_thrust_policy(), + thrust::make_zip_iterator( + labels->begin(), majors.begin(), type_vec.begin(), random_numbers.begin()), + thrust::make_zip_iterator( + labels->end(), majors.end(), type_vec.end(), random_numbers.end()), + minors.begin()); + + random_numbers.resize(0, handle.get_stream()); + random_numbers.shrink_to_fit(handle.get_stream()); + + std::tie(keep_count, keep_flags) = detail::mark_entries( + handle, + majors.size(), + cuda::proclaim_return_type( + [majors_size = majors.size(), + d_labels = labels->data(), + d_majors = 
majors.data(), + d_types = type_vec.data(), + carry_frontier_size = carryover_frontier_majors.size(), + carry_frontier_labels = carryover_frontier_labels->data(), + carry_frontier_majors = carryover_frontier_majors.data(), + carry_frontier_types = carryover_frontier_types->data(), + carry_frontier_capacity = carryover_frontier_capacity.data()] __device__(size_t i) { + auto key = cuda::std::make_tuple(d_labels[i], d_majors[i], d_types[i]); + auto carry_frontier_begin = thrust::make_zip_iterator( + carry_frontier_labels, carry_frontier_majors, carry_frontier_types); + auto lb = thrust::lower_bound(thrust::seq, + carry_frontier_begin, + carry_frontier_begin + carry_frontier_size, + key); + + auto pos = cuda::std::distance(carry_frontier_begin, lb); + + if ((pos == carry_frontier_size) || (*lb != key)) { return false; } + + auto needed_count = carry_frontier_capacity[pos]; + + auto d_begin = thrust::make_zip_iterator(d_labels, d_majors, d_types); + auto lb2 = thrust::lower_bound(thrust::seq, d_begin, d_begin + majors_size, key); + + auto position_count = (i - cuda::std::distance(d_begin, lb2)); + return position_count < needed_count; + }), + std::nullopt); + + } else { + thrust::sort_by_key( + handle.get_thrust_policy(), + thrust::make_zip_iterator(majors.begin(), type_vec.begin(), random_numbers.begin()), + thrust::make_zip_iterator(majors.end(), type_vec.end(), random_numbers.end()), + minors.begin()); + + random_numbers.resize(0, handle.get_stream()); + random_numbers.shrink_to_fit(handle.get_stream()); + + std::tie(keep_count, keep_flags) = detail::mark_entries( + handle, + majors.size(), + cuda::proclaim_return_type( + [majors_size = majors.size(), + d_majors = majors.data(), + d_types = type_vec.data(), + carry_frontier_size = carryover_frontier_majors.size(), + carry_frontier_majors = carryover_frontier_majors.data(), + carry_frontier_types = carryover_frontier_types->data(), + carry_frontier_capacity = carryover_frontier_capacity.data()] __device__(size_t i) { 
+ auto key = cuda::std::make_tuple(d_majors[i], d_types[i]); + auto carry_frontier_begin = + thrust::make_zip_iterator(carry_frontier_majors, carry_frontier_types); + auto lb = thrust::lower_bound(thrust::seq, + carry_frontier_begin, + carry_frontier_begin + carry_frontier_size, + key); + + auto pos = cuda::std::distance(carry_frontier_begin, lb); + + if ((pos == carry_frontier_size) || (*lb != key)) { return false; } + + auto needed_count = carry_frontier_capacity[pos]; + + auto d_begin = thrust::make_zip_iterator(d_majors, d_types); + auto lb2 = thrust::lower_bound(thrust::seq, d_begin, d_begin + majors_size, key); + + auto position_count = (i - cuda::std::distance(d_begin, lb2)); + return position_count < needed_count; + }), + std::nullopt); + } + } else { + if (labels) { + thrust::sort_by_key( + handle.get_thrust_policy(), + thrust::make_zip_iterator(labels->begin(), majors.begin(), random_numbers.begin()), + thrust::make_zip_iterator(labels->end(), majors.end(), random_numbers.end()), + minors.begin()); + + random_numbers.resize(0, handle.get_stream()); + random_numbers.shrink_to_fit(handle.get_stream()); + + std::tie(keep_count, keep_flags) = detail::mark_entries( + handle, + majors.size(), + cuda::proclaim_return_type( + [majors_size = majors.size(), + d_labels = labels->data(), + d_majors = majors.data(), + carry_frontier_size = carryover_frontier_majors.size(), + carry_frontier_labels = carryover_frontier_labels->data(), + carry_frontier_majors = carryover_frontier_majors.data(), + carry_frontier_capacity = carryover_frontier_capacity.data()] __device__(size_t i) { + auto key = cuda::std::make_tuple(d_labels[i], d_majors[i]); + auto carry_frontier_begin = + thrust::make_zip_iterator(carry_frontier_labels, carry_frontier_majors); + auto lb = thrust::lower_bound(thrust::seq, + carry_frontier_begin, + carry_frontier_begin + carry_frontier_size, + key); + + auto pos = cuda::std::distance(carry_frontier_begin, lb); + + if ((pos == carry_frontier_size) || (*lb 
!= key)) { return false; } + + auto needed_count = carry_frontier_capacity[pos]; + + auto d_begin = thrust::make_zip_iterator(d_labels, d_majors); + auto lb2 = thrust::lower_bound(thrust::seq, d_begin, d_begin + majors_size, key); + + auto position_count = (i - cuda::std::distance(d_begin, lb2)); + return position_count < needed_count; + }), + std::nullopt); + + } else { + thrust::sort_by_key(handle.get_thrust_policy(), + thrust::make_zip_iterator(majors.begin(), random_numbers.begin()), + thrust::make_zip_iterator(majors.end(), random_numbers.end()), + minors.begin()); + + random_numbers.resize(0, handle.get_stream()); + random_numbers.shrink_to_fit(handle.get_stream()); + + std::tie(keep_count, keep_flags) = detail::mark_entries( + handle, + majors.size(), + cuda::proclaim_return_type( + [majors_size = majors.size(), + d_majors = majors.data(), + carry_frontier_size = carryover_frontier_majors.size(), + carry_frontier_majors = carryover_frontier_majors.data(), + carry_frontier_capacity = carryover_frontier_capacity.data()] __device__(size_t i) { + auto key = d_majors[i]; + auto lb = thrust::lower_bound(thrust::seq, + carry_frontier_majors, + carry_frontier_majors + carry_frontier_size, + key); + + auto pos = cuda::std::distance(carry_frontier_majors, lb); + + if ((pos == carry_frontier_size) || (*lb != key)) { return false; } + + auto needed_count = carry_frontier_capacity[pos]; + + auto lb2 = thrust::lower_bound(thrust::seq, d_majors, d_majors + majors_size, key); + + auto position_count = (i - cuda::std::distance(d_majors, lb2)); + return position_count < needed_count; + }), + std::nullopt); + } + } + + raft::device_span const keep_mask{keep_flags.data(), keep_flags.size()}; + majors = detail::keep_marked_entries(handle, std::move(majors), keep_mask, keep_count); + minors = detail::keep_marked_entries(handle, std::move(minors), keep_mask, keep_count); + if (carryover_frontier_types) { + types = arithmetic_device_uvector_t{detail::keep_marked_entries( + handle, 
std::move(std::get>(types)), keep_mask, keep_count)}; + } + if (!std::holds_alternative(prop)) { + prop = cugraph::variant_type_dispatch(prop, [&](auto& index_vec) { + return arithmetic_device_uvector_t{ + detail::keep_marked_entries(handle, std::move(index_vec), keep_mask, keep_count)}; + }); + } + if (labels) { + *labels = detail::keep_marked_entries(handle, std::move(*labels), keep_mask, keep_count); + } + + sampled_majors = std::move(majors); + sampled_minors = std::move(minors); + sampled_property = std::move(prop); + sampled_types = std::move(types); + sampled_labels = std::move(labels); + } + // Check for duplicates in the sampled minor vertices + [[maybe_unused]] rmm::device_uvector discarded_minors(0, handle.get_stream()); + [[maybe_unused]] arithmetic_device_uvector_t discarded_edge_property{std::monostate{}}; + [[maybe_unused]] arithmetic_device_uvector_t discarded_types{std::monostate{}}; + rmm::device_uvector discarded_majors(0, handle.get_stream()); + std::optional> discarded_major_labels{std::nullopt}; + std::tie(sampled_majors, sampled_minors, sampled_property, sampled_labels, - resample_active_majors, - resample_active_major_labels) = deduplicate_edges_by_minor(handle, - graph_view, - std::move(sampled_majors), - std::move(sampled_minors), - std::move(sampled_property), - std::move(sampled_labels), - true); + discarded_majors, + discarded_minors, + discarded_edge_property, + discarded_types, + discarded_major_labels) = deduplicate_edges_by_minor(handle, + graph_view, + std::move(sampled_majors), + std::move(sampled_minors), + std::move(sampled_property), + std::move(sampled_types), + std::move(sampled_labels)); + + carryover_frontier_labels = std::nullopt; + carryover_frontier_types = std::nullopt; + carryover_frontier_majors = rmm::device_uvector(0, handle.get_stream()); + carryover_frontier_capacity = rmm::device_uvector(0, handle.get_stream()); + + if (discarded_majors.size() != 0) { + size_t const num_types = Ks.size(); + + rmm::device_uvector 
agg_majors = std::move(discarded_majors); + std::optional> agg_labels = std::move(discarded_major_labels); + + if (num_types == size_t{1}) { + if (agg_labels) { + thrust::sort(handle.get_thrust_policy(), + thrust::make_zip_iterator(agg_labels->begin(), agg_majors.begin()), + thrust::make_zip_iterator(agg_labels->end(), agg_majors.end())); + } else { + cugraph::detail::sort_ints( + handle, raft::device_span{agg_majors.data(), agg_majors.size()}); + } + + rmm::device_uvector agg_counts(agg_majors.size(), handle.get_stream()); + + if (agg_labels) { + auto ends = thrust::reduce_by_key( + handle.get_thrust_policy(), + thrust::make_zip_iterator(agg_labels->begin(), agg_majors.begin()), + thrust::make_zip_iterator(agg_labels->end(), agg_majors.end()), + thrust::make_constant_iterator(size_t{1}), + thrust::make_zip_iterator(agg_labels->begin(), agg_majors.begin()), + agg_counts.begin()); + agg_labels->resize( + static_cast(cuda::std::distance( + thrust::make_zip_iterator(agg_labels->begin(), agg_majors.begin()), ends.first)), + handle.get_stream()); + agg_majors.resize(agg_labels->size(), handle.get_stream()); + agg_counts.resize(agg_labels->size(), handle.get_stream()); + carryover_frontier_labels = std::move(agg_labels); + carryover_frontier_majors = std::move(agg_majors); + } else { + auto ends = thrust::reduce_by_key(handle.get_thrust_policy(), + agg_majors.begin(), + agg_majors.end(), + thrust::make_constant_iterator(size_t{1}), + agg_majors.begin(), + agg_counts.begin()); + agg_majors.resize( + static_cast(cuda::std::distance(agg_majors.begin(), ends.first)), + handle.get_stream()); + agg_counts.resize(agg_majors.size(), handle.get_stream()); + carryover_frontier_majors = std::move(agg_majors); + } + carryover_frontier_capacity = std::move(agg_counts); + } else { + rmm::device_uvector types = + std::get>(std::move(discarded_types)); + + if (agg_labels) { + thrust::sort( + handle.get_thrust_policy(), + thrust::make_zip_iterator(agg_labels->begin(), agg_majors.begin(), 
types.begin()), + thrust::make_zip_iterator(agg_labels->end(), agg_majors.end(), types.end())); + } else { + thrust::sort(handle.get_thrust_policy(), + thrust::make_zip_iterator(agg_majors.begin(), types.begin()), + thrust::make_zip_iterator(agg_majors.end(), types.end())); + } + + rmm::device_uvector kr_cnt(agg_majors.size(), handle.get_stream()); + + size_t nt = 0; + if (agg_labels) { + auto zip_keys_begin = + thrust::make_zip_iterator(agg_labels->begin(), agg_majors.begin(), types.begin()); + auto ends = thrust::reduce_by_key( + handle.get_thrust_policy(), + zip_keys_begin, + thrust::make_zip_iterator(agg_labels->end(), agg_majors.end(), types.end()), + thrust::make_constant_iterator(size_t{1}), + zip_keys_begin, + kr_cnt.begin()); + nt = static_cast(cuda::std::distance(kr_cnt.begin(), ends.second)); + } else { + auto zip_keys_begin = thrust::make_zip_iterator(agg_majors.begin(), types.begin()); + auto ends = + thrust::reduce_by_key(handle.get_thrust_policy(), + zip_keys_begin, + thrust::make_zip_iterator(agg_majors.end(), types.end()), + thrust::make_constant_iterator(size_t{1}), + zip_keys_begin, + kr_cnt.begin()); + nt = static_cast(cuda::std::distance(kr_cnt.begin(), ends.second)); + } + if (agg_labels) { agg_labels->resize(nt, handle.get_stream()); } + agg_majors.resize(nt, handle.get_stream()); + types.resize(nt, handle.get_stream()); + kr_cnt.resize(nt, handle.get_stream()); + + carryover_frontier_majors = std::move(agg_majors); + carryover_frontier_labels = std::move(agg_labels); + carryover_frontier_types = std::make_optional(std::move(types)); + carryover_frontier_capacity = std::move(kr_cnt); + } + } std::tie(visited_minors, visited_minor_labels) = detail::update_dst_visited_vertices_and_labels( @@ -771,31 +1179,42 @@ sample_unvisited_with_one_property( : std::nullopt); if constexpr (multi_gpu) { - sample_and_append = - (host_scalar_allreduce(handle.get_comms(), - (resample_active_majors ? 
resample_active_majors->size() : 0), - raft::comms::op_t::SUM, - handle.get_stream()) > 0); + sample_and_append = (host_scalar_allreduce(handle.get_comms(), + carryover_frontier_majors.size(), + raft::comms::op_t::SUM, + handle.get_stream()) > 0); } else { - sample_and_append = (resample_active_majors ? resample_active_majors->size() : 0) > 0; + sample_and_append = carryover_frontier_majors.size() > 0; } if (sample_and_append) { - if constexpr (std::is_same_v) { - active_bucket_view = cugraph::key_bucket_view_t( - handle, - raft::device_span{resample_active_majors->data(), - resample_active_majors->size()}); + if (carryover_frontier_majors.size() > 0) { + if constexpr (std::is_same_v) { + active_bucket_view = cugraph::key_bucket_view_t( + handle, + raft::device_span{carryover_frontier_majors.data(), + carryover_frontier_majors.size()}); + } else { + active_bucket_view = cugraph::key_bucket_view_t( + handle, + raft::device_span{carryover_frontier_majors.data(), + carryover_frontier_majors.size()}, + raft::device_span{carryover_frontier_labels->data(), + carryover_frontier_labels->size()}); + handle.sync_stream(); + active_major_labels = raft::device_span{carryover_frontier_labels->data(), + carryover_frontier_labels->size()}; + } } else { - active_bucket_view = cugraph::key_bucket_view_t( - handle, - raft::device_span{resample_active_majors->data(), - resample_active_majors->size()}, - raft::device_span{resample_active_major_labels->data(), - resample_active_major_labels->size()}); - handle.sync_stream(); - active_major_labels = raft::device_span{ - resample_active_major_labels->data(), resample_active_major_labels->size()}; + if constexpr (std::is_same_v) { + active_bucket_view = cugraph::key_bucket_view_t( + handle, raft::device_span(nullptr, size_t(0))); + } else { + active_bucket_view = cugraph::key_bucket_view_t( + handle, + raft::device_span(nullptr, size_t(0)), + raft::device_span(nullptr, size_t(0))); + } } } diff --git 
a/cpp/src/sampling/detail/sampling_utils.hpp b/cpp/src/sampling/detail/sampling_utils.hpp index 3ccdfcedec..c784d45b71 100644 --- a/cpp/src/sampling/detail/sampling_utils.hpp +++ b/cpp/src/sampling/detail/sampling_utils.hpp @@ -494,11 +494,10 @@ update_dst_visited_vertices_and_labels( * Keeps one edge per (label, minor) or per minor. Callers should merge the returned minors into * visited sets via update_dst_visited_vertices_and_labels. * - * When call_from_sampling is true, we need to skip all edges that come from any major vertex that - * we are going to skip. So this function will return the majors that need to be resampled. - * - * FIXME: We should eliminate the call_from_sampling flag by refactoring - * sample_edges_to_unvisited_neighbors to use the partial results we return here. + * When edges are removed, the last five return values contain the discarded edges (majors, minors, + * edge_property, types, labels in lockstep). Callers such as sample_unvisited_with_one_property + * use those to build the resample frontier. When there are no duplicates, discarded bundles are + * empty (length zero / monostate / nullopt). * * @tparam vertex_t Type of vertex identifiers. Needs to be an integral type. * @tparam edge_t Type of edge identifiers. Needs to be an integral type. @@ -507,28 +506,33 @@ update_dst_visited_vertices_and_labels( * @param graph_view Graph View object for context (partitioning, MG routing, etc.). * @param result_majors Device vector of edge major (source) vertices. * @param result_minors Device vector of edge minor (destination) vertices. - * @param tmp_edge_indices Multi-edge indices (or single property column) to filter in lockstep; - * arithmetic_device_uvector_t (monostate or device vector). Returned as filtered. + * @param result_edge_property Per-edge property column in lockstep with majors/minors + * (monostate if none, or rmm::device_uvector multi-edge index when present). 
+ * @param result_types Optional per-edge edge-type column (monostate or int32_t), in lockstep when + * present for sort/split; only discarded_types is returned, not kept types. * @param result_labels Optional device vector of labels per edge. - * @param call_from_sampling If true, the last two return values are populated with majors (and - * optional major labels) that need to be resampled (edges to those were removed). - * @return Tuple of filtered (result_majors, result_minors, tmp_edge_indices, result_labels), and - * when call_from_sampling is true, optional (resample_majors, resample_major_labels). + * @return Tuple of kept (majors, minors, result_edge_property, labels), then discarded (majors, + * minors, discarded_edge_property, discarded_types, labels). Kept rows are the first edge + * per (label, minor) or per minor after sorting by that key. Callers that do not need + * discards may bind the last five tuple elements to `std::ignore`. */ template std::tuple, rmm::device_uvector, arithmetic_device_uvector_t, std::optional>, - std::optional>, + rmm::device_uvector, + rmm::device_uvector, + arithmetic_device_uvector_t, + arithmetic_device_uvector_t, std::optional>> deduplicate_edges_by_minor(raft::handle_t const& handle, graph_view_t const& graph_view, rmm::device_uvector&& result_majors, rmm::device_uvector&& result_minors, - arithmetic_device_uvector_t&& tmp_edge_indices, - std::optional>&& result_labels, - bool call_from_sampling); + arithmetic_device_uvector_t&& result_edge_property, + arithmetic_device_uvector_t&& result_types, + std::optional>&& result_labels); } // namespace detail } // namespace cugraph diff --git a/cpp/src/sampling/neighbor_sampling_impl.hpp b/cpp/src/sampling/neighbor_sampling_impl.hpp index 30c0efd269..367ef672c4 100644 --- a/cpp/src/sampling/neighbor_sampling_impl.hpp +++ b/cpp/src/sampling/neighbor_sampling_impl.hpp @@ -661,9 +661,6 @@ homogeneous_uniform_neighbor_sample( CUGRAPH_EXPECTS(!(sampling_flags.with_replacement && 
sampling_flags.disjoint_sampling), "Invalid input argument: disjoint sampling and sampling with replacement are " "mutually exclusive."); - CUGRAPH_EXPECTS(!(sampling_flags.disjoint_sampling && graph_view.is_multigraph()), - "Invalid input argument: disjoint sampling is not supported for multi-graphs."); - auto [majors, minors, weights, edge_ids, edge_types, hops, labels, offsets] = detail::neighbor_sample_impl( handle, @@ -724,9 +721,6 @@ heterogeneous_uniform_neighbor_sample( CUGRAPH_EXPECTS(!(sampling_flags.with_replacement && sampling_flags.disjoint_sampling), "Invalid input argument: disjoint sampling and sampling with replacement are " "mutually exclusive."); - CUGRAPH_EXPECTS(!(sampling_flags.disjoint_sampling && graph_view.is_multigraph()), - "Invalid input argument: disjoint sampling is not supported for multi-graphs."); - auto [majors, minors, weights, edge_ids, edge_types, hops, labels, offsets] = detail::neighbor_sample_impl( handle, @@ -786,9 +780,6 @@ homogeneous_biased_neighbor_sample( CUGRAPH_EXPECTS(!(sampling_flags.with_replacement && sampling_flags.disjoint_sampling), "Invalid input argument: disjoint sampling and sampling with replacement are " "mutually exclusive."); - CUGRAPH_EXPECTS(!(sampling_flags.disjoint_sampling && graph_view.is_multigraph()), - "Invalid input argument: disjoint sampling is not supported for multi-graphs."); - auto [majors, minors, weights, edge_ids, edge_types, hops, labels, offsets] = detail::neighbor_sample_impl( handle, @@ -848,9 +839,6 @@ heterogeneous_biased_neighbor_sample( CUGRAPH_EXPECTS(!(sampling_flags.with_replacement && sampling_flags.disjoint_sampling), "Invalid input argument: disjoint sampling and sampling with replacement are " "mutually exclusive."); - CUGRAPH_EXPECTS(!(sampling_flags.disjoint_sampling && graph_view.is_multigraph()), - "Invalid input argument: disjoint sampling is not supported for multi-graphs."); - auto [majors, minors, weights, edge_ids, edge_types, hops, labels, offsets] = 
detail::neighbor_sample_impl( handle, diff --git a/cpp/src/sampling/temporal_sampling_impl.hpp b/cpp/src/sampling/temporal_sampling_impl.hpp index 93f02f8129..cbcda34d9f 100644 --- a/cpp/src/sampling/temporal_sampling_impl.hpp +++ b/cpp/src/sampling/temporal_sampling_impl.hpp @@ -112,6 +112,11 @@ temporal_neighbor_sample_impl( "Invalid input argument: number of levels should not overflow int32_t"); // as we use int32_t // to store hops + CUGRAPH_EXPECTS( + !sampling_flags.disjoint_sampling, + "Invalid input argument: disjoint sampling is not supported for temporal neighbor " + "sampling."); + // Get the number of hop. auto num_hops = raft::div_rounding_up_safe( fan_out.size(), static_cast(num_edge_types ? *num_edge_types : edge_type_t{1})); @@ -215,27 +220,6 @@ temporal_neighbor_sample_impl( rmm::device_uvector frontier_vertex_times_has_duplicates(0, handle.get_stream()); std::optional> frontier_vertex_labels_has_duplicates{std::nullopt}; - auto visited_minors = - (sampling_flags.disjoint_sampling) - ? 
std::make_optional(rmm::device_uvector(0, handle.get_stream())) - : std::nullopt; - - if (visited_minors) { - thrust::fill(handle.get_thrust_policy(), - visited_minors->begin(), - visited_minors->end(), - cugraph::invalid_vertex_id()); - - thrust::scatter( - handle.get_thrust_policy(), - starting_vertices.begin(), - starting_vertices.end(), - thrust::make_transform_iterator( - starting_vertices.begin(), - detail::shift_left_t{graph_view.local_vertex_partition_range_first()}), - visited_minors->begin()); - } - for (size_t hop = 0; hop < num_hops; ++hop) { std::optional> level_Ks{std::nullopt}; std::unique_ptr gather_flags{}; @@ -902,9 +886,6 @@ homogeneous_uniform_temporal_neighbor_sample( CUGRAPH_EXPECTS(!(sampling_flags.with_replacement && sampling_flags.disjoint_sampling), "Invalid input argument: disjoint sampling and sampling with replacement are " "mutually exclusive."); - CUGRAPH_EXPECTS(!(sampling_flags.disjoint_sampling && graph_view.is_multigraph()), - "Invalid input argument: disjoint sampling is not supported for multi-graphs."); - return detail:: temporal_neighbor_sample_impl( handle, @@ -966,9 +947,6 @@ heterogeneous_uniform_temporal_neighbor_sample( CUGRAPH_EXPECTS(!(sampling_flags.with_replacement && sampling_flags.disjoint_sampling), "Invalid input argument: disjoint sampling and sampling with replacement are " "mutually exclusive."); - CUGRAPH_EXPECTS(!(sampling_flags.disjoint_sampling && graph_view.is_multigraph()), - "Invalid input argument: disjoint sampling is not supported for multi-graphs."); - return detail:: temporal_neighbor_sample_impl( handle, @@ -1029,9 +1007,6 @@ homogeneous_biased_temporal_neighbor_sample( CUGRAPH_EXPECTS(!(sampling_flags.with_replacement && sampling_flags.disjoint_sampling), "Invalid input argument: disjoint sampling and sampling with replacement are " "mutually exclusive."); - CUGRAPH_EXPECTS(!(sampling_flags.disjoint_sampling && graph_view.is_multigraph()), - "Invalid input argument: disjoint sampling is not 
supported for multi-graphs."); - return detail:: temporal_neighbor_sample_impl( handle, @@ -1092,9 +1067,6 @@ heterogeneous_biased_temporal_neighbor_sample( CUGRAPH_EXPECTS(!(sampling_flags.with_replacement && sampling_flags.disjoint_sampling), "Invalid input argument: disjoint sampling and sampling with replacement are " "mutually exclusive."); - CUGRAPH_EXPECTS(!(sampling_flags.disjoint_sampling && graph_view.is_multigraph()), - "Invalid input argument: disjoint sampling is not supported for multi-graphs."); - return detail:: temporal_neighbor_sample_impl( handle, diff --git a/cpp/tests/c_api/mg_test_utils.cpp b/cpp/tests/c_api/mg_test_utils.cpp index a9c9223a92..b822f4374c 100644 --- a/cpp/tests/c_api/mg_test_utils.cpp +++ b/cpp/tests/c_api/mg_test_utils.cpp @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2022-2025, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2022-2026, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ @@ -21,6 +21,10 @@ #include +#include +#include +#include + namespace { template raft::device_span make_span(cugraph_type_erased_device_array_view_t const* view) @@ -1134,6 +1138,69 @@ int mg_validate_sample_result(const cugraph_resource_handle_t* handle, raft::update_host( h_result_dsts, renumbered_dsts.data(), renumbered_dsts.size(), raft_handle.get_stream()); + // Global disjoint validation uses a concatenation of each rank's local start list (same layout as + // device_gatherv elsewhere in this file). This gatherv is collective — all ranks must call it + // whenever disjoint sampling is enabled, not only rank 0 (see disjoint validation block below). 
+ rmm::device_uvector gathered_starts(0, raft_handle.get_stream()); + if (internal_sampling_options->disjoint_sampling_ == TRUE) { + rmm::device_uvector d_starting_vertices(num_start_vertices, raft_handle.get_stream()); + if (num_start_vertices > 0) { + raft::update_device( + d_starting_vertices.data(), h_start_vertices, num_start_vertices, raft_handle.get_stream()); + } + gathered_starts = cugraph::test::device_gatherv( + raft_handle, + raft::device_span{d_starting_vertices.data(), d_starting_vertices.size()}); + } + + if ((test_ret_value == 0) && (internal_sampling_options->disjoint_sampling_ == TRUE) && + (cugraph_resource_handle_get_rank(handle) == 0)) { + std::optional> label_offsets_dev{std::nullopt}; + std::optional> batch_nums_dev{std::nullopt}; + rmm::device_uvector d_label_offsets(0, raft_handle.get_stream()); + rmm::device_uvector d_batch_nums(0, raft_handle.get_stream()); + std::vector h_batch_nums_host; + + if (h_start_label_offsets != NULL && num_start_label_offsets > 0) { + d_label_offsets.resize(num_start_label_offsets, raft_handle.get_stream()); + raft::update_device(d_label_offsets.data(), + h_start_label_offsets, + num_start_label_offsets, + raft_handle.get_stream()); + label_offsets_dev = + raft::device_span{d_label_offsets.data(), d_label_offsets.size()}; + + auto const global_num_starts = gathered_starts.size(); + d_batch_nums.resize(global_num_starts, raft_handle.get_stream()); + h_batch_nums_host.resize(global_num_starts); + for (size_t i = 0; i < global_num_starts; ++i) { + int32_t lab = -1; + for (size_t j = 0; j + 1 < num_start_label_offsets; ++j) { + if (i >= h_start_label_offsets[j] && i < h_start_label_offsets[j + 1]) { + lab = static_cast(j); + break; + } + } + h_batch_nums_host[i] = lab; + } + raft::update_device( + d_batch_nums.data(), h_batch_nums_host.data(), global_num_starts, raft_handle.get_stream()); + batch_nums_dev = raft::device_span{d_batch_nums.data(), d_batch_nums.size()}; + } + + raft_handle.sync_stream(); + + bool 
const disjoint_ok = cugraph::test::validate_disjoint_sampling( + raft_handle, + raft::device_span{renumbered_srcs.data(), renumbered_srcs.size()}, + raft::device_span{renumbered_dsts.data(), renumbered_dsts.size()}, + raft::device_span{gathered_starts.data(), gathered_starts.size()}, + label_offsets_dev, + batch_nums_dev); + + TEST_ASSERT(test_ret_value, disjoint_ok, "validate_disjoint_sampling failed"); + } + original_srcs = cugraph::test::device_gatherv( raft_handle, raft::device_span{original_srcs.data(), original_srcs.size()}); raft::update_host( diff --git a/cpp/tests/c_api/mg_uniform_neighbor_sample_test.c b/cpp/tests/c_api/mg_uniform_neighbor_sample_test.c index 3ae605f3d5..43953ac7db 100644 --- a/cpp/tests/c_api/mg_uniform_neighbor_sample_test.c +++ b/cpp/tests/c_api/mg_uniform_neighbor_sample_test.c @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2022-2025, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2022-2026, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ @@ -10,6 +10,7 @@ #include #include +#include #include typedef int32_t vertex_t; @@ -141,8 +142,6 @@ int generic_uniform_neighbor_sample_test(const cugraph_resource_handle_t* handle TEST_ASSERT(test_ret_value, ret_code == CUGRAPH_SUCCESS, cugraph_error_message(ret_error)); TEST_ASSERT(test_ret_value, ret_code == CUGRAPH_SUCCESS, "uniform_neighbor_sample failed."); - cugraph_sampling_options_free(sampling_options); - test_ret_value = mg_validate_sample_result(handle, result, h_src, @@ -163,6 +162,156 @@ int generic_uniform_neighbor_sample_test(const cugraph_resource_handle_t* handle sampling_options, FALSE); TEST_ASSERT(test_ret_value, test_ret_value == 0, "validate_sample_result failed."); + + cugraph_sampling_options_free(sampling_options); + cugraph_rng_state_free(rng_state); + cugraph_type_erased_device_array_view_free(d_start_view); + if (d_start_label_offsets != NULL) { + cugraph_type_erased_device_array_view_free(d_start_label_offsets_view); + 
cugraph_type_erased_device_array_free(d_start_label_offsets); + } + cugraph_type_erased_host_array_view_free(h_fan_out_view); + cugraph_type_erased_device_array_free(d_start); + } + + cugraph_sample_result_free(result); + cugraph_graph_free(graph); + + cugraph_error_free(ret_error); + return test_ret_value; +} + +/* + * Disjoint sampling on a multigraph (MG CAPI): parallel edges with is_multigraph=true. + * Single seed vertex on rank 0 only (other ranks get an empty start list), which exercises + * empty local partitions while the owning rank still drives resample / collectives. + */ +int generic_uniform_neighbor_sample_disjoint_multigraph_mg_test( + const cugraph_resource_handle_t* handle, + vertex_t* h_src, + vertex_t* h_dst, + weight_t* h_wgt, + edge_t* h_edge_ids, + int32_t* h_edge_types, + size_t num_vertices, + size_t num_edges, + vertex_t* h_start, + size_t num_start_vertices, + int* fan_out, + size_t fan_out_size, + bool_t with_replacement, + bool_t return_hops, + cugraph_prior_sources_behavior_t prior_sources_behavior, + bool_t dedupe_sources, + bool_t renumber_results) +{ + int test_ret_value = 0; + cugraph_error_code_t ret_code = CUGRAPH_SUCCESS; + cugraph_error_t* ret_error = NULL; + cugraph_graph_t* graph = NULL; + cugraph_sample_result_t* result = NULL; + + int rank = cugraph_resource_handle_get_rank(handle); + + ret_code = create_mg_test_graph_new(handle, + vertex_tid, + edge_tid, + h_src, + h_dst, + weight_tid, + h_wgt, + edge_type_tid, + h_edge_types, + edge_id_tid, + h_edge_ids, + edge_time_tid, + NULL, + NULL, + num_edges, + FALSE, + TRUE, + FALSE, + TRUE, + &graph, + &ret_error); + + TEST_ASSERT(test_ret_value, ret_code == CUGRAPH_SUCCESS, "graph creation failed."); + + if (test_ret_value == 0) { + cugraph_type_erased_device_array_t* d_start = NULL; + cugraph_type_erased_device_array_view_t* d_start_view = NULL; + cugraph_type_erased_host_array_view_t* h_fan_out_view = NULL; + + if (rank > 0) num_start_vertices = 0; + + h_fan_out_view = 
cugraph_type_erased_host_array_view_create(fan_out, fan_out_size, INT32); + + ret_code = cugraph_type_erased_device_array_create( + handle, num_start_vertices, INT32, &d_start, &ret_error); + TEST_ASSERT(test_ret_value, ret_code == CUGRAPH_SUCCESS, "d_start create failed."); + + d_start_view = cugraph_type_erased_device_array_view(d_start); + + ret_code = cugraph_type_erased_device_array_view_copy_from_host( + handle, d_start_view, (byte_t*)h_start, &ret_error); + TEST_ASSERT( + test_ret_value, ret_code == CUGRAPH_SUCCESS, "start vertices copy_from_host failed."); + + cugraph_rng_state_t* rng_state; + ret_code = cugraph_rng_state_create(handle, rank, &rng_state, &ret_error); + TEST_ASSERT(test_ret_value, ret_code == CUGRAPH_SUCCESS, "rng_state create failed."); + + cugraph_sampling_options_t* sampling_options; + + ret_code = cugraph_sampling_options_create(&sampling_options, &ret_error); + TEST_ASSERT(test_ret_value, ret_code == CUGRAPH_SUCCESS, "sampling_options create failed."); + + cugraph_sampling_set_with_replacement(sampling_options, with_replacement); + cugraph_sampling_set_return_hops(sampling_options, return_hops); + cugraph_sampling_set_prior_sources_behavior(sampling_options, prior_sources_behavior); + cugraph_sampling_set_dedupe_sources(sampling_options, dedupe_sources); + cugraph_sampling_set_renumber_results(sampling_options, renumber_results); + cugraph_sampling_set_disjoint_sampling(sampling_options, TRUE); + + ret_code = cugraph_homogeneous_uniform_neighbor_sample(handle, + rng_state, + graph, + d_start_view, + NULL, + h_fan_out_view, + sampling_options, + FALSE, + &result, + &ret_error); + TEST_ASSERT(test_ret_value, ret_code == CUGRAPH_SUCCESS, cugraph_error_message(ret_error)); + TEST_ASSERT(test_ret_value, ret_code == CUGRAPH_SUCCESS, "uniform_neighbor_sample failed."); + + test_ret_value = mg_validate_sample_result(handle, + result, + h_src, + h_dst, + h_wgt, + h_edge_ids, + h_edge_types, + NULL, + NULL, + num_vertices, + num_edges, + h_start, 
+ num_start_vertices, + NULL, + 0, + fan_out, + fan_out_size, + sampling_options, + FALSE); + TEST_ASSERT(test_ret_value, test_ret_value == 0, "validate_sample_result failed."); + + cugraph_sampling_options_free(sampling_options); + cugraph_rng_state_free(rng_state); + cugraph_type_erased_device_array_view_free(d_start_view); + cugraph_type_erased_host_array_view_free(h_fan_out_view); + cugraph_type_erased_device_array_free(d_start); } cugraph_sample_result_free(result); @@ -616,6 +765,228 @@ int test_uniform_neighbor_sample_carry_over_sources(const cugraph_resource_handl FALSE); } +int test_uniform_neighbor_sample_disjoint_multigraph(const cugraph_resource_handle_t* handle) +{ + size_t num_edges = 10; + size_t num_vertices = 6; + size_t fan_out_size = 3; + size_t num_starts = 1; + + vertex_t src[] = {0, 0, 0, 1, 1, 2, 2, 2, 3, 4}; + vertex_t dst[] = {1, 1, 3, 3, 4, 0, 1, 3, 5, 5}; + edge_t edge_ids[] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}; + weight_t weight[] = {0.05f, 0.1f, 0.2f, 0.3f, 0.4f, 0.5f, 0.6f, 0.7f, 0.8f, 0.9f}; + int32_t edge_types[] = {9, 8, 7, 6, 5, 4, 3, 2, 1, 0}; + vertex_t start[] = {0}; + int fan_out[] = {3, 3, 3}; + + bool_t with_replacement = FALSE; + bool_t return_hops = TRUE; + cugraph_prior_sources_behavior_t prior_sources_behavior = DEFAULT; + bool_t dedupe_sources = FALSE; + bool_t renumber_results = FALSE; + + return generic_uniform_neighbor_sample_disjoint_multigraph_mg_test(handle, + src, + dst, + weight, + edge_ids, + edge_types, + num_vertices, + num_edges, + start, + num_starts, + fan_out, + fan_out_size, + with_replacement, + return_hops, + prior_sources_behavior, + dedupe_sources, + renumber_results); +} + +int generic_uniform_neighbor_sample_disjoint_heterogeneous_multigraph_mg_test( + const cugraph_resource_handle_t* handle, + vertex_t* h_src, + vertex_t* h_dst, + weight_t* h_wgt, + edge_t* h_edge_ids, + int32_t* h_edge_types, + size_t num_vertices, + size_t num_edges, + vertex_t* h_start, + size_t num_start_vertices, + int* fan_out, + 
size_t fan_out_size, + int num_edge_types, + bool_t with_replacement, + bool_t return_hops, + cugraph_prior_sources_behavior_t prior_sources_behavior, + bool_t dedupe_sources, + bool_t renumber_results) +{ + int test_ret_value = 0; + cugraph_error_code_t ret_code = CUGRAPH_SUCCESS; + cugraph_error_t* ret_error = NULL; + cugraph_graph_t* graph = NULL; + cugraph_sample_result_t* result = NULL; + + int rank = cugraph_resource_handle_get_rank(handle); + + ret_code = create_mg_test_graph_new(handle, + vertex_tid, + edge_tid, + h_src, + h_dst, + weight_tid, + h_wgt, + edge_type_tid, + h_edge_types, + edge_id_tid, + h_edge_ids, + edge_time_tid, + NULL, + NULL, + num_edges, + FALSE, + TRUE, + FALSE, + TRUE, + &graph, + &ret_error); + + TEST_ASSERT(test_ret_value, ret_code == CUGRAPH_SUCCESS, "graph creation failed."); + + if (test_ret_value == 0) { + cugraph_type_erased_device_array_t* d_start = NULL; + cugraph_type_erased_device_array_view_t* d_start_view = NULL; + cugraph_type_erased_host_array_view_t* h_fan_out_view = NULL; + + if (rank > 0) num_start_vertices = 0; + + h_fan_out_view = cugraph_type_erased_host_array_view_create(fan_out, fan_out_size, INT32); + + ret_code = cugraph_type_erased_device_array_create( + handle, num_start_vertices, INT32, &d_start, &ret_error); + TEST_ASSERT(test_ret_value, ret_code == CUGRAPH_SUCCESS, "d_start create failed."); + + d_start_view = cugraph_type_erased_device_array_view(d_start); + + ret_code = cugraph_type_erased_device_array_view_copy_from_host( + handle, d_start_view, (byte_t*)h_start, &ret_error); + TEST_ASSERT( + test_ret_value, ret_code == CUGRAPH_SUCCESS, "start vertices copy_from_host failed."); + + cugraph_rng_state_t* rng_state; + ret_code = cugraph_rng_state_create(handle, rank, &rng_state, &ret_error); + TEST_ASSERT(test_ret_value, ret_code == CUGRAPH_SUCCESS, "rng_state create failed."); + + cugraph_sampling_options_t* sampling_options; + + ret_code = cugraph_sampling_options_create(&sampling_options, &ret_error); 
+ TEST_ASSERT(test_ret_value, ret_code == CUGRAPH_SUCCESS, "sampling_options create failed."); + + cugraph_sampling_set_with_replacement(sampling_options, with_replacement); + cugraph_sampling_set_return_hops(sampling_options, return_hops); + cugraph_sampling_set_prior_sources_behavior(sampling_options, prior_sources_behavior); + cugraph_sampling_set_dedupe_sources(sampling_options, dedupe_sources); + cugraph_sampling_set_renumber_results(sampling_options, renumber_results); + cugraph_sampling_set_disjoint_sampling(sampling_options, TRUE); + + ret_code = cugraph_heterogeneous_uniform_neighbor_sample(handle, + rng_state, + graph, + d_start_view, + NULL, + NULL, + h_fan_out_view, + num_edge_types, + sampling_options, + FALSE, + &result, + &ret_error); + TEST_ASSERT(test_ret_value, ret_code == CUGRAPH_SUCCESS, cugraph_error_message(ret_error)); + TEST_ASSERT( + test_ret_value, ret_code == CUGRAPH_SUCCESS, "heterogeneous uniform_neighbor_sample failed."); + + test_ret_value = mg_validate_sample_result(handle, + result, + h_src, + h_dst, + h_wgt, + h_edge_ids, + h_edge_types, + NULL, + NULL, + num_vertices, + num_edges, + h_start, + num_start_vertices, + NULL, + 0, + fan_out, + fan_out_size, + sampling_options, + FALSE); + TEST_ASSERT(test_ret_value, test_ret_value == 0, "validate_sample_result failed."); + + cugraph_sampling_options_free(sampling_options); + cugraph_rng_state_free(rng_state); + cugraph_sample_result_free(result); + cugraph_type_erased_device_array_view_free(d_start_view); + cugraph_type_erased_host_array_view_free(h_fan_out_view); + cugraph_type_erased_device_array_free(d_start); + } + + cugraph_graph_free(graph); + + return test_ret_value; +} + +int test_uniform_neighbor_sample_disjoint_heterogeneous_multigraph( + const cugraph_resource_handle_t* handle) +{ + size_t num_edges = 10; + size_t num_vertices = 6; + size_t fan_out_size = 2; + size_t num_starts = 1; + int const num_edge_types = 2; + + vertex_t src[] = {0, 0, 0, 1, 1, 2, 2, 2, 3, 4}; + 
vertex_t dst[] = {1, 1, 3, 3, 4, 0, 1, 3, 5, 5}; + edge_t edge_ids[] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}; + weight_t weight[] = {0.05f, 0.1f, 0.2f, 0.3f, 0.4f, 0.5f, 0.6f, 0.7f, 0.8f, 0.9f}; + int32_t edge_types[] = {0, 1, 0, 0, 1, 1, 0, 0, 1, 0}; + vertex_t start[] = {0}; + int fan_out[] = {2, 2}; + + bool_t with_replacement = FALSE; + bool_t return_hops = TRUE; + cugraph_prior_sources_behavior_t prior_sources_behavior = DEFAULT; + bool_t dedupe_sources = FALSE; + bool_t renumber_results = FALSE; + + return generic_uniform_neighbor_sample_disjoint_heterogeneous_multigraph_mg_test( + handle, + src, + dst, + weight, + edge_ids, + edge_types, + num_vertices, + num_edges, + start, + num_starts, + fan_out, + fan_out_size, + num_edge_types, + with_replacement, + return_hops, + prior_sources_behavior, + dedupe_sources, + renumber_results); +} + /******************************************************************************/ int main(int argc, char** argv) @@ -628,6 +999,8 @@ int main(int argc, char** argv) result |= RUN_MG_TEST(test_uniform_neighbor_from_alex, handle); // result |= RUN_MG_TEST(test_uniform_neighbor_sample_alex_bug, handle); result |= RUN_MG_TEST(test_uniform_neighbor_sample_sort_by_hop, handle); + result |= RUN_MG_TEST(test_uniform_neighbor_sample_disjoint_multigraph, handle); + result |= RUN_MG_TEST(test_uniform_neighbor_sample_disjoint_heterogeneous_multigraph, handle); // result |= RUN_MG_TEST(test_uniform_neighbor_sample_dedupe_sources, handle); // result |= RUN_MG_TEST(test_uniform_neighbor_sample_unique_sources, handle); // result |= RUN_MG_TEST(test_uniform_neighbor_sample_carry_over_sources, handle); diff --git a/cpp/tests/c_api/test_utils.cpp b/cpp/tests/c_api/test_utils.cpp index fff00faf47..3d3b01ba2b 100644 --- a/cpp/tests/c_api/test_utils.cpp +++ b/cpp/tests/c_api/test_utils.cpp @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2021-2025, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2021-2026, NVIDIA CORPORATION. 
* SPDX-License-Identifier: Apache-2.0 */ @@ -15,8 +15,13 @@ #include #include +#include + #include +#include +#include + namespace { template raft::device_span make_span(cugraph_type_erased_device_array_view_t const* view) @@ -851,6 +856,7 @@ extern "C" int validate_sample_result(const cugraph_resource_handle_t* handle, raft::update_host(h_result_srcs, renumbered_srcs.data(), result_size, raft_handle.get_stream()); raft::update_host(h_result_dsts, renumbered_dsts.data(), result_size, raft_handle.get_stream()); + raft_handle.sync_stream(); for (int label_id = 0; label_id < (h_start_label_offsets != NULL ? (num_start_label_offsets - 1) : 1); @@ -1066,6 +1072,56 @@ extern "C" int validate_sample_result(const cugraph_resource_handle_t* handle, } } + if ((test_ret_value == 0) && (internal_sampling_options->disjoint_sampling_ == TRUE)) { + rmm::device_uvector d_starting_vertices(num_start_vertices, raft_handle.get_stream()); + raft::update_device( + d_starting_vertices.data(), h_start_vertices, num_start_vertices, raft_handle.get_stream()); + + std::optional> label_offsets_dev{std::nullopt}; + std::optional> batch_nums_dev{std::nullopt}; + rmm::device_uvector d_label_offsets(0, raft_handle.get_stream()); + rmm::device_uvector d_batch_nums(0, raft_handle.get_stream()); + + if (h_start_label_offsets != NULL) { + d_label_offsets.resize(num_start_label_offsets, raft_handle.get_stream()); + raft::update_device(d_label_offsets.data(), + h_start_label_offsets, + num_start_label_offsets, + raft_handle.get_stream()); + label_offsets_dev = + raft::device_span{d_label_offsets.data(), d_label_offsets.size()}; + + d_batch_nums.resize(num_start_vertices, raft_handle.get_stream()); + std::vector h_batch_nums(num_start_vertices); + for (size_t i = 0; i < num_start_vertices; ++i) { + int32_t lab = -1; + for (size_t j = 0; j + 1 < num_start_label_offsets; ++j) { + if (i >= h_start_label_offsets[j] && i < h_start_label_offsets[j + 1]) { + lab = static_cast(j); + break; + } + } + 
h_batch_nums[i] = lab; + } + raft::update_device( + d_batch_nums.data(), h_batch_nums.data(), num_start_vertices, raft_handle.get_stream()); + batch_nums_dev = raft::device_span{d_batch_nums.data(), d_batch_nums.size()}; + } + + raft_handle.sync_stream(); + + TEST_ASSERT( + test_ret_value, + cugraph::test::validate_disjoint_sampling( + raft_handle, + raft::device_span{renumbered_srcs.data(), renumbered_srcs.size()}, + raft::device_span{renumbered_dsts.data(), renumbered_dsts.size()}, + raft::device_span{d_starting_vertices.data(), d_starting_vertices.size()}, + label_offsets_dev, + batch_nums_dev), + "validate_disjoint_sampling failed"); + } + // FIXME: Add other C++ checks here return test_ret_value; diff --git a/cpp/tests/c_api/uniform_neighbor_sample_test.c b/cpp/tests/c_api/uniform_neighbor_sample_test.c index 158e7946e0..a328864320 100644 --- a/cpp/tests/c_api/uniform_neighbor_sample_test.c +++ b/cpp/tests/c_api/uniform_neighbor_sample_test.c @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2022-2025, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2022-2026, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ @@ -164,6 +164,137 @@ int generic_uniform_neighbor_sample_test(const cugraph_resource_handle_t* handle return test_ret_value; } +/* + * Disjoint sampling on a multigraph: parallel edges (duplicate src,dst) with + * is_multigraph=true. Single batch of seeds (NULL starting_vertex_label_offsets). + * Expects implementation to allow disjoint sampling when the graph is a multigraph. 
+ */ +int generic_uniform_neighbor_sample_disjoint_multigraph_test( + const cugraph_resource_handle_t* handle, + vertex_t* h_src, + vertex_t* h_dst, + weight_t* h_wgt, + edge_t* h_edge_ids, + int32_t* h_edge_types, + size_t num_vertices, + size_t num_edges, + vertex_t* h_start, + size_t num_start_vertices, + int* fan_out, + size_t fan_out_size, + bool_t with_replacement, + bool_t return_hops, + cugraph_prior_sources_behavior_t prior_sources_behavior, + bool_t dedupe_sources, + bool_t renumber_results) +{ + int test_ret_value = 0; + cugraph_error_code_t ret_code = CUGRAPH_SUCCESS; + cugraph_error_t* ret_error = NULL; + cugraph_graph_t* graph = NULL; + cugraph_sample_result_t* result = NULL; + + ret_code = create_sg_test_graph(handle, + vertex_tid, + edge_tid, + h_src, + h_dst, + weight_tid, + h_wgt, + edge_type_tid, + h_edge_types, + edge_id_tid, + h_edge_ids, + INT32, + NULL, + NULL, + num_edges, + FALSE, + TRUE, + FALSE, + TRUE, + &graph, + &ret_error); + + TEST_ASSERT(test_ret_value, ret_code == CUGRAPH_SUCCESS, "graph creation failed."); + + cugraph_type_erased_device_array_t* d_start = NULL; + cugraph_type_erased_device_array_view_t* d_start_view = NULL; + cugraph_type_erased_host_array_view_t* h_fan_out_view = NULL; + + ret_code = cugraph_type_erased_device_array_create( + handle, num_start_vertices, INT32, &d_start, &ret_error); + TEST_ASSERT(test_ret_value, ret_code == CUGRAPH_SUCCESS, "d_start create failed."); + + d_start_view = cugraph_type_erased_device_array_view(d_start); + + ret_code = cugraph_type_erased_device_array_view_copy_from_host( + handle, d_start_view, (byte_t*)h_start, &ret_error); + TEST_ASSERT(test_ret_value, ret_code == CUGRAPH_SUCCESS, "start vertices copy_from_host failed."); + + h_fan_out_view = cugraph_type_erased_host_array_view_create(fan_out, fan_out_size, INT32); + + cugraph_rng_state_t* rng_state; + ret_code = cugraph_rng_state_create(handle, 0, &rng_state, &ret_error); + TEST_ASSERT(test_ret_value, ret_code == CUGRAPH_SUCCESS, 
"rng_state create failed."); + + cugraph_sampling_options_t* sampling_options; + + ret_code = cugraph_sampling_options_create(&sampling_options, &ret_error); + TEST_ASSERT(test_ret_value, ret_code == CUGRAPH_SUCCESS, "sampling_options create failed."); + + cugraph_sampling_set_with_replacement(sampling_options, with_replacement); + cugraph_sampling_set_return_hops(sampling_options, return_hops); + cugraph_sampling_set_prior_sources_behavior(sampling_options, prior_sources_behavior); + cugraph_sampling_set_dedupe_sources(sampling_options, dedupe_sources); + cugraph_sampling_set_renumber_results(sampling_options, renumber_results); + cugraph_sampling_set_disjoint_sampling(sampling_options, TRUE); + + ret_code = cugraph_homogeneous_uniform_neighbor_sample(handle, + rng_state, + graph, + d_start_view, + NULL, + h_fan_out_view, + sampling_options, + FALSE, + &result, + &ret_error); + TEST_ASSERT(test_ret_value, ret_code == CUGRAPH_SUCCESS, cugraph_error_message(ret_error)); + TEST_ASSERT(test_ret_value, ret_code == CUGRAPH_SUCCESS, "uniform_neighbor_sample failed."); + + test_ret_value = validate_sample_result(handle, + result, + h_src, + h_dst, + h_wgt, + h_edge_ids, + h_edge_types, + NULL, + NULL, + num_vertices, + num_edges, + h_start, + num_start_vertices, + NULL, + 0, + fan_out, + fan_out_size, + sampling_options, + FALSE); + TEST_ASSERT(test_ret_value, test_ret_value == 0, "validate_sample_result failed."); + + cugraph_sampling_options_free(sampling_options); + cugraph_rng_state_free(rng_state); + cugraph_sample_result_free(result); + cugraph_graph_free(graph); + cugraph_type_erased_device_array_view_free(d_start_view); + cugraph_type_erased_host_array_view_free(h_fan_out_view); + cugraph_type_erased_device_array_free(d_start); + + return test_ret_value; +} + int create_test_graph_with_edge_ids(const cugraph_resource_handle_t* p_handle, vertex_t* h_src, vertex_t* h_dst, @@ -258,6 +389,140 @@ int create_test_graph_with_edge_ids(const cugraph_resource_handle_t* 
p_handle, return test_ret_value; } +/* + * Disjoint heterogeneous sampling on a multigraph: parallel edges with distinct edge types, + * fan_out per type, and multi-edge indices for property gather. + */ +int generic_uniform_neighbor_sample_disjoint_heterogeneous_multigraph_test( + const cugraph_resource_handle_t* handle, + vertex_t* h_src, + vertex_t* h_dst, + weight_t* h_wgt, + edge_t* h_edge_ids, + int32_t* h_edge_types, + size_t num_vertices, + size_t num_edges, + vertex_t* h_start, + size_t num_start_vertices, + int* fan_out, + size_t fan_out_size, + int num_edge_types, + bool_t with_replacement, + bool_t return_hops, + cugraph_prior_sources_behavior_t prior_sources_behavior, + bool_t dedupe_sources, + bool_t renumber_results) +{ + int test_ret_value = 0; + cugraph_error_code_t ret_code = CUGRAPH_SUCCESS; + cugraph_error_t* ret_error = NULL; + cugraph_graph_t* graph = NULL; + cugraph_sample_result_t* result = NULL; + + ret_code = create_sg_test_graph(handle, + vertex_tid, + edge_tid, + h_src, + h_dst, + weight_tid, + h_wgt, + edge_type_tid, + h_edge_types, + edge_id_tid, + h_edge_ids, + INT32, + NULL, + NULL, + num_edges, + FALSE, + TRUE, + FALSE, + TRUE, + &graph, + &ret_error); + + TEST_ASSERT(test_ret_value, ret_code == CUGRAPH_SUCCESS, "graph creation failed."); + + cugraph_type_erased_device_array_t* d_start = NULL; + cugraph_type_erased_device_array_view_t* d_start_view = NULL; + cugraph_type_erased_host_array_view_t* h_fan_out_view = NULL; + + ret_code = cugraph_type_erased_device_array_create( + handle, num_start_vertices, INT32, &d_start, &ret_error); + TEST_ASSERT(test_ret_value, ret_code == CUGRAPH_SUCCESS, "d_start create failed."); + + d_start_view = cugraph_type_erased_device_array_view(d_start); + + ret_code = cugraph_type_erased_device_array_view_copy_from_host( + handle, d_start_view, (byte_t*)h_start, &ret_error); + TEST_ASSERT(test_ret_value, ret_code == CUGRAPH_SUCCESS, "start vertices copy_from_host failed."); + + h_fan_out_view = 
cugraph_type_erased_host_array_view_create(fan_out, fan_out_size, INT32); + + cugraph_rng_state_t* rng_state; + ret_code = cugraph_rng_state_create(handle, 0, &rng_state, &ret_error); + TEST_ASSERT(test_ret_value, ret_code == CUGRAPH_SUCCESS, "rng_state create failed."); + + cugraph_sampling_options_t* sampling_options; + + ret_code = cugraph_sampling_options_create(&sampling_options, &ret_error); + TEST_ASSERT(test_ret_value, ret_code == CUGRAPH_SUCCESS, "sampling_options create failed."); + + cugraph_sampling_set_with_replacement(sampling_options, with_replacement); + cugraph_sampling_set_return_hops(sampling_options, return_hops); + cugraph_sampling_set_prior_sources_behavior(sampling_options, prior_sources_behavior); + cugraph_sampling_set_dedupe_sources(sampling_options, dedupe_sources); + cugraph_sampling_set_renumber_results(sampling_options, renumber_results); + cugraph_sampling_set_disjoint_sampling(sampling_options, TRUE); + + ret_code = cugraph_heterogeneous_uniform_neighbor_sample(handle, + rng_state, + graph, + d_start_view, + NULL, + NULL, + h_fan_out_view, + num_edge_types, + sampling_options, + FALSE, + &result, + &ret_error); + TEST_ASSERT(test_ret_value, ret_code == CUGRAPH_SUCCESS, cugraph_error_message(ret_error)); + TEST_ASSERT( + test_ret_value, ret_code == CUGRAPH_SUCCESS, "heterogeneous uniform_neighbor_sample failed."); + + test_ret_value = validate_sample_result(handle, + result, + h_src, + h_dst, + h_wgt, + h_edge_ids, + h_edge_types, + NULL, + NULL, + num_vertices, + num_edges, + h_start, + num_start_vertices, + NULL, + 0, + fan_out, + fan_out_size, + sampling_options, + FALSE); + TEST_ASSERT(test_ret_value, test_ret_value == 0, "validate_sample_result failed."); + + cugraph_sampling_options_free(sampling_options); + cugraph_rng_state_free(rng_state); + cugraph_sample_result_free(result); + cugraph_graph_free(graph); + cugraph_type_erased_device_array_view_free(d_start_view); + cugraph_type_erased_host_array_view_free(h_fan_out_view); + 
cugraph_type_erased_device_array_free(d_start); + + return test_ret_value; +} + int test_uniform_neighbor_sample_clean(const cugraph_resource_handle_t* handle) { cugraph_data_type_id_t vertex_tid = INT32; @@ -528,6 +793,90 @@ int test_uniform_neighbor_sample_renumber_results(const cugraph_resource_handle_ renumber_results); } +int test_uniform_neighbor_sample_disjoint_multigraph(const cugraph_resource_handle_t* handle) +{ + size_t num_edges = 10; + size_t num_vertices = 6; + size_t fan_out_size = 3; + size_t num_starts = 1; + + vertex_t src[] = {0, 0, 0, 1, 1, 2, 2, 2, 3, 4}; + vertex_t dst[] = {1, 1, 3, 3, 4, 0, 1, 3, 5, 5}; + edge_t edge_ids[] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}; + weight_t weight[] = {0.05f, 0.1f, 0.2f, 0.3f, 0.4f, 0.5f, 0.6f, 0.7f, 0.8f, 0.9f}; + int32_t edge_types[] = {9, 8, 7, 6, 5, 4, 3, 2, 1, 0}; + vertex_t start[] = {0}; + int fan_out[] = {3, 3, 3}; + + bool_t with_replacement = FALSE; + bool_t return_hops = TRUE; + cugraph_prior_sources_behavior_t prior_sources_behavior = DEFAULT; + bool_t dedupe_sources = FALSE; + bool_t renumber_results = FALSE; + + return generic_uniform_neighbor_sample_disjoint_multigraph_test(handle, + src, + dst, + weight, + edge_ids, + edge_types, + num_vertices, + num_edges, + start, + num_starts, + fan_out, + fan_out_size, + with_replacement, + return_hops, + prior_sources_behavior, + dedupe_sources, + renumber_results); +} + +int test_uniform_neighbor_sample_disjoint_heterogeneous_multigraph( + const cugraph_resource_handle_t* handle) +{ + size_t num_edges = 10; + size_t num_vertices = 6; + size_t fan_out_size = 2; + size_t num_starts = 1; + int const num_edge_types = 2; + + vertex_t src[] = {0, 0, 0, 1, 1, 2, 2, 2, 3, 4}; + vertex_t dst[] = {1, 1, 3, 3, 4, 0, 1, 3, 5, 5}; + edge_t edge_ids[] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}; + weight_t weight[] = {0.05f, 0.1f, 0.2f, 0.3f, 0.4f, 0.5f, 0.6f, 0.7f, 0.8f, 0.9f}; + int32_t edge_types[] = {0, 1, 0, 0, 1, 1, 0, 0, 1, 0}; + vertex_t start[] = {0}; + int fan_out[] = {2, 
2}; + + bool_t with_replacement = FALSE; + bool_t return_hops = TRUE; + cugraph_prior_sources_behavior_t prior_sources_behavior = DEFAULT; + bool_t dedupe_sources = FALSE; + bool_t renumber_results = FALSE; + + return generic_uniform_neighbor_sample_disjoint_heterogeneous_multigraph_test( + handle, + src, + dst, + weight, + edge_ids, + edge_types, + num_vertices, + num_edges, + start, + num_starts, + fan_out, + fan_out_size, + num_edge_types, + with_replacement, + return_hops, + prior_sources_behavior, + dedupe_sources, + renumber_results); +} + int main(int argc, char** argv) { cugraph_resource_handle_t* handle = NULL; @@ -540,6 +889,8 @@ int main(int argc, char** argv) result |= RUN_TEST_NEW(test_uniform_neighbor_sample_unique_sources, handle); result |= RUN_TEST_NEW(test_uniform_neighbor_sample_carry_over_sources, handle); result |= RUN_TEST_NEW(test_uniform_neighbor_sample_renumber_results, handle); + result |= RUN_TEST_NEW(test_uniform_neighbor_sample_disjoint_multigraph, handle); + result |= RUN_TEST_NEW(test_uniform_neighbor_sample_disjoint_heterogeneous_multigraph, handle); cugraph_free_resource_handle(handle);