Skip to content

Commit 0d7feda

Browse files
authored
C++ for normalize_splits (#257)
1 parent 5e1351f commit 0d7feda

File tree

10 files changed

+105
-39
lines changed

10 files changed

+105
-39
lines changed

.positai/settings.json

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,5 +11,9 @@
1111
"executeCode": {
1212
"*": "allow"
1313
}
14+
},
15+
"model": {
16+
"id": "claude-opus-4-6",
17+
"provider": "positai"
1418
}
1519
}

AGENTS.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,3 +27,10 @@ of internal nodes for any topologically identical tree.
2727

2828
Splits objects are defined in `as.Splits()`, and denote split membership as
2929
binary 0/1 in an underlying `raw` object.
30+
31+
## Workflow requirements
32+
33+
- After completing each optimization or user-visible change, update `NEWS.md`
34+
before moving on to the next task.
35+
- Increment the `.900X` dev version suffix in `DESCRIPTION` with each
36+
`NEWS.md` update.

DESCRIPTION

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
Package: TreeTools
22
Title: Create, Modify and Analyse Phylogenetic Trees
3-
Version: 2.1.0.9004
3+
Version: 2.1.0.9005
44
Authors@R: c(
55
person("Martin R.", 'Smith', role = c("aut", "cre", "cph"),
66
email = "martin.smith@durham.ac.uk",

NEWS.md

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
1-
# TreeTools 2.1.0.9004 (2026-03-12) #
1+
# TreeTools 2.1.0.9005 (2026-03-13) #
22

3-
- Rewrite popcount calculation for more efficient `TipsInSplits()`.
3+
- Split normalization in `SplitFrequency(reference = NULL)` now runs in C++,
4+
eliminating an R-level per-split loop.
45

56
# TreeTools 2.1.0.9003 (2026-03-09) #
67

R/RcppExports.R

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,10 @@ cpp_count_splits <- function(edge, nTip) {
137137
.Call(`_TreeTools_cpp_count_splits`, edge, nTip)
138138
}
139139

140+
normalize_splits <- function(splits, n_tip) {
141+
.Call(`_TreeTools_normalize_splits`, splits, n_tip)
142+
}
143+
140144
splits_to_edge <- function(splits, nTip) {
141145
.Call(`_TreeTools_splits_to_edge`, splits, nTip)
142146
}

R/Support.R

Lines changed: 4 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -58,30 +58,10 @@ SplitFrequency <- function(reference, forest = NULL) {
5858
return(structure(splits, nTip = nTip, tip.label = tipLabels, # nocov
5959
count = integer(), class = "Splits")) # nocov
6060
}
61-
# The ClusterTable outputs clusters (clades); normalize so bit 0 (tip 1)
62-
# is not in the set (matching as.Splits convention)
63-
nTipMod <- nTip %% 8L
64-
lastByteMask <- if (nTipMod == 0L) as.raw(0xff) else as.raw(bitwShiftL(1L, nTipMod) - 1L)
65-
keep <- logical(nrow(splits))
66-
for (i in seq_len(nrow(splits))) {
67-
val <- splits[i, ]
68-
# Count bits set (to filter trivial splits)
69-
nBits <- sum(vapply(as.integer(val), function(b) sum(as.integer(intToBits(b))), integer(1)))
70-
if (nBits < 2L || nBits > nTip - 2L) next # trivial split
71-
# Normalize: if bit 0 is NOT set, complement to match as.Splits format
72-
if (!as.logical(as.integer(val[1]) %% 2L)) {
73-
for (j in seq_along(val)) {
74-
splits[i, j] <- as.raw(bitwXor(as.integer(val[j]), 0xffL))
75-
}
76-
# Mask last byte
77-
if (nTipMod > 0L) {
78-
splits[i, nbin] <- as.raw(bitwAnd(as.integer(splits[i, nbin]),
79-
as.integer(lastByteMask)))
80-
}
81-
}
82-
keep[i] <- TRUE
83-
}
84-
splits <- splits[keep, , drop = FALSE]
61+
# Normalize splits: ensure bit 0 is set, filter trivial splits
62+
normalized <- normalize_splits(splits, nTip)
63+
keep <- normalized[["keep"]]
64+
splits <- normalized[["splits"]][keep, , drop = FALSE]
8565
counts <- counts[keep]
8666
ret <- structure(splits,
8767
nTip = nTip,

benchmark/_compare_results.R

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,12 @@ for (pr_file in pr_files) {
2424
# Prepare a report
2525
report <- list()
2626

27+
# Use deparse1 for reliable expression-to-string conversion;
28+
# as.character(unlist()) decomposes call objects into their components.
29+
expr_names <- vapply(pr1[["expression"]], deparse1, "")
30+
2731
# Iterate over each function benchmarked
28-
for (fn_name in unique(as.character(unlist(pr1[["expression"]])))) {
32+
for (fn_name in unique(expr_names)) {
2933
pr1_times <- as.numeric(pr1[["time"]][[1]])
3034
pr2_times <- as.numeric(pr2[["time"]][[1]])
3135
pr_times <- if (rep_exists) c(pr1_times, pr2_times) else pr1_times
@@ -79,18 +83,21 @@ for (pr_file in pr_files) {
7983
# Create a markdown-formatted message
8084
has_significant_regression <- FALSE
8185

82-
for (fn_name in names(report)) {
83-
res <- report[[fn_name]]
84-
status <- if (res$matched) {
86+
for (i in seq_along(report)) {
87+
fn_name <- names(report)[[i]]
88+
res <- report[[i]]
89+
if (is.null(res) || is.null(res$matched)) next
90+
91+
status <- if (isTRUE(res$matched)) {
8592
if (res$slower) {
86-
if (abs(percentage_change) > threshold_percent) {
93+
if (abs(res$change) > threshold_percent) {
8794
has_significant_regression <- TRUE
8895
"\U1F7E0 Slower \U1F641"
8996
} else {
9097
"\U1F7E3 ~same"
9198
}
9299
} else if (res$faster) {
93-
if (abs(percentage_change) > threshold_percent) {
100+
if (abs(res$change) > threshold_percent) {
94101
"\U1F7E2 Faster!"
95102
} else {
96103
"\U1F7E3 ~same"
@@ -111,14 +118,14 @@ for (pr_file in pr_files) {
111118
signif(res$median_pr * 1e3, 3), ", ",
112119
signif(res$median_cf * 1e3, 3), " |\n"
113120
)
121+
122+
cat(message)
123+
output <- paste0(output, message)
114124
}
115125

116126
if (has_significant_regression) {
117127
regressions <- TRUE
118128
}
119-
120-
cat(message)
121-
output <- paste0(output, message)
122129
}
123130

124131
cat(paste0(output, "\nEOF"), file = Sys.getenv("GITHUB_OUTPUT"), append = TRUE)

benchmark/_init.R

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,14 @@
11
library("TreeTools")
22

33
Benchmark <- function(..., min_iterations = NULL, min_time = NULL) {
4-
args <- list(..., min_iterations = min_iterations %||% 3, time_unit = "us")
5-
if (!is.null(min_time)) args[["min_time"]] <- min_time
6-
result <- do.call(bench::mark, args)
4+
# Pass ... directly to bench::mark to preserve non-standard evaluation;
5+
# do.call() would evaluate expressions first, breaking expression capture.
6+
result <- if (is.null(min_time)) {
7+
bench::mark(..., min_iterations = min_iterations %||% 3, time_unit = "us")
8+
} else {
9+
bench::mark(..., min_iterations = min_iterations %||% 3,
10+
min_time = min_time, time_unit = "us")
11+
}
712
if (interactive()) {
813
print(result)
914
} else {

src/RcppExports.cpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -426,6 +426,18 @@ BEGIN_RCPP
426426
return rcpp_result_gen;
427427
END_RCPP
428428
}
429+
// normalize_splits
430+
Rcpp::List normalize_splits(Rcpp::RawMatrix splits, const int n_tip);
431+
RcppExport SEXP _TreeTools_normalize_splits(SEXP splitsSEXP, SEXP n_tipSEXP) {
432+
BEGIN_RCPP
433+
Rcpp::RObject rcpp_result_gen;
434+
Rcpp::RNGScope rcpp_rngScope_gen;
435+
Rcpp::traits::input_parameter< Rcpp::RawMatrix >::type splits(splitsSEXP);
436+
Rcpp::traits::input_parameter< const int >::type n_tip(n_tipSEXP);
437+
rcpp_result_gen = Rcpp::wrap(normalize_splits(splits, n_tip));
438+
return rcpp_result_gen;
439+
END_RCPP
440+
}
429441
// splits_to_edge
430442
IntegerMatrix splits_to_edge(const RawMatrix splits, const IntegerVector nTip);
431443
RcppExport SEXP _TreeTools_splits_to_edge(SEXP splitsSEXP, SEXP nTipSEXP) {
@@ -517,6 +529,7 @@ static const R_CallMethodDef CallEntries[] = {
517529
{"_TreeTools_pack_splits_logical", (DL_FUNC) &_TreeTools_pack_splits_logical, 1},
518530
{"_TreeTools_pack_splits_logical_vec", (DL_FUNC) &_TreeTools_pack_splits_logical_vec, 1},
519531
{"_TreeTools_cpp_count_splits", (DL_FUNC) &_TreeTools_cpp_count_splits, 2},
532+
{"_TreeTools_normalize_splits", (DL_FUNC) &_TreeTools_normalize_splits, 2},
520533
{"_TreeTools_splits_to_edge", (DL_FUNC) &_TreeTools_splits_to_edge, 2},
521534
{"_TreeTools_tips_in_splits", (DL_FUNC) &_TreeTools_tips_in_splits, 1},
522535
{"_TreeTools_edge_to_rooted_shape", (DL_FUNC) &_TreeTools_edge_to_rooted_shape, 3},

src/splits.cpp

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -714,3 +714,48 @@ int cpp_count_splits(const Rcpp::IntegerMatrix& edge, const int nTip) {
714714

715715
return (n_internal - n_singles) - 1 - (is_rooted ? 1 : 0);
716716
}
717+
718+
// Normalize splits from ClusterTable output for as.Splits convention:
719+
// - Flag trivial splits (fewer than 2 or more than n_tip - 2 bits set) as FALSE in "keep"
720+
// - Ensure bit 0 is set (complement if not)
721+
// - Mask unused trailing bits in last byte
722+
// Returns a List with "splits" (RawMatrix, all rows; subset by "keep") and "keep" (LogicalVector)
723+
// [[Rcpp::export]]
724+
Rcpp::List normalize_splits(Rcpp::RawMatrix splits, const int n_tip) {
725+
const int n_split = splits.nrow();
726+
const int n_bin = splits.ncol();
727+
const int n_spare = n_tip % BIN_SIZE;
728+
const Rbyte last_mask = n_spare == 0
729+
? Rbyte(0xff)
730+
: static_cast<Rbyte>((1 << n_spare) - 1);
731+
732+
Rcpp::LogicalVector keep(n_split, false);
733+
734+
for (int i = 0; i < n_split; ++i) {
735+
// Count bits set
736+
int n_bits = 0;
737+
for (int j = 0; j < n_bin; ++j) {
738+
n_bits += __builtin_popcount(static_cast<unsigned>(splits(i, j)));
739+
}
740+
741+
if (n_bits < 2 || n_bits > n_tip - 2) continue; // trivial
742+
743+
// Normalize: if bit 0 is NOT set, complement
744+
if (!(splits(i, 0) & Rbyte(1))) {
745+
for (int j = 0; j < n_bin; ++j) {
746+
splits(i, j) = static_cast<Rbyte>(~splits(i, j));
747+
}
748+
// Mask trailing bits in last byte
749+
if (n_spare > 0) {
750+
splits(i, n_bin - 1) &= last_mask;
751+
}
752+
}
753+
754+
keep[i] = true;
755+
}
756+
757+
return Rcpp::List::create(
758+
Rcpp::Named("splits") = splits,
759+
Rcpp::Named("keep") = keep
760+
);
761+
}

0 commit comments

Comments
 (0)