diff --git a/CHANGELOG.md b/CHANGELOG.md index a57bb5b9..f6944653 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,9 @@ ## Bug Fixes +* Fixed multi-chain BCF bugs with the parametric intercept term in R and Python [#326](https://github.com/StochasticTree/stochtree/pull/326) +* Fixed indexing bugs for multivariate treatment BCF in Python [#326](https://github.com/StochasticTree/stochtree/pull/326) + ## Documentation and Other Maintenance # stochtree 0.4.1 diff --git a/NEWS.md b/NEWS.md index ebfc8d91..5aafe8fd 100644 --- a/NEWS.md +++ b/NEWS.md @@ -6,6 +6,9 @@ ## Bug Fixes +* Fixed multi-chain BCF bugs with the parametric intercept term in R and Python [#326](https://github.com/StochasticTree/stochtree/pull/326) +* Fixed indexing bugs for multivariate treatment BCF in Python [#326](https://github.com/StochasticTree/stochtree/pull/326) + ## Documentation and Other Maintenance # stochtree 0.4.1 diff --git a/R/bcf.R b/R/bcf.R index 93663faf..1597e4c3 100644 --- a/R/bcf.R +++ b/R/bcf.R @@ -583,6 +583,11 @@ bcf <- function( previous_b_1_samples <- NULL previous_b_0_samples <- NULL } + if (previous_bcf_model$model_params$sample_tau_0) { + previous_tau_0_samples <- previous_bcf_model$tau_0_samples + } else { + previous_tau_0_samples <- NULL + } previous_model_num_samples <- previous_bcf_model$model_params$num_samples if (previous_model_warmstart_sample_num > previous_model_num_samples) { stop( @@ -601,6 +606,7 @@ bcf <- function( previous_forest_samples_variance <- NULL previous_b_1_samples <- NULL previous_b_0_samples <- NULL + previous_tau_0_samples <- NULL } # Determine whether conditional variance will be modeled @@ -2162,6 +2168,7 @@ bcf <- function( ) } if (adaptive_coding) { + tau_basis_train_old <- tau_basis_train current_b_1 <- b_1_samples[forest_ind + 1] current_b_0 <- b_0_samples[forest_ind + 1] tau_basis_train <- (1 - Z_train) * @@ -2179,6 +2186,21 @@ bcf <- function( outcome_train, active_forest_tau ) + # Correct residual for tau_0 component of the basis change + if (sample_tau_0) { + outcome_train$subtract_vector( + as.numeric((tau_basis_train - tau_basis_train_old) * tau_0[1]) + ) + } + } + # Reset tau_0 intercept and correct the running residual + if (sample_tau_0) { + tau_0_old <- tau_0 + tau_0 <- tau_0_samples[, forest_ind + 1] + Z_basis_gfr <- as.matrix(tau_basis_train) + outcome_train$subtract_vector( + as.numeric(Z_basis_gfr %*% matrix(tau_0 - tau_0_old, ncol = 1)) + ) } if (sample_sigma2_global) { current_sigma2 <- global_var_samples[forest_ind + 1] @@ -2255,6 +2277,7 @@ bcf <- function( ) } if (adaptive_coding) { + tau_basis_train_old <- tau_basis_train if (!is.null(previous_b_1_samples)) { current_b_1 <- previous_b_1_samples[ warmstart_index @@ -2280,6 +2303,22 @@ bcf <- function( outcome_train, active_forest_tau ) + # Correct residual for tau_0 component of the basis change + if (sample_tau_0) { + outcome_train$subtract_vector( + as.numeric((tau_basis_train - tau_basis_train_old) * tau_0[1]) + ) + } + } + # Reset tau_0 intercept and correct the running residual + if (sample_tau_0 && !is.null(previous_tau_0_samples)) { + tau_0_old <- tau_0 + # previous model stores tau_0 in original scale; convert to standardized scale + tau_0 <- as.numeric(previous_tau_0_samples[, warmstart_index] / previous_y_scale) + Z_basis_ws <- as.matrix(tau_basis_train) + outcome_train$subtract_vector( + as.numeric(Z_basis_ws %*% matrix(tau_0 - tau_0_old, ncol = 1)) + ) } if (has_rfx) { if (is.null(previous_rfx_samples)) { @@ -2389,6 +2428,7 @@ bcf <- function( ) } if (adaptive_coding) { + tau_basis_train_old <- tau_basis_train current_b_1 <- b_1 current_b_0 <- b_0 tau_basis_train <- (1 - Z_train) * @@ -2406,6 +2446,21 @@ bcf <- function( outcome_train, active_forest_tau ) + # Correct residual for tau_0 component of the basis change + if (sample_tau_0) { + outcome_train$subtract_vector( + as.numeric((tau_basis_train - tau_basis_train_old) * tau_0[1]) + ) + } + } + # Reset tau_0 to initial value (0) and correct the running residual + if (sample_tau_0) { + tau_0_old <- tau_0 + tau_0 <- rep(0.0, p_tau0) + Z_basis_reset <- as.matrix(tau_basis_train) + outcome_train$subtract_vector( + as.numeric(Z_basis_reset %*% matrix(tau_0 - tau_0_old, ncol = 1)) + ) } if (sample_sigma2_global) { current_sigma2 <- sigma2_init @@ -4314,7 +4369,27 @@ extractParameter.bcfmodel <- function(object, term) { } } - if (term %in% c("tau_hat_train")) { + if (term %in% c("mu_hat_train", "prognostic_function_train")) { + if (!is.null(object$mu_hat_train)) { + return(object$mu_hat_train) + } else { + stop( + "This model does not have in-sample prognostic function predictions" + ) + } + } + + if (term %in% c("mu_hat_test", "prognostic_function_test")) { + if (!is.null(object$mu_hat_test)) { + return(object$mu_hat_test) + } else { + stop( + "This model does not have test set prognostic function predictions" + ) + } + } + + if (term %in% c("tau_hat_train", "cate_train")) { if (!is.null(object$tau_hat_train)) { return(object$tau_hat_train) } else { @@ -4324,7 +4399,7 @@ extractParameter.bcfmodel <- function(object, term) { } } - if (term %in% c("tau_hat_test")) { + if (term %in% c("tau_hat_test", "cate_test")) { if (!is.null(object$tau_hat_test)) { return(object$tau_hat_test) } else { diff --git a/stochtree/bcf.py b/stochtree/bcf.py index 34802b7e..735154b6 100644 --- a/stochtree/bcf.py +++ b/stochtree/bcf.py @@ -1637,7 +1637,7 @@ def sample( else: raise ValueError("sigma2_leaf_mu must be a scalar") sigma2_leaf_tau = ( - np.squeeze(np.var(resid_train) * 0.5) / (num_trees_tau) + np.squeeze(0.5 * np.var(resid_train)) / (num_trees_tau) if sigma2_leaf_tau is None else sigma2_leaf_tau ) @@ -2306,6 +2306,7 @@ def sample( ) # Reset adaptive coding parameters if self.adaptive_coding: + tau_basis_train_old = tau_basis_train.copy() if self.b0_samples is not None: current_b_0 = self.b0_samples[forest_ind] else: @@ -2326,6 +2327,23 @@ def sample( forest_sampler_tau.propagate_basis_update( forest_dataset_train, residual_train, active_forest_tau ) + # Correct residual for tau_0 component of the basis change + if self.sample_tau_0: + residual_train.add_vector( + -(np.squeeze(tau_basis_train) - np.squeeze(tau_basis_train_old)) * tau_0[0] + ) + # Reset tau_0 intercept and correct the running residual + if self.sample_tau_0: + tau_0_old = tau_0.copy() + tau_0 = self.tau_0_samples[:, forest_ind].copy() + Z_basis_gfr = ( + tau_basis_train.reshape(-1, 1) + if tau_basis_train.ndim == 1 + else tau_basis_train + ) + residual_train.add_vector( + -np.squeeze(Z_basis_gfr @ (tau_0 - tau_0_old)) + ) # Reset random effects terms if self.has_rfx: rfx_model.reset( @@ -2405,6 +2423,7 @@ def sample( ) # Reset adaptive coding parameters if self.adaptive_coding: + tau_basis_train_old = tau_basis_train.copy() if previous_b0_samples is not None: current_b_0 = previous_b0_samples[warmstart_index] if previous_b1_samples is not None: @@ -2421,6 +2440,26 @@ def sample( forest_sampler_tau.propagate_basis_update( forest_dataset_train, residual_train, active_forest_tau ) + # Correct residual for tau_0 component of the basis change + if self.sample_tau_0: + residual_train.add_vector( + -(np.squeeze(tau_basis_train) - np.squeeze(tau_basis_train_old)) * tau_0[0] + ) + # Reset tau_0 intercept and correct the running residual + if self.sample_tau_0: + prev_tau_0_samples = getattr(previous_bcf_model, "tau_0_samples", None) + if prev_tau_0_samples is not None: + tau_0_old = tau_0.copy() + # tau_0_samples in previous model are in original scale; convert back + tau_0 = (prev_tau_0_samples[:, warmstart_index] / previous_bcf_model.y_std).copy() + Z_basis_ws = ( + tau_basis_train.reshape(-1, 1) + if tau_basis_train.ndim == 1 + else tau_basis_train + ) + residual_train.add_vector( + -np.squeeze(Z_basis_ws @ (tau_0 - tau_0_old)) + ) # Reset random effects terms if self.has_rfx: rfx_model.reset( @@ -2495,6 +2534,7 @@ def sample( ) # Reset adaptive coding parameters if self.adaptive_coding: + tau_basis_train_old = tau_basis_train.copy() current_b_0 = b_0 current_b_1 = b_1 tau_basis_train = ( @@ -2509,6 +2549,24 @@ def sample( forest_sampler_tau.propagate_basis_update( forest_dataset_train, residual_train, active_forest_tau ) + # Correct residual for tau_0 component of the basis change + if self.sample_tau_0: + residual_train.add_vector( + -(np.squeeze(tau_basis_train) - np.squeeze(tau_basis_train_old)) + * tau_0[0] + ) + # Reset tau_0 to initial value (0) and correct the running residual + if self.sample_tau_0: + tau_0_old = tau_0.copy() + tau_0 = np.zeros_like(tau_0) + Z_basis_reset = ( + tau_basis_train.reshape(-1, 1) + if tau_basis_train.ndim == 1 + else tau_basis_train + ) + residual_train.add_vector( + -np.squeeze(Z_basis_reset @ (tau_0 - tau_0_old)) + ) # Reset random effects terms if self.has_rfx: rfx_model.root_reset( @@ -3351,17 +3409,11 @@ def predict( if predict_mu_forest: mu_x = np.mean(mu_x, axis=1) if predict_tau_forest: - if Z.shape[1] > 1: - tau_x = np.mean(tau_x, axis=2) - else: - tau_x = np.mean(tau_x, axis=1) + tau_x = np.mean(tau_x, axis=1) if predict_prog_function: prognostic_function = np.mean(prognostic_function, axis=1) if predict_cate_function: - if Z.shape[1] > 1: - cate = np.mean(cate, axis=2) - else: - cate = np.mean(cate, axis=1) + cate = np.mean(cate, axis=1) if predict_rfx: rfx_preds = np.mean(rfx_preds, axis=1) if predict_y_hat: @@ -4524,20 +4576,34 @@ def extract_parameter(self, term: str) -> np.array: else: raise ValueError("This model does not have test set mean function prediction samples") - if term in ["tau_hat_train"]: + if term in ["tau_hat_train", "cate_train"]: tht = getattr(self, "tau_hat_train", None) if tht is not None: return tht else: raise ValueError("This model does not have in-sample treatment effect forest predictions") - if term in ["tau_hat_test"]: + if term in ["tau_hat_test", "cate_test"]: tht = getattr(self, "tau_hat_test", None) if tht is not None: return tht else: raise ValueError("This model does not have test set treatment effect forest predictions") + if term in ["mu_hat_train", "prognostic_function_train"]: + mht = getattr(self, "mu_hat_train", None) + if mht is not None: + return mht + else: + raise ValueError("This model does not have in-sample prognostic function predictions") + + if term in ["mu_hat_test", "prognostic_function_test"]: + mht = getattr(self, "mu_hat_test", None) + if mht is not None: + return mht + else: + raise ValueError("This model does not have test set prognostic function predictions") + if term in ["sigma2_x_train", "var_x_train"]: s2x = getattr(self, "sigma2_x_train", None) if s2x is not None: diff --git a/test/R/testthat/test-extract-parameter.R b/test/R/testthat/test-extract-parameter.R index d800028b..c4d8c640 100644 --- a/test/R/testthat/test-extract-parameter.R +++ b/test/R/testthat/test-extract-parameter.R @@ -250,11 +250,25 @@ test_that("extractParameter.bcfmodel", { yhtest <- extractParameter(bcf_base, "y_hat_test") expect_equal(dim(yhtest), c(n_test, num_mcmc)) - # tau_hat_train and tau_hat_test + # mu_hat_train / prognostic_function_train + mht <- extractParameter(bcf_base, "mu_hat_train") + expect_equal(dim(mht), c(n_train, num_mcmc)) + expect_equal(mht, extractParameter(bcf_base, "prognostic_function_train")) + + # mu_hat_test / prognostic_function_test + mhtest <- extractParameter(bcf_base, "mu_hat_test") + expect_equal(dim(mhtest), c(n_test, num_mcmc)) + expect_equal(mhtest, extractParameter(bcf_base, "prognostic_function_test")) + + # tau_hat_train / cate_train tht <- extractParameter(bcf_base, "tau_hat_train") expect_equal(dim(tht), c(n_train, num_mcmc)) + expect_equal(tht, extractParameter(bcf_base, "cate_train")) + + # tau_hat_test / cate_test thtest <- extractParameter(bcf_base, "tau_hat_test") expect_equal(dim(thtest), c(n_test, num_mcmc)) + expect_equal(thtest, extractParameter(bcf_base, "cate_test")) # sigma2_x_train / var_x_train and sigma2_x_test / var_x_test (variance forest) bcf_var <- bcf( diff --git a/test/R/testthat/test-multi-chain.R b/test/R/testthat/test-multi-chain.R new file mode 100644 index 00000000..21570994 --- /dev/null +++ b/test/R/testthat/test-multi-chain.R @@ -0,0 +1,409 @@ +# Tests for multi-chain BART and BCF sampling. +# +# Covers sample-count correctness, GFR warm-start path, chain independence, +# extractParameter dimensions, serialization round-trip, and the +# num_gfr >= num_chains validation. + +# --------------------------------------------------------------------------- +# Shared test data helpers +# --------------------------------------------------------------------------- + +.make_bart_data <- function() { + set.seed(42) + n <- 200; p <- 5 + X <- matrix(runif(n * p), ncol = p) + y <- 5 * X[, 1] + rnorm(n) + test_inds <- sort(sample(1:n, 40)) + train_inds <- setdiff(1:n, test_inds) + list( + X_train = X[train_inds, ], + X_test = X[test_inds, ], + y_train = y[train_inds], + n_train = length(train_inds), + n_test = length(test_inds) + ) +} + +.make_bcf_data <- function() { + set.seed(42) + n <- 200; p <- 5 + X <- matrix(runif(n * p), ncol = p) + pi_X <- 0.25 + 0.5 * X[, 1] + Z <- rbinom(n, 1, pi_X) + y <- 5 * X[, 1] + 2 * X[, 2] * Z + rnorm(n) + test_inds <- sort(sample(1:n, 40)) + train_inds <- setdiff(1:n, test_inds) + list( + X_train = X[train_inds, ], + X_test = X[test_inds, ], + Z_train = Z[train_inds], + Z_test = Z[test_inds], + y_train = y[train_inds], + pi_train = pi_X[train_inds], + pi_test = pi_X[test_inds], + n_train = length(train_inds), + n_test = length(test_inds) + ) +} + +# --------------------------------------------------------------------------- +# BARTModel multi-chain tests +# --------------------------------------------------------------------------- + +test_that("BART multi-chain: sample counts with no GFR", { + skip_on_cran() + d <- .make_bart_data() + n_chains <- 3; n_mcmc <- 10 + m <- bart( + X_train = d$X_train, y_train = d$y_train, X_test = d$X_test, + num_gfr = 0, num_burnin = 0, num_mcmc = n_mcmc, + general_params = list(num_chains = n_chains, num_threads = 1) + ) + expected <- n_chains * n_mcmc + expect_length(m$sigma2_global_samples, expected) + expect_equal(dim(m$y_hat_train), c(d$n_train, expected)) + expect_equal(dim(m$y_hat_test), c(d$n_test, expected)) +}) + +test_that("BART multi-chain: sample counts with GFR warm-start", { + skip_on_cran() + d <- .make_bart_data() + n_chains <- 3; n_mcmc <- 10; n_gfr <- 6 + m <- bart( + X_train = d$X_train, y_train = d$y_train, X_test = d$X_test, + num_gfr = n_gfr, num_burnin = 5, num_mcmc = n_mcmc, + general_params = list(num_chains = n_chains, num_threads = 1) + ) + expected <- n_chains * n_mcmc + expect_length(m$sigma2_global_samples, expected) + expect_equal(dim(m$y_hat_train), c(d$n_train, expected)) +}) + +test_that("BART multi-chain: leaf-scale sample count", { + skip_on_cran() + d <- .make_bart_data() + n_chains <- 3; n_mcmc <- 10 + m <- bart( + X_train = d$X_train, y_train = d$y_train, X_test = d$X_test, + num_gfr = 0, num_burnin = 0, num_mcmc = n_mcmc, + general_params = list( + num_chains = n_chains, num_threads = 1, + sample_sigma2_global = FALSE + ), + mean_forest_params = list(sample_sigma2_leaf = TRUE) + ) + expect_length(m$sigma2_leaf_samples, n_chains * n_mcmc) +}) + +test_that("BART multi-chain: chain independence (no GFR)", { + skip_on_cran() + d <- .make_bart_data() + n_mcmc <- 10 + m <- bart( + X_train = d$X_train, y_train = d$y_train, X_test = d$X_test, + num_gfr = 0, num_burnin = 0, num_mcmc = n_mcmc, + general_params = list(num_chains = 2, num_threads = 1) + ) + chain1 <- m$sigma2_global_samples[seq_len(n_mcmc)] + chain2 <- m$sigma2_global_samples[seq(n_mcmc + 1, 2 * n_mcmc)] + expect_false(isTRUE(all.equal(chain1, chain2)), + label = "Chains should produce distinct sigma2 samples") +}) + +test_that("BART multi-chain: chain independence (with GFR)", { + skip_on_cran() + d <- .make_bart_data() + n_mcmc <- 10; n_gfr <- 4 + m <- bart( + X_train = d$X_train, y_train = d$y_train, X_test = d$X_test, + num_gfr = n_gfr, num_burnin = 5, num_mcmc = n_mcmc, + general_params = list(num_chains = 2, num_threads = 1) + ) + chain1 <- m$sigma2_global_samples[seq_len(n_mcmc)] + chain2 <- m$sigma2_global_samples[seq(n_mcmc + 1, 2 * n_mcmc)] + expect_false(isTRUE(all.equal(chain1, chain2))) +}) + +test_that("BART multi-chain: extractParameter dimensions", { + skip_on_cran() + d <- .make_bart_data() + n_chains <- 3; n_mcmc <- 10 + m <- bart( + X_train = d$X_train, y_train = d$y_train, X_test = d$X_test, + num_gfr = 0, num_burnin = 0, num_mcmc = n_mcmc, + general_params = list(num_chains = n_chains, num_threads = 1) + ) + expected <- n_chains * n_mcmc + s2 <- extractParameter(m, "sigma2_global") + expect_length(s2, expected) + yht <- extractParameter(m, "y_hat_train") + expect_equal(dim(yht), c(d$n_train, expected)) +}) + +test_that("BART multi-chain: predict() shape and finiteness (no GFR)", { + skip_on_cran() + d <- .make_bart_data() + n_chains <- 3; n_mcmc <- 10 + m <- bart( + X_train = d$X_train, y_train = d$y_train, X_test = d$X_test, + num_gfr = 0, num_burnin = 0, num_mcmc = n_mcmc, + general_params = list(num_chains = n_chains, num_threads = 1) + ) + expected_cols <- n_chains * n_mcmc + result <- predict(m, X = d$X_test, terms = "y_hat") + expect_equal(dim(result), c(d$n_test, expected_cols)) + expect_true(all(is.finite(result))) +}) + +test_that("BART multi-chain: predict() shape and finiteness (GFR path)", { + skip_on_cran() + d <- .make_bart_data() + n_chains <- 3; n_mcmc <- 10 + m <- bart( + X_train = d$X_train, y_train = d$y_train, X_test = d$X_test, + num_gfr = 6, num_burnin = 5, num_mcmc = n_mcmc, + general_params = list(num_chains = n_chains, num_threads = 1) + ) + result <- predict(m, X = d$X_test, terms = "y_hat") + expect_equal(dim(result), c(d$n_test, n_chains * n_mcmc)) + expect_true(all(is.finite(result))) +}) + +test_that("BART multi-chain: num_gfr < num_chains raises an error", { + skip_on_cran() + d <- .make_bart_data() + expect_error( + bart( + X_train = d$X_train, y_train = d$y_train, X_test = d$X_test, + num_gfr = 2, num_burnin = 0, num_mcmc = 5, + general_params = list(num_chains = 4, num_threads = 1) + ) + ) +}) + +test_that("BART multi-chain: sigma2 samples are finite and positive with GFR", { + skip_on_cran() + d <- .make_bart_data() + m <- bart( + X_train = d$X_train, y_train = d$y_train, X_test = d$X_test, + num_gfr = 6, num_burnin = 10, num_mcmc = 10, + general_params = list(num_chains = 3, num_threads = 1) + ) + expect_true(all(is.finite(m$sigma2_global_samples))) + expect_true(all(m$sigma2_global_samples > 0)) +}) + +# --------------------------------------------------------------------------- +# BCFModel multi-chain tests +# --------------------------------------------------------------------------- + +test_that("BCF multi-chain: sample counts with no GFR", { + skip_on_cran() + d <- .make_bcf_data() + n_chains <- 3; n_mcmc <- 10 + m <- bcf( + X_train = d$X_train, Z_train = d$Z_train, y_train = d$y_train, + propensity_train = d$pi_train, + X_test = d$X_test, Z_test = d$Z_test, propensity_test = d$pi_test, + num_gfr = 0, num_burnin = 10, num_mcmc = n_mcmc, + general_params = list(num_chains = n_chains, num_threads = 1) + ) + expected <- n_chains * n_mcmc + expect_length(m$sigma2_global_samples, expected) + expect_equal(dim(m$tau_hat_train), c(d$n_train, expected)) + expect_equal(dim(m$mu_hat_train), c(d$n_train, expected)) + expect_equal(dim(m$tau_hat_test), c(d$n_test, expected)) +}) + +test_that("BCF multi-chain: sample counts with GFR warm-start", { + skip_on_cran() + d <- .make_bcf_data() + n_chains <- 3; n_mcmc <- 10; n_gfr <- 6 + m <- bcf( + X_train = d$X_train, Z_train = d$Z_train, y_train = d$y_train, + propensity_train = d$pi_train, + X_test = d$X_test, Z_test = d$Z_test, propensity_test = d$pi_test, + num_gfr = n_gfr, num_burnin = 5, num_mcmc = n_mcmc, + general_params = list(num_chains = n_chains, num_threads = 1) + ) + expected <- n_chains * n_mcmc + expect_length(m$sigma2_global_samples, expected) + expect_equal(dim(m$tau_hat_train), c(d$n_train, expected)) + expect_equal(dim(m$mu_hat_train), c(d$n_train, expected)) + # BCF-specific scalar parameter arrays + expect_length(m$b_0_samples, expected) + expect_length(m$b_1_samples, expected) + expect_length(m$sigma2_leaf_mu_samples, expected) +}) + +test_that("BCF multi-chain: chain independence (no GFR)", { + skip_on_cran() + d <- .make_bcf_data() + n_mcmc <- 10 + m <- bcf( + X_train = d$X_train, Z_train = d$Z_train, y_train = d$y_train, + propensity_train = d$pi_train, + X_test = d$X_test, Z_test = d$Z_test, propensity_test = d$pi_test, + num_gfr = 0, num_burnin = 10, num_mcmc = n_mcmc, + general_params = list(num_chains = 2, num_threads = 1) + ) + chain1 <- m$sigma2_global_samples[seq_len(n_mcmc)] + chain2 <- m$sigma2_global_samples[seq(n_mcmc + 1, 2 * n_mcmc)] + expect_false(isTRUE(all.equal(chain1, chain2)), + label = "BCF chains should produce distinct sigma2 samples") +}) + +test_that("BCF multi-chain: chain independence (with GFR)", { + skip_on_cran() + d <- .make_bcf_data() + n_mcmc <- 10 + m <- bcf( + X_train = d$X_train, Z_train = d$Z_train, y_train = d$y_train, + propensity_train = d$pi_train, + X_test = d$X_test, Z_test = d$Z_test, propensity_test = d$pi_test, + num_gfr = 4, num_burnin = 5, num_mcmc = n_mcmc, + general_params = list(num_chains = 2, num_threads = 1) + ) + chain1 <- m$sigma2_global_samples[seq_len(n_mcmc)] + chain2 <- m$sigma2_global_samples[seq(n_mcmc + 1, 2 * n_mcmc)] + expect_false(isTRUE(all.equal(chain1, chain2))) +}) + +test_that("BCF multi-chain: all samples finite with GFR + multiple chains", { + skip_on_cran() + # Exercises the tau_0 / adaptive-coding reset logic introduced to prevent + # residual blowup when transitioning between chains. + d <- .make_bcf_data() + m <- bcf( + X_train = d$X_train, Z_train = d$Z_train, y_train = d$y_train, + propensity_train = d$pi_train, + X_test = d$X_test, Z_test = d$Z_test, propensity_test = d$pi_test, + num_gfr = 6, num_burnin = 20, num_mcmc = 10, + general_params = list(num_chains = 3, num_threads = 1) + ) + expect_true(all(is.finite(m$sigma2_global_samples)), + label = "sigma2 samples must be finite (no chain-transition blowup)") + expect_true(all(m$sigma2_global_samples > 0)) + expect_true(all(is.finite(m$b_0_samples))) + expect_true(all(is.finite(m$b_1_samples))) +}) + +test_that("BCF multi-chain: extractParameter dimensions", { + skip_on_cran() + d <- .make_bcf_data() + n_chains <- 3; n_mcmc <- 10 + m <- bcf( + X_train = d$X_train, Z_train = d$Z_train, y_train = d$y_train, + propensity_train = d$pi_train, + X_test = d$X_test, Z_test = d$Z_test, propensity_test = d$pi_test, + num_gfr = 0, num_burnin = 10, num_mcmc = n_mcmc, + general_params = list(num_chains = n_chains, num_threads = 1) + ) + expected <- n_chains * n_mcmc + s2 <- extractParameter(m, "sigma2_global") + expect_length(s2, expected) + tau_test <- extractParameter(m, "tau_hat_test") + expect_equal(dim(tau_test), c(d$n_test, expected)) + tau_train <- extractParameter(m, "tau_hat_train") + expect_equal(dim(tau_train), c(d$n_train, expected)) +}) + +test_that("BCF multi-chain: predict() shape and finiteness for all forest terms (no GFR)", { + skip_on_cran() + d <- .make_bcf_data() + n_chains <- 3; n_mcmc <- 10 + m <- bcf( + X_train = d$X_train, Z_train = d$Z_train, y_train = d$y_train, + propensity_train = d$pi_train, + X_test = d$X_test, Z_test = d$Z_test, propensity_test = d$pi_test, + num_gfr = 0, num_burnin = 10, num_mcmc = n_mcmc, + general_params = list(num_chains = n_chains, num_threads = 1) + ) + expected_cols <- n_chains * n_mcmc + for (term in c("y_hat", "cate", "prognostic_function", "mu", "tau")) { + result <- predict(m, X = d$X_test, Z = d$Z_test, propensity = d$pi_test, terms = term) + expect_equal(dim(result), c(d$n_test, expected_cols), + label = paste0("dim for term='", term, "'")) + expect_true(all(is.finite(result)), + label = paste0("finiteness for term='", term, "'")) + } +}) + +test_that("BCF multi-chain: predict() shape and finiteness for all forest terms (GFR path)", { + skip_on_cran() + d <- .make_bcf_data() + n_chains <- 3; n_mcmc <- 10; n_gfr <- 6 + m <- bcf( + X_train = d$X_train, Z_train = d$Z_train, y_train = d$y_train, + propensity_train = d$pi_train, + X_test = d$X_test, Z_test = d$Z_test, propensity_test = d$pi_test, + num_gfr = n_gfr, num_burnin = 5, num_mcmc = n_mcmc, + general_params = list(num_chains = n_chains, num_threads = 1) + ) + expected_cols <- n_chains * n_mcmc + for (term in c("y_hat", "cate", "prognostic_function", "mu", "tau")) { + result <- predict(m, X = d$X_test, Z = d$Z_test, propensity = d$pi_test, terms = term) + expect_equal(dim(result), c(d$n_test, expected_cols), + label = paste0("dim for term='", term, "'")) + expect_true(all(is.finite(result)), + label = paste0("finiteness for term='", term, "'")) + } +}) + +test_that("BCF multi-chain: predict() shape and positivity for variance forest term", { + skip_on_cran() + d <- .make_bcf_data() + n_chains <- 3; n_mcmc <- 10 + m <- bcf( + X_train = d$X_train, Z_train = d$Z_train, y_train = d$y_train, + propensity_train = d$pi_train, + X_test = d$X_test, Z_test = d$Z_test, propensity_test = d$pi_test, + num_gfr = 0, num_burnin = 10, num_mcmc = n_mcmc, + general_params = list(num_chains = n_chains, num_threads = 1), + variance_forest_params = list(num_trees = 10) + ) + result <- predict(m, X = d$X_test, Z = d$Z_test, propensity = d$pi_test, + terms = "variance_forest") + expect_equal(dim(result), c(d$n_test, n_chains * n_mcmc)) + expect_true(all(is.finite(result))) + expect_true(all(result > 0)) +}) + +test_that("BCF multi-chain: num_gfr < num_chains raises an error", { + skip_on_cran() + d <- .make_bcf_data() + expect_error( + bcf( + X_train = d$X_train, Z_train = d$Z_train, y_train = d$y_train, + propensity_train = d$pi_train, + X_test = d$X_test, Z_test = d$Z_test, propensity_test = d$pi_test, + num_gfr = 2, num_burnin = 0, num_mcmc = 5, + general_params = list(num_chains = 4, num_threads = 1) + ) + ) +}) + +test_that("BCF multi-chain: serialization round-trip preserves predictions", { + skip_on_cran() + d <- .make_bcf_data() + n_chains <- 2; n_mcmc <- 10; n_gfr <- 4 + m <- bcf( + X_train = d$X_train, Z_train = d$Z_train, y_train = d$y_train, + propensity_train = d$pi_train, + X_test = d$X_test, Z_test = d$Z_test, propensity_test = d$pi_test, + num_gfr = n_gfr, num_burnin = 5, num_mcmc = n_mcmc, + general_params = list(num_chains = n_chains, num_threads = 1) + ) + + json_str <- saveBCFModelToJsonString(m) + m2 <- createBCFModelFromJsonString(json_str) + + pred_orig <- predict( + m, X = d$X_test, Z = d$Z_test, propensity = d$pi_test, terms = "cate" + ) + pred_rt <- predict( + m2, X = d$X_test, Z = d$Z_test, propensity = d$pi_test, terms = "cate" + ) + expect_equal(dim(pred_orig), dim(pred_rt)) + expect_equal(pred_orig, pred_rt) +}) diff --git a/test/python/test_multi_chain.py b/test/python/test_multi_chain.py new file mode 100644 index 00000000..9f297a2f --- /dev/null +++ b/test/python/test_multi_chain.py @@ -0,0 +1,358 @@ +"""Tests for multi-chain BART and BCF sampling. + +Covers sample-count correctness, GFR warm-start path, chain independence, +extract_parameter dimensions, and the num_gfr >= num_chains validation. +""" +import numpy as np +import pytest +from sklearn.model_selection import train_test_split + +from stochtree import BARTModel, BCFModel + + +# --------------------------------------------------------------------------- +# Shared fixtures +# --------------------------------------------------------------------------- + +@pytest.fixture(scope="module") +def bart_data(): + rng = np.random.default_rng(42) + n, p = 200, 5 + X = rng.uniform(0, 1, (n, p)) + y = 5 * X[:, 0] + rng.standard_normal(n) + idx = np.arange(n) + train_inds, test_inds = train_test_split(idx, test_size=0.2, random_state=42) + return { + "X_train": X[train_inds], + "X_test": X[test_inds], + "y_train": y[train_inds], + "n_train": len(train_inds), + "n_test": len(test_inds), + } + + +@pytest.fixture(scope="module") +def bcf_data(): + rng = np.random.default_rng(42) + n, p = 200, 5 + X = rng.uniform(0, 1, (n, p)) + pi_X = 0.25 + 0.5 * X[:, 0] + Z = rng.binomial(1, pi_X, n).astype(float) + y = 5 * X[:, 0] + 2 * X[:, 1] * Z + rng.standard_normal(n) + idx = np.arange(n) + train_inds, test_inds = train_test_split(idx, test_size=0.2, random_state=42) + return { + "X_train": X[train_inds], + "X_test": X[test_inds], + "Z_train": Z[train_inds], + "Z_test": Z[test_inds], + "y_train": y[train_inds], + "pi_train": pi_X[train_inds], + "pi_test": pi_X[test_inds], + "n_train": len(train_inds), + "n_test": len(test_inds), + } + + +# --------------------------------------------------------------------------- +# Helper +# --------------------------------------------------------------------------- + +def _bart(data, *, num_gfr, num_burnin, num_mcmc, num_chains, **kw): + m = BARTModel() + m.sample( + X_train=data["X_train"], + y_train=data["y_train"], + X_test=data["X_test"], + num_gfr=num_gfr, + num_burnin=num_burnin, + num_mcmc=num_mcmc, + general_params={"num_chains": num_chains, "num_threads": 1, **kw}, + ) + return m + + +def _bcf(data, *, num_gfr, num_burnin, num_mcmc, num_chains, **kw): + m = BCFModel() + m.sample( + X_train=data["X_train"], + Z_train=data["Z_train"], + y_train=data["y_train"], + propensity_train=data["pi_train"], + X_test=data["X_test"], + Z_test=data["Z_test"], + propensity_test=data["pi_test"], + num_gfr=num_gfr, + num_burnin=num_burnin, + num_mcmc=num_mcmc, + general_params={"num_chains": num_chains, "num_threads": 1, **kw}, + ) + return m + + +# --------------------------------------------------------------------------- +# BARTModel multi-chain tests +# --------------------------------------------------------------------------- + +class TestBARTMultiChain: + NUM_MCMC = 10 + NUM_CHAINS = 3 + + def test_sample_counts_no_gfr(self, bart_data): + """Total kept samples = num_chains * num_mcmc when num_gfr=0.""" + n_chains, n_mcmc = self.NUM_CHAINS, self.NUM_MCMC + m = _bart(bart_data, num_gfr=0, num_burnin=0, num_mcmc=n_mcmc, num_chains=n_chains) + expected = n_chains * n_mcmc + assert m.global_var_samples.shape == (expected,) + assert m.y_hat_train.shape == (bart_data["n_train"], expected) + assert m.y_hat_test.shape == (bart_data["n_test"], expected) + + def test_sample_counts_with_gfr(self, bart_data): + """With GFR, total kept samples = num_chains * num_mcmc (GFR dropped by default).""" + n_chains, n_mcmc, n_gfr = self.NUM_CHAINS, self.NUM_MCMC, 6 + m = _bart(bart_data, num_gfr=n_gfr, num_burnin=5, num_mcmc=n_mcmc, num_chains=n_chains) + expected = n_chains * n_mcmc + assert m.global_var_samples.shape == (expected,) + assert m.y_hat_train.shape == (bart_data["n_train"], expected) + + def test_leaf_scale_sample_count(self, bart_data): + """Leaf-scale samples also have num_chains * num_mcmc entries.""" + n_chains, n_mcmc = self.NUM_CHAINS, self.NUM_MCMC + m = _bart( + bart_data, + num_gfr=0, num_burnin=0, num_mcmc=n_mcmc, num_chains=n_chains, + sample_sigma2_global=False, + ) + m2 = BARTModel() + m2.sample( + X_train=bart_data["X_train"], + y_train=bart_data["y_train"], + X_test=bart_data["X_test"], + num_gfr=0, + num_burnin=0, + num_mcmc=n_mcmc, + general_params={"num_chains": n_chains, "num_threads": 1, "sample_sigma2_global": False}, + mean_forest_params={"sample_sigma2_leaf": True}, + ) + assert m2.leaf_scale_samples.shape == (n_chains * n_mcmc,) + + def test_chain_independence_no_gfr(self, bart_data): + """With 2 chains, sigma2 samples from different chains are not identical.""" + m = _bart(bart_data, num_gfr=0, num_burnin=0, num_mcmc=self.NUM_MCMC, num_chains=2) + chain1 = m.global_var_samples[: self.NUM_MCMC] + chain2 = m.global_var_samples[self.NUM_MCMC :] + assert not np.allclose(chain1, chain2), ( + "Chains produced identical sigma2 samples; they should be independent." + ) + + def test_chain_independence_with_gfr(self, bart_data): + """With GFR warm-start, different chains still produce distinct sigma2 samples.""" + n_gfr, n_mcmc = 4, self.NUM_MCMC + m = _bart(bart_data, num_gfr=n_gfr, num_burnin=5, num_mcmc=n_mcmc, num_chains=2) + chain1 = m.global_var_samples[:n_mcmc] + chain2 = m.global_var_samples[n_mcmc:] + assert not np.allclose(chain1, chain2) + + def test_extract_parameter_multi_chain(self, bart_data): + """extract_parameter returns num_chains * num_mcmc samples.""" + n_chains, n_mcmc = self.NUM_CHAINS, self.NUM_MCMC + m = _bart(bart_data, num_gfr=0, num_burnin=0, num_mcmc=n_mcmc, num_chains=n_chains) + s2 = m.extract_parameter("sigma2_global") + assert s2.shape == (n_chains * n_mcmc,) + yht = m.extract_parameter("y_hat_train") + assert yht.shape == (bart_data["n_train"], n_chains * n_mcmc) + + def test_predict_multi_chain(self, bart_data): + """predict() returns correct shape and finite values for multi-chain BART.""" + n_chains, n_mcmc = self.NUM_CHAINS, self.NUM_MCMC + m = _bart(bart_data, num_gfr=0, num_burnin=0, num_mcmc=n_mcmc, num_chains=n_chains) + expected_cols = n_chains * n_mcmc + n_test = bart_data["n_test"] + result = m.predict(X=bart_data["X_test"], terms="y_hat") + assert result.shape == (n_test, expected_cols) + assert np.all(np.isfinite(result)) + + def test_predict_multi_chain_gfr(self, bart_data): + """predict() stays finite with GFR warm-start + multiple chains.""" + m = _bart(bart_data, num_gfr=6, num_burnin=5, num_mcmc=self.NUM_MCMC, num_chains=self.NUM_CHAINS) + result = m.predict(X=bart_data["X_test"], terms="y_hat") + assert result.shape == (bart_data["n_test"], self.NUM_CHAINS * self.NUM_MCMC) + assert np.all(np.isfinite(result)) + + def test_num_gfr_less_than_num_chains_raises(self, bart_data): + """num_chains > num_gfr must raise a ValueError.""" + with pytest.raises((ValueError, Exception)): + _bart(bart_data, num_gfr=2, num_burnin=0, num_mcmc=5, num_chains=4) + + def test_samples_finite_multi_chain_gfr(self, bart_data): + """sigma2 samples are finite and positive with GFR warm-start + multiple chains.""" + m = _bart(bart_data, num_gfr=6, num_burnin=10, num_mcmc=self.NUM_MCMC, num_chains=3) + assert np.all(np.isfinite(m.global_var_samples)) + assert np.all(m.global_var_samples > 0) + + +# --------------------------------------------------------------------------- +# BCFModel multi-chain tests +# --------------------------------------------------------------------------- + +class TestBCFMultiChain: + NUM_MCMC = 10 + NUM_CHAINS = 3 + NUM_GFR = 6 + + def test_sample_counts_no_gfr(self, bcf_data): + """Total kept samples = num_chains * num_mcmc when num_gfr=0.""" + n_chains, n_mcmc = self.NUM_CHAINS, self.NUM_MCMC + m = _bcf(bcf_data, num_gfr=0, num_burnin=10, num_mcmc=n_mcmc, num_chains=n_chains) + expected = n_chains * n_mcmc + assert m.global_var_samples.shape == (expected,) + assert m.tau_hat_train.shape == (bcf_data["n_train"], expected) + assert m.mu_hat_train.shape == (bcf_data["n_train"], expected) + assert m.tau_hat_test.shape == (bcf_data["n_test"], expected) + + def test_sample_counts_with_gfr(self, bcf_data): + """GFR warm-start path: all parameter arrays have num_chains * num_mcmc entries.""" + n_chains = self.NUM_CHAINS + n_mcmc = self.NUM_MCMC + n_gfr = self.NUM_GFR + m = _bcf(bcf_data, num_gfr=n_gfr, num_burnin=5, num_mcmc=n_mcmc, num_chains=n_chains) + expected = n_chains * n_mcmc + assert m.global_var_samples.shape == (expected,) + # BCF-specific samples + assert m.tau_0_samples.shape == (1, expected) + assert m.b0_samples.shape == (expected,) + assert m.b1_samples.shape == (expected,) + assert m.leaf_scale_mu_samples.shape == (expected,) + # Predictions + assert m.tau_hat_train.shape == (bcf_data["n_train"], expected) + assert m.mu_hat_train.shape == (bcf_data["n_train"], expected) + + def test_chain_independence_no_gfr(self, bcf_data): + """With 2 chains (no GFR), sigma2 samples differ across chains.""" + m = _bcf(bcf_data, num_gfr=0, num_burnin=10, num_mcmc=self.NUM_MCMC, num_chains=2) + chain1 = m.global_var_samples[: self.NUM_MCMC] + chain2 = m.global_var_samples[self.NUM_MCMC :] + assert not np.allclose(chain1, chain2) + + def test_chain_independence_with_gfr(self, bcf_data): + """With GFR warm-start, chains produce distinct sigma2 samples.""" + n_mcmc = self.NUM_MCMC + m = _bcf(bcf_data, num_gfr=4, num_burnin=5, num_mcmc=n_mcmc, num_chains=2) + chain1 = m.global_var_samples[:n_mcmc] + chain2 = m.global_var_samples[n_mcmc:] + assert not np.allclose(chain1, chain2) + + def test_samples_finite_gfr_multi_chain(self, bcf_data): + """sigma2 samples remain finite with GFR warm-start + multiple chains. + + This exercises the tau_0 / adaptive-coding reset logic introduced to + prevent residual blowup when transitioning between chains. + """ + m = _bcf( + bcf_data, + num_gfr=self.NUM_GFR, + num_burnin=20, + num_mcmc=self.NUM_MCMC, + num_chains=self.NUM_CHAINS, + ) + assert np.all(np.isfinite(m.global_var_samples)), ( + "sigma2 samples contain non-finite values; possible chain-transition blowup." + ) + assert np.all(m.global_var_samples > 0) + assert np.all(np.isfinite(m.tau_0_samples)) + assert np.all(np.isfinite(m.b0_samples)) + assert np.all(np.isfinite(m.b1_samples)) + + def test_extract_parameter_multi_chain(self, bcf_data): + """extract_parameter returns num_chains * num_mcmc samples for BCF.""" + n_chains, n_mcmc = self.NUM_CHAINS, self.NUM_MCMC + m = _bcf(bcf_data, num_gfr=0, num_burnin=10, num_mcmc=n_mcmc, num_chains=n_chains) + expected = n_chains * n_mcmc + s2 = m.extract_parameter("sigma2_global") + assert s2.shape == (expected,) + cate = m.extract_parameter("cate_test") + assert cate.shape == (bcf_data["n_test"], expected) + prog = m.extract_parameter("prognostic_function_test") + assert prog.shape == (bcf_data["n_test"], expected) + + def test_predict_terms_multi_chain_no_gfr(self, bcf_data): + """predict() returns correct shape and finite values for each forest term (no GFR).""" + n_chains, n_mcmc = self.NUM_CHAINS, self.NUM_MCMC + m = _bcf(bcf_data, num_gfr=0, num_burnin=10, num_mcmc=n_mcmc, num_chains=n_chains) + expected_cols = n_chains * n_mcmc + n_test = bcf_data["n_test"] + kw = dict(X=bcf_data["X_test"], Z=bcf_data["Z_test"], propensity=bcf_data["pi_test"]) + for term in ["y_hat", "cate", "prognostic_function", "mu", "tau"]: + result = m.predict(**kw, terms=term) + assert result.shape == (n_test, expected_cols), f"shape mismatch for term={term!r}" + assert np.all(np.isfinite(result)), f"non-finite values for term={term!r}" + + def test_predict_terms_multi_chain_with_gfr(self, bcf_data): + """predict() returns correct shape and finite values for each forest term (GFR path).""" + n_chains, n_mcmc = self.NUM_CHAINS, self.NUM_MCMC + m = _bcf(bcf_data, num_gfr=self.NUM_GFR, num_burnin=5, num_mcmc=n_mcmc, num_chains=n_chains) + expected_cols = n_chains * n_mcmc + n_test = bcf_data["n_test"] + kw = dict(X=bcf_data["X_test"], Z=bcf_data["Z_test"], propensity=bcf_data["pi_test"]) + for term in ["y_hat", "cate", "prognostic_function", "mu", "tau"]: + result = m.predict(**kw, terms=term) + assert result.shape == (n_test, expected_cols), f"shape mismatch for term={term!r}" + assert np.all(np.isfinite(result)), f"non-finite values for term={term!r}" + + def test_predict_variance_forest_multi_chain(self, bcf_data): + """predict() returns correct shape and positive values for variance forest term.""" + n_chains, n_mcmc = self.NUM_CHAINS, self.NUM_MCMC + m = BCFModel() + m.sample( + X_train=bcf_data["X_train"], + Z_train=bcf_data["Z_train"], + y_train=bcf_data["y_train"], + propensity_train=bcf_data["pi_train"], + X_test=bcf_data["X_test"], + Z_test=bcf_data["Z_test"], + propensity_test=bcf_data["pi_test"], + num_gfr=0, + num_burnin=10, + num_mcmc=n_mcmc, + general_params={"num_chains": n_chains, "num_threads": 1}, + variance_forest_params={"num_trees": 10}, + ) + result = m.predict( + X=bcf_data["X_test"], + Z=bcf_data["Z_test"], + propensity=bcf_data["pi_test"], + terms="variance_forest", + ) + assert result.shape == (bcf_data["n_test"], n_chains * n_mcmc) + assert np.all(np.isfinite(result)) + assert np.all(result > 0) + + def test_num_gfr_less_than_num_chains_raises(self, bcf_data): + """num_chains > num_gfr must raise an error.""" + with pytest.raises((ValueError, Exception)): + _bcf(bcf_data, num_gfr=2, num_burnin=0, num_mcmc=5, num_chains=4) + + def test_serialization_round_trip_multi_chain(self, bcf_data): + """Serialize and reload a multi-chain BCF model; predictions must match.""" + import json + n_chains, n_mcmc, n_gfr = 2, self.NUM_MCMC, 4 + m = _bcf(bcf_data, num_gfr=n_gfr, num_burnin=5, num_mcmc=n_mcmc, num_chains=n_chains) + json_str = m.to_json() + + m2 = BCFModel() + m2.from_json(json_str) + + pred_orig = m.predict( + X=bcf_data["X_test"], + Z=bcf_data["Z_test"], + propensity=bcf_data["pi_test"], + terms="cate", + ) + pred_rt = m2.predict( + X=bcf_data["X_test"], + Z=bcf_data["Z_test"], + propensity=bcf_data["pi_test"], + terms="cate", + ) + assert pred_orig.shape == pred_rt.shape + np.testing.assert_allclose(pred_orig, pred_rt) diff --git a/tools/debug/R_bcf_troubleshoot.R b/tools/debug/R_bcf_troubleshoot.R new file mode 100644 index 00000000..50b55101 --- /dev/null +++ b/tools/debug/R_bcf_troubleshoot.R @@ -0,0 +1,133 @@ +# Load package +library(stochtree) + +# Helper functions +g <- function(x) { + ifelse(x[, 5] == 1, 2, ifelse(x[, 5] == 2, -1, -4)) +} +mu1 <- function(x) { + 1 + g(x) + x[, 1] * x[, 3] +} +mu2 <- function(x) { + 1 + g(x) + 6 * abs(x[, 3] - 1) +} +tau1 <- function(x) { + rep(3, nrow(x)) +} +tau2 <- function(x) { + 1 + 2 * x[, 2] * x[, 4] +} + +# Generate data +n <- 500 +snr <- 3 +x1 <- rnorm(n) +x2 <- rnorm(n) +x3 <- rnorm(n) +x4 <- as.numeric(rbinom(n, 1, 0.5)) +x5 <- as.numeric(sample(1:3, n, replace = TRUE)) +X <- cbind(x1, x2, x3, x4, x5) +p <- ncol(X) +mu_x <- mu1(X) +tau_x <- tau2(X) +pi_x <- 0.8 * pnorm((3 * mu_x / sd(mu_x)) - 0.5 * X[, 1]) + 0.05 + runif(n) / 10 +Z <- rbinom(n, 1, pi_x) +E_XZ <- mu_x + Z * tau_x +y <- E_XZ + rnorm(n, 0, 1) * (sd(E_XZ) / snr) +X <- as.data.frame(X) +X$x4 <- factor(X$x4, ordered = TRUE) +X$x5 <- factor(X$x5, ordered = TRUE) + +# Split data into test and train sets +test_set_pct <- 0.2 +n_test <- round(test_set_pct * n) +n_train <- n - n_test +test_inds <- sort(sample(1:n, n_test, replace = FALSE)) +train_inds <- (1:n)[!((1:n) %in% test_inds)] +X_test <- X[test_inds, ] +X_train <- X[train_inds, ] +pi_test <- pi_x[test_inds] +pi_train <- pi_x[train_inds] +Z_test <- Z[test_inds] +Z_train <- Z[train_inds] +y_test <- y[test_inds] +y_train <- y[train_inds] +mu_test <- mu_x[test_inds] +mu_train <- mu_x[train_inds] +tau_test <- tau_x[test_inds] +tau_train <- tau_x[train_inds] + +# Sample the model +general_params <- list(num_threads = 1, num_chains = 4) +bcf_model <- bcf( + X_train = X_train, + Z_train = Z_train, + y_train = y_train, + propensity_train = pi_train, + X_test = X_test, + Z_test = Z_test, + propensity_test = pi_test, + num_gfr = 10, + num_burnin = 1000, + num_mcmc = 100, + general_params = general_params +) + +# Plot true versus estimated prognostic function +mu_hat_test <- predict( + bcf_model, + X = X_test, + Z = Z_test, + propensity = pi_test, + terms = "prognostic_function" +) +plot( + rowMeans(mu_hat_test), + mu_test, + xlab = "predicted", + ylab = "actual", + main = "Prognostic function" +) +abline(0, 1, col = "red", lty = 3, lwd = 3) + +# Plot true versus estimated CATE function +tau_hat_test <- predict( + bcf_model, + X = X_test, + Z = Z_test, + propensity = pi_test, + terms = "cate" +) +plot( + rowMeans(tau_hat_test), + tau_test, + xlab = "predicted", + ylab = "actual", + main = "Treatment effect" +) +abline(0, 1, col = "red", lty = 3, lwd = 3) +sqrt(mean((rowMeans(tau_hat_test) - tau_test)^2)) +cor(rowMeans(tau_hat_test), tau_test) + +# Inspect sigma^2 traceplot +sigma_observed <- var(y - E_XZ) +sigma2_global_samples <- extractParameter(bcf_model, "sigma2_global") +plot_bounds <- c( + min(c(sigma2_global_samples, sigma_observed)), + max(c(sigma2_global_samples, sigma_observed)) +) +plot( + sigma2_global_samples, + ylim = plot_bounds, + ylab = "sigma^2", + xlab = "Sample", + main = "Global variance parameter" +) +abline(h = sigma_observed, lty = 3, lwd = 3, col = "blue") + +# Assess CATE function coverage +test_lb <- apply(tau_hat_test, 1, quantile, 0.025) +test_ub <- apply(tau_hat_test, 1, quantile, 0.975) +cover <- ((test_lb <= tau_x[test_inds]) & + (test_ub >= tau_x[test_inds])) +mean(cover) diff --git a/tools/debug/python_bcf_troubleshoot.py b/tools/debug/python_bcf_troubleshoot.py new file mode 100644 index 00000000..c0acd3e5 --- /dev/null +++ b/tools/debug/python_bcf_troubleshoot.py @@ -0,0 +1,123 @@ +# Load packages +import numpy as np +import pandas as pd +import matplotlib.pyplot as plt +from scipy.stats import norm +from stochtree import BCFModel + +# Helper functions +def g(x): + return np.where(x[:, 4] == 1, 2, np.where(x[:, 4] == 2, -1, -4)) + +def mu1(x): + return 1 + g(x) + x[:, 0] * x[:, 2] + +def mu2(x): + return 1 + g(x) + 6 * np.abs(x[:, 2] - 1) + +def tau1(x): + return np.full(x.shape[0], 3.0) + +def tau2(x): + return 1 + 2 * x[:, 1] * x[:, 3] + +rng = np.random.default_rng(101) + +# Generate data +n = 500 +snr = 3 +x1 = rng.normal(size=n) +x2 = rng.normal(size=n) +x3 = rng.normal(size=n) +x4 = rng.binomial(1, 0.5, n).astype(float) +x5 = rng.choice([1, 2, 3], size=n).astype(float) +X = np.column_stack([x1, x2, x3, x4, x5]) +mu_x = mu1(X) +tau_x = tau2(X) +pi_x = (0.8 * norm.cdf((3 * mu_x / np.std(mu_x)) - 0.5 * X[:, 0]) + + 0.05 + rng.uniform(size=n) / 10) +Z = rng.binomial(1, pi_x, n).astype(float) +E_XZ = mu_x + Z * tau_x +y = E_XZ + rng.normal(size=n) * (np.std(E_XZ) / snr) + +# Convert to DataFrame with ordered categoricals (matching R's factor(..., ordered=TRUE)) +X_df = pd.DataFrame({"x1": x1, "x2": x2, "x3": x3, "x4": x4, "x5": x5}) +X_df["x4"] = pd.Categorical(X_df["x4"].astype(int), categories=[0, 1], ordered=True) +X_df["x5"] = pd.Categorical(X_df["x5"].astype(int), categories=[1, 2, 3], ordered=True) + +# Split data into test and train sets +test_set_pct = 0.2 +n_test = round(test_set_pct * n) +n_train = n - n_test +test_inds = rng.choice(n, n_test, replace=False) +train_inds = np.setdiff1d(np.arange(n), test_inds) +X_test = X_df.iloc[test_inds] +X_train = X_df.iloc[train_inds] +pi_test, pi_train = pi_x[test_inds], pi_x[train_inds] +Z_test, Z_train = Z[test_inds], Z[train_inds] +y_test, y_train = y[test_inds], y[train_inds] +mu_test, mu_train = mu_x[test_inds], mu_x[train_inds] +tau_test, tau_train = tau_x[test_inds], tau_x[train_inds] + +# Sample the model +bcf_model = BCFModel() +bcf_model.sample( + X_train=X_train, + Z_train=Z_train, + y_train=y_train, + propensity_train=pi_train, + X_test=X_test, + Z_test=Z_test, + num_gfr=10, + num_burnin=1000, + num_mcmc=100, + propensity_test=pi_test, + general_params={"num_threads": 1, "num_chains": 4}, +) + +# Plot true versus estimated prognostic function +mu_hat_test = bcf_model.predict(X=X_test, Z=Z_test, propensity=pi_test, terms="prognostic_function") +sigma_observed = np.var(y - E_XZ) +mu_pred = mu_hat_test.mean(axis=1) +lo, hi = min(mu_pred.min(), mu_test.min()), max(mu_pred.max(), mu_test.max()) +plt.close() +plt.scatter(mu_pred, mu_test, alpha=0.5) +plt.plot([lo, hi], [lo, hi], color="red", linestyle="dashed", linewidth=2) +plt.xlabel("Predicted") +plt.ylabel("Actual") +plt.title("Prognostic function") +plt.show() + +# Plot true versus estimated CATE function +tau_hat_test = bcf_model.predict(X=X_test, Z=Z_test, propensity=pi_test, terms="cate") +tau_pred = tau_hat_test.mean(axis=1) +lo, hi = min(tau_pred.min(), tau_test.min()), max(tau_pred.max(), tau_test.max()) +plt.close() +plt.scatter(tau_pred, tau_test, alpha=0.5) +plt.plot([lo, hi], [lo, hi], color="red", linestyle="dashed", linewidth=2) +plt.xlabel("Predicted") +plt.ylabel("Actual") +plt.title("Treatment effect") +plt.show() +rmse = np.sqrt(np.mean((tau_pred - tau_test) ** 2)) +corr = np.corrcoef(tau_pred, tau_test)[0, 1] +print(f"RMSE between predicted and actual treatment effects: {rmse:.2f}") +print(f"Correlation between predicted and actual treatment effects: {corr:.2f}") + +# Inspect sigma^2 traceplot +global_var_samples = bcf_model.extract_parameter("sigma2_global") +plt.close() +plt.plot(global_var_samples) +plt.axhline(sigma_observed, color="blue", linestyle="dashed", linewidth=2) +plt.xlabel("Sample") +plt.ylabel(r"$\sigma^2$") +plt.title("Global variance parameter") +plt.show() + +# Assess CATE function coverage +test_lb = np.quantile(tau_hat_test, 0.025, axis=1) +test_ub = np.quantile(tau_hat_test, 0.975, axis=1) +cover = ((test_lb <= tau_x[test_inds]) & + (test_ub >= tau_x[test_inds])) +coverage = np.mean(cover) +print(f"Coverage of 95% credible intervals of CATE function: {coverage*100:.2f}%")