From 6f2fdbb12b209b82d65ac06edb48ca48a6a1d20a Mon Sep 17 00:00:00 2001 From: tom-whitehead Date: Sun, 23 Nov 2025 21:30:34 +0000 Subject: [PATCH 1/4] feat: par minimum spanning tree --- src/hdbscan.rs | 94 ++++++-------------- src/lib.rs | 1 + src/min_spanning_tree.rs | 185 +++++++++++++++++++++++++++++++++++++++ 3 files changed, 213 insertions(+), 67 deletions(-) create mode 100644 src/min_spanning_tree.rs diff --git a/src/hdbscan.rs b/src/hdbscan.rs index 822591b..0155c0f 100644 --- a/src/hdbscan.rs +++ b/src/hdbscan.rs @@ -3,6 +3,11 @@ use crate::core_distances::parallel::CoreDistanceCalculatorPar; #[cfg(feature = "serial")] use crate::core_distances::serial::CoreDistanceCalculator; use crate::data_wrappers::{CondensedNode, MSTEdge, SLTNode}; +#[cfg(feature = "parallel")] +use crate::min_spanning_tree::parallel::PrimsMinSpanningTreePar; +#[cfg(feature = "serial")] +use crate::min_spanning_tree::serial::PrimsMinSpanningTree; +use crate::min_spanning_tree::MinSpanningTree; use crate::union_find::UnionFind; use crate::validation::DataValidator; use crate::{distance, Center, DistanceMetric, HdbscanError, HdbscanHyperParams}; @@ -26,7 +31,7 @@ impl Hdbscan<'_, T> { /// /// # Returns /// * A result that, if successful, contains a list of cluster labels, with a length equal to - /// the numbe of samples passed to the constructor. Positive integers mean a data point + /// the number of samples passed to the constructor. Positive integers mean a data point /// belongs to a cluster of that label. -1 labels mean that a data point is noise and does /// not belong to any cluster. An Error will be returned if the dimensionality of the input /// vectors are mismatched, if any vector contains non-finite coordinates, or if the passed @@ -60,15 +65,20 @@ impl Hdbscan<'_, T> { ///assert_eq!(-1, labels[10]); /// ``` pub fn cluster(&self) -> Result, HdbscanError> { - let validator = DataValidator::new(self.data, &self.hp); - validator.validate_input_data()?; - let calculator = CoreDistanceCalculator::new(self.data, &self.hp); - let core_distances = calculator.calc_core_distances(); - let min_spanning_tree = self.prims_min_spanning_tree(&core_distances); + DataValidator::new(self.data, &self.hp).validate_input_data()?; + + let core_dist_calculator = CoreDistanceCalculator::new(self.data, &self.hp); + let core_distances = core_dist_calculator.calc_core_distances(); + + let mst_calculator = + PrimsMinSpanningTree::new(self.data, self.hp.dist_metric, &core_distances); + let min_spanning_tree = mst_calculator.compute(); + let single_linkage_tree = self.make_single_linkage_tree(&min_spanning_tree); let condensed_tree = self.condense_tree(&single_linkage_tree); let winning_clusters = self.extract_winning_clusters(&condensed_tree); let labelled_data = self.label_data(&winning_clusters, &condensed_tree); + Ok(labelled_data) } } @@ -80,7 +90,7 @@ impl Hdbscan<'_, T> { /// /// # Returns /// * A result that, if successful, contains a list of cluster labels, with a length equal to - /// the numbe of samples passed to the constructor. Positive integers mean a data point + /// the number of samples passed to the constructor. Positive integers mean a data point /// belongs to a cluster of that label. -1 labels mean that a data point is noise and does /// not belong to any cluster. An Error will be returned if the dimensionality of the input /// vectors are mismatched, if any vector contains non-finite coordinates, or if the passed @@ -114,15 +124,20 @@ impl Hdbscan<'_, T> { ///assert_eq!(-1, labels[10]); /// ``` pub fn cluster_par(&self) -> Result, HdbscanError> { - let validator = DataValidator::new(self.data, &self.hp); - validator.validate_input_data()?; - let calculator = CoreDistanceCalculatorPar::new(self.data, &self.hp); - let core_distances = calculator.calc_core_distances(); - let min_spanning_tree = self.prims_min_spanning_tree(&core_distances); + DataValidator::new(self.data, &self.hp).validate_input_data()?; + + let core_dist_calculator = CoreDistanceCalculatorPar::new(self.data, &self.hp); + let core_distances = core_dist_calculator.calc_core_distances(); + + let mst_calculator = + PrimsMinSpanningTreePar::new(self.data, self.hp.dist_metric, &core_distances); + let min_spanning_tree = mst_calculator.compute(); + let single_linkage_tree = self.make_single_linkage_tree(&min_spanning_tree); let condensed_tree = self.condense_tree(&single_linkage_tree); let winning_clusters = self.extract_winning_clusters(&condensed_tree); let labelled_data = self.label_data(&winning_clusters, &condensed_tree); + Ok(labelled_data) } } @@ -275,61 +290,6 @@ impl<'a, T: Float> Hdbscan<'a, T> { )) } - fn prims_min_spanning_tree(&self, core_distances: &[T]) -> Vec> { - let mut in_tree = vec![false; self.n_samples]; - let mut distances = vec![T::infinity(); self.n_samples]; - distances[0] = T::zero(); - - let mut mst = Vec::with_capacity(self.n_samples); - - let mut left_node_id = 0; - let mut right_node_id = 0; - - for _ in 1..self.n_samples { - in_tree[left_node_id] = true; - let mut current_min_dist = T::infinity(); - - for i in 0..self.n_samples { - if in_tree[i] { - continue; - } - let mrd = self.calc_mutual_reachability_dist(left_node_id, i, core_distances); - if mrd < distances[i] { - distances[i] = mrd; - } - if distances[i] < current_min_dist { - right_node_id = i; - current_min_dist = distances[i]; - } - } - mst.push(MSTEdge { - left_node_id, - right_node_id, - distance: current_min_dist, - }); - left_node_id = right_node_id; - } - self.sort_mst_by_dist(&mut mst); - mst - } - - fn calc_mutual_reachability_dist(&self, a: usize, b: usize, core_distances: &[T]) -> T { - let core_dist_a = core_distances[a]; - let core_dist_b = core_distances[b]; - let dist_a_b = if self.hp.dist_metric == DistanceMetric::Precalculated { - self.data[a][b] - } else { - self.hp.dist_metric.calc_dist(&self.data[a], &self.data[b]) - }; - - core_dist_a.max(core_dist_b).max(dist_a_b) - } - - fn sort_mst_by_dist(&self, min_spanning_tree: &mut [MSTEdge]) { - min_spanning_tree - .sort_by(|a, b| a.distance.partial_cmp(&b.distance).expect("Invalid floats")); - } - fn make_single_linkage_tree(&self, min_spanning_tree: &[MSTEdge]) -> Vec> { let mut single_linkage_tree: Vec> = Vec::with_capacity(self.n_samples - 1); diff --git a/src/lib.rs b/src/lib.rs index 7dbc5a7..151f4f1 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -63,5 +63,6 @@ mod distance; mod error; mod hdbscan; mod hyper_parameters; +mod min_spanning_tree; mod union_find; mod validation; diff --git a/src/min_spanning_tree.rs b/src/min_spanning_tree.rs new file mode 100644 index 0000000..82d027c --- /dev/null +++ b/src/min_spanning_tree.rs @@ -0,0 +1,185 @@ +use crate::data_wrappers::MSTEdge; +use crate::DistanceMetric; +use num_traits::Float; + +pub(crate) trait MinSpanningTree<'a, T> { + fn compute(&self) -> Vec>; +} + +#[derive(Clone, Debug)] +struct MinSpanningTreeCommon<'a, T> { + data: &'a [Vec], + dist_metric: DistanceMetric, + core_distances: &'a [T], + n_samples: usize, +} + +impl<'a, T: Float> MinSpanningTreeCommon<'a, T> { + fn new(data: &'a [Vec], dist_metric: DistanceMetric, core_distances: &'a [T]) -> Self { + MinSpanningTreeCommon { + data, + dist_metric, + core_distances, + n_samples: data.len(), + } + } + + fn calc_mutual_reachability_dist(&self, a: usize, b: usize) -> T { + let core_dist_a = self.core_distances[a]; + let core_dist_b = self.core_distances[b]; + let dist_a_b = if self.dist_metric == DistanceMetric::Precalculated { + self.data[a][b] + } else { + self.dist_metric.calc_dist(&self.data[a], &self.data[b]) + }; + core_dist_a.max(core_dist_b).max(dist_a_b) + } + + fn sort_mst_by_dist(&self, min_spanning_tree: &mut [MSTEdge]) { + min_spanning_tree + .sort_by(|a, b| a.distance.partial_cmp(&b.distance).expect("Invalid floats")); + } +} + +#[cfg(feature = "serial")] +pub(crate) mod serial { + use super::*; + use crate::data_wrappers::MSTEdge; + use num_traits::Float; + + #[derive(Clone, Debug)] + pub(crate) struct PrimsMinSpanningTree<'a, T> { + common: MinSpanningTreeCommon<'a, T>, + } + + impl<'a, T: Float> PrimsMinSpanningTree<'a, T> { + pub(crate) fn new( + data: &'a [Vec], + dist_metric: DistanceMetric, + core_distances: &'a [T], + ) -> Self { + let common = MinSpanningTreeCommon::new(data, dist_metric, core_distances); + PrimsMinSpanningTree { common } + } + } + + impl<'a, T: Float> MinSpanningTree<'a, T> for PrimsMinSpanningTree<'a, T> { + fn compute(&self) -> Vec> { + let n_samples = self.common.n_samples; + + let mut in_tree = vec![false; n_samples]; + let mut distances = vec![T::infinity(); n_samples]; + distances[0] = T::zero(); + + let mut mst = Vec::with_capacity(n_samples); + + let mut left_node_id = 0; + let mut right_node_id = 0; + + for _ in 1..n_samples { + in_tree[left_node_id] = true; + let mut current_min_dist = T::infinity(); + + for i in 0..n_samples { + if in_tree[i] { + continue; + } + let mrd = self.common.calc_mutual_reachability_dist(left_node_id, i); + if mrd < distances[i] { + distances[i] = mrd; + } + if distances[i] < current_min_dist { + right_node_id = i; + current_min_dist = distances[i]; + } + } + mst.push(MSTEdge { + left_node_id, + right_node_id, + distance: current_min_dist, + }); + left_node_id = right_node_id; + } + self.common.sort_mst_by_dist(&mut mst); + mst + } + } +} + +#[cfg(feature = "parallel")] +pub(crate) mod parallel { + use super::*; + use crate::data_wrappers::MSTEdge; + use num_traits::Float; + use rayon::prelude::*; + + #[derive(Clone, Debug)] + pub(crate) struct PrimsMinSpanningTreePar<'a, T> { + common: MinSpanningTreeCommon<'a, T>, + } + + impl<'a, T: Float + Send + Sync> PrimsMinSpanningTreePar<'a, T> { + pub(crate) fn new( + data: &'a [Vec], + dist_metric: DistanceMetric, + core_distances: &'a [T], + ) -> Self { + let common = MinSpanningTreeCommon::new(data, dist_metric, core_distances); + PrimsMinSpanningTreePar { common } + } + } + + impl<'a, T: Float + Send + Sync> MinSpanningTree<'a, T> for PrimsMinSpanningTreePar<'a, T> { + fn compute(&self) -> Vec> { + let n_samples = self.common.n_samples; + + let mut in_tree = vec![false; n_samples]; + let mut distances = vec![T::infinity(); n_samples]; + distances[0] = T::zero(); + + let mut mst = Vec::with_capacity(n_samples); + + let mut left_node_id = 0; + let mut right_node_id = 0; + + for _ in 1..n_samples { + in_tree[left_node_id] = true; + + let updates: Vec<(usize, T)> = (0..n_samples) + .into_par_iter() + .filter(|&i| !in_tree[i]) + .map(|i| { + let mrd = self.common.calc_mutual_reachability_dist(left_node_id, i); + (i, mrd) + }) + .collect(); + + let (min_idx, min_dist) = updates + .iter() + .map(|&(i, mrd)| { + let dist = distances[i].min(mrd); + (i, dist) + }) + .min_by(|(_, dist_a), (_, dist_b)| dist_a.partial_cmp(dist_b).unwrap()) + .expect("No minimum candidate"); + + updates.into_par_iter().for_each(|(i, mrd)| { + if mrd < distances[i] { + distances[i] = mrd; + } + }); + + right_node_id = min_idx; + mst.push(MSTEdge { + left_node_id, + right_node_id, + distance: min_dist, + }); + left_node_id = right_node_id; + } + + self.common.sort_mst_by_dist(&mut mst); + mst + } + } +} From d601841e44e3f7d26eb699d5a6a7a7483d63cc2f Mon Sep 17 00:00:00 2001 From: tom-whitehead Date: Mon, 24 Nov 2025 20:15:34 +0000 Subject: [PATCH 2/4] parallel iterator --- src/min_spanning_tree.rs | 43 +++++++++++++++++----------------------- 1 file changed, 18 insertions(+), 25 deletions(-) diff --git a/src/min_spanning_tree.rs b/src/min_spanning_tree.rs index 82d027c..2bee5a1 100644 --- a/src/min_spanning_tree.rs +++ b/src/min_spanning_tree.rs @@ -140,42 +140,35 @@ pub(crate) mod parallel { let mut mst = Vec::with_capacity(n_samples); let mut left_node_id = 0; - let mut right_node_id = 0; for _ in 1..n_samples { in_tree[left_node_id] = true; - let updates: Vec<(usize, T)> = (0..n_samples) - .into_par_iter() - .filter(|&i| !in_tree[i]) - .map(|i| { - let mrd = self.common.calc_mutual_reachability_dist(left_node_id, i); - (i, mrd) + let (min_idx, min_dist) = distances + .par_iter_mut() + .enumerate() + .filter_map(|(i, dist)| { + if in_tree[i] { + None + } else { + let mrd = self.common.calc_mutual_reachability_dist(left_node_id, i); + if mrd < *dist { + *dist = mrd; + } + Some((i, *dist)) + } }) - .collect(); - - let (min_idx, min_dist) = updates - .iter() - .map(|&(i, mrd)| { - let dist = distances[i].min(mrd); - (i, dist) + .min_by(|(_, dist_a), (_, dist_b)| { + dist_a.partial_cmp(dist_b).expect("Invalid floats") }) - .min_by(|(_, dist_a), (_, dist_b)| dist_a.partial_cmp(dist_b).unwrap()) - .expect("No minimum candidate"); + .expect("Malformed distance array"); - updates.into_par_iter().for_each(|(i, mrd)| { - if mrd < distances[i] { - distances[i] = mrd; - } - }); - - right_node_id = min_idx; mst.push(MSTEdge { left_node_id, - right_node_id, + right_node_id: min_idx, distance: min_dist, }); - left_node_id = right_node_id; + left_node_id = min_idx; } self.common.sort_mst_by_dist(&mut mst); From e9086d31564ca6a2ba810e5aafe7f18f61ea3c2b Mon Sep 17 00:00:00 2001 From: tom-whitehead Date: Mon, 24 Nov 2025 20:43:50 +0000 Subject: [PATCH 3/4] bump version --- CHANGELOG.md | 7 ++++++- Cargo.toml | 2 +- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index bdfcd3c..e642b8b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,8 @@ +# Version 0.12.0 2025-11-24 +## Changes +- Improvement to parallel clustering via the `cluster_par` method with parallelisation of the calculation of + Prim's minimum spanning tree. + # Version 0.11.0 2025-08-03 ## Changes - Addition of optional `parallel` feature that adds a method `cluster_par` to the `Hdbscan` struct. This method @@ -78,4 +83,4 @@ # Version 0.3.0 2024-02-20 ## Changes - Added `max_cluster_size` hyper parameter, with support in the hyper parameter builder - - Improved read me documentation on current state of the algorithm \ No newline at end of file + - Improved read me documentation on current state of the algorithm diff --git a/Cargo.toml b/Cargo.toml index d5e8426..b38c520 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "hdbscan" -version = "0.11.0" +version = "0.12.0" edition = "2021" authors = [ "Tom Whitehead ", ] description = "HDBSCAN clustering in pure Rust. A huge improvement on DBSCAN, capable of identifying clusters of varying densities." From 25bb2570b9258bf26956e379ca32fb1217cbf27d Mon Sep 17 00:00:00 2001 From: tom-whitehead Date: Mon, 24 Nov 2025 20:46:58 +0000 Subject: [PATCH 4/4] bump rust version in CI --- .github/workflows/rust.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index 2672d50..350aef7 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -21,7 +21,7 @@ jobs: run: sudo apt-get update - name: Setup | Rust toolchain - uses: dtolnay/rust-toolchain@1.79.0 + uses: dtolnay/rust-toolchain@1.80.0 with: components: clippy, rustfmt