Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 12 additions & 9 deletions R/tree_distance_utilities.R
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,15 @@
#' @importFrom TreeTools as.Splits TipLabels
#' @importFrom utils combn
#' @export
# Keep in sync with C++ guard: min(SL_MAX_TIPS, int16_t::max()).
.MaxSupportedTips <- 32767L

.AssertNtipSupported <- function(nTip) {
if (!is.na(nTip) && nTip > .MaxSupportedTips) {
stop("This many tips are not (yet) supported.")
}
}

CalculateTreeDistance <- function(Func, tree1, tree2 = NULL,
reportMatching = FALSE, ...) {
supportedClasses <- c("phylo", "Splits")
Expand Down Expand Up @@ -132,9 +141,7 @@ CalculateTreeDistance <- function(Func, tree1, tree2 = NULL,
# Fast paths: use OpenMP batch functions when all trees share the same tip
# set and no R-level cluster has been configured. Each branch mirrors the
# generic path exactly but avoids per-pair R overhead.
if (!is.na(nTip) && nTip > 32767L) {
stop("This many tips are not (yet) supported.")
}
.AssertNtipSupported(nTip)
if (!is.na(nTip) && is.null(cluster)) {
.n_threads <- as.integer(getOption("mc.cores", 1L))
.batch_result <- if (identical(Func, MutualClusteringInfoSplits)) {
Expand Down Expand Up @@ -235,9 +242,7 @@ CalculateTreeDistance <- function(Func, tree1, tree2 = NULL,
#' @importFrom stats setNames
.SplitDistanceManyMany <- function(Func, splits1, splits2,
tipLabels, nTip = length(tipLabels), ...) {
if (!is.na(nTip) && nTip > 32767L) {
stop("This many tips are not (yet) supported.")
}
.AssertNtipSupported(nTip)
if (is.na(nTip)) {
tipLabels <- union(unlist(tipLabels, use.names = FALSE),
unlist(TipLabels(splits2), use.names = FALSE))
Expand Down Expand Up @@ -408,9 +413,7 @@ CalculateTreeDistance <- function(Func, tree1, tree2 = NULL,
if (ncol(x) != ncol(y)) {
stop("Input splits must address same number of tips.")
}
if (nTip > 32767L) {
stop("This many tips are not (yet) supported.")
}
.AssertNtipSupported(nTip)
}

.CheckLabelsSame <- function(labelList) {
Expand Down
37 changes: 24 additions & 13 deletions src/tree_distances.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#include <TreeTools/SplitList.h>
#include <TreeTools/assert.h>
#include <algorithm>
#include <cmath>
#include <memory> /* for unique_ptr, make_unique */
#include <Rcpp/Lightest>
Expand Down Expand Up @@ -29,9 +30,16 @@ namespace TreeDist {
}

void check_ntip(const double n) {
// Validated by R caller (nTip > 32767 guard in CalculateTreeDistance et al.)
ASSERT(n <= static_cast<double>(std::numeric_limits<int16>::max())
&& "This many tips are not (yet) supported.");
// SplitList dimensions are bounded by SL_MAX_TIPS, and current scoring
// paths use int16-sized counts internally.
static_assert(SL_MAX_TIPS <= std::numeric_limits<int32>::max(),
"SL_MAX_TIPS must fit in int32");
constexpr int32 max_supported_tips = std::min(
int32(SL_MAX_TIPS), int32(std::numeric_limits<int16_t>::max())
);
if (n > max_supported_tips) {
Rcpp::stop("This many tips are not (yet) supported.");
}
}


Expand All @@ -51,12 +59,13 @@ inline List robinson_foulds_distance(const RawMatrix &x, const RawMatrix &y,

grf_match matching(a.n_splits, NA_INTEGER);

splitbit b_complement[SL_MAX_SPLITS][SL_MAX_BINS];
// Heap-backed scratch avoids large fixed-size stack allocation.
std::vector<splitbit> b_complement(size_t(b.n_splits) * size_t(a.n_bins));
for (int32 i = b.n_splits; i--; ) {
for (int32 bin = last_bin; bin--; ) {
b_complement[i][bin] = ~b.state[i][bin];
b_complement[size_t(i) * a.n_bins + bin] = ~b.state[i][bin];
}
b_complement[i][last_bin] = b.state[i][last_bin] ^ unset_mask;
b_complement[size_t(i) * a.n_bins + last_bin] = b.state[i][last_bin] ^ unset_mask;
}

for (int32 ai = a.n_splits; ai--; ) {
Expand All @@ -73,7 +82,7 @@ inline List robinson_foulds_distance(const RawMatrix &x, const RawMatrix &y,
}
if (!all_match) {
for (int32 bin = 0; bin < a.n_bins; ++bin) {
if (a.state[ai][bin] != b_complement[bi][bin]) {
if (a.state[ai][bin] != b_complement[size_t(bi) * a.n_bins + bin]) {
all_complement = false;
break;
}
Expand Down Expand Up @@ -105,13 +114,13 @@ inline List robinson_foulds_info(const RawMatrix &x, const RawMatrix &y,

grf_match matching(a.n_splits, NA_INTEGER);

/* Dynamic allocation 20% faster for 105 tips, but VLA not permitted in C11 */
splitbit b_complement[SL_MAX_SPLITS][SL_MAX_BINS];
// Heap-backed scratch avoids large fixed-size stack allocation.
std::vector<splitbit> b_complement(size_t(b.n_splits) * size_t(a.n_bins));
for (int16 i = 0; i < b.n_splits; i++) {
for (int16 bin = 0; bin < last_bin; ++bin) {
b_complement[i][bin] = ~b.state[i][bin];
b_complement[size_t(i) * a.n_bins + bin] = ~b.state[i][bin];
}
b_complement[i][last_bin] = b.state[i][last_bin] ^ unset_mask;
b_complement[size_t(i) * a.n_bins + last_bin] = b.state[i][last_bin] ^ unset_mask;
}

for (int16 ai = 0; ai < a.n_splits; ++ai) {
Expand All @@ -127,7 +136,7 @@ inline List robinson_foulds_info(const RawMatrix &x, const RawMatrix &y,
}
if (!all_match) {
for (int16 bin = 0; bin < a.n_bins; ++bin) {
if ((a.state[ai][bin] != b_complement[bi][bin])) {
if ((a.state[ai][bin] != b_complement[size_t(bi) * a.n_bins + bin])) {
all_complement = false;
break;
}
Expand Down Expand Up @@ -607,7 +616,9 @@ inline List shared_phylo (const RawMatrix &x, const RawMatrix &y,
List cpp_robinson_foulds_distance(const RawMatrix &x, const RawMatrix &y,
const IntegerVector &nTip) {
ASSERT(x.cols() == y.cols() && "Input splits must address same number of tips.");
return robinson_foulds_distance(x, y, static_cast<int32>(nTip[0]));
const int32 n_tip = static_cast<int32>(nTip[0]);
TreeDist::check_ntip(n_tip);
return robinson_foulds_distance(x, y, n_tip);
}

// [[Rcpp::export]]
Expand Down
21 changes: 21 additions & 0 deletions tests/testthat/test-tree_distance_utilities.R
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,27 @@ test_that("CalculateTreeDistance() errs appropriately", {
expect_error(CalculateTreeDistance(RobinsonFouldsSplits, BalancedTree(8), "Not a tree"))
})

test_that("Tip-count guard is applied consistently", {
expect_no_error(.AssertNtipSupported(1000L))
expect_no_error(.AssertNtipSupported(32766L))
expect_no_error(.AssertNtipSupported(32767L))
expect_error(.AssertNtipSupported(32768L),
"This many tips are not \\(yet\\) supported\\.")

splits8 <- unclass(as.Splits(BalancedTree(8)))
expect_error(cpp_robinson_foulds_distance(splits8, splits8, 32768L),
"This many tips are not \\(yet\\) supported\\.")
expect_error(cpp_robinson_foulds_info(splits8, splits8, 32768L),
"This many tips are not \\(yet\\) supported\\.")

trees <- list(BalancedTree(8), PectinateTree(8))
class(trees) <- "multiPhylo"
expect_error(
.SplitDistanceAllPairs(RobinsonFouldsSplits, trees, letters[1:8], 32768L),
"This many tips are not \\(yet\\) supported\\."
)
})

test_that("CalculateTreeDistance() handles splits appropriately", {
set.seed(101)
tree10 <- ape::rtree(10)
Expand Down
Loading