22#define _TREETOOLS_SPLITLIST_H
33
44#include < Rcpp/Lightest>
5-
65#include < stdexcept> /* for errors */
6+ #include < vector> /* for heap allocation */
7+ #include < algorithm> /* for std::fill */
78
89#include " assert.h" /* for ASSERT */
9- #include " types.h" /* for int16 */
10+ #include " types.h" /* for int16, int32 */
1011
/* One packed word of split data. The SplitList constructor ORs successive
 * input bytes into a splitbit, each shifted a further R_BIN_SIZE bits. */
using splitbit = uint_fast64_t;

#define R_BIN_SIZE int16(8)   /* bits per element of the input RawMatrix */
#define SL_BIN_SIZE int16(64) /* bits used per splitbit bin */
#define SL_MAX_BINS int16(32) /* bins per split in the fixed-size arrays */
/* Capacity of SplitList's fixed-size member arrays.
 * The constructor uses those arrays when n_splits <= SL_MAX_SPLITS and
 * n_bins <= SL_MAX_BINS; if either limit is exceeded it switches to
 * std::vector-backed storage instead.
 */
#define SL_MAX_TIPS (SL_BIN_SIZE * SL_MAX_BINS) // 64 * 32 = 2048
#define SL_MAX_SPLITS (SL_MAX_TIPS - 3)
2224
/* Number of units that fall in the final, possibly partial, group when `n`
 * units are packed `size` to a group: n % size, except that an exact multiple
 * yields `size` rather than 0.
 * No space may separate the macro name from `(`, otherwise the preprocessor
 * treats this as an object-like macro. */
#define INLASTBIN(n, size) int16((size) - int16((size) - int16((n) % (size))) % (size))
2426#define INSUBBIN (bin, offset ) \
@@ -38,31 +40,29 @@ namespace TreeTools {
3840
3941#if __cplusplus >= 202002L
4042#include < bit> // C++20 header for std::popcount
41-
42- inline int16 count_bits (splitbit x) {
43- return static_cast <int16>(std::popcount (x));
43+ inline int32 count_bits (splitbit x) {
44+ return static_cast <int32>(std::popcount (x));
4445 }
45-
4646 // Option 2: Fallback for C++17 and older
4747#else
4848#if defined(__GNUC__) || defined(__clang__)
4949 // GCC and Clang support __builtin_popcountll for long long
50- inline int16 count_bits (splitbit x) {
51- return static_cast <int16 >(__builtin_popcountll (x));
50+ inline int32 count_bits (splitbit x) {
51+ return static_cast <int32 >(__builtin_popcountll (x));
5252 }
5353#elif defined(_MSC_VER)
5454#include < intrin.h>
55- inline int16 count_bits (splitbit x) {
56- return static_cast <int16 >(__popcnt64 (x));
55+ inline int32 count_bits (splitbit x) {
56+ return static_cast <int32 >(__popcnt64 (x));
5757 }
5858#else
5959 // A slower, but safe and highly portable fallback for all other compilers
6060 // This is a last resort if no built-in is available.
61- inline int16_t count_bits (splitbit x) {
62- int16_t count = 0 ;
61+ inline int32_t count_bits (splitbit x) {
62+ int32_t count = 0 ;
6363 while (x != 0 ) {
6464 x &= (x - 1 );
65- count++ ;
65+ ++count ;
6666 }
6767 return count;
6868 }
@@ -72,45 +72,73 @@ namespace TreeTools {
7272
7373 class SplitList {
7474 public:
75- int16 n_splits, n_bins;
76- int16 in_split[SL_MAX_SPLITS];
77- splitbit state[SL_MAX_SPLITS][SL_MAX_BINS];
75+ int32 n_splits;
76+ int32 n_bins;
77+ int32* in_split;
78+ splitbit** state;
79+
80+ private:
81+ /* STACK STORAGE (Fast path for small trees) */
82+ int32 stack_in_split[SL_MAX_SPLITS];
83+ splitbit stack_state[SL_MAX_SPLITS][SL_MAX_BINS];
84+ splitbit* stack_rows[SL_MAX_SPLITS];
85+
86+ /* HEAP STORAGE (Large trees) */
87+ std::vector<int32> heap_in_split;
88+ std::vector<splitbit> heap_data;
89+ std::vector<splitbit*> heap_rows;
90+
91+ public:
7892 SplitList (const Rcpp::RawMatrix &x) {
79- if (double (x.rows ()) > double (std::numeric_limits<int16>::max ())) {
80- Rcpp::stop (" This many splits cannot be supported. "
81- " Please contact the TreeTools maintainer if "
82- " you need to use more!" );
83- }
84- if (double (x.cols ()) > double (std::numeric_limits<int16>::max ())) {
85- Rcpp::stop (" This many leaves cannot be supported. "
86- " Please contact the TreeTools maintainer if "
87- " you need to use more!" );
93+
94+ const double n_rows = static_cast <double >(x.rows ());
95+
96+ /* Check limits */
97+ if (n_rows > static_cast <double >(std::numeric_limits<int32>::max ())) {
98+ Rcpp::stop (" Too many splits (exceeds int32 limit)." ); // #nocov
8899 }
89100
90- n_splits = int16 (x.rows ());
101+ n_splits = int32 (x.rows ());
91102 ASSERT (n_splits >= 0 );
92103
93- const int16 n_input_bins = int16 (x.cols ());
94-
95- n_bins = int16 (n_input_bins + R_BIN_SIZE - 1 ) / input_bins_per_bin;
96-
97- if (n_bins > SL_MAX_BINS) {
98- Rcpp::stop (" This many leaves cannot be supported. "
99- " Please contact the TreeTools maintainer if "
100- " you need to use more!" );
101- }
104+ const int32 n_input_bins = int32 (x.cols ());
105+ ASSERT (n_input_bins > 0 );
106+ n_bins = int32 (n_input_bins + R_BIN_SIZE - 1 ) / input_bins_per_bin;
107+
108+ bool use_heap = (n_splits > SL_MAX_SPLITS) || (n_bins > SL_MAX_BINS);
102109
103- for (int16 split = 0 ; split < n_splits; ++split) {
104- in_split[split] = 0 ;
110+ if (use_heap) {
111+ heap_in_split.resize (n_splits, 0 );
112+ in_split = heap_in_split.data ();
113+
114+ size_t total_elements = static_cast <size_t >(n_splits) *
115+ static_cast <size_t >(n_bins);
116+ heap_data.resize (total_elements);
117+
118+ heap_rows.resize (n_splits);
119+ for (int32 i = 0 ; i < n_splits; ++i) {
120+ heap_rows[i] = &heap_data[i * n_bins];
121+ }
122+ state = heap_rows.data ();
123+
124+ } else {
125+ in_split = stack_in_split;
126+
127+ for (int32 i = 0 ; i < n_splits; ++i) {
128+ stack_rows[i] = stack_state[i];
129+ in_split[i] = 0 ;
130+ }
131+ state = stack_rows;
105132 }
106133
107- for (int16 bin = 0 ; bin < n_bins - 1 ; ++bin) {
108- const int16 bin_offset = bin * input_bins_per_bin;
134+
135+ for (int32 bin = 0 ; bin < n_bins - 1 ; ++bin) {
136+ const int32 bin_offset = bin * input_bins_per_bin;
109137
110- for (int16 split = 0 ; split < n_splits; ++split) {
138+ for (int32 split = 0 ; split < n_splits; ++split) {
111139 splitbit combined = splitbit (x (split, bin_offset));
112140
113- for (int16 input_bin = 1 ; input_bin < input_bins_per_bin; ++input_bin) {
141+ for (int32 input_bin = 1 ; input_bin < input_bins_per_bin; ++input_bin) {
114142 combined |= splitbit (x (split, bin_offset + input_bin)) <<
115143 (R_BIN_SIZE * input_bin);
116144 }
@@ -120,19 +148,26 @@ namespace TreeTools {
120148 }
121149 }
122150
123- const int16 last_bin = n_bins - 1 ;
124- const int16 raggedy_bins = INLASTBIN (n_input_bins, R_BIN_SIZE);
151+ const int32 last_bin = n_bins - 1 ;
152+ const int32 raggedy_bins = INLASTBIN (n_input_bins, R_BIN_SIZE);
125153
126- for (int16 split = 0 ; split < n_splits; ++split) {
154+ for (int32 split = 0 ; split < n_splits; ++split) {
127155 state[split][last_bin] = INSUBBIN (last_bin, 0 );
128156
129- for (int16 input_bin = 1 ; input_bin < raggedy_bins; ++input_bin) {
157+ for (int32 input_bin = 1 ; input_bin < raggedy_bins; ++input_bin) {
130158 state[split][last_bin] += INBIN (input_bin, last_bin);
131159 }
132160
133161 in_split[split] += count_bits (state[split][last_bin]);
134162 }
135163 }
164+
165+ // Default destructor handles vector cleanup automatically
166+ ~SplitList () = default ;
167+
168+ // Disable copy/move to prevent pointer invalidation issues
169+ SplitList (const SplitList&) = delete ;
170+ SplitList& operator =(const SplitList&) = delete ;
136171 };
137172}
138173
0 commit comments