Skip to content

Commit 1e751c5

Browse files
committed
Deploying to r-dev from @ 022c022 🚀
1 parent fbb1726 commit 1e751c5

8 files changed

Lines changed: 110 additions & 24 deletions

File tree

NEWS.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
## New Features
44

55
* Replaced C++ standard library distributions (`discrete_distribution`, `uniform_real_distribution`, `normal_distribution`, and `gamma_distribution`) with custom implementations for cross-platform reproducibility.
6+
* Substituted custom implementations for base R `mean()`, `var()`, and `sd()` in the preprocessing logic of the R `bart()` and `bcf()` functions for enhanced numeric stability across platforms.
67

78
## Bug Fixes
89

R/bart.R

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -854,11 +854,11 @@ bart <- function(
854854
# differently for binary and continuous outcomes
855855
if (probit_outcome_model) {
856856
# Compute a probit-scale offset and fix scale to 1
857-
y_bar_train <- qnorm(mean(y_train))
857+
y_bar_train <- qnorm(mean_cpp(as.numeric(y_train)))
858858
y_std_train <- 1
859859

860-
# Set a pseudo outcome by subtracting mean(y_train) from y_train
861-
resid_train <- y_train - mean(y_train)
860+
# Set a pseudo outcome by subtracting mean_cpp(y_train) from y_train
861+
resid_train <- y_train - mean_cpp(as.numeric(y_train))
862862

863863
# Set initial values of root nodes to 0.0 (in probit scale)
864864
init_val_mean <- 0.0
@@ -910,8 +910,8 @@ bart <- function(
910910
} else {
911911
# Only standardize if user requested
912912
if (standardize) {
913-
y_bar_train <- mean(y_train)
914-
y_std_train <- sd(y_train)
913+
y_bar_train <- mean_cpp(as.numeric(y_train))
914+
y_std_train <- sd_cpp(as.numeric(y_train))
915915
} else {
916916
y_bar_train <- 0
917917
y_std_train <- 1
@@ -921,23 +921,23 @@ bart <- function(
921921
resid_train <- (y_train - y_bar_train) / y_std_train
922922

923923
# Compute initial value of root nodes in mean forest
924-
init_val_mean <- mean(resid_train)
924+
init_val_mean <- mean_cpp(as.numeric(resid_train))
925925

926926
# Calibrate priors for sigma^2 and tau
927927
if (is.null(sigma2_init)) {
928-
sigma2_init <- 1.0 * var(resid_train)
928+
sigma2_init <- 1.0 * var_cpp(as.numeric(resid_train))
929929
}
930930
if (is.null(variance_forest_init)) {
931-
variance_forest_init <- 1.0 * var(resid_train)
931+
variance_forest_init <- 1.0 * var_cpp(as.numeric(resid_train))
932932
}
933933
if (is.null(b_leaf)) {
934-
b_leaf <- var(resid_train) / (2 * num_trees_mean)
934+
b_leaf <- var_cpp(as.numeric(resid_train)) / (2 * num_trees_mean)
935935
}
936936
if (has_basis) {
937937
if (ncol(leaf_basis_train) > 1) {
938938
if (is.null(sigma2_leaf_init)) {
939939
sigma2_leaf_init <- diag(
940-
2 * var(resid_train) / (num_trees_mean),
940+
2 * var_cpp(as.numeric(resid_train)) / (num_trees_mean),
941941
ncol(leaf_basis_train)
942942
)
943943
}
@@ -952,7 +952,7 @@ bart <- function(
952952
} else {
953953
if (is.null(sigma2_leaf_init)) {
954954
sigma2_leaf_init <- as.matrix(
955-
2 * var(resid_train) / (num_trees_mean)
955+
2 * var_cpp(as.numeric(resid_train)) / (num_trees_mean)
956956
)
957957
}
958958
if (!is.matrix(sigma2_leaf_init)) {
@@ -964,7 +964,7 @@ bart <- function(
964964
} else {
965965
if (is.null(sigma2_leaf_init)) {
966966
sigma2_leaf_init <- as.matrix(
967-
2 * var(resid_train) / (num_trees_mean)
967+
2 * var_cpp(as.numeric(resid_train)) / (num_trees_mean)
968968
)
969969
}
970970
if (!is.matrix(sigma2_leaf_init)) {

R/bcf.R

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1207,11 +1207,11 @@ bcf <- function(
12071207
# differently for binary and continuous outcomes
12081208
if (probit_outcome_model) {
12091209
# Compute a probit-scale offset and fix scale to 1
1210-
y_bar_train <- qnorm(mean(y_train))
1210+
y_bar_train <- qnorm(mean_cpp(as.numeric(y_train)))
12111211
y_std_train <- 1
12121212

1213-
# Set a pseudo outcome by subtracting mean(y_train) from y_train
1214-
resid_train <- y_train - mean(y_train)
1213+
# Set a pseudo outcome by subtracting mean_cpp(y_train) from y_train
1214+
resid_train <- y_train - mean_cpp(as.numeric(y_train))
12151215

12161216
# Set initial value for the mu forest
12171217
init_mu <- 0.0
@@ -1274,8 +1274,8 @@ bcf <- function(
12741274
} else {
12751275
# Only standardize if user requested
12761276
if (standardize) {
1277-
y_bar_train <- mean(y_train)
1278-
y_std_train <- sd(y_train)
1277+
y_bar_train <- mean_cpp(as.numeric(y_train))
1278+
y_std_train <- sd_cpp(as.numeric(y_train))
12791279
} else {
12801280
y_bar_train <- 0
12811281
y_std_train <- 1
@@ -1285,23 +1285,23 @@ bcf <- function(
12851285
resid_train <- (y_train - y_bar_train) / y_std_train
12861286

12871287
# Set initial value for the mu forest
1288-
init_mu <- mean(resid_train)
1288+
init_mu <- mean_cpp(as.numeric(resid_train))
12891289

12901290
# Calibrate priors for global sigma^2 and sigma2_leaf_mu / sigma2_leaf_tau
12911291
if (is.null(sigma2_init)) {
1292-
sigma2_init <- 1.0 * var(resid_train)
1292+
sigma2_init <- 1.0 * var_cpp(as.numeric(resid_train))
12931293
}
12941294
if (is.null(variance_forest_init)) {
1295-
variance_forest_init <- 1.0 * var(resid_train)
1295+
variance_forest_init <- 1.0 * var_cpp(as.numeric(resid_train))
12961296
}
12971297
if (is.null(b_leaf_mu)) {
1298-
b_leaf_mu <- var(resid_train) / (num_trees_mu)
1298+
b_leaf_mu <- var_cpp(as.numeric(resid_train)) / (num_trees_mu)
12991299
}
13001300
if (is.null(b_leaf_tau)) {
1301-
b_leaf_tau <- var(resid_train) / (2 * num_trees_tau)
1301+
b_leaf_tau <- var_cpp(as.numeric(resid_train)) / (2 * num_trees_tau)
13021302
}
13031303
if (is.null(sigma2_leaf_mu)) {
1304-
sigma2_leaf_mu <- 2.0 * var(resid_train) / (num_trees_mu)
1304+
sigma2_leaf_mu <- 2.0 * var_cpp(as.numeric(resid_train)) / (num_trees_mu)
13051305
current_leaf_scale_mu <- as.matrix(sigma2_leaf_mu)
13061306
} else {
13071307
if (!is.matrix(sigma2_leaf_mu)) {
@@ -1311,7 +1311,7 @@ bcf <- function(
13111311
}
13121312
}
13131313
if (is.null(sigma2_leaf_tau)) {
1314-
sigma2_leaf_tau <- var(resid_train) / (num_trees_tau)
1314+
sigma2_leaf_tau <- var_cpp(as.numeric(resid_train)) / (num_trees_tau)
13151315
current_leaf_scale_tau <- as.matrix(diag(
13161316
sigma2_leaf_tau,
13171317
ncol(Z_train)

R/cpp11.R

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -276,6 +276,22 @@ root_reset_rfx_tracker_cpp <- function(tracker, dataset, residual, rfx_model) {
276276
invisible(.Call(`_stochtree_root_reset_rfx_tracker_cpp`, tracker, dataset, residual, rfx_model))
277277
}
278278

279+
sum_cpp <- function(x) {
280+
.Call(`_stochtree_sum_cpp`, x)
281+
}
282+
283+
mean_cpp <- function(x) {
284+
.Call(`_stochtree_mean_cpp`, x)
285+
}
286+
287+
var_cpp <- function(x) {
288+
.Call(`_stochtree_var_cpp`, x)
289+
}
290+
291+
sd_cpp <- function(x) {
292+
.Call(`_stochtree_sd_cpp`, x)
293+
}
294+
279295
active_forest_cpp <- function(num_trees, output_dimension, is_leaf_constant, is_exponentiated) {
280296
.Call(`_stochtree_active_forest_cpp`, num_trees, output_dimension, is_leaf_constant, is_exponentiated)
281297
}

src/Makevars.in

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ OBJECTS = \
2222
kernel.o \
2323
R_data.o \
2424
R_random_effects.o \
25+
R_utils.o \
2526
sampler.o \
2627
serialization.o \
2728
cpp11.o \

src/Makevars.win.in

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ OBJECTS = \
2323
kernel.o \
2424
R_data.o \
2525
R_random_effects.o \
26+
R_utils.o \
2627
sampler.o \
2728
serialization.o \
2829
cpp11.o \

src/R_utils.cpp

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
#include <cpp11.hpp>
2+
#include <cmath>
3+
4+
[[cpp11::register]]
5+
double sum_cpp(cpp11::doubles x) {
6+
double output = 0.0;
7+
for (int i = 0; i < x.size(); i++) {
8+
output += x[i];
9+
}
10+
return output;
11+
}
12+
13+
[[cpp11::register]]
14+
double mean_cpp(cpp11::doubles x) {
15+
double output = 0.0;
16+
for (int i = 0; i < x.size(); i++) {
17+
output += x[i];
18+
}
19+
return output / x.size();
20+
}
21+
22+
[[cpp11::register]]
23+
double var_cpp(cpp11::doubles x) {
24+
double mean = mean_cpp(x);
25+
double output = 0.0;
26+
for (int i = 0; i < x.size(); i++) {
27+
output += (x[i] - mean) * (x[i] - mean);
28+
}
29+
return output / (x.size() - 1);
30+
}
31+
32+
[[cpp11::register]]
33+
double sd_cpp(cpp11::doubles x) {
34+
return std::sqrt(var_cpp(x));
35+
}

src/cpp11.cpp

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -516,6 +516,34 @@ extern "C" SEXP _stochtree_root_reset_rfx_tracker_cpp(SEXP tracker, SEXP dataset
516516
return R_NilValue;
517517
END_CPP11
518518
}
519+
// R_utils.cpp
520+
double sum_cpp(cpp11::doubles x);
521+
extern "C" SEXP _stochtree_sum_cpp(SEXP x) {
522+
BEGIN_CPP11
523+
return cpp11::as_sexp(sum_cpp(cpp11::as_cpp<cpp11::decay_t<cpp11::doubles>>(x)));
524+
END_CPP11
525+
}
526+
// R_utils.cpp
527+
double mean_cpp(cpp11::doubles x);
528+
extern "C" SEXP _stochtree_mean_cpp(SEXP x) {
529+
BEGIN_CPP11
530+
return cpp11::as_sexp(mean_cpp(cpp11::as_cpp<cpp11::decay_t<cpp11::doubles>>(x)));
531+
END_CPP11
532+
}
533+
// R_utils.cpp
534+
double var_cpp(cpp11::doubles x);
535+
extern "C" SEXP _stochtree_var_cpp(SEXP x) {
536+
BEGIN_CPP11
537+
return cpp11::as_sexp(var_cpp(cpp11::as_cpp<cpp11::decay_t<cpp11::doubles>>(x)));
538+
END_CPP11
539+
}
540+
// R_utils.cpp
541+
double sd_cpp(cpp11::doubles x);
542+
extern "C" SEXP _stochtree_sd_cpp(SEXP x) {
543+
BEGIN_CPP11
544+
return cpp11::as_sexp(sd_cpp(cpp11::as_cpp<cpp11::decay_t<cpp11::doubles>>(x)));
545+
END_CPP11
546+
}
519547
// forest.cpp
520548
cpp11::external_pointer<StochTree::TreeEnsemble> active_forest_cpp(int num_trees, int output_dimension, bool is_leaf_constant, bool is_exponentiated);
521549
extern "C" SEXP _stochtree_active_forest_cpp(SEXP num_trees, SEXP output_dimension, SEXP is_leaf_constant, SEXP is_exponentiated) {
@@ -1700,6 +1728,7 @@ static const R_CallMethodDef CallEntries[] = {
17001728
{"_stochtree_leaf_values_forest_container_cpp", (DL_FUNC) &_stochtree_leaf_values_forest_container_cpp, 4},
17011729
{"_stochtree_leaves_forest_container_cpp", (DL_FUNC) &_stochtree_leaves_forest_container_cpp, 3},
17021730
{"_stochtree_left_child_node_forest_container_cpp", (DL_FUNC) &_stochtree_left_child_node_forest_container_cpp, 4},
1731+
{"_stochtree_mean_cpp", (DL_FUNC) &_stochtree_mean_cpp, 1},
17031732
{"_stochtree_multiply_forest_forest_container_cpp", (DL_FUNC) &_stochtree_multiply_forest_forest_container_cpp, 3},
17041733
{"_stochtree_node_depth_forest_container_cpp", (DL_FUNC) &_stochtree_node_depth_forest_container_cpp, 4},
17051734
{"_stochtree_nodes_forest_container_cpp", (DL_FUNC) &_stochtree_nodes_forest_container_cpp, 3},
@@ -1781,6 +1810,7 @@ static const R_CallMethodDef CallEntries[] = {
17811810
{"_stochtree_sample_sigma2_one_iteration_cpp", (DL_FUNC) &_stochtree_sample_sigma2_one_iteration_cpp, 5},
17821811
{"_stochtree_sample_tau_one_iteration_cpp", (DL_FUNC) &_stochtree_sample_tau_one_iteration_cpp, 4},
17831812
{"_stochtree_sample_without_replacement_integer_cpp", (DL_FUNC) &_stochtree_sample_without_replacement_integer_cpp, 3},
1813+
{"_stochtree_sd_cpp", (DL_FUNC) &_stochtree_sd_cpp, 1},
17841814
{"_stochtree_set_leaf_value_active_forest_cpp", (DL_FUNC) &_stochtree_set_leaf_value_active_forest_cpp, 2},
17851815
{"_stochtree_set_leaf_value_forest_container_cpp", (DL_FUNC) &_stochtree_set_leaf_value_forest_container_cpp, 2},
17861816
{"_stochtree_set_leaf_vector_active_forest_cpp", (DL_FUNC) &_stochtree_set_leaf_vector_active_forest_cpp, 2},
@@ -1789,12 +1819,14 @@ static const R_CallMethodDef CallEntries[] = {
17891819
{"_stochtree_split_index_forest_container_cpp", (DL_FUNC) &_stochtree_split_index_forest_container_cpp, 4},
17901820
{"_stochtree_split_theshold_forest_container_cpp", (DL_FUNC) &_stochtree_split_theshold_forest_container_cpp, 4},
17911821
{"_stochtree_subtract_from_column_vector_cpp", (DL_FUNC) &_stochtree_subtract_from_column_vector_cpp, 2},
1822+
{"_stochtree_sum_cpp", (DL_FUNC) &_stochtree_sum_cpp, 1},
17921823
{"_stochtree_sum_leaves_squared_ensemble_forest_container_cpp", (DL_FUNC) &_stochtree_sum_leaves_squared_ensemble_forest_container_cpp, 2},
17931824
{"_stochtree_tree_prior_cpp", (DL_FUNC) &_stochtree_tree_prior_cpp, 4},
17941825
{"_stochtree_update_alpha_tree_prior_cpp", (DL_FUNC) &_stochtree_update_alpha_tree_prior_cpp, 2},
17951826
{"_stochtree_update_beta_tree_prior_cpp", (DL_FUNC) &_stochtree_update_beta_tree_prior_cpp, 2},
17961827
{"_stochtree_update_max_depth_tree_prior_cpp", (DL_FUNC) &_stochtree_update_max_depth_tree_prior_cpp, 2},
17971828
{"_stochtree_update_min_samples_leaf_tree_prior_cpp", (DL_FUNC) &_stochtree_update_min_samples_leaf_tree_prior_cpp, 2},
1829+
{"_stochtree_var_cpp", (DL_FUNC) &_stochtree_var_cpp, 1},
17981830
{NULL, NULL, 0}
17991831
};
18001832
}

0 commit comments

Comments
 (0)