Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,13 @@ categories = ["algorithms", "mathematics", "science"]
num-traits = "0.2.18"
kdtree = "0.7.0"
rayon = { version = "1.10.0", optional = true }
serde = { version = "1.0", features = ["derive"], optional = true }

[dev-dependencies]
bincode = "1.3"

[features]
default = ["serial"]
serial = []
parallel = ["dep:rayon"]
serde = ["dep:serde"]
26 changes: 21 additions & 5 deletions src/data_wrappers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,25 @@ pub(crate) struct SLTNode<T> {
pub(crate) size: usize,
}

pub(crate) struct CondensedNode<T> {
pub(crate) node_id: usize,
pub(crate) parent_node_id: usize,
pub(crate) lambda_birth: T,
pub(crate) size: usize,
/// A node in the condensed cluster tree produced by HDBSCAN.
///
/// Exposed publicly to enable external `approximate_predict`-style inference
/// on new points (assigning previously-unseen samples to existing clusters
/// without re-running the full algorithm).
///
/// Marked `#[non_exhaustive]` so additional fields can be added in future
/// revisions without breaking external consumers. As a consequence, code
/// outside this crate can read the fields but cannot build a node with a
/// struct literal.
#[non_exhaustive]
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(
    feature = "serde",
    serde(bound(
        serialize = "T: serde::Serialize",
        deserialize = "T: serde::Deserialize<'de>"
    ))
)]
pub struct CondensedNode<T> {
    /// Identifier of this node within the condensed tree.
    pub node_id: usize,
    /// Identifier of the node this one descends from.
    // NOTE(review): how the root's parent is encoded is not visible here — confirm.
    pub parent_node_id: usize,
    /// Lambda value recorded for this node's birth.
    // NOTE(review): presumably the reciprocal of the distance at which the
    // node split off from its parent, as in standard HDBSCAN — confirm.
    pub lambda_birth: T,
    /// Size of the node.
    // NOTE(review): presumably the number of data points it holds — confirm.
    pub size: usize,
}
64 changes: 62 additions & 2 deletions src/hdbscan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,11 @@ use num_traits::Float;
use std::collections::{HashMap, VecDeque};
use std::ops::Range;

type CondensedTree<T> = Vec<CondensedNode<T>>;
/// The condensed cluster tree, stored as a flat list of [`CondensedNode`]s.
///
/// HDBSCAN builds this tree internally and derives the cluster labels from
/// it. It is exposed publicly to enable external `approximate_predict`-style
/// inference; obtain an instance via [`Hdbscan::cluster_with_tree`].
pub type CondensedTree<T> = Vec<CondensedNode<T>>;

/// The HDBSCAN clustering algorithm in Rust. Generic over floating point numeric types.
#[derive(Debug, Clone, PartialEq)]
Expand Down Expand Up @@ -65,6 +69,29 @@ impl<T: Float> Hdbscan<'_, T> {
///assert_eq!(-1, labels[10]);
/// ```
pub fn cluster(&self) -> Result<Vec<i32>, HdbscanError> {
    // Run the shared clustering routine and keep only the labels; the
    // condensed tree it also yields is not part of this method's result.
    let (labels, _condensed_tree) = self.cluster_internal()?;
    Ok(labels)
}

/// Clusters the data and returns the labels together with the condensed
/// cluster tree.
///
/// Behaves like [`cluster`](Hdbscan::cluster), except that the condensed
/// tree used internally to derive the labels is handed back to the caller
/// as well. That tree is what `approximate_predict`-style inference needs
/// in order to assign previously-unseen points to existing clusters
/// without re-running the full algorithm.
///
/// # Returns
/// * A `Result` whose `Ok` contains `(labels, condensed_tree)`.
///   `labels` follows the same semantics as [`cluster`](Hdbscan::cluster);
///   `condensed_tree` is a `Vec<CondensedNode<T>>` describing the cluster
///   hierarchy.
///
/// # Errors
/// Propagates any [`HdbscanError`] raised by the shared clustering
/// routine, for example when input data validation fails.
pub fn cluster_with_tree(
    &self,
) -> Result<(Vec<i32>, CondensedTree<T>), HdbscanError> {
    let (labels, condensed_tree) = self.cluster_internal()?;
    Ok((labels, condensed_tree))
}

fn cluster_internal(
&self,
) -> Result<(Vec<i32>, CondensedTree<T>), HdbscanError> {
DataValidator::new(self.data, &self.hp).validate_input_data()?;

let core_dist_calculator = CoreDistanceCalculator::new(self.data, &self.hp);
Expand All @@ -79,7 +106,7 @@ impl<T: Float> Hdbscan<'_, T> {
let winning_clusters = self.extract_winning_clusters(&condensed_tree);
let labelled_data = self.label_data(&winning_clusters, &condensed_tree);

Ok(labelled_data)
Ok((labelled_data, condensed_tree))
}
}

Expand Down Expand Up @@ -806,3 +833,36 @@ impl<'a, T: Float> Hdbscan<'a, T> {
}
}
}

#[cfg(all(test, feature = "serde"))]
mod serde_tests {
    use super::CondensedTree;
    use crate::data_wrappers::CondensedNode;

    /// Serialising a condensed tree with bincode and reading it back must
    /// reproduce every field of every node, in the original order.
    #[test]
    fn condensed_tree_bincode_roundtrip() {
        let tree: CondensedTree<f64> = vec![
            CondensedNode {
                node_id: 5,
                parent_node_id: 11,
                lambda_birth: 0.42,
                size: 3,
            },
            CondensedNode {
                node_id: 6,
                parent_node_id: 11,
                lambda_birth: 0.42,
                size: 4,
            },
        ];

        let bytes = bincode::serialize(&tree).expect("serialize tree");
        let decoded: CondensedTree<f64> =
            bincode::deserialize(&bytes).expect("deserialize tree");

        assert_eq!(decoded.len(), tree.len());

        // Compare field by field so the test does not rely on
        // `CondensedNode` implementing `PartialEq` or `Debug`.
        for (index, (expected, actual)) in tree.iter().zip(&decoded).enumerate() {
            assert_eq!(actual.node_id, expected.node_id, "node_id of node {index}");
            assert_eq!(
                actual.parent_node_id, expected.parent_node_id,
                "parent_node_id of node {index}"
            );
            // Exact float equality is intended: the round trip must return
            // the very same value that was serialised.
            assert_eq!(
                actual.lambda_birth, expected.lambda_birth,
                "lambda_birth of node {index}"
            );
            assert_eq!(actual.size, expected.size, "size of node {index}");
        }
    }
}
3 changes: 2 additions & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,10 @@

pub use crate::centers::Center;
pub use crate::core_distances::NnAlgorithm;
pub use crate::data_wrappers::CondensedNode;
pub use crate::distance::DistanceMetric;
pub use crate::error::HdbscanError;
pub use crate::hdbscan::Hdbscan;
pub use crate::hdbscan::{CondensedTree, Hdbscan};
pub use crate::hyper_parameters::{HdbscanHyperParams, HyperParamBuilder};

mod centers;
Expand Down
Loading