diff --git a/Cargo.toml b/Cargo.toml index b38c520..9578ed3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -14,8 +14,13 @@ categories = ["algorithms", "mathematics", "science"] num-traits = "0.2.18" kdtree = "0.7.0" rayon = { version = "1.10.0", optional = true } +serde = { version = "1.0", features = ["derive"], optional = true } + +[dev-dependencies] +bincode = "1.3" [features] default = ["serial"] serial = [] parallel = ["dep:rayon"] +serde = ["dep:serde"] diff --git a/src/data_wrappers.rs b/src/data_wrappers.rs index a7f529d..321f49e 100644 --- a/src/data_wrappers.rs +++ b/src/data_wrappers.rs @@ -12,9 +12,25 @@ pub(crate) struct SLTNode { pub(crate) size: usize, } -pub(crate) struct CondensedNode<T> { - pub(crate) node_id: usize, - pub(crate) parent_node_id: usize, - pub(crate) lambda_birth: T, - pub(crate) size: usize, +/// A node in the condensed cluster tree produced by HDBSCAN. Exposed +/// publicly to enable external `approximate_predict`-style inference on +/// new points (assigning previously-unseen samples to existing clusters +/// without re-running the full algorithm). +/// +/// Marked `#[non_exhaustive]` so additional fields can be added in +/// future revisions without breaking external consumers. +#[non_exhaustive] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +#[cfg_attr( + feature = "serde", + serde(bound( + serialize = "T: serde::Serialize", + deserialize = "T: serde::Deserialize<'de>" + )) +)] +pub struct CondensedNode<T> { + pub node_id: usize, + pub parent_node_id: usize, + pub lambda_birth: T, + pub size: usize, } diff --git a/src/hdbscan.rs b/src/hdbscan.rs index 0155c0f..695fc60 100644 --- a/src/hdbscan.rs +++ b/src/hdbscan.rs @@ -15,7 +15,11 @@ use num_traits::Float; use std::collections::{HashMap, VecDeque}; use std::ops::Range; -type CondensedTree<T> = Vec<CondensedNode<T>>; +/// The condensed cluster tree produced internally by HDBSCAN, used to +/// derive cluster labels. 
Exposed publicly to enable external +/// `approximate_predict`-style inference; obtain an instance via +/// [`Hdbscan::cluster_with_tree`]. +pub type CondensedTree<T> = Vec<CondensedNode<T>>; /// The HDBSCAN clustering algorithm in Rust. Generic over floating point numeric types. #[derive(Debug, Clone, PartialEq)] @@ -65,6 +69,29 @@ impl<T: Float> Hdbscan<'_, T> { ///assert_eq!(-1, labels[10]); /// ``` pub fn cluster(&self) -> Result<Vec<i32>, HdbscanError> { + self.cluster_internal().map(|(labels, _tree)| labels) + } + + /// Same as [`cluster`](Hdbscan::cluster), but additionally returns + /// the condensed cluster tree used internally to derive the cluster + /// labels. The condensed tree is required for + /// `approximate_predict`-style inference: assigning previously-unseen + /// points to existing clusters without re-running the full algorithm. + /// + /// # Returns + /// * A `Result` whose `Ok` contains `(labels, condensed_tree)`. + /// Labels follow the same semantics as [`cluster`](Hdbscan::cluster); + /// the condensed tree is a `Vec<CondensedNode<T>>` describing the + /// cluster hierarchy. 
+ pub fn cluster_with_tree( + &self, + ) -> Result<(Vec<i32>, CondensedTree<T>), HdbscanError> { + self.cluster_internal() + } + + fn cluster_internal( + &self, + ) -> Result<(Vec<i32>, CondensedTree<T>), HdbscanError> { DataValidator::new(self.data, &self.hp).validate_input_data()?; let core_dist_calculator = CoreDistanceCalculator::new(self.data, &self.hp); @@ -79,7 +106,7 @@ impl<T: Float> Hdbscan<'_, T> { let winning_clusters = self.extract_winning_clusters(&condensed_tree); let labelled_data = self.label_data(&winning_clusters, &condensed_tree); - Ok(labelled_data) + Ok((labelled_data, condensed_tree)) } } @@ -806,3 +833,36 @@ impl<'a, T: Float> Hdbscan<'a, T> { } } } + +#[cfg(all(test, feature = "serde"))] +mod serde_tests { + use super::CondensedTree; + use crate::data_wrappers::CondensedNode; + + #[test] + fn condensed_tree_bincode_roundtrip() { + let tree: CondensedTree<f64> = vec![ + CondensedNode { + node_id: 5, + parent_node_id: 11, + lambda_birth: 0.42, + size: 3, + }, + CondensedNode { + node_id: 6, + parent_node_id: 11, + lambda_birth: 0.42, + size: 4, + }, + ]; + let bytes = bincode::serialize(&tree).expect("serialize tree"); + let decoded: CondensedTree<f64> = + bincode::deserialize(&bytes).expect("deserialize tree"); + assert_eq!(decoded.len(), 2); + assert_eq!(decoded[0].node_id, 5); + assert_eq!(decoded[0].parent_node_id, 11); + assert_eq!(decoded[0].lambda_birth, 0.42); + assert_eq!(decoded[0].size, 3); + assert_eq!(decoded[1].node_id, 6); + } +} diff --git a/src/lib.rs b/src/lib.rs index 151f4f1..6391985 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -51,9 +51,10 @@ pub use crate::centers::Center; pub use crate::core_distances::NnAlgorithm; +pub use crate::data_wrappers::CondensedNode; pub use crate::distance::DistanceMetric; pub use crate::error::HdbscanError; -pub use crate::hdbscan::Hdbscan; +pub use crate::hdbscan::{CondensedTree, Hdbscan}; pub use crate::hyper_parameters::{HdbscanHyperParams, HyperParamBuilder}; mod centers;