diff --git a/R/bcf.R b/R/bcf.R index 1e7daad4..44700393 100644 --- a/R/bcf.R +++ b/R/bcf.R @@ -134,20 +134,22 @@ NULL #' #' @param treatment_effect_forest_params (Optional) A list of treatment effect forest model parameters, each of which has a default value processed internally, so this argument list is optional. #' -#' - `num_trees` Number of trees in the ensemble for the treatment effect forest. Default: `50`. Must be a positive integer. +#' - `num_trees` Number of trees in the ensemble for the treatment effect forest. Default: `100`. Must be a positive integer. #' - `alpha` Prior probability of splitting for a tree of depth 0 in the treatment effect forest. Tree split prior combines `alpha` and `beta` via `alpha*(1+node_depth)^-beta`. Default: `0.25`. #' - `beta` Exponent that decreases split probabilities for nodes of depth > 0 in the treatment effect forest. Tree split prior combines `alpha` and `beta` via `alpha*(1+node_depth)^-beta`. Default: `3`. #' - `min_samples_leaf` Minimum allowable size of a leaf, in terms of training samples, in the treatment effect forest. Default: `5`. #' - `max_depth` Maximum depth of any tree in the ensemble in the treatment effect forest. Default: `5`. Can be overridden with ``-1`` which does not enforce any depth limits on trees. #' - `variable_weights` Numeric weights reflecting the relative probability of splitting on each variable in the treatment effect forest. Does not need to sum to 1 but cannot be negative. Defaults to `rep(1/ncol(X_train), ncol(X_train))` if not set here. #' - `sample_sigma2_leaf` Whether or not to update the leaf scale variance parameter based on `IG(sigma2_leaf_shape, sigma2_leaf_scale)`. Cannot (currently) be set to true if `ncol(Z_train)>1`. Default: `FALSE`. -#' - `sigma2_leaf_init` Starting value of leaf node scale parameter. Calibrated internally as `1/num_trees` if not set here. +#' - `sigma2_leaf_init` Starting value of leaf node scale parameter. 
Calibrated internally as `0.5 * var(y)/num_trees` if not set here (`0.5 / num_trees` if `y` is continuous and `standardize = TRUE` in the `general_params` list). #' - `sigma2_leaf_shape` Shape parameter in the `IG(sigma2_leaf_shape, sigma2_leaf_scale)` leaf node parameter variance model. Default: `3`. #' - `sigma2_leaf_scale` Scale parameter in the `IG(sigma2_leaf_shape, sigma2_leaf_scale)` leaf node parameter variance model. Calibrated internally as `0.5/num_trees` if not set here. #' - `delta_max` Maximum plausible conditional distributional treatment effect (i.e. P(Y(1) = 1 | X) - P(Y(0) = 1 | X)) when the outcome is binary. Only used when the outcome is specified as a probit model in `general_params`. Must be > 0 and < 1. Default: `0.9`. Ignored if `sigma2_leaf_init` is set directly, as this parameter is used to calibrate `sigma2_leaf_init`. #' - `keep_vars` Vector of variable names or column indices denoting variables that should be included in the forest. Default: `NULL`. #' - `drop_vars` Vector of variable names or column indices denoting variables that should be excluded from the forest. Default: `NULL`. If both `drop_vars` and `keep_vars` are set, `drop_vars` will be ignored. #' - `num_features_subsample` How many features to subsample when growing each tree for the GFR algorithm. Defaults to the number of features in the training dataset. +#' - `sample_intercept` Whether to sample a global treatment effect intercept `tau_0` so the full CATE is `tau_0 + tau(X)`. Default: `TRUE`. Compatible with `adaptive_coding = TRUE`, in which case the recoded treatment basis is used. +#' - `tau_0_prior_var` Variance of the normal prior on `tau_0` (a scalar applied to each treatment dimension independently). Auto-calibrated to outcome variance when `NULL` and outcome is continuous. Only used when `sample_intercept = TRUE`. 
#' #' @param variance_forest_params (Optional) A list of variance forest model parameters, each of which has a default value processed internally, so this argument list is optional. #' @@ -297,7 +299,7 @@ bcf <- function( # Update tau forest BCF parameters treatment_effect_forest_params_default <- list( - num_trees = 50, + num_trees = 100, alpha = 0.25, beta = 3.0, min_samples_leaf = 5, @@ -309,7 +311,9 @@ bcf <- function( keep_vars = NULL, drop_vars = NULL, delta_max = 0.9, - num_features_subsample = NULL + num_features_subsample = NULL, + sample_intercept = TRUE, + tau_0_prior_var = NULL ) treatment_effect_forest_params_updated <- preprocessParams( treatment_effect_forest_params_default, @@ -403,6 +407,8 @@ bcf <- function( drop_vars_tau <- treatment_effect_forest_params_updated$drop_vars delta_max <- treatment_effect_forest_params_updated$delta_max num_features_subsample_tau <- treatment_effect_forest_params_updated$num_features_subsample + sample_tau_0 <- treatment_effect_forest_params_updated$sample_intercept + tau_0_prior_var <- treatment_effect_forest_params_updated$tau_0_prior_var # 4. 
Variance forest parameters num_trees_variance <- variance_forest_params_updated$num_trees @@ -1157,6 +1163,17 @@ bcf <- function( adaptive_coding <- FALSE } + # Validate tau_0_prior_var if sample_tau_0 is TRUE + if (sample_tau_0 && !is.null(tau_0_prior_var)) { + if ( + !is.numeric(tau_0_prior_var) || + length(tau_0_prior_var) != 1 || + tau_0_prior_var <= 0 + ) { + stop("tau_0_prior_var must be a single positive numeric value") + } + } + # Check if propensity_covariate is one of the required inputs if ( !(propensity_covariate %in% @@ -1421,7 +1438,9 @@ bcf <- function( } } if (is.null(sigma2_leaf_tau)) { - sigma2_leaf_tau <- var_cpp(as.numeric(resid_train)) / (num_trees_tau) + sigma2_leaf_tau <- 0.5 * + var_cpp(as.numeric(resid_train)) / + (num_trees_tau) current_leaf_scale_tau <- as.matrix(diag( sigma2_leaf_tau, ncol(Z_train) @@ -1559,6 +1578,10 @@ bcf <- function( if (sample_sigma2_leaf_tau) { leaf_scale_tau_samples <- rep(NA, num_retained_samples) } + if (sample_tau_0) { + p_tau0 <- ncol(as.matrix(Z_train)) + tau_0_samples <- matrix(NA_real_, p_tau0, num_retained_samples) + } muhat_train_raw <- matrix(NA_real_, nrow(X_train), num_retained_samples) if (include_variance_forest) { sigma2_x_train_raw <- matrix( @@ -1592,6 +1615,19 @@ bcf <- function( if (has_test) tau_basis_test <- Z_test } + # Prepare tau_0 (global treatment effect intercept) structure + if (sample_tau_0) { + if (!exists("p_tau0")) { + p_tau0 <- ncol(as.matrix(Z_train)) + } + tau_0 <- rep(0.0, p_tau0) + # Auto-calibrate prior variance if not provided + if (is.null(tau_0_prior_var)) { + tau_0_prior_var <- var_cpp(as.numeric(resid_train)) + } + prior_var_tau0 <- diag(p_tau0) * tau_0_prior_var + } + # Data forest_dataset_train <- createForestDataset(X_train, tau_basis_train) if (has_test) { @@ -1839,6 +1875,45 @@ bcf <- function( ) } + # Sample tau_0 (global treatment effect intercept, if requested) + if (sample_tau_0) { + mu_x_raw_tau0 <- active_forest_mu$predict_raw(forest_dataset_train) + 
tau_x_raw_tau0 <- active_forest_tau$predict_raw(forest_dataset_train) + Z_basis_mat <- as.matrix(tau_basis_train) + # tau(X) * basis contribution per observation + tau_x_full <- rowSums(Z_basis_mat * as.matrix(tau_x_raw_tau0)) + partial_resid_tau0 <- resid_train - + as.numeric(mu_x_raw_tau0) - + tau_x_full + if (has_rfx) { + partial_resid_tau0 <- partial_resid_tau0 - + as.numeric( + rfx_model$predict(rfx_dataset_train, rfx_tracker_train) + ) + } + Ztr_tau0 <- t(Z_basis_mat) %*% as.matrix(partial_resid_tau0) + ZtZ_current <- crossprod(Z_basis_mat) + Sigma_post <- solve( + ZtZ_current / current_sigma2 + diag(p_tau0) / tau_0_prior_var + ) + mu_post_tau0 <- as.numeric(Sigma_post %*% Ztr_tau0 / current_sigma2) + if (p_tau0 == 1) { + tau_0_new <- rnorm(1, mu_post_tau0, sqrt(as.numeric(Sigma_post))) + } else { + tau_0_new <- as.numeric( + mu_post_tau0 + t(chol(Sigma_post)) %*% rnorm(p_tau0) + ) + } + resid_delta <- as.numeric( + Z_basis_mat %*% matrix(tau_0_new - tau_0, ncol = 1) + ) + outcome_train$subtract_vector(resid_delta) + tau_0 <- tau_0_new + if (keep_sample) { + tau_0_samples[, sample_counter] <- tau_0 + } + } + # Sample the treatment forest forest_model_tau$sample_one_iteration( forest_dataset = forest_dataset_train, @@ -1876,17 +1951,23 @@ bcf <- function( rfx_preds_train } - # Compute sufficient statistics for regression of y - mu(X) on [tau(X)(1-Z), tau(X)Z] - s_tt0 <- sum(tau_x_raw_train * tau_x_raw_train * (Z_train == 0)) - s_tt1 <- sum(tau_x_raw_train * tau_x_raw_train * (Z_train == 1)) + # Compute sufficient statistics for regression of y - mu(X) on [tau_total(X)(1-Z), tau_total(X)Z] + # where tau_total(X) = tau_0 + tau(X) when sample_tau_0, else tau(X) + tau_x_for_coding <- if (sample_tau_0) { + tau_x_raw_train + tau_0[1] + } else { + tau_x_raw_train + } + s_tt0 <- sum(tau_x_for_coding * tau_x_for_coding * (Z_train == 0)) + s_tt1 <- sum(tau_x_for_coding * tau_x_for_coding * (Z_train == 1)) s_ty0 <- sum( - tau_x_raw_train * partial_resid_mu_train * 
(Z_train == 0) + tau_x_for_coding * partial_resid_mu_train * (Z_train == 0) ) s_ty1 <- sum( - tau_x_raw_train * partial_resid_mu_train * (Z_train == 1) + tau_x_for_coding * partial_resid_mu_train * (Z_train == 1) ) - # Sample b0 (coefficient on tau(X)(1-Z)) and b1 (coefficient on tau(X)Z) + # Sample b0 (coefficient on tau_total(X)(1-Z)) and b1 (coefficient on tau_total(X)Z) current_b_0 <- rnorm( 1, (s_ty0 / (s_tt0 + 2 * current_sigma2)), @@ -1899,6 +1980,9 @@ bcf <- function( ) # Update basis for the leaf regression + if (sample_tau_0) { + tau_basis_old <- tau_basis_train + } tau_basis_train <- (1 - Z_train) * current_b_0 + Z_train * current_b_1 @@ -1920,6 +2004,13 @@ bcf <- function( outcome_train, active_forest_tau ) + + # Fix tau_0 component of residual after basis change + if (sample_tau_0) { + outcome_train$subtract_vector( + as.numeric(tau_basis_train - tau_basis_old) * tau_0[1] + ) + } } # Sample variance parameters (if requested) @@ -2473,6 +2564,45 @@ bcf <- function( ) } + # Sample tau_0 (global treatment effect intercept, if requested) + if (sample_tau_0) { + mu_x_raw_tau0 <- active_forest_mu$predict_raw(forest_dataset_train) + tau_x_raw_tau0 <- active_forest_tau$predict_raw(forest_dataset_train) + Z_basis_mat <- as.matrix(tau_basis_train) + # tau(X) * basis contribution per observation + tau_x_full <- rowSums(Z_basis_mat * as.matrix(tau_x_raw_tau0)) + partial_resid_tau0 <- resid_train - + as.numeric(mu_x_raw_tau0) - + tau_x_full + if (has_rfx) { + partial_resid_tau0 <- partial_resid_tau0 - + as.numeric( + rfx_model$predict(rfx_dataset_train, rfx_tracker_train) + ) + } + Ztr_tau0 <- t(Z_basis_mat) %*% as.matrix(partial_resid_tau0) + ZtZ_current <- crossprod(Z_basis_mat) + Sigma_post <- solve( + ZtZ_current / current_sigma2 + diag(p_tau0) / tau_0_prior_var + ) + mu_post_tau0 <- as.numeric(Sigma_post %*% Ztr_tau0 / current_sigma2) + if (p_tau0 == 1) { + tau_0_new <- rnorm(1, mu_post_tau0, sqrt(as.numeric(Sigma_post))) + } else { + tau_0_new <- as.numeric( 
+ mu_post_tau0 + t(chol(Sigma_post)) %*% rnorm(p_tau0) + ) + } + resid_delta <- as.numeric( + Z_basis_mat %*% matrix(tau_0_new - tau_0, ncol = 1) + ) + outcome_train$subtract_vector(resid_delta) + tau_0 <- tau_0_new + if (keep_sample) { + tau_0_samples[, sample_counter] <- tau_0 + } + } + # Sample the treatment forest forest_model_tau$sample_one_iteration( forest_dataset = forest_dataset_train, @@ -2510,25 +2640,31 @@ bcf <- function( rfx_preds_train } - # Compute sufficient statistics for regression of y - mu(X) on [tau(X)(1-Z), tau(X)Z] + # Compute sufficient statistics for regression of y - mu(X) on [tau_total(X)(1-Z), tau_total(X)Z] + # where tau_total(X) = tau_0 + tau(X) when sample_tau_0, else tau(X) + tau_x_for_coding <- if (sample_tau_0) { + tau_x_raw_train + tau_0[1] + } else { + tau_x_raw_train + } s_tt0 <- sum( - tau_x_raw_train * tau_x_raw_train * (Z_train == 0) + tau_x_for_coding * tau_x_for_coding * (Z_train == 0) ) s_tt1 <- sum( - tau_x_raw_train * tau_x_raw_train * (Z_train == 1) + tau_x_for_coding * tau_x_for_coding * (Z_train == 1) ) s_ty0 <- sum( - tau_x_raw_train * + tau_x_for_coding * partial_resid_mu_train * (Z_train == 0) ) s_ty1 <- sum( - tau_x_raw_train * + tau_x_for_coding * partial_resid_mu_train * (Z_train == 1) ) - # Sample b0 (coefficient on tau(X)(1-Z)) and b1 (coefficient on tau(X)Z) + # Sample b0 (coefficient on tau_total(X)(1-Z)) and b1 (coefficient on tau_total(X)Z) current_b_0 <- rnorm( 1, (s_ty0 / (s_tt0 + 2 * current_sigma2)), @@ -2541,6 +2677,9 @@ bcf <- function( ) # Update basis for the leaf regression + if (sample_tau_0) { + tau_basis_old <- tau_basis_train + } tau_basis_train <- (1 - Z_train) * current_b_0 + Z_train * current_b_1 @@ -2562,6 +2701,13 @@ bcf <- function( outcome_train, active_forest_tau ) + + # Fix tau_0 component of residual after basis change + if (sample_tau_0) { + outcome_train$subtract_vector( + as.numeric(tau_basis_train - tau_basis_old) * tau_0[1] + ) + } } # Sample variance parameters (if requested) 
@@ -2666,6 +2812,12 @@ bcf <- function( b_1_samples <- b_1_samples[(num_gfr + 1):length(b_1_samples)] b_0_samples <- b_0_samples[(num_gfr + 1):length(b_0_samples)] } + if (sample_tau_0) { + tau_0_samples <- tau_0_samples[, + (num_gfr + 1):ncol(tau_0_samples), + drop = FALSE + ] + } muhat_train_raw <- muhat_train_raw[, (num_gfr + 1):ncol(muhat_train_raw) ] @@ -2691,8 +2843,39 @@ bcf <- function( tau_hat_train <- forest_samples_tau$predict_raw(forest_dataset_train) * y_std_train } + # tau_hat_train stores the forest-only component tau(X); compute cate_train + # (tau_0 + tau(X)) separately for the treatment term used in y_hat + if (sample_tau_0) { + tau_0_vec <- as.numeric(tau_0_samples) # num_retained_samples vector (scalar treatment) + if (adaptive_coding) { + # CATE = (b_1 - b_0) * (tau_0 + tau(X)); control adj to mu = b_0 * (tau_0 + tau(X)) + cate_train <- sweep( + tau_hat_train, + 2, + (b_1_samples - b_0_samples) * tau_0_vec * y_std_train, + "+" + ) + mu_hat_train <- sweep( + mu_hat_train, + 2, + b_0_samples * tau_0_vec * y_std_train, + "+" + ) + } else if (!has_multivariate_treatment) { + cate_train <- sweep(tau_hat_train, 2, tau_0_vec * y_std_train, "+") + } else { + # tau_hat_train: n x p x num_retained_samples; tau_0_samples: p x num_retained_samples + cate_train <- tau_hat_train + for (j in seq_len(p_tau0)) { + cate_train[, j, ] <- cate_train[, j, ] + + outer(rep(1, nrow(X_train)), tau_0_samples[j, ] * y_std_train) + } + } + } else { + cate_train <- tau_hat_train + } if (has_multivariate_treatment) { - tau_train_dim <- dim(tau_hat_train) + tau_train_dim <- dim(cate_train) tau_num_obs <- tau_train_dim[1] tau_num_samples <- tau_train_dim[3] treatment_term_train <- matrix( @@ -2702,11 +2885,11 @@ bcf <- function( ) for (i in 1:nrow(Z_train)) { treatment_term_train[i, ] <- colSums( - tau_hat_train[i, , ] * Z_train[i, ] + cate_train[i, , ] * Z_train[i, ] ) } } else { - treatment_term_train <- tau_hat_train * as.numeric(Z_train) + treatment_term_train <- 
cate_train * as.numeric(Z_train) } y_hat_train <- mu_hat_train + treatment_term_train if (has_test) { @@ -2729,8 +2912,35 @@ bcf <- function( ) * y_std_train } + # tau_hat_test stores forest-only tau(X); compute cate_test for y_hat + if (sample_tau_0) { + if (adaptive_coding) { + cate_test <- sweep( + tau_hat_test, + 2, + (b_1_samples - b_0_samples) * tau_0_vec * y_std_train, + "+" + ) + mu_hat_test <- sweep( + mu_hat_test, + 2, + b_0_samples * tau_0_vec * y_std_train, + "+" + ) + } else if (!has_multivariate_treatment) { + cate_test <- sweep(tau_hat_test, 2, tau_0_vec * y_std_train, "+") + } else { + cate_test <- tau_hat_test + for (j in seq_len(p_tau0)) { + cate_test[, j, ] <- cate_test[, j, ] + + outer(rep(1, nrow(X_test)), tau_0_samples[j, ] * y_std_train) + } + } + } else { + cate_test <- tau_hat_test + } if (has_multivariate_treatment) { - tau_test_dim <- dim(tau_hat_test) + tau_test_dim <- dim(cate_test) tau_num_obs <- tau_test_dim[1] tau_num_samples <- tau_test_dim[3] treatment_term_test <- matrix( @@ -2740,11 +2950,11 @@ bcf <- function( ) for (i in 1:nrow(Z_test)) { treatment_term_test[i, ] <- colSums( - tau_hat_test[i, , ] * Z_test[i, ] + cate_test[i, , ] * Z_test[i, ] ) } } else { - treatment_term_test <- tau_hat_test * as.numeric(Z_test) + treatment_term_test <- cate_test * as.numeric(Z_test) } y_hat_test <- mu_hat_test + treatment_term_test } @@ -2850,6 +3060,8 @@ bcf <- function( "binary_treatment" = binary_treatment, "multivariate_treatment" = has_multivariate_treatment, "adaptive_coding" = adaptive_coding, + "sample_tau_0" = sample_tau_0, + "tau_0_prior_var" = if (sample_tau_0) tau_0_prior_var else NULL, "internal_propensity_model" = internal_propensity_model, "num_samples" = num_retained_samples, "num_gfr" = num_gfr, @@ -2904,6 +3116,9 @@ bcf <- function( result[["b_0_samples"]] = b_0_samples result[["b_1_samples"]] = b_1_samples } + if (sample_tau_0) { + result[["tau_0_samples"]] = tau_0_samples * y_std_train + } if (has_rfx) { 
result[["rfx_samples"]] = rfx_samples result[["rfx_preds_train"]] = rfx_preds_train @@ -3285,16 +3500,48 @@ predict.bcfmodel <- function( tau_hat_forest <- object$forests_tau$predict_raw(forest_dataset_pred) * y_std } + # tau_hat_forest is the forest-only component tau(X); compute cate_hat_forest + # (tau_0 + tau(X)) for the "cate" term and treatment_term used in y_hat + if (object$model_params$sample_tau_0 && !is.null(object$tau_0_samples)) { + tau_0_samp <- object$tau_0_samples # p_tau0 x num_samples (already in original scale) + if (object$model_params$adaptive_coding) { + cate_hat_forest <- sweep( + tau_hat_forest, + 2, + (object$b_1_samples - object$b_0_samples) * as.numeric(tau_0_samp), + "+" + ) + if (predict_mu_forest || predict_mu_forest_intermediate) { + mu_hat_forest <- sweep( + mu_hat_forest, + 2, + object$b_0_samples * as.numeric(tau_0_samp), + "+" + ) + } + } else if (!object$model_params$multivariate_treatment) { + cate_hat_forest <- sweep(tau_hat_forest, 2, as.numeric(tau_0_samp), "+") + } else { + p_tau0 <- nrow(tau_0_samp) + cate_hat_forest <- tau_hat_forest + for (j in seq_len(p_tau0)) { + cate_hat_forest[, j, ] <- cate_hat_forest[, j, ] + + outer(rep(1, nrow(X)), tau_0_samp[j, ]) + } + } + } else { + cate_hat_forest <- tau_hat_forest + } if (object$model_params$multivariate_treatment) { - tau_dim <- dim(tau_hat_forest) + tau_dim <- dim(cate_hat_forest) tau_num_obs <- tau_dim[1] tau_num_samples <- tau_dim[3] treatment_term <- matrix(NA_real_, nrow = tau_num_obs, tau_num_samples) for (i in 1:nrow(Z)) { - treatment_term[i, ] <- colSums(tau_hat_forest[i, , ] * Z[i, ]) + treatment_term[i, ] <- colSums(cate_hat_forest[i, , ] * Z[i, ]) } } else { - treatment_term <- tau_hat_forest * as.numeric(Z) + treatment_term <- cate_hat_forest * as.numeric(Z) } } @@ -3344,10 +3591,10 @@ predict.bcfmodel <- function( } if (predict_cate_function) { if (tau_cate_separate) { - cate <- (tau_hat_forest + + cate <- (cate_hat_forest + rfx_predictions_raw[, 
2:ncol(rfx_basis), ]) } else { - cate <- tau_hat_forest + cate <- cate_hat_forest } } @@ -3521,6 +3768,9 @@ print.bcfmodel <- function(x, ...) { if (x$model_params$sample_sigma2_leaf_tau) { model_terms <- c(model_terms, "treatment effect forest leaf scale model") } + if (x$model_params$sample_tau_0) { + model_terms <- c(model_terms, "treatment effect intercept model") + } if (length(model_terms) > 2) { summary_message <- paste0( "stochtree::bcf() run with ", @@ -3766,6 +4016,25 @@ summary.bcfmodel <- function(object, ...) { print(quantiles_b1) } + # Treatment effect intercept (tau_0) + if (object$model_params$sample_tau_0 && !is.null(object$tau_0_samples)) { + tau_0_vec <- as.numeric(object$tau_0_samples) + n_samples <- ncol(object$tau_0_samples) + mean_tau_0 <- mean(tau_0_vec) + sd_tau_0 <- sd(tau_0_vec) + quantiles_tau_0 <- quantile( + tau_0_vec, + probs = c(0.025, 0.1, 0.25, 0.5, 0.75, 0.9, 0.975) + ) + cat(sprintf( + "Summary of treatment effect intercept (tau_0) posterior: \n%d samples, mean = %.3f, standard deviation = %.3f, quantiles:\n", + n_samples, + mean_tau_0, + sd_tau_0 + )) + print(quantiles_tau_0) + } + # In-sample predictions if (!is.null(object$y_hat_train)) { y_hat_train_mean <- rowMeans(object$y_hat_train) @@ -3942,6 +4211,7 @@ plot.bcfmodel <- function(x, ...) 
{ #' - Test set mean function predictions: `"y_hat_test"` #' - In-sample treatment effect forest predictions: `"tau_hat_train"` #' - Test set treatment effect forest predictions: `"tau_hat_test"` +#' - Treatment effect intercept: `"tau_0"`, `"treatment_intercept"`, `"tau_intercept"` #' - In-sample variance forest predictions: `"sigma2_x_train"`, `"var_x_train"` #' - Test set variance forest predictions: `"sigma2_x_test"`, `"var_x_test"` #' @@ -4080,6 +4350,16 @@ extractParameter.bcfmodel <- function(object, term) { } } + if (term %in% c("tau_0", "treatment_intercept", "tau_intercept")) { + if (!is.null(object$tau_0_samples)) { + return(object$tau_0_samples) + } else { + stop( + "This model does not have treatment effect intercept (tau_0) samples" + ) + } + } + stop(paste0("term ", term, " is not a valid BCF model term")) } @@ -4278,6 +4558,7 @@ saveBCFModelToJson <- function(object) { object$model_params$multivariate_treatment ) jsonobj$add_boolean("adaptive_coding", object$model_params$adaptive_coding) + jsonobj$add_boolean("sample_tau_0", object$model_params$sample_tau_0) jsonobj$add_boolean( "internal_propensity_model", object$model_params$internal_propensity_model @@ -4330,6 +4611,14 @@ saveBCFModelToJson <- function(object) { jsonobj$add_vector("b_1_samples", object$b_1_samples, "parameters") jsonobj$add_vector("b_0_samples", object$b_0_samples, "parameters") } + if (object$model_params$sample_tau_0 && !is.null(object$tau_0_samples)) { + jsonobj$add_scalar("tau_0_dim", nrow(object$tau_0_samples)) + jsonobj$add_vector( + "tau_0_samples", + as.numeric(object$tau_0_samples), + "parameters" + ) + } # Add random effects (if present) if (object$model_params$has_rfx) { @@ -4472,6 +4761,7 @@ createBCFModelFromJson <- function(json_object) { model_params[["adaptive_coding"]] <- json_object$get_boolean( "adaptive_coding" ) + model_params[["sample_tau_0"]] <- json_object$get_boolean("sample_tau_0") model_params[["multivariate_treatment"]] <- json_object$get_boolean( 
"multivariate_treatment" ) @@ -4526,6 +4816,11 @@ createBCFModelFromJson <- function(json_object) { "parameters" ) } + if (model_params[["sample_tau_0"]]) { + tau_0_dim <- as.integer(json_object$get_scalar("tau_0_dim")) + tau_0_vec <- json_object$get_vector("tau_0_samples", "parameters") + output[["tau_0_samples"]] <- matrix(tau_0_vec, nrow = tau_0_dim) + } # Unpack random effects if (model_params[["has_rfx"]]) { @@ -4698,6 +4993,9 @@ createBCFModelFromCombinedJson <- function(json_object_list) { model_params[["adaptive_coding"]] <- json_object_default$get_boolean( "adaptive_coding" ) + model_params[["sample_tau_0"]] <- json_object_default$get_boolean( + "sample_tau_0" + ) model_params[["multivariate_treatment"]] <- json_object_default$get_boolean( "multivariate_treatment" ) @@ -4845,6 +5143,24 @@ createBCFModelFromCombinedJson <- function(json_object_list) { } } } + if (model_params[["sample_tau_0"]]) { + tau_0_dim <- as.integer(json_object_default$get_scalar("tau_0_dim")) + for (i in 1:length(json_object_list)) { + json_object <- json_object_list[[i]] + tau_0_mat_i <- matrix( + json_object$get_vector("tau_0_samples", "parameters"), + nrow = tau_0_dim + ) + if (i == 1) { + output[["tau_0_samples"]] <- tau_0_mat_i + } else { + output[["tau_0_samples"]] <- cbind( + output[["tau_0_samples"]], + tau_0_mat_i + ) + } + } + } # Unpack random effects if (model_params[["has_rfx"]]) { @@ -5000,6 +5316,9 @@ createBCFModelFromCombinedJsonString <- function(json_string_list) { model_params[["adaptive_coding"]] <- json_object_default$get_boolean( "adaptive_coding" ) + model_params[["sample_tau_0"]] <- json_object_default$get_boolean( + "sample_tau_0" + ) model_params[[ "internal_propensity_model" ]] <- json_object_default$get_boolean("internal_propensity_model") @@ -5144,6 +5463,24 @@ createBCFModelFromCombinedJsonString <- function(json_string_list) { } } } + if (model_params[["sample_tau_0"]]) { + tau_0_dim <- as.integer(json_object_default$get_scalar("tau_0_dim")) + for (i 
in 1:length(json_object_list)) { + json_object <- json_object_list[[i]] + tau_0_mat_i <- matrix( + json_object$get_vector("tau_0_samples", "parameters"), + nrow = tau_0_dim + ) + if (i == 1) { + output[["tau_0_samples"]] <- tau_0_mat_i + } else { + output[["tau_0_samples"]] <- cbind( + output[["tau_0_samples"]], + tau_0_mat_i + ) + } + } + } # Unpack random effects if (model_params[["has_rfx"]]) { diff --git a/R/posterior_transformation.R b/R/posterior_transformation.R index b101981e..e3776393 100644 --- a/R/posterior_transformation.R +++ b/R/posterior_transformation.R @@ -1055,11 +1055,12 @@ computeBCFPosteriorInterval <- function( ) } } - needs_covariates_intermediate <- ((("y_hat" %in% terms) || - ("all" %in% terms))) - needs_covariates <- (("prognostic_function" %in% terms) || - ("cate" %in% terms) || - ("variance_forest" %in% terms) || + predict_terms <- terms + needs_covariates_intermediate <- ((("y_hat" %in% predict_terms) || + ("all" %in% predict_terms))) + needs_covariates <- (("prognostic_function" %in% predict_terms) || + ("cate" %in% predict_terms) || + ("variance_forest" %in% predict_terms) || (needs_covariates_intermediate)) if (needs_covariates) { if (is.null(X)) { @@ -1119,10 +1120,10 @@ computeBCFPosteriorInterval <- function( } } } - needs_rfx_data_intermediate <- ((("y_hat" %in% terms) || - ("all" %in% terms)) && + needs_rfx_data_intermediate <- ((("y_hat" %in% predict_terms) || + ("all" %in% predict_terms)) && model_object$model_params$has_rfx) - needs_rfx_data <- (("rfx" %in% terms) || + needs_rfx_data <- (("rfx" %in% predict_terms) || (needs_rfx_data_intermediate)) if (needs_rfx_data) { if (is.null(rfx_group_ids)) { @@ -1154,42 +1155,47 @@ computeBCFPosteriorInterval <- function( } } - # Compute posterior matrices for the requested model terms - predictions <- predict( - model_object, - X = X, - Z = Z, - propensity = propensity, - rfx_group_ids = rfx_group_ids, - rfx_basis = rfx_basis, - type = "posterior", - terms = terms, - scale = scale - ) 
- has_multiple_terms <- ifelse(is.list(predictions), TRUE, FALSE) + result <- list() - # Compute the interval - if (has_multiple_terms) { - result <- list() - for (term_name in names(predictions)) { - if (!is.null(predictions[[term_name]])) { - result[[term_name]] <- summarize_interval( - predictions[[term_name]], - sample_dim = 2, - level = level - ) - } else { - result[[term_name]] <- NULL + # Compute posterior matrices for predict-able terms (if any) + if (length(predict_terms) > 0) { + predictions <- predict( + model_object, + X = X, + Z = Z, + propensity = propensity, + rfx_group_ids = rfx_group_ids, + rfx_basis = rfx_basis, + type = "posterior", + terms = predict_terms, + scale = scale + ) + if (is.list(predictions)) { + for (term_name in names(predictions)) { + if (!is.null(predictions[[term_name]])) { + result[[term_name]] <- summarize_interval( + predictions[[term_name]], + sample_dim = 2, + level = level + ) + } else { + result[[term_name]] <- NULL + } } + } else { + result[[predict_terms]] <- summarize_interval( + predictions, + sample_dim = 2, + level = level + ) } - return(result) - } else { - return(summarize_interval( - predictions, - sample_dim = 2, - level = level - )) } + + # Return single interval directly if only one term was requested + if (length(terms) == 1) { + return(result[[terms]]) + } + return(result) } #' @title Compute BART Posterior Credible Intervals diff --git a/demo/notebooks/reparameterized_causal_inference.ipynb b/demo/notebooks/reparameterized_causal_inference.ipynb new file mode 100644 index 00000000..77451e84 --- /dev/null +++ b/demo/notebooks/reparameterized_causal_inference.ipynb @@ -0,0 +1,907 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Reparameterized Causal Inference" + ] + }, + { + "cell_type": "markdown", + "id": "2d2af897", + "metadata": {}, + "source": [ + "The classic BCF model of Hahn, Murray, and Carvalho (2020) is defined as\n", + "\n", + "$$\n", + "Y_i \\mid x_i, z_i 
\\sim \\mathrm{N}\\!\\left(f_0(x_i) + \\tau(x_i)\\, z_i,\\, \\sigma^2\\right)\n", + "$$\n", + "\n", + "where $f_0$ and $\\tau$ each have BART priors. Separating the prognostic function $f_0(x)$ from the CATE function $\\tau(x)$ can improve estimation in settings with strong confounding and treatment effect heterogeneity.\n", + "\n", + "`stochtree` implements a modification of this model that decomposes the treatment effect function into parametric and nonparametric components:\n", + "\n", + "$$\n", + "Y_i \\mid x_i, z_i \\sim \\mathrm{N}\\!\\left(f_0(x_i) + (\\tau_0 + t(x_i))\\, z_i,\\, \\sigma^2\\right)\n", + "$$\n", + "\n", + "where $\\tau_0 \\sim \\mathrm{N}(0,\\, \\sigma_{\\tau_0}^2)$ is a global treatment effect intercept and $t(x_i)$ is a BART forest capturing heterogeneity around it. This allows the forest term to focus on heterogeneity \"offsets\" relative to a parametric average effect." + ] + }, + { + "cell_type": "markdown", + "id": "9750a3a5", + "metadata": {}, + "source": [ + "Load necessary libraries" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4c271064", + "metadata": {}, + "outputs": [], + "source": [ + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "from scipy.stats import norm\n", + "\n", + "from stochtree import BCFModel" + ] + }, + { + "cell_type": "markdown", + "id": "f1dadc4f", + "metadata": {}, + "source": [ + "Set random seed for reproducibility" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d0b5720e", + "metadata": {}, + "outputs": [], + "source": [ + "random_seed = 1234\n", + "rng = np.random.default_rng(random_seed)" + ] + }, + { + "cell_type": "markdown", + "id": "76c88a6e", + "metadata": {}, + "source": [ + "## Binary Treatment with Homogeneous Treatment Effect\n", + "\n", + "Consider the following data generating process:\n", + "\n", + "$$\n", + "\\begin{aligned}\n", + "y &= \\mu(X) + \\tau(X)\\, Z + \\epsilon \\\\\n", + "\\mu(X) &= 2\\sin(2\\pi X_1) - 2(2X_3 - 1) 
\\\\\n", + "\\tau(X) &= 5 \\\\\n", + "\\pi(X) &= \\Phi\\!\\left(\\mu(X)/4\\right) \\\\\n", + "X_1,\\ldots,X_p &\\sim \\mathrm{Uniform}(0,1) \\\\\n", + "Z &\\sim \\mathrm{Bernoulli}(\\pi(X)) \\\\\n", + "\\epsilon &\\sim \\mathrm{N}(0, \\sigma^2)\n", + "\\end{aligned}\n", + "$$\n", + "\n", + "### Simulation" + ] + }, + { + "cell_type": "markdown", + "id": "7d98efb2", + "metadata": {}, + "source": [ + "We draw from the DGP defined above" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f7914119", + "metadata": {}, + "outputs": [], + "source": [ + "n = 500\n", + "p = 20\n", + "snr = 2\n", + "X = rng.uniform(0, 1, (n, p))\n", + "mu_X = 2 * np.sin(2 * np.pi * X[:, 0]) - 2 * (2 * X[:, 2] - 1)\n", + "tau_X = 5.0\n", + "pi_X = norm.cdf(mu_X / 4)\n", + "Z = rng.binomial(1, pi_X, n).astype(float)\n", + "E_XZ = mu_X + Z * tau_X\n", + "sigma_true = np.std(E_XZ) / snr\n", + "y = E_XZ + rng.standard_normal(n) * sigma_true" + ] + }, + { + "cell_type": "markdown", + "id": "5bd4f6ac", + "metadata": {}, + "source": [ + "And split data into test and train sets" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "79f2b8b9", + "metadata": {}, + "outputs": [], + "source": [ + "n_test = round(0.5 * n)\n", + "n_train = n - n_test\n", + "test_inds = np.sort(rng.choice(n, n_test, replace=False))\n", + "train_inds = np.setdiff1d(np.arange(n), test_inds)\n", + "X_train, X_test = X[train_inds], X[test_inds]\n", + "Z_train, Z_test = Z[train_inds], Z[test_inds]\n", + "y_train, y_test = y[train_inds], y[test_inds]\n", + "pi_train, pi_test = pi_X[train_inds], pi_X[test_inds]\n", + "mu_train, mu_test = mu_X[train_inds], mu_X[test_inds]" + ] + }, + { + "cell_type": "markdown", + "id": "847209bd", + "metadata": {}, + "source": [ + "## Sampling and Analysis\n", + "\n", + "### Classic BCF Model\n", + "\n", + "We first fit the classic BCF model with no parametric treatment effect term." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "30f5aa2c", + "metadata": {}, + "outputs": [], + "source": [ + "num_gfr = 0\n", + "num_burnin = 1000\n", + "num_mcmc = 500\n", + "num_trees_tau = 50\n", + "general_params = {\n", + " \"adaptive_coding\": True,\n", + " \"num_chains\": 4,\n", + " \"random_seed\": random_seed,\n", + " \"num_threads\": 1,\n", + "}\n", + "treatment_effect_forest_params = {\n", + " \"num_trees\": num_trees_tau,\n", + " \"sample_intercept\": False,\n", + " \"sigma2_leaf_init\": 1 / num_trees_tau,\n", + "}\n", + "bcf_model_classic = BCFModel()\n", + "bcf_model_classic.sample(\n", + " X_train=X_train,\n", + " Z_train=Z_train,\n", + " y_train=y_train,\n", + " propensity_train=pi_train,\n", + " X_test=X_test,\n", + " Z_test=Z_test,\n", + " propensity_test=pi_test,\n", + " num_gfr=num_gfr,\n", + " num_burnin=num_burnin,\n", + " num_mcmc=num_mcmc,\n", + " general_params=general_params,\n", + " treatment_effect_forest_params=treatment_effect_forest_params,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "4f895dbe", + "metadata": {}, + "source": [ + "Compare the posterior distribution of the ATE to its true value" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c951a26d", + "metadata": {}, + "outputs": [], + "source": [ + "cate_posterior_classic = bcf_model_classic.predict(\n", + " X=X_test,\n", + " Z=Z_test,\n", + " propensity=pi_test,\n", + " type=\"posterior\",\n", + " terms=\"cate\",\n", + ")\n", + "ate_posterior_classic = np.mean(cate_posterior_classic, axis=0)\n", + "plt.figure(figsize=(7, 5))\n", + "plt.hist(ate_posterior_classic, density=True, bins=30, color=\"steelblue\", edgecolor=\"white\")\n", + "plt.axvline(tau_X, color=\"red\", linestyle=\"dotted\", linewidth=2, label=f\"True ATE = {tau_X}\")\n", + "plt.xlabel(\"ATE\")\n", + "plt.ylabel(\"Density\")\n", + "plt.title(\"Posterior Distribution of ATE (Classic BCF)\")\n", + "plt.legend()\n", + "plt.show()" + ] + }, + { + "cell_type": 
"markdown", + "id": "bf8ef874", + "metadata": {}, + "source": [ + "As a rough convergence check, inspect the traceplot of the global error variance $\\sigma^2$" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4c8d9301", + "metadata": {}, + "outputs": [], + "source": [ + "sigma2_samples = bcf_model_classic.global_var_samples\n", + "plt.figure(figsize=(7, 4))\n", + "plt.plot(sigma2_samples, color=\"steelblue\", linewidth=0.8)\n", + "plt.axhline(sigma_true**2, color=\"red\", linestyle=\"dotted\", linewidth=2, label=f\"True $\\\\sigma^2$ = {sigma_true**2:.3f}\")\n", + "plt.xlabel(\"Iteration\")\n", + "plt.ylabel(\"$\\\\sigma^2$\")\n", + "plt.title(\"Traceplot of $\\\\sigma^2$ (Classic BCF)\")\n", + "plt.legend()\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "90baef7c", + "metadata": {}, + "source": [ + "### Reparameterized BCF Model\n", + "\n", + "Now we fit the reparameterized model, regularizing the $t(x)$ forest more heavily to account for the standard normal prior on $\\tau_0$." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a3df8cf9", + "metadata": {}, + "outputs": [], + "source": [ + "num_trees_tau = 50\n", + "general_params = {\n", + " \"adaptive_coding\": False,\n", + " \"num_chains\": 4,\n", + " \"random_seed\": random_seed,\n", + " \"num_threads\": 1,\n", + "}\n", + "treatment_effect_forest_params = {\n", + " \"num_trees\": num_trees_tau,\n", + " \"sample_intercept\": True,\n", + " \"sigma2_leaf_init\": 0.25 / num_trees_tau,\n", + " \"tau_0_prior_var\": 1.0,\n", + "}\n", + "bcf_model_reparam = BCFModel()\n", + "bcf_model_reparam.sample(\n", + " X_train=X_train,\n", + " Z_train=Z_train,\n", + " y_train=y_train,\n", + " propensity_train=pi_train,\n", + " X_test=X_test,\n", + " Z_test=Z_test,\n", + " propensity_test=pi_test,\n", + " num_gfr=num_gfr,\n", + " num_burnin=num_burnin,\n", + " num_mcmc=num_mcmc,\n", + " general_params=general_params,\n", + " treatment_effect_forest_params=treatment_effect_forest_params,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "b797238b", + "metadata": {}, + "source": [ + "Compare the posterior distribution of the ATE to its true value" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9178ed4c", + "metadata": {}, + "outputs": [], + "source": [ + "cate_posterior_reparam = bcf_model_reparam.predict(\n", + " X=X_test,\n", + " Z=Z_test,\n", + " propensity=pi_test,\n", + " type=\"posterior\",\n", + " terms=\"cate\",\n", + ")\n", + "ate_posterior_reparam = np.mean(cate_posterior_reparam, axis=0)\n", + "plt.figure(figsize=(7, 5))\n", + "plt.hist(ate_posterior_reparam, density=True, bins=30, color=\"steelblue\", edgecolor=\"white\")\n", + "plt.axvline(tau_X, color=\"red\", linestyle=\"dotted\", linewidth=2, label=f\"True ATE = {tau_X}\")\n", + "plt.xlabel(\"ATE\")\n", + "plt.ylabel(\"Density\")\n", + "plt.title(\"Posterior Distribution of ATE (Reparameterized BCF)\")\n", + "plt.legend()\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": 
"9491f868", + "metadata": {}, + "source": [ + "Convergence check: traceplot of $\\sigma^2$" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2b9f8d88", + "metadata": {}, + "outputs": [], + "source": [ + "sigma2_samples = bcf_model_reparam.global_var_samples\n", + "plt.figure(figsize=(7, 4))\n", + "plt.plot(sigma2_samples, color=\"steelblue\", linewidth=0.8)\n", + "plt.axhline(sigma_true**2, color=\"red\", linestyle=\"dotted\", linewidth=2, label=f\"True $\\\\sigma^2$ = {sigma_true**2:.3f}\")\n", + "plt.xlabel(\"Iteration\")\n", + "plt.ylabel(\"$\\\\sigma^2$\")\n", + "plt.title(\"Traceplot of $\\\\sigma^2$ (Reparameterized BCF)\")\n", + "plt.legend()\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "37cc6ca0", + "metadata": {}, + "source": [ + "Since $t(X)$ is not constrained to sum to zero, $\\tau_0$ does not directly identify the ATE. We can see this by comparing the posteriors of $\\tau_0$ and $\\bar{t}(X)$ (the test-set mean of $t(X)$ for each posterior draw) — they are strongly negatively correlated, reflecting the partial non-identifiability between the intercept and the forest mean." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a91a86da", + "metadata": {}, + "outputs": [], + "source": [ + "tau_0_posterior = bcf_model_reparam.tau_0_samples[0, :]\n", + "tau_x_posterior = bcf_model_reparam.predict(\n", + " X=X_test,\n", + " Z=Z_test,\n", + " propensity=pi_test,\n", + " type=\"posterior\",\n", + " terms=\"tau\",\n", + ")\n", + "t_x_mean = np.mean(tau_x_posterior, axis=0)\n", + "plt.figure(figsize=(6, 5))\n", + "plt.scatter(tau_0_posterior, t_x_mean, alpha=0.3, s=10, color=\"steelblue\")\n", + "plt.xlabel(\"$\\\\tau_0$\")\n", + "plt.ylabel(\"$\\\\bar{t}(X)$\")\n", + "plt.title(\"Posterior of $\\\\tau_0$ vs $\\\\bar{t}(X)$\")\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "e86ceb51", + "metadata": {}, + "source": [ + "While `stochtree` does not currently support constraining $t(X)$ to sum to zero over the training set, we can more heavily regularize $t(X)$ so its values stay close to zero. Using a single tree with a very small leaf scale effectively collapses the forest to a constant near zero, making $\\tau_0$ the primary vehicle for the treatment effect." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "637f916d", + "metadata": {}, + "outputs": [], + "source": [ + "general_params = {\n", + " \"adaptive_coding\": False,\n", + " \"num_chains\": 4,\n", + " \"random_seed\": random_seed,\n", + " \"num_threads\": 1,\n", + "}\n", + "treatment_effect_forest_params = {\n", + " \"num_trees\": 1,\n", + " \"sample_intercept\": True,\n", + " \"sigma2_leaf_init\": 1e-6,\n", + " \"tau_0_prior_var\": 1.0,\n", + "}\n", + "bcf_model_reparam_shrunk = BCFModel()\n", + "bcf_model_reparam_shrunk.sample(\n", + " X_train=X_train,\n", + " Z_train=Z_train,\n", + " y_train=y_train,\n", + " propensity_train=pi_train,\n", + " X_test=X_test,\n", + " Z_test=Z_test,\n", + " propensity_test=pi_test,\n", + " num_gfr=num_gfr,\n", + " num_burnin=num_burnin,\n", + " num_mcmc=num_mcmc,\n", + " general_params=general_params,\n", + " treatment_effect_forest_params=treatment_effect_forest_params,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6fb51063", + "metadata": {}, + "outputs": [], + "source": [ + "cate_posterior_shrunk = bcf_model_reparam_shrunk.predict(\n", + " X=X_test,\n", + " Z=Z_test,\n", + " propensity=pi_test,\n", + " type=\"posterior\",\n", + " terms=\"cate\",\n", + ")\n", + "ate_posterior_shrunk = np.mean(cate_posterior_shrunk, axis=0)\n", + "plt.figure(figsize=(7, 5))\n", + "plt.hist(ate_posterior_shrunk, density=True, bins=30, color=\"steelblue\", edgecolor=\"white\")\n", + "plt.axvline(tau_X, color=\"red\", linestyle=\"dotted\", linewidth=2, label=f\"True ATE = {tau_X}\")\n", + "plt.xlabel(\"ATE\")\n", + "plt.ylabel(\"Density\")\n", + "plt.title(\"Posterior Distribution of ATE (Shrunk Forest)\")\n", + "plt.legend()\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "1ea05eb9", + "metadata": {}, + "source": [ + "With the forest heavily regularized, $\\tau_0$ and $\\bar{t}(X)$ are no longer correlated — $\\bar{t}(X)$ is near zero and $\\tau_0$ directly captures the 
treatment effect." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "99fa0c45", + "metadata": {}, + "outputs": [], + "source": [ + "tau_0_posterior_shrunk = bcf_model_reparam_shrunk.tau_0_samples[0, :]\n", + "tau_x_posterior_shrunk = bcf_model_reparam_shrunk.predict(\n", + " X=X_test,\n", + " Z=Z_test,\n", + " propensity=pi_test,\n", + " type=\"posterior\",\n", + " terms=\"tau\",\n", + ")\n", + "t_x_mean_shrunk = np.mean(tau_x_posterior_shrunk, axis=0)\n", + "plt.figure(figsize=(6, 5))\n", + "plt.scatter(tau_0_posterior_shrunk, t_x_mean_shrunk, alpha=0.3, s=10, color=\"steelblue\")\n", + "plt.xlabel(\"$\\\\tau_0$\")\n", + "plt.ylabel(\"$\\\\bar{t}(X)$\")\n", + "plt.title(\"Posterior of $\\\\tau_0$ vs $\\\\bar{t}(X)$ (Shrunk Forest)\")\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "ac87f5dd", + "metadata": {}, + "source": [ + "We can further regularize estimation of the ATE by reducing $\\sigma_{\\tau_0}^2$" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a3e461f6", + "metadata": {}, + "outputs": [], + "source": [ + "general_params = {\n", + " \"adaptive_coding\": False,\n", + " \"num_chains\": 4,\n", + " \"random_seed\": random_seed,\n", + " \"num_threads\": 1,\n", + "}\n", + "treatment_effect_forest_params = {\n", + " \"num_trees\": 1,\n", + " \"sample_intercept\": True,\n", + " \"sigma2_leaf_init\": 1e-6,\n", + " \"tau_0_prior_var\": 0.05,\n", + "}\n", + "bcf_model_tight_prior = BCFModel()\n", + "bcf_model_tight_prior.sample(\n", + " X_train=X_train,\n", + " Z_train=Z_train,\n", + " y_train=y_train,\n", + " propensity_train=pi_train,\n", + " X_test=X_test,\n", + " Z_test=Z_test,\n", + " propensity_test=pi_test,\n", + " num_gfr=num_gfr,\n", + " num_burnin=num_burnin,\n", + " num_mcmc=num_mcmc,\n", + " general_params=general_params,\n", + " treatment_effect_forest_params=treatment_effect_forest_params,\n", + ")\n", + "cate_posterior_tight = bcf_model_tight_prior.predict(\n", + " X=X_test,\n", + " 
Z=Z_test,\n", + " propensity=pi_test,\n", + " type=\"posterior\",\n", + " terms=\"cate\",\n", + ")\n", + "ate_posterior_tight = np.mean(cate_posterior_tight, axis=0)\n", + "plt.figure(figsize=(7, 5))\n", + "plt.hist(ate_posterior_tight, density=True, bins=30, color=\"steelblue\", edgecolor=\"white\")\n", + "plt.axvline(tau_X, color=\"red\", linestyle=\"dotted\", linewidth=2, label=f\"True ATE = {tau_X}\")\n", + "plt.xlabel(\"ATE\")\n", + "plt.ylabel(\"Density\")\n", + "plt.title(\"Posterior Distribution of ATE\")\n", + "plt.legend()\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "6b330905", + "metadata": {}, + "source": [ + "## Continuous Treatment with Homogeneous Treatment Effect\n", + "\n", + "The $\\tau_0 + t(x)$ reparameterization generalizes naturally to continuous treatment. With a continuous $Z$, $\\tau(x)$ represents the marginal effect of a one-unit increase in $Z$, and $\\tau_0$ captures the homogeneous component of that effect.\n", + "\n", + "Consider the following data generating process:\n", + "\n", + "$$\n", + "\\begin{aligned}\n", + "y &= \\mu(X) + \\tau(X)\\, Z + \\epsilon \\\\\n", + "\\mu(X) &= 2\\sin(2\\pi X_1) - 2(2X_3 - 1) \\\\\n", + "\\tau(X) &= 2 \\\\\n", + "\\pi(X) &= \\mathrm{E}[Z \\mid X] = \\mu(X)/8 \\\\\n", + "Z \\mid X &\\sim \\mathrm{N}(\\pi(X),\\, 1) \\\\\n", + "\\epsilon &\\sim \\mathrm{N}(0, \\sigma^2)\n", + "\\end{aligned}\n", + "$$\n", + "\n", + "### Simulation" + ] + }, + { + "cell_type": "markdown", + "id": "d2228d6c", + "metadata": {}, + "source": [ + "We draw from the DGP defined above" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9bdfce2c", + "metadata": {}, + "outputs": [], + "source": [ + "n = 500\n", + "p = 20\n", + "snr = 2\n", + "X = rng.uniform(0, 1, (n, p))\n", + "mu_X = 2 * np.sin(2 * np.pi * X[:, 0]) - 2 * (2 * X[:, 2] - 1)\n", + "tau_X = 2.0\n", + "pi_X = mu_X / 8\n", + "Z = pi_X + rng.standard_normal(n)\n", + "E_XZ = mu_X + Z * tau_X\n", + "sigma_true = np.std(E_XZ) / 
snr\n", + "y = E_XZ + rng.standard_normal(n) * sigma_true" + ] + }, + { + "cell_type": "markdown", + "id": "a9602f53", + "metadata": {}, + "source": [ + "And split data into test and train sets" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f6653df2", + "metadata": {}, + "outputs": [], + "source": [ + "n_test = round(0.5 * n)\n", + "n_train = n - n_test\n", + "test_inds = np.sort(rng.choice(n, n_test, replace=False))\n", + "train_inds = np.setdiff1d(np.arange(n), test_inds)\n", + "X_train, X_test = X[train_inds], X[test_inds]\n", + "Z_train, Z_test = Z[train_inds], Z[test_inds]\n", + "y_train, y_test = y[train_inds], y[test_inds]\n", + "pi_train, pi_test = pi_X[train_inds], pi_X[test_inds]\n", + "mu_train, mu_test = mu_X[train_inds], mu_X[test_inds]" + ] + }, + { + "cell_type": "markdown", + "id": "a7ab61e7", + "metadata": {}, + "source": [ + "## Sampling and Analysis\n", + "\n", + "Note that `adaptive_coding` must be `False` for continuous treatment, since the adaptive coding scheme is designed for binary treatment.\n", + "\n", + "### Classic BCF Model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7c6d33f8", + "metadata": {}, + "outputs": [], + "source": [ + "num_gfr = 0\n", + "num_burnin = 1000\n", + "num_mcmc = 500\n", + "num_trees_tau = 50\n", + "general_params = {\n", + " \"adaptive_coding\": False,\n", + " \"num_chains\": 4,\n", + " \"random_seed\": random_seed,\n", + " \"num_threads\": 1,\n", + "}\n", + "treatment_effect_forest_params = {\n", + " \"num_trees\": num_trees_tau,\n", + " \"sample_intercept\": False,\n", + " \"sigma2_leaf_init\": 1 / num_trees_tau,\n", + "}\n", + "bcf_model_classic = BCFModel()\n", + "bcf_model_classic.sample(\n", + " X_train=X_train,\n", + " Z_train=Z_train,\n", + " y_train=y_train,\n", + " propensity_train=pi_train,\n", + " X_test=X_test,\n", + " Z_test=Z_test,\n", + " propensity_test=pi_test,\n", + " num_gfr=num_gfr,\n", + " num_burnin=num_burnin,\n", + " 
num_mcmc=num_mcmc,\n", + " general_params=general_params,\n", + " treatment_effect_forest_params=treatment_effect_forest_params,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "c4599eb3", + "metadata": {}, + "source": [ + "We compare the posterior distribution of the ATE to its true value" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "09767571", + "metadata": {}, + "outputs": [], + "source": [ + "cate_posterior_classic = bcf_model_classic.predict(\n", + " X=X_test,\n", + " Z=Z_test,\n", + " propensity=pi_test,\n", + " type=\"posterior\",\n", + " terms=\"cate\",\n", + ")\n", + "ate_posterior_classic = np.mean(cate_posterior_classic, axis=0)\n", + "plt.figure(figsize=(7, 5))\n", + "plt.hist(ate_posterior_classic, density=True, bins=30, color=\"steelblue\", edgecolor=\"white\")\n", + "plt.axvline(tau_X, color=\"red\", linestyle=\"dotted\", linewidth=2, label=\"True ATE\")\n", + "plt.xlabel(\"ATE\")\n", + "plt.ylabel(\"Density\")\n", + "plt.title(\"Posterior Distribution of ATE\")\n", + "plt.legend()\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "e6d5e3da", + "metadata": {}, + "source": [ + "As a rough convergence check, we inspect the traceplot of $\\sigma^2$" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c420812e", + "metadata": {}, + "outputs": [], + "source": [ + "sigma2_samples = bcf_model_classic.global_var_samples\n", + "plt.figure(figsize=(7, 4))\n", + "plt.plot(sigma2_samples, color=\"steelblue\", linewidth=0.8)\n", + "plt.axhline(sigma_true**2, color=\"red\", linestyle=\"dotted\", linewidth=2, label=\"True $\\\\sigma^2$\")\n", + "plt.xlabel(\"Iteration\")\n", + "plt.ylabel(\"$\\\\sigma^2$\")\n", + "plt.title(\"Traceplot of $\\\\sigma^2$\")\n", + "plt.legend()\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "6d3dff4d", + "metadata": {}, + "source": [ + "### Reparameterized BCF Model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": 
"d66ec013", + "metadata": {}, + "outputs": [], + "source": [ + "general_params = {\n", + " \"adaptive_coding\": False,\n", + " \"num_chains\": 4,\n", + " \"random_seed\": random_seed,\n", + " \"num_threads\": 1,\n", + "}\n", + "treatment_effect_forest_params = {\n", + " \"num_trees\": num_trees_tau,\n", + " \"sample_intercept\": True,\n", + " \"sigma2_leaf_init\": 0.25 / num_trees_tau,\n", + " \"tau_0_prior_var\": 1.0,\n", + "}\n", + "bcf_model_reparam = BCFModel()\n", + "bcf_model_reparam.sample(\n", + " X_train=X_train,\n", + " Z_train=Z_train,\n", + " y_train=y_train,\n", + " propensity_train=pi_train,\n", + " X_test=X_test,\n", + " Z_test=Z_test,\n", + " propensity_test=pi_test,\n", + " num_gfr=num_gfr,\n", + " num_burnin=num_burnin,\n", + " num_mcmc=num_mcmc,\n", + " general_params=general_params,\n", + " treatment_effect_forest_params=treatment_effect_forest_params,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "7f3122b8", + "metadata": {}, + "source": [ + "And we compare the posterior distribution of the ATE to its true value" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "da9a95aa", + "metadata": {}, + "outputs": [], + "source": [ + "cate_posterior_reparam = bcf_model_reparam.predict(\n", + " X=X_test,\n", + " Z=Z_test,\n", + " propensity=pi_test,\n", + " type=\"posterior\",\n", + " terms=\"cate\",\n", + ")\n", + "ate_posterior_reparam = np.mean(cate_posterior_reparam, axis=0)\n", + "plt.figure(figsize=(7, 5))\n", + "plt.hist(ate_posterior_reparam, density=True, bins=30, color=\"steelblue\", edgecolor=\"white\")\n", + "plt.axvline(tau_X, color=\"red\", linestyle=\"dotted\", linewidth=2, label=f\"True ATE = {tau_X}\")\n", + "plt.xlabel(\"ATE\")\n", + "plt.ylabel(\"Density\")\n", + "plt.title(\"Posterior Distribution of ATE (Reparameterized BCF, Continuous Treatment)\")\n", + "plt.legend()\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "12494281", + "metadata": {}, + "source": [ + "As above, we check 
convergence by inspecting the traceplot of $\\sigma^2$" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e1a1c2d9", + "metadata": {}, + "outputs": [], + "source": [ + "sigma2_samples = bcf_model_reparam.global_var_samples\n", + "plt.figure(figsize=(7, 4))\n", + "plt.plot(sigma2_samples, color=\"steelblue\", linewidth=0.8)\n", + "plt.axhline(sigma_true**2, color=\"red\", linestyle=\"dotted\", linewidth=2, label=\"True $\\\\sigma^2$\")\n", + "plt.xlabel(\"Iteration\")\n", + "plt.ylabel(\"$\\\\sigma^2$\")\n", + "plt.title(\"Traceplot of $\\\\sigma^2$\")\n", + "plt.legend()\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "cccae071", + "metadata": {}, + "source": [ + "As in the binary treatment case, $\\tau_0$ and $\\bar{t}(X)$ are negatively correlated across posterior draws" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d45496ef", + "metadata": {}, + "outputs": [], + "source": [ + "tau_0_posterior = bcf_model_reparam.tau_0_samples[0, :]\n", + "tau_x_posterior = bcf_model_reparam.predict(\n", + " X=X_test,\n", + " Z=Z_test,\n", + " propensity=pi_test,\n", + " type=\"posterior\",\n", + " terms=\"tau\",\n", + ")\n", + "t_x_mean = np.mean(tau_x_posterior, axis=0)\n", + "plt.figure(figsize=(6, 5))\n", + "plt.scatter(tau_0_posterior, t_x_mean, alpha=0.3, s=10, color=\"steelblue\")\n", + "plt.xlabel(\"$\\\\tau_0$\")\n", + "plt.ylabel(\"$\\\\bar{t}(X)$\")\n", + "plt.title(\"Posterior of $\\\\tau_0$ vs $\\\\bar{t}(X)$\")\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "0c1cb45b", + "metadata": {}, + "source": [ + "# References\n", + "\n", + "Hahn, P Richard, Jared S Murray, and Carlos M Carvalho. 2020. “Bayesian Regression Tree Models for Causal Inference: Regularization, Confounding, and Heterogeneous Effects (with Discussion).” *Bayesian Analysis* 15 (3): 965–1056." 
+ ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/stochtree/bcf.py b/stochtree/bcf.py index fef6951f..8476e04f 100644 --- a/stochtree/bcf.py +++ b/stochtree/bcf.py @@ -184,19 +184,21 @@ def sample( treatment_effect_forest_params : dict, optional Dictionary of treatment effect forest model parameters, each of which has a default value processed internally, so this argument is optional. - * `num_trees` (`int`): Number of trees in the treatment effect forest. Defaults to `50`. Must be a positive integer. + * `num_trees` (`int`): Number of trees in the treatment effect forest. Defaults to `100`. Must be a positive integer. * `alpha` (`float`): Prior probability of splitting for a tree of depth 0 in the treatment effect forest. Tree split prior combines `alpha` and `beta` via `alpha*(1+node_depth)^-beta`. Defaults to `0.25`. * `beta` (`float`): Exponent that decreases split probabilities for nodes of depth > 0 in the treatment effect forest. Tree split prior combines `alpha` and `beta` via `alpha*(1+node_depth)^-beta`. Defaults to `3`. * `min_samples_leaf` (`int`): Minimum allowable size of a leaf, in terms of training samples, in the treatment effect forest. Defaults to `5`. * `max_depth` (`int`): Maximum depth of any tree in the ensemble in the treatment effect forest. Defaults to `5`. Can be overridden with `-1` which does not enforce any depth limits on trees. * `sample_sigma2_leaf` (`bool`): Whether or not to update the `tau` leaf scale variance parameter based on `IG(sigma2_leaf_shape, sigma2_leaf_scale)`. Cannot (currently) be set to true if `basis_train` has more than one column. Defaults to `False`. - * `sigma2_leaf_init` (`float`): Starting value of leaf node scale parameter. Calibrated internally as `1/num_trees` if not set here. 
+ * `sigma2_leaf_init` (`float`): Starting value of leaf node scale parameter. Calibrated internally as `0.5 * np.var(y) / num_trees` if not set here (`0.5 / num_trees` if `y` is continuous and `standardize: True` in the `general_params` dictionary). * `sigma2_leaf_shape` (`float`): Shape parameter in the `IG(sigma2_leaf_shape, sigma2_leaf_scale)` leaf node parameter variance model. Defaults to `3`. * `sigma2_leaf_scale` (`float`): Scale parameter in the `IG(sigma2_leaf_shape, sigma2_leaf_scale)` leaf node parameter variance model. Calibrated internally as `0.5/num_trees` if not set here. * `delta_max` (`float`): Maximum plausible conditional distributional treatment effect (i.e. P(Y(1) = 1 | X) - P(Y(0) = 1 | X)) when the outcome is binary. Only used when the outcome is specified as a probit model in `general_params`. Must be > 0 and < 1. Defaults to `0.9`. Ignored if `sigma2_leaf_init` is set directly, as this parameter is used to calibrate `sigma2_leaf_init`. * `keep_vars` (`list` or `np.array`): Vector of variable names or column indices denoting variables that should be included in the treatment effect (`tau(X)`) forest. Defaults to `None`. * `drop_vars` (`list` or `np.array`): Vector of variable names or column indices denoting variables that should be excluded from the treatment effect (`tau(X)`) forest. Defaults to `None`. If both `drop_vars` and `keep_vars` are set, `drop_vars` will be ignored. * `num_features_subsample` (`int`): How many features to subsample when growing each tree for the GFR algorithm. Defaults to the number of features in the training dataset. + * `sample_intercept` (`bool`): Whether to sample a global treatment effect intercept `tau_0` so the full CATE is `tau_0 + tau(X)`. Defaults to `True`. Compatible with `adaptive_coding = True`, in which case the recoded treatment basis is used. + * `tau_0_prior_var` (`float`): Variance of the normal prior on `tau_0` (applied independently to each treatment dimension). 
Auto-calibrated to outcome variance when `None` and outcome is continuous. Only used when `sample_intercept = True`. variance_forest_params : dict, optional Dictionary of variance forest model parameters, each of which has a default value processed internally, so this argument is optional. @@ -282,7 +284,7 @@ def sample( # Update tau forest BART parameters treatment_effect_forest_params_default = { - "num_trees": 50, + "num_trees": 100, "alpha": 0.25, "beta": 3.0, "min_samples_leaf": 5, @@ -295,6 +297,8 @@ def sample( "keep_vars": None, "drop_vars": None, "num_features_subsample": None, + "sample_intercept": True, + "tau_0_prior_var": None, } treatment_effect_forest_params_updated = _preprocess_params( treatment_effect_forest_params_default, treatment_effect_forest_params @@ -391,6 +395,8 @@ def sample( num_features_subsample_tau = treatment_effect_forest_params_updated[ "num_features_subsample" ] + self.sample_tau_0 = treatment_effect_forest_params_updated["sample_intercept"] + tau_0_prior_var = treatment_effect_forest_params_updated["tau_0_prior_var"] # 4. 
Variance forest parameters num_trees_variance = variance_forest_params_updated["num_trees"] @@ -1372,6 +1378,11 @@ def sample( ) self.adaptive_coding = False + # Validate tau_0_prior_var if sample_tau_0 is True + if self.sample_tau_0 and tau_0_prior_var is not None: + if not isinstance(tau_0_prior_var, (int, float)) or tau_0_prior_var <= 0: + raise ValueError("tau_0_prior_var must be a single positive numeric value") + # Sampling sigma2_leaf_tau will be ignored for multivariate treatments if sample_sigma2_leaf_tau and self.multivariate_treatment: warnings.warn( @@ -1623,7 +1634,7 @@ def sample( else: raise ValueError("sigma2_leaf_mu must be a scalar") sigma2_leaf_tau = ( - np.squeeze(np.var(resid_train)) / (num_trees_tau) + np.squeeze(np.var(resid_train) * 0.5) / (num_trees_tau) if sigma2_leaf_tau is None else sigma2_leaf_tau ) @@ -1830,6 +1841,9 @@ def sample( self.leaf_scale_mu_samples = np.empty(self.num_samples, dtype=np.float64) if sample_sigma2_leaf_tau: self.leaf_scale_tau_samples = np.empty(self.num_samples, dtype=np.float64) + if self.sample_tau_0: + p_tau0 = Z_train.shape[1] if Z_train.ndim > 1 else 1 + self.tau_0_samples = np.empty((p_tau0, self.num_samples), dtype=np.float64) muhat_train_raw = np.empty((self.n_train, self.num_samples), dtype=np.float64) if self.include_variance_forest: sigma2_x_train_raw = np.empty( @@ -1855,6 +1869,13 @@ def sample( if self.has_test: tau_basis_test = Z_test + # Prepare tau_0 (global treatment effect intercept) structure + if self.sample_tau_0: + tau_0 = np.zeros(p_tau0) + # Auto-calibrate prior variance if not provided + if tau_0_prior_var is None: + tau_0_prior_var = np.var(resid_train) + # Prognostic Forest Dataset (covariates) forest_dataset_train = Dataset() forest_dataset_train.add_covariates(X_train_processed) @@ -2077,6 +2098,28 @@ def sample( current_leaf_scale_mu[0, 0] ) + # Sample tau_0 (global treatment effect intercept, if requested) + if self.sample_tau_0: + mu_x_tau0 = 
np.squeeze(active_forest_mu.predict_raw(forest_dataset_train)) + tau_x_raw_tau0 = active_forest_tau.predict_raw(forest_dataset_train) + Z_basis = tau_basis_train.reshape(-1, 1) if tau_basis_train.ndim == 1 else tau_basis_train + tau_x_raw_2d = tau_x_raw_tau0.reshape(self.n_train, -1) + tau_x_full = np.sum(Z_basis * tau_x_raw_2d, axis=1) + partial_resid_tau0 = np.squeeze(resid_train) - mu_x_tau0 - tau_x_full + if self.has_rfx: + partial_resid_tau0 = partial_resid_tau0 - np.squeeze( + rfx_model.predict(rfx_dataset_train, rfx_tracker) + ) + Ztr = Z_basis.T @ partial_resid_tau0 + ZtZ_current = Z_basis.T @ Z_basis + Sigma_post = np.linalg.inv(ZtZ_current / current_sigma2 + np.eye(p_tau0) / tau_0_prior_var) + mu_post = Sigma_post @ Ztr / current_sigma2 + tau_0_new = self.rng.multivariate_normal(mean=mu_post, cov=Sigma_post) + residual_train.add_vector(-np.squeeze(Z_basis @ (tau_0_new - tau_0))) + tau_0 = tau_0_new + if keep_sample: + self.tau_0_samples[:, sample_counter] = tau_0 + # Sample the treatment forest forest_sampler_tau.sample_one_iteration( self.forest_container_tau, @@ -2107,13 +2150,15 @@ def sample( rfx_model.predict(rfx_dataset_train, rfx_tracker) ) partial_resid_train = partial_resid_train - rfx_pred - s_tt0 = np.sum(tau_x * tau_x * (np.squeeze(Z_train) == 0)) - s_tt1 = np.sum(tau_x * tau_x * (np.squeeze(Z_train) == 1)) + # Use tau_total = tau_0 + tau(X) for sufficient stats when sample_tau_0 + tau_x_for_coding = (tau_x + tau_0[0]) if self.sample_tau_0 else tau_x + s_tt0 = np.sum(tau_x_for_coding * tau_x_for_coding * (np.squeeze(Z_train) == 0)) + s_tt1 = np.sum(tau_x_for_coding * tau_x_for_coding * (np.squeeze(Z_train) == 1)) s_ty0 = np.sum( - tau_x * partial_resid_train * (np.squeeze(Z_train) == 0) + tau_x_for_coding * partial_resid_train * (np.squeeze(Z_train) == 0) ) s_ty1 = np.sum( - tau_x * partial_resid_train * (np.squeeze(Z_train) == 1) + tau_x_for_coding * partial_resid_train * (np.squeeze(Z_train) == 1) ) current_b_0 = self.rng.normal( loc=(s_ty0 
/ (s_tt0 + 2 * current_sigma2)), @@ -2125,6 +2170,8 @@ def sample( scale=np.sqrt(current_sigma2 / (s_tt1 + 2 * current_sigma2)), size=1, )[0] + if self.sample_tau_0: + tau_basis_old = np.squeeze(tau_basis_train).copy() tau_basis_train = ( 1 - np.squeeze(Z_train) ) * current_b_0 + np.squeeze(Z_train) * current_b_1 @@ -2143,6 +2190,12 @@ def sample( forest_dataset_train, residual_train, active_forest_tau ) + # Fix tau_0 component of residual after basis change + if self.sample_tau_0: + residual_train.add_vector( + -(np.squeeze(tau_basis_train) - tau_basis_old) * tau_0[0] + ) + # Sample the variance forest if self.include_variance_forest: forest_sampler_variance.sample_one_iteration( @@ -2557,6 +2610,28 @@ def sample( current_leaf_scale_mu[0, 0] ) + # Sample tau_0 (global treatment effect intercept, if requested) + if self.sample_tau_0: + mu_x_tau0 = np.squeeze(active_forest_mu.predict_raw(forest_dataset_train)) + tau_x_raw_tau0 = active_forest_tau.predict_raw(forest_dataset_train) + Z_basis = tau_basis_train.reshape(-1, 1) if tau_basis_train.ndim == 1 else tau_basis_train + tau_x_raw_2d = tau_x_raw_tau0.reshape(self.n_train, -1) + tau_x_full = np.sum(Z_basis * tau_x_raw_2d, axis=1) + partial_resid_tau0 = np.squeeze(resid_train) - mu_x_tau0 - tau_x_full + if self.has_rfx: + partial_resid_tau0 = partial_resid_tau0 - np.squeeze( + rfx_model.predict(rfx_dataset_train, rfx_tracker) + ) + Ztr = Z_basis.T @ partial_resid_tau0 + ZtZ_current = Z_basis.T @ Z_basis + Sigma_post = np.linalg.inv(ZtZ_current / current_sigma2 + np.eye(p_tau0) / tau_0_prior_var) + mu_post = Sigma_post @ Ztr / current_sigma2 + tau_0_new = self.rng.multivariate_normal(mean=mu_post, cov=Sigma_post) + residual_train.add_vector(-np.squeeze(Z_basis @ (tau_0_new - tau_0))) + tau_0 = tau_0_new + if keep_sample: + self.tau_0_samples[:, sample_counter] = tau_0 + # Sample the treatment forest forest_sampler_tau.sample_one_iteration( self.forest_container_tau, @@ -2587,13 +2662,15 @@ def sample( 
rfx_model.predict(rfx_dataset_train, rfx_tracker) ) partial_resid_train = partial_resid_train - rfx_pred - s_tt0 = np.sum(tau_x * tau_x * (np.squeeze(Z_train) == 0)) - s_tt1 = np.sum(tau_x * tau_x * (np.squeeze(Z_train) == 1)) + # Use tau_total = tau_0 + tau(X) for sufficient stats when sample_tau_0 + tau_x_for_coding = (tau_x + tau_0[0]) if self.sample_tau_0 else tau_x + s_tt0 = np.sum(tau_x_for_coding * tau_x_for_coding * (np.squeeze(Z_train) == 0)) + s_tt1 = np.sum(tau_x_for_coding * tau_x_for_coding * (np.squeeze(Z_train) == 1)) s_ty0 = np.sum( - tau_x * partial_resid_train * (np.squeeze(Z_train) == 0) + tau_x_for_coding * partial_resid_train * (np.squeeze(Z_train) == 0) ) s_ty1 = np.sum( - tau_x * partial_resid_train * (np.squeeze(Z_train) == 1) + tau_x_for_coding * partial_resid_train * (np.squeeze(Z_train) == 1) ) current_b_0 = self.rng.normal( loc=(s_ty0 / (s_tt0 + 2 * current_sigma2)), @@ -2609,6 +2686,8 @@ def sample( ), size=1, )[0] + if self.sample_tau_0: + tau_basis_old = np.squeeze(tau_basis_train).copy() tau_basis_train = ( 1 - np.squeeze(Z_train) ) * current_b_0 + np.squeeze(Z_train) * current_b_1 @@ -2627,6 +2706,12 @@ def sample( forest_dataset_train, residual_train, active_forest_tau ) + # Fix tau_0 component of residual after basis change + if self.sample_tau_0: + residual_train.add_vector( + -(tau_basis_train - tau_basis_old) * tau_0[0] + ) + # Sample the variance forest if self.include_variance_forest: forest_sampler_variance.sample_one_iteration( @@ -2697,6 +2782,8 @@ def sample( if self.adaptive_coding: self.b1_samples = self.b1_samples[num_gfr:] self.b0_samples = self.b0_samples[num_gfr:] + if self.sample_tau_0: + self.tau_0_samples = self.tau_0_samples[:, num_gfr:] if self.sample_sigma2_global: self.global_var_samples = self.global_var_samples[num_gfr:] if self.sample_sigma2_leaf_mu: @@ -2723,12 +2810,34 @@ def sample( self.tau_hat_train = self.tau_hat_train * adaptive_coding_weights self.mu_hat_train = self.mu_hat_train + 
np.squeeze(control_adj_train) self.tau_hat_train = np.squeeze(self.tau_hat_train * self.y_std) + # tau_hat_train stores the forest-only component tau(X); compute cate_train + # (tau_0 + tau(X)) separately for the treatment term used in y_hat + if self.sample_tau_0: + tau_0_vec = self.tau_0_samples[0, :] # num_samples vector (scalar treatment) + if self.adaptive_coding: + # CATE = (b_1 - b_0) * (tau_0 + tau(X)); control adj to mu = b_0 * (tau_0 + tau(X)) + cate_train = self.tau_hat_train + ( + (self.b1_samples - self.b0_samples) * tau_0_vec * self.y_std + ) + self.mu_hat_train = self.mu_hat_train + ( + self.b0_samples * tau_0_vec * self.y_std + ) + elif self.multivariate_treatment: + cate_train = self.tau_hat_train.copy() + for j in range(p_tau0): + cate_train[:, :, j] = cate_train[:, :, j] + ( + self.tau_0_samples[j, :] * self.y_std + ) + else: + cate_train = self.tau_hat_train + tau_0_vec * self.y_std + else: + cate_train = self.tau_hat_train if self.multivariate_treatment: treatment_term_train = np.multiply( - np.atleast_3d(Z_train).swapaxes(1, 2), self.tau_hat_train + np.atleast_3d(Z_train).swapaxes(1, 2), cate_train ).sum(axis=2) else: - treatment_term_train = Z_train * np.squeeze(self.tau_hat_train) + treatment_term_train = Z_train * np.squeeze(cate_train) self.y_hat_train = self.mu_hat_train + treatment_term_train if self.has_test: mu_raw_test = self.forest_container_mu.forest_container_cpp.Predict( @@ -2748,12 +2857,31 @@ def sample( self.tau_hat_test = self.tau_hat_test * adaptive_coding_weights_test self.mu_hat_test = self.mu_hat_test + np.squeeze(control_adj_test) self.tau_hat_test = np.squeeze(self.tau_hat_test * self.y_std) + # tau_hat_test stores forest-only tau(X); compute cate_test for y_hat + if self.sample_tau_0: + if self.adaptive_coding: + cate_test = self.tau_hat_test + ( + (self.b1_samples - self.b0_samples) * tau_0_vec * self.y_std + ) + self.mu_hat_test = self.mu_hat_test + ( + self.b0_samples * tau_0_vec * self.y_std + ) + elif 
self.multivariate_treatment: + cate_test = self.tau_hat_test.copy() + for j in range(p_tau0): + cate_test[:, :, j] = cate_test[:, :, j] + ( + self.tau_0_samples[j, :] * self.y_std + ) + else: + cate_test = self.tau_hat_test + tau_0_vec * self.y_std + else: + cate_test = self.tau_hat_test if self.multivariate_treatment: treatment_term_test = np.multiply( - np.atleast_3d(Z_test).swapaxes(1, 2), self.tau_hat_test + np.atleast_3d(Z_test).swapaxes(1, 2), cate_test ).sum(axis=2) else: - treatment_term_test = Z_test * np.squeeze(self.tau_hat_test) + treatment_term_test = Z_test * np.squeeze(cate_test) self.y_hat_test = self.mu_hat_test + treatment_term_test # TODO: make rfx_preds_train and rfx_preds_test persistent properties @@ -2784,6 +2912,9 @@ def sample( self.b0_samples = self.b0_samples self.b1_samples = self.b1_samples + if self.sample_tau_0: + self.tau_0_samples = self.tau_0_samples * self.y_std + if self.include_variance_forest: if self.sample_sigma2_global: self.sigma2_x_train = np.empty_like(sigma2_x_train_raw) @@ -3057,12 +3188,35 @@ def predict( mu_x_forest = mu_x_forest + np.squeeze(control_adj) tau_raw = tau_raw * adaptive_coding_weights tau_x_forest = np.squeeze(tau_raw * self.y_std) + # tau_x_forest is the forest-only component tau(X); compute cate_x_forest + # (tau_0 + tau(X)) for the "cate" term and treatment_term used in y_hat + if getattr(self, "sample_tau_0", False) and hasattr(self, "tau_0_samples"): + tau_0_vec = self.tau_0_samples[0, :] + if self.adaptive_coding: + cate_x_forest = tau_x_forest + ( + (self.b1_samples - self.b0_samples) * tau_0_vec + ) + if predict_mu_forest or predict_mu_forest_intermediate: + mu_x_forest = mu_x_forest + ( + self.b0_samples * tau_0_vec + ) + elif Z.shape[1] > 1: + p_tau0 = Z.shape[1] + cate_x_forest = tau_x_forest.copy() + for j in range(p_tau0): + cate_x_forest[:, :, j] = cate_x_forest[:, :, j] + ( + self.tau_0_samples[j, :] + ) + else: + cate_x_forest = tau_x_forest + tau_0_vec + else: + cate_x_forest = 
tau_x_forest if Z.shape[1] > 1: treatment_term = np.multiply( - np.atleast_3d(Z).swapaxes(1, 2), tau_x_forest + np.atleast_3d(Z).swapaxes(1, 2), cate_x_forest ).sum(axis=2) else: - treatment_term = Z * np.squeeze(tau_x_forest) + treatment_term = Z * np.squeeze(cate_x_forest) # Random effects data checks if has_rfx: @@ -3142,9 +3296,9 @@ def predict( prognostic_function = mu_x_forest if predict_cate_function: if tau_cate_separate: - cate = tau_x_forest + np.squeeze(rfx_predictions_raw[:, 1:, :]) + cate = cate_x_forest + np.squeeze(rfx_predictions_raw[:, 1:, :]) else: - cate = tau_x_forest + cate = cate_x_forest # Combine into y hat predictions needs_mean_term_preds = ( @@ -3411,7 +3565,7 @@ def compute_posterior_interval( rfx_basis : np.array, optional Optional matrix of basis function evaluations for random effects. Required if the requested term includes random effects. terms : str, optional - Character string specifying the model term(s) for which to compute intervals. Options for BCF models are `"prognostic_function"`, `"mu"`, `"cate"`, `"tau"`, `"variance_forest"`, `"rfx"`, or `"y_hat"`. Defaults to `"all"`. Note that `"mu"` is only different from `"prognostic_function"` if random effects are included with a model spec of `"intercept_only"` or `"intercept_plus_treatment"` and `"tau"` is only different from `"cate"` if random effects are included with a model spec of `"intercept_plus_treatment"`. + Character string specifying the model term(s) for which to compute intervals. Options for BCF models are `"prognostic_function"`, `"mu"`, `"cate"`, `"tau"`, `"tau_0"`, `"variance_forest"`, `"rfx"`, or `"y_hat"`. Defaults to `"all"`. Note that `"mu"` is only different from `"prognostic_function"` if random effects are included with a model spec of `"intercept_only"` or `"intercept_plus_treatment"` and `"tau"` is only different from `"cate"` if random effects are included with a model spec of `"intercept_plus_treatment"`. 
`"tau_0"` is only available when the model was fit with `sample_intercept = True`. scale : str, optional Scale of mean function predictions. Options are "linear", which returns predictions on the original scale of the mean forest / RFX terms, and "probability", which transforms predictions into a probability of observing `y == 1`. "probability" is only valid for models fit with a probit outcome model. Defaults to `"linear"`. level : float, optional @@ -3452,23 +3606,24 @@ def compute_posterior_interval( terms = [terms] for term in terms: if term not in [ - "y_hat", "prognostic_function", "mu", "cate", "tau", "rfx", "variance_forest", + "y_hat", "all", ]: raise ValueError( - f"term '{term}' was requested. Valid terms are 'y_hat', 'prognostic_function', 'mu', 'cate', 'tau', 'rfx', 'variance_forest', and 'all'" + f"term '{term}' was requested. Valid terms are 'prognostic_function', 'mu', 'cate', 'tau', 'rfx', 'variance_forest', 'y_hat', and 'all'" ) - needs_covariates_intermediate = ("y_hat" in terms) or ("all" in terms) + predict_terms = terms + needs_covariates_intermediate = ("y_hat" in predict_terms) or ("all" in predict_terms) needs_covariates = ( - ("prognostic_function" in terms) - or ("cate" in terms) - or ("variance_forest" in terms) + ("prognostic_function" in predict_terms) + or ("cate" in predict_terms) + or ("variance_forest" in predict_terms) or needs_covariates_intermediate ) if needs_covariates: @@ -3505,9 +3660,9 @@ def compute_posterior_interval( "'propensity' must have the same number of rows as 'X'" ) needs_rfx_data_intermediate = ( - ("y_hat" in terms) or ("all" in terms) + ("y_hat" in predict_terms) or ("all" in predict_terms) ) and self.has_rfx - needs_rfx_data = ("rfx" in terms) or needs_rfx_data_intermediate + needs_rfx_data = ("rfx" in predict_terms) or needs_rfx_data_intermediate if needs_rfx_data: if rfx_group_ids is None: raise ValueError( @@ -3532,30 +3687,35 @@ def compute_posterior_interval( "'rfx_basis' must have the same number of 
rows as 'X'" ) - # Compute posterior matrices for the requested model terms - predictions = self.predict( - X=X, - Z=Z, - propensity=propensity, - rfx_group_ids=rfx_group_ids, - rfx_basis=rfx_basis, - type="posterior", - terms=terms, - scale=scale, - ) - has_multiple_terms = True if isinstance(predictions, dict) else False + result = dict() + + # Compute posterior matrices for predict-able terms (if any) + if len(predict_terms) > 0: + predictions = self.predict( + X=X, + Z=Z, + propensity=propensity, + rfx_group_ids=rfx_group_ids, + rfx_basis=rfx_basis, + type="posterior", + terms=predict_terms, + scale=scale, + ) + if isinstance(predictions, dict): + for term in predictions.keys(): + if predictions[term] is not None: + result[term] = _summarize_interval( + predictions[term], 1, level=level + ) + else: + result[predict_terms[0]] = _summarize_interval( + predictions, 1, level=level + ) - # Compute posterior intervals - if has_multiple_terms: - result = dict() - for term in predictions.keys(): - if predictions[term] is not None: - result[term] = _summarize_interval( - predictions[term], 1, level=level - ) - return result - else: - return _summarize_interval(predictions, 1, level=level) + # Return single interval directly if only one specific term was requested + if len(terms) == 1 and terms[0] in result: + return result[terms[0]] + return result def sample_posterior_predictive( self, @@ -3761,6 +3921,7 @@ def to_json(self) -> str: bcf_json.add_scalar("keep_every", self.keep_every) bcf_json.add_scalar("num_samples", self.num_samples) bcf_json.add_boolean("adaptive_coding", self.adaptive_coding) + bcf_json.add_boolean("sample_tau_0", self.sample_tau_0) bcf_json.add_string("propensity_covariate", self.propensity_covariate) bcf_json.add_boolean( "internal_propensity_model", self.internal_propensity_model @@ -3787,6 +3948,11 @@ def to_json(self) -> str: if self.adaptive_coding: bcf_json.add_numeric_vector("b0_samples", self.b0_samples, "parameters") 
bcf_json.add_numeric_vector("b1_samples", self.b1_samples, "parameters") + if self.sample_tau_0 and hasattr(self, "tau_0_samples"): + bcf_json.add_scalar("tau_0_dim", self.tau_0_samples.shape[0]) + bcf_json.add_numeric_vector( + "tau_0_samples", self.tau_0_samples.ravel(), "parameters" + ) # Add propensity model (if it exists) if self.internal_propensity_model: @@ -3855,6 +4021,7 @@ def from_json(self, json_string: str) -> None: self.keep_every = int(bcf_json.get_scalar("keep_every")) self.num_samples = int(bcf_json.get_scalar("num_samples")) self.adaptive_coding = bcf_json.get_boolean("adaptive_coding") + self.sample_tau_0 = bcf_json.get_boolean("sample_tau_0") self.propensity_covariate = bcf_json.get_string("propensity_covariate") self.internal_propensity_model = bcf_json.get_boolean( "internal_propensity_model" @@ -3881,6 +4048,10 @@ def from_json(self, json_string: str) -> None: if self.adaptive_coding: self.b1_samples = bcf_json.get_numeric_vector("b1_samples", "parameters") self.b0_samples = bcf_json.get_numeric_vector("b0_samples", "parameters") + if self.sample_tau_0: + tau_0_dim = int(bcf_json.get_scalar("tau_0_dim")) + tau_0_vec = bcf_json.get_numeric_vector("tau_0_samples", "parameters") + self.tau_0_samples = tau_0_vec.reshape(tau_0_dim, -1) # Unpack internal propensity model if self.internal_propensity_model: @@ -3990,6 +4161,7 @@ def from_json_string_list(self, json_string_list: list[str]) -> None: self.num_chains = int(json_object_default.get_scalar("num_chains")) self.keep_every = int(json_object_default.get_scalar("keep_every")) self.adaptive_coding = json_object_default.get_boolean("adaptive_coding") + self.sample_tau_0 = json_object_default.get_boolean("sample_tau_0") self.propensity_covariate = json_object_default.get_string( "propensity_covariate" ) @@ -4057,6 +4229,18 @@ def from_json_string_list(self, json_string_list: list[str]) -> None: sample_sigma2_leaf_tau, )) + if self.sample_tau_0: + tau_0_dim = 
int(json_object_default.get_scalar("tau_0_dim")) + for i in range(len(json_object_list)): + tau_0_vec_i = json_object_list[i].get_numeric_vector( + "tau_0_samples", "parameters" + ) + tau_0_mat_i = tau_0_vec_i.reshape(tau_0_dim, -1) + if i == 0: + self.tau_0_samples = tau_0_mat_i + else: + self.tau_0_samples = np.hstack((self.tau_0_samples, tau_0_mat_i)) + # Unpack internal propensity model if self.internal_propensity_model: bart_propensity_string = json_object_default.get_string( @@ -4132,6 +4316,7 @@ def extract_parameter(self, term: str) -> np.array: - Test set mean function predictions: `"y_hat_test"` - In-sample treatment effect forest predictions: `"tau_hat_train"` - Test set treatment effect forest predictions: `"tau_hat_test"` + - Treatment effect intercept: `"tau_0"`, `"treatment_intercept"`, `"tau_intercept"` - In-sample variance forest predictions: `"sigma2_x_train"`, `"var_x_train"` - Test set variance forest predictions: `"sigma2_x_test"`, `"var_x_test"` @@ -4213,6 +4398,13 @@ def extract_parameter(self, term: str) -> np.array: else: raise ValueError("This model does not have test set variance forest predictions") + if term in ["tau_0", "treatment_intercept", "tau_intercept"]: + t0 = getattr(self, "tau_0_samples", None) + if t0 is not None: + return t0 + else: + raise ValueError("This model does not have treatment effect intercept (tau_0) samples") + raise ValueError(f"term {term} is not a valid BCF model term") def summary(self) -> None: @@ -4289,6 +4481,20 @@ def summary(self) -> None: for p, q in zip(probs, quantiles_b1): output_str += f" {p*100:5.1f}%: {q:.3f}\n" + # Treatment effect intercept (tau_0) + if self.sample_tau_0: + tau_0_samp = getattr(self, "tau_0_samples", None) + if tau_0_samp is not None: + tau_0_vec = tau_0_samp.ravel() + n_samples = tau_0_samp.shape[1] + mean_tau_0 = np.mean(tau_0_vec) + sd_tau_0 = np.std(tau_0_vec) + quantiles_tau_0 = np.quantile(tau_0_vec, probs) + output_str += f"Summary of treatment effect intercept (tau_0) 
posterior: " + output_str += f"{n_samples} samples, mean = {mean_tau_0:.3f}, standard deviation = {sd_tau_0:.3f}, quantiles:\n" + for p, q in zip(probs, quantiles_tau_0): + output_str += f" {p*100:5.1f}%: {q:.3f}\n" + # In-sample predictions yht = getattr(self, "y_hat_train", None) if yht is not None: @@ -4388,6 +4594,8 @@ def __str__(self) -> str: model_terms.append("prognostic forest leaf scale model") if self.sample_sigma2_leaf_tau: model_terms.append("treatment effect forest leaf scale model") + if self.sample_tau_0: + model_terms.append("treatment effect intercept model") if len(model_terms) > 2: output_str = f"BCFModel run with {', '.join(model_terms[:-1])}, and {model_terms[-1]}" elif len(model_terms) == 2: diff --git a/test/python/test_bcf.py b/test/python/test_bcf.py index 5396f779..f49cd6ff 100644 --- a/test/python/test_bcf.py +++ b/test/python/test_bcf.py @@ -338,20 +338,22 @@ def test_continuous_univariate_bcf(self): # Assertions bcf_preds_3 = bcf_model_3.predict(X_test, Z_test, pi_test) - tau_hat_3, mu_hat_3, y_hat_3 = ( - bcf_preds_3["tau_hat"], + # Use "cate" (tau_0 + tau(X)) for the CATE comparison, consistent with how + # tau_hat and tau_hat_2 were set above (via predict(terms="cate")) + cate_hat_3, mu_hat_3, y_hat_3 = ( + bcf_preds_3["cate"], bcf_preds_3["mu_hat"], bcf_preds_3["y_hat"], ) - assert tau_hat_3.shape == (n_train, num_mcmc * 2) + assert cate_hat_3.shape == (n_train, num_mcmc * 2) assert mu_hat_3.shape == (n_train, num_mcmc * 2) assert y_hat_3.shape == (n_train, num_mcmc * 2) np.testing.assert_allclose(y_hat_3[:, 0:num_mcmc], y_hat) np.testing.assert_allclose(y_hat_3[:, num_mcmc : (2 * num_mcmc)], y_hat_2) np.testing.assert_allclose(mu_hat_3[:, 0:num_mcmc], mu_hat) np.testing.assert_allclose(mu_hat_3[:, num_mcmc : (2 * num_mcmc)], mu_hat_2) - np.testing.assert_allclose(tau_hat_3[:, 0:num_mcmc], tau_hat) - np.testing.assert_allclose(tau_hat_3[:, num_mcmc : (2 * num_mcmc)], tau_hat_2) + np.testing.assert_allclose(cate_hat_3[:, 
0:num_mcmc], tau_hat) + np.testing.assert_allclose(cate_hat_3[:, num_mcmc : (2 * num_mcmc)], tau_hat_2) np.testing.assert_allclose( bcf_model_3.global_var_samples[0:num_mcmc], bcf_model.global_var_samples ) @@ -513,17 +515,17 @@ def test_continuous_univariate_bcf(self): # Assertions bcf_preds_3 = bcf_model_3.predict(X_test, Z_test) - tau_hat_3, mu_hat_3, y_hat_3 = ( - bcf_preds_3["tau_hat"], + cate_hat_3, mu_hat_3, y_hat_3 = ( + bcf_preds_3["cate"], bcf_preds_3["mu_hat"], bcf_preds_3["y_hat"], ) - assert tau_hat_3.shape == (n_train, num_mcmc * 2) + assert cate_hat_3.shape == (n_train, num_mcmc * 2) assert mu_hat_3.shape == (n_train, num_mcmc * 2) assert y_hat_3.shape == (n_train, num_mcmc * 2) np.testing.assert_allclose(y_hat_3[:, 0:num_mcmc], y_hat) np.testing.assert_allclose(mu_hat_3[:, 0:num_mcmc], mu_hat) - np.testing.assert_allclose(tau_hat_3[:, 0:num_mcmc], tau_hat) + np.testing.assert_allclose(cate_hat_3[:, 0:num_mcmc], tau_hat) np.testing.assert_allclose( bcf_model_3.global_var_samples[0:num_mcmc], bcf_model.global_var_samples ) diff --git a/test/python/test_str.py b/test/python/test_str.py index ae2ccfd7..55b81459 100644 --- a/test/python/test_str.py +++ b/test/python/test_str.py @@ -355,7 +355,7 @@ def test_default_model(self, bcf_data): assert "retaining every iteration" in s def test_more_than_two_model_terms(self, bcf_data): - """Adding sigma2_global gives >2 terms and triggers Oxford-comma format.""" + """Adding sigma2_global and tau_0 gives >2 terms and triggers Oxford-comma format.""" model = BCFModel() model.sample( X_train=bcf_data["X_train"], @@ -372,7 +372,7 @@ def test_more_than_two_model_terms(self, bcf_data): prognostic_forest_params={"sample_sigma2_leaf": False}, treatment_effect_forest_params={"sample_sigma2_leaf": False}, ) - assert ", and global error variance model" in str(model) + assert ", and treatment effect intercept model" in str(model) def test_adaptive_coding_disabled(self, bcf_data): """Binary treatment without adaptive coding 
shows 'default coding'.""" diff --git a/tools/debug/bcf_parametric_treatment_term.R b/tools/debug/bcf_parametric_treatment_term.R new file mode 100644 index 00000000..63b6cab2 --- /dev/null +++ b/tools/debug/bcf_parametric_treatment_term.R @@ -0,0 +1,210 @@ +# Load libraries +library(stochtree) + +# Set seed +random_seed <- 1234 +set.seed(random_seed) + +# Prepare simulation study +n_sim <- 50 +ate_squared_errors_classic_homogeneous <- rep(NA_real_, n_sim) +ate_coverage_classic_homogeneous <- rep(NA_integer_, n_sim) +ate_squared_errors_parametric_homogeneous <- rep(NA_real_, n_sim) +ate_coverage_parametric_homogeneous <- rep(NA_integer_, n_sim) +ate_squared_errors_classic_heterogeneous <- rep(NA_real_, n_sim) +ate_coverage_classic_heterogeneous <- rep(NA_integer_, n_sim) +ate_squared_errors_parametric_heterogeneous <- rep(NA_real_, n_sim) +ate_coverage_parametric_heterogeneous <- rep(NA_integer_, n_sim) + +# Below we run two different simulation studies, in which +# we compare "traditional" BCF with a treatment effect forest +# to a modified version that includes a parametric treatment effect term +# and a forest-based offset in the case where the true treatment effect +# is homogeneous and where the true treatment effect is heterogeneous +num_trees_tau_classic <- 100 +num_trees_tau_parametric <- 100 +leaf_scale_tau_classic <- 1 / num_trees_tau_classic +leaf_scale_tau_parametric <- 0.25 / num_trees_tau_parametric +leaf_scale_tau_0_parametric <- 1 +num_gfr <- 3 +num_burnin <- 200 +num_mcmc <- 500 + +for (i in 1:n_sim) { + # Shared aspects of both DGPs + n <- 500 + p <- 5 + X <- matrix(runif(n * p), n, p) + pi_X <- X[, 1] * 0.6 + 0.2 + mu_X <- (pi_X - 0.5) * 5 + Z <- rbinom(n, 1, pi_X) + + # Generate data with no treatment effect heterogeneity + tau_X <- 0.5 + y <- mu_X + tau_X * Z + rnorm(n) + ATE_true <- tau_X + + # Run traditional BCF + bcf_model_classic <- bcf( + X_train = X, + Z_train = Z, + y_train = y, + propensity_train = pi_X, + num_gfr = num_gfr, + 
num_burnin = num_burnin, + num_mcmc = num_mcmc, + general_params = list( + adaptive_coding = FALSE + ), + treatment_effect_forest_params = list( + num_trees = num_trees_tau_classic, + sample_intercept = FALSE, + sigma2_leaf_init = leaf_scale_tau_classic + ) + ) + CATE_posterior_classic <- predict( + bcf_model_classic, + X = X, + Z = Z, + propensity = pi_X, + type = "posterior", + terms = "cate" + ) + ATE_posterior_classic <- colMeans(CATE_posterior_classic) + ate_squared_errors_classic_homogeneous[i] <- (mean(ATE_posterior_classic) - + ATE_true)^2 + ate_coverage_classic_homogeneous[i] <- (quantile( + ATE_posterior_classic, + 0.025 + ) <= + ATE_true & + ATE_true <= quantile(ATE_posterior_classic, 0.975)) + + # Run BCF with parametric term + bcf_model_parametric <- bcf( + X_train = X, + Z_train = Z, + y_train = y, + propensity_train = pi_X, + num_gfr = num_gfr, + num_burnin = num_burnin, + num_mcmc = num_mcmc, + general_params = list( + adaptive_coding = FALSE + ), + treatment_effect_forest_params = list( + num_trees = num_trees_tau_parametric, + sample_intercept = TRUE, + tau_0_prior_var = leaf_scale_tau_0_parametric, + sigma2_leaf_init = leaf_scale_tau_parametric + ) + ) + CATE_posterior_parametric <- predict( + bcf_model_parametric, + X = X, + Z = Z, + propensity = pi_X, + type = "posterior", + terms = "cate" + ) + ATE_posterior_parametric <- colMeans(CATE_posterior_parametric) + ate_squared_errors_parametric_homogeneous[i] <- (mean( + ATE_posterior_parametric + ) - + ATE_true)^2 + ate_coverage_parametric_homogeneous[i] <- (quantile( + ATE_posterior_parametric, + 0.025 + ) <= + ATE_true & + ATE_true <= quantile(ATE_posterior_parametric, 0.975)) + + # Generate data with significant treatment effect heterogeneity + tau_X <- 2 * X[, 2] - 1 + y <- mu_X + tau_X * Z + rnorm(n) + ATE_true <- mean(tau_X) + + # Run traditional BCF + bcf_model_classic <- bcf( + X_train = X, + Z_train = Z, + y_train = y, + propensity_train = pi_X, + num_gfr = num_gfr, + num_burnin = 
num_burnin, + num_mcmc = num_mcmc, + general_params = list( + adaptive_coding = FALSE + ), + treatment_effect_forest_params = list( + num_trees = num_trees_tau_classic, + sample_intercept = FALSE, + sigma2_leaf_init = leaf_scale_tau_classic + ) + ) + CATE_posterior_classic <- predict( + bcf_model_classic, + X = X, + Z = Z, + propensity = pi_X, + type = "posterior", + terms = "cate" + ) + ATE_posterior_classic <- colMeans(CATE_posterior_classic) + ate_squared_errors_classic_heterogeneous[i] <- (mean(ATE_posterior_classic) - + ATE_true)^2 + ate_coverage_classic_heterogeneous[i] <- (quantile( + ATE_posterior_classic, + 0.025 + ) <= + ATE_true & + ATE_true <= quantile(ATE_posterior_classic, 0.975)) + + # Run BCF with parametric term + bcf_model_parametric <- bcf( + X_train = X, + Z_train = Z, + y_train = y, + propensity_train = pi_X, + num_gfr = num_gfr, + num_burnin = num_burnin, + num_mcmc = num_mcmc, + general_params = list( + adaptive_coding = FALSE + ), + treatment_effect_forest_params = list( + num_trees = num_trees_tau_parametric, + sample_intercept = TRUE, + tau_0_prior_var = leaf_scale_tau_0_parametric, + sigma2_leaf_init = leaf_scale_tau_parametric + ) + ) + CATE_posterior_parametric <- predict( + bcf_model_parametric, + X = X, + Z = Z, + propensity = pi_X, + type = "posterior", + terms = "cate" + ) + ATE_posterior_parametric <- colMeans(CATE_posterior_parametric) + ate_squared_errors_parametric_heterogeneous[i] <- (mean( + ATE_posterior_parametric + ) - + ATE_true)^2 + ate_coverage_parametric_heterogeneous[i] <- (quantile( + ATE_posterior_parametric, + 0.025 + ) <= + ATE_true & + ATE_true <= quantile(ATE_posterior_parametric, 0.975)) +} + +mean(ate_squared_errors_classic_homogeneous) +mean(ate_squared_errors_parametric_homogeneous) +mean(ate_coverage_classic_homogeneous) +mean(ate_coverage_parametric_homogeneous) +mean(ate_squared_errors_classic_heterogeneous) +mean(ate_squared_errors_parametric_heterogeneous) +mean(ate_coverage_classic_heterogeneous) 
+mean(ate_coverage_parametric_heterogeneous) diff --git a/tools/debug/bcf_tau0_scaling_debug.R b/tools/debug/bcf_tau0_scaling_debug.R new file mode 100644 index 00000000..002e6763 --- /dev/null +++ b/tools/debug/bcf_tau0_scaling_debug.R @@ -0,0 +1,121 @@ +# Debug script: verify that colMeans(tau(X)) + tau_0 == colMeans(CATE) +# for the reparameterized BCF model (sample_intercept = TRUE). +# Based on DGP from vignettes/ReparameterizedCausalInference.Rmd + +library(stochtree) + +set.seed(42) + +# --- DGP (from vignette) --- +n <- 400 +p <- 20 +snr <- 2 +X <- matrix(runif(n * p), n, p) +mu_x <- sin(pi * X[, 1] * X[, 2]) + 2 * (X[, 3] - 0.5)^2 + X[, 4] +tau_x <- 5 +pi_x <- pnorm(1.5 * X[, 1] - 0.5 * X[, 2]) +Z <- rbinom(n, 1, pi_x) +E_XZ <- mu_x + Z * tau_x +sigma_true <- sd(E_XZ) / snr +y <- E_XZ + rnorm(n, 0, 1) * sigma_true + +test_set_pct <- 0.2 +n_test <- round(test_set_pct * n) +n_train <- n - n_test +test_inds <- sort(sample(1:n, n_test, replace = FALSE)) +train_inds <- (1:n)[!((1:n) %in% test_inds)] +X_test <- X[test_inds, ]; X_train <- X[train_inds, ] +pi_test <- pi_x[test_inds]; pi_train <- pi_x[train_inds] +Z_test <- Z[test_inds]; Z_train <- Z[train_inds] +y_test <- y[test_inds]; y_train <- y[train_inds] + +# --- Fit reparameterized BCF --- +bcf_model <- bcf( + X_train = X_train, + Z_train = Z_train, + y_train = y_train, + propensity_train = pi_train, + X_test = X_test, + Z_test = Z_test, + propensity_test = pi_test, + num_gfr = 0, + num_burnin = 1000, + num_mcmc = 500, + general_params = list(adaptive_coding = FALSE, num_chains = 4, keep_every = 5), + treatment_effect_forest_params = list( + num_trees = 50, + sample_intercept = TRUE, + sigma2_leaf_init = 0.25 / 50, + tau_0_prior_var = 1 + ) +) + +# --- Extract components --- +# tau(X) forest-only predictions: n_test x num_samples +tau_x_posterior <- predict( + bcf_model, X = X_test, Z = Z_test, propensity = pi_test, + type = "posterior", terms = "tau" +) + +# CATE = tau_0 + tau(X): n_test x num_samples 
+cate_posterior <- predict( + bcf_model, X = X_test, Z = Z_test, propensity = pi_test, + type = "posterior", terms = "cate" +) + +# tau_0 samples (stored as p_tau0 x num_samples matrix, p_tau0 = 1 here) +tau_0_samples <- extractParameter(bcf_model, "tau_0") + +cat("--- Dimensions ---\n") +cat("tau_x_posterior:", paste(dim(tau_x_posterior), collapse = " x "), "\n") +cat("cate_posterior: ", paste(dim(cate_posterior), collapse = " x "), "\n") +cat("tau_0_samples: ", paste(dim(tau_0_samples), collapse = " x "), "\n") + +# --- ATE posteriors --- +# ATE via colMeans of CATE +ate_via_cate <- colMeans(cate_posterior) + +# ATE via colMeans(tau(X)) + tau_0 +ate_via_parts <- colMeans(tau_x_posterior) + as.numeric(tau_0_samples) + +cat("\n--- First 10 sample-level comparison: colMeans(tau) + tau_0 vs colMeans(cate) ---\n") +comparison <- data.frame( + tau_x_mean = colMeans(tau_x_posterior)[1:10], + tau_0 = as.numeric(tau_0_samples)[1:10], + sum_parts = ate_via_parts[1:10], + cate_mean = ate_via_cate[1:10], + diff = (ate_via_parts - ate_via_cate)[1:10] +) +print(round(comparison, 6)) + +cat("\n--- Max absolute difference across all samples ---\n") +cat("max|colMeans(tau) + tau_0 - colMeans(cate)|:", max(abs(ate_via_parts - ate_via_cate)), "\n") + +# --- Observation-level check: tau_x[i,s] + tau_0[s] vs cate[i,s] --- +# Reconstruct expected CATE from parts +cate_reconstructed <- sweep(tau_x_posterior, 2, as.numeric(tau_0_samples), "+") +cat("\n--- Max absolute difference (observation-level): tau_x + tau_0 vs cate ---\n") +cat("max|tau_x[i,s] + tau_0[s] - cate[i,s]|:", max(abs(cate_reconstructed - cate_posterior)), "\n") + +# --- Scale checks --- +cat("\n--- Scale diagnostics ---\n") +cat("outcome_scale (y_std):", bcf_model$model_params$outcome_scale, "\n") +cat("outcome_mean (y_bar):", bcf_model$model_params$outcome_mean, "\n") +cat("mean(tau_x_posterior):", mean(tau_x_posterior), "\n") +cat("mean(tau_0_samples): ", mean(tau_0_samples), "\n") +cat("mean(cate_posterior): ", 
mean(cate_posterior), "\n") +cat("true ATE: ", tau_x, "\n") + +# --- Posterior summaries --- +cat("\n--- Posterior mean of ATE (via CATE) ---\n") +cat("mean:", mean(ate_via_cate), " 95% CI: [", + quantile(ate_via_cate, 0.025), ",", quantile(ate_via_cate, 0.975), "]\n") + +cat("\n--- Posterior of tau_0 alone ---\n") +cat("mean:", mean(tau_0_samples), " 95% CI: [", + quantile(tau_0_samples, 0.025), ",", quantile(tau_0_samples, 0.975), "]\n") + +cat("\n--- Posterior of colMeans(tau(X)) alone ---\n") +tau_x_test_mean <- colMeans(tau_x_posterior) +cat("mean:", mean(tau_x_test_mean), " 95% CI: [", + quantile(tau_x_test_mean, 0.025), ",", quantile(tau_x_test_mean, 0.975), "]\n") diff --git a/vignettes/ReparameterizedCausalInference.Rmd b/vignettes/ReparameterizedCausalInference.Rmd new file mode 100644 index 00000000..bd2ff034 --- /dev/null +++ b/vignettes/ReparameterizedCausalInference.Rmd @@ -0,0 +1,609 @@ +--- +title: "Semiparametric Causal Inference in StochTree" +output: rmarkdown::html_vignette +bibliography: vignettes.bib +vignette: > + %\VignetteIndexEntry{Reparameterized-BCF} + %\VignetteEngine{knitr::rmarkdown} + %\VignetteEncoding{UTF-8} +--- + +```{r, include = FALSE} +knitr::opts_chunk$set( + collapse = TRUE, + comment = "#>" +) +``` + +The classic BCF model of @hahn2020bayesian is defined as +\begin{equation} +\begin{aligned} +Y_i \mid x_i, z_i &\sim \mathrm{N}(f_0(x_i) + \tau(x_i) z_i, \sigma^2)\\ +f_0 &\sim \mathrm{BART}(\alpha_0, \beta_0, m_0)\\ +\tau &\sim \mathrm{BART}(\alpha_{\tau}, \beta_{\tau}, m_{\tau}). +\end{aligned} +\end{equation} +where $\mathrm{BART}(\alpha, \beta, m)$ defines a BART model with $m$ trees and split prior parameters $\alpha$ and $\beta$. + +The authors noted that separating estimation / regularization of a control function, $f_0(x)$, and a CATE function, $\tau(x)$, can give better estimation error and interval coverage in settings with strong confounding and treatment effect moderation. 
+ +`stochtree` now defaults to a slight modification of this model, with the treatment effect function decomposed into parametric and nonparametric components +\begin{equation} +\begin{aligned} +Y_i \mid x_i, z_i &\sim \mathrm{N}(f_0(x_i) + (\tau_0 + t(x_i)) z_i, \sigma^2)\\ +f_0 &\sim \mathrm{BART}(\alpha_0, \beta_0, m_0)\\ +t &\sim \mathrm{BART}(\alpha_{t}, \beta_{t}, m_{t})\\ +\tau_0 &\sim \mathrm{N}\left(0, \sigma_{\tau_0}^2 \right), +\end{aligned} +\end{equation} +where $\tau_0 + t(x_i)$ takes the place of the $\tau(x_i)$ forest term in the original BCF model. This decomposition allows the forest term to focus on capturing heterogeneity "offsets" to a parametric model of homogeneous treatment effects. + +Below we demonstrate the advantages of this "reparameterization" of BCF on a synthetic dataset. + +First, we load the necessary libraries + +```{r setup} +library(stochtree) +``` + +We set a seed for reproducibility + +```{r} +random_seed <- 1234 +set.seed(random_seed) +``` + +# Binary Treatment with Homogeneous Treatment Effect + +Consider the following data generating process + +\begin{equation*} +\begin{aligned} +y &= \mu(X) + \tau(X) Z + \epsilon\\ +\mu(X) &= 2 \sin(2 \pi X_1) - 2 (2 X_3 - 1)\\ +\tau(X) &= 5\\ +\pi(X) &= \phi\left(\frac{\mu(X)}{4}\right)\\ +X_1,\dots,X_p &\sim \text{Uniform}\left(0,1\right)\\ +Z &\sim \text{Bernoulli}\left(\pi(X)\right)\\ +\epsilon &\sim N\left(0,\sigma^2\right) +\end{aligned} +\end{equation*} + +### Simulation + +We draw from the DGP defined above + +```{r data} +n <- 500 +p <- 20 +snr <- 2 +X <- matrix(runif(n * p), n, p) +mu_x <- 2 * sin(2 * pi * X[, 1]) - 2 * (2 * X[, 3] - 1) +tau_x <- 5 +pi_x <- pnorm(mu_x / 4) +Z <- rbinom(n, 1, pi_x) +E_XZ <- mu_x + Z * tau_x +sigma_true <- sd(E_XZ) / snr +y <- E_XZ + rnorm(n, 0, 1) * sigma_true +``` + +And split data into test and train sets + +```{r} +test_set_pct <- 0.5 +n_test <- round(test_set_pct * n) +n_train <- n - n_test +test_inds <- sort(sample(1:n, n_test, replace = 
FALSE)) +train_inds <- (1:n)[!((1:n) %in% test_inds)] +X_test <- X[test_inds, ] +X_train <- X[train_inds, ] +pi_test <- pi_x[test_inds] +pi_train <- pi_x[train_inds] +Z_test <- Z[test_inds] +Z_train <- Z[train_inds] +y_test <- y[test_inds] +y_train <- y[train_inds] +mu_test <- mu_x[test_inds] +mu_train <- mu_x[train_inds] +tau_test <- rep(tau_x, n_test) +tau_train <- rep(tau_x, n_train) +``` + +## Sampling and Analysis + +### Classic BCF Model + +We first fit the classic BCF model with no parametric treatment effect term + +```{r} +num_gfr <- 0 +num_burnin <- 1000 +num_mcmc <- 500 +general_params <- list( + adaptive_coding = TRUE, + num_chains = 4, + random_seed = random_seed, + num_threads = 1 +) +num_trees_tau <- 50 +treatment_effect_forest_params <- list( + num_trees = num_trees_tau, + sample_intercept = FALSE, + sigma2_leaf_init = 1 / num_trees_tau +) +bcf_model_classic <- bcf( + X_train = X_train, + Z_train = Z_train, + y_train = y_train, + propensity_train = pi_train, + X_test = X_test, + Z_test = Z_test, + propensity_test = pi_test, + num_gfr = num_gfr, + num_burnin = num_burnin, + num_mcmc = num_mcmc, + general_params = general_params, + treatment_effect_forest_params = treatment_effect_forest_params +) +``` + +And we compare the posterior distribution of the ATE to its true value + +```{r} +cate_posterior_classic <- predict( + bcf_model_classic, + X = X_test, + Z = Z_test, + propensity = pi_test, + type = "posterior", + terms = "cate" +) +ate_posterior_classic <- colMeans(cate_posterior_classic) +hist( + ate_posterior_classic, + freq = F, + xlab = "ATE", + ylab = "Density", + main = "Posterior Distribution of ATE" +) +abline(v = tau_x, col = "red", lty = 3, lwd = 3) +``` + +As a rough convergence check, we inspect the traceplot of the global error scale parameter, $\sigma^2$ + +```{r} +sigma2_samples <- extractParameter(bcf_model_classic, "sigma2") +plot( + sigma2_samples, + type = "l", + main = "Traceplot of Sigma^2", + ylab = "Sigma^2", + xlab = 
"Iteration" +) +abline(h = sigma_true^2, col = "red", lty = 3, lwd = 3) +``` + +### Reparameterized BCF Model + +Now we fit the reparameterized model, regularizing the $t(x)$ forest more heavily to account for the standard normal prior on the $\tau_0$ term. + +```{r} +num_trees_tau <- 50 +general_params <- list( + adaptive_coding = FALSE, + num_chains = 4, + random_seed = random_seed, + num_threads = 1 +) +treatment_effect_forest_params <- list( + num_trees = num_trees_tau, + sample_intercept = TRUE, + sigma2_leaf_init = 0.25 / num_trees_tau, + tau_0_prior_var = 1 +) +bcf_model_reparam <- bcf( + X_train = X_train, + Z_train = Z_train, + y_train = y_train, + propensity_train = pi_train, + X_test = X_test, + Z_test = Z_test, + propensity_test = pi_test, + num_gfr = num_gfr, + num_burnin = num_burnin, + num_mcmc = num_mcmc, + general_params = general_params, + treatment_effect_forest_params = treatment_effect_forest_params +) +``` + +And we compare the posterior distribution of the ATE to its true value + +```{r} +cate_posterior_reparam <- predict( + bcf_model_reparam, + X = X_test, + Z = Z_test, + propensity = pi_test, + type = "posterior", + terms = "cate" +) +ate_posterior_reparam <- colMeans(cate_posterior_reparam) +hist( + ate_posterior_reparam, + freq = F, + xlab = "ATE", + ylab = "Density", + main = "Posterior Distribution of ATE" +) +abline(v = tau_x, col = "red", lty = 3, lwd = 3) +``` + +As above, we check convergence by inspecting the traceplot of the global error scale parameter, $\sigma^2$ + +```{r} +sigma2_samples <- extractParameter(bcf_model_reparam, "sigma2") +plot( + sigma2_samples, + type = "l", + main = "Traceplot of Sigma^2", + ylab = "Sigma^2", + xlab = "Iteration" +) +abline(h = sigma_true^2, col = "red", lty = 3, lwd = 3) +``` + +Since $t(X)$ is not constrained to sum to 0, the parameter $\tau_0$ does not identify the ATE. 
We can see this by averaging each posterior draw of $t(X)$ over the test set and comparing the posterior draws of $\tau_0$ and $\bar{t}(X)$. + +```{r} +tau_0_posterior <- extractParameter(bcf_model_reparam, "tau_0") +t_x_posterior_reparam <- predict( + bcf_model_reparam, + X = X_test, + Z = Z_test, + propensity = pi_test, + type = "posterior", + terms = "tau" +) +t_x_posterior_reparam <- colMeans(t_x_posterior_reparam) +plot( + tau_0_posterior, + t_x_posterior_reparam, + xlab = "tau_0", + ylab = "t(X)", + main = "Posterior of tau_0 vs t(X), averaged over X" +) +``` + +While `stochtree` does not currently support constraining $t(X)$ to sum to 0 over the training set, we can more heavily regularize $t(X)$ so that its values are much closer to zero. Using a single tree with a very small leaf scale effectively collapses the forest to a constant near zero, making $\tau_0$ the primary vehicle for the treatment effect. + +```{r} +general_params <- list( + adaptive_coding = FALSE, + num_chains = 4, + random_seed = random_seed, + num_threads = 1 +) +treatment_effect_forest_params <- list( + num_trees = 1, + sample_intercept = TRUE, + sigma2_leaf_init = 1e-6, + tau_0_prior_var = 1 +) +bcf_model_reparam <- bcf( + X_train = X_train, + Z_train = Z_train, + y_train = y_train, + propensity_train = pi_train, + X_test = X_test, + Z_test = Z_test, + propensity_test = pi_test, + num_gfr = num_gfr, + num_burnin = num_burnin, + num_mcmc = num_mcmc, + general_params = general_params, + treatment_effect_forest_params = treatment_effect_forest_params +) +``` + +Again we plot the posterior distribution of the ATE + +```{r} +cate_posterior_reparam <- predict( + bcf_model_reparam, + X = X_test, + Z = Z_test, + propensity = pi_test, + type = "posterior", + terms = "cate" +) +ate_posterior_reparam <- colMeans(cate_posterior_reparam) +hist( + ate_posterior_reparam, + freq = F, + xlab = "ATE", + ylab = "Density", + main = "Posterior Distribution of ATE" +) +abline(v = tau_x, col = "red", 
lty = 3, lwd = 3) +``` + +This time we see no correlation between the $\tau_0$ posterior and the (highly-regularized) $\bar{t}(X)$ posterior -- $\tau_0$ more directly captures the majority of the ATE + +```{r} +tau_0_posterior <- extractParameter(bcf_model_reparam, "tau_0") +t_x_posterior_reparam <- predict( + bcf_model_reparam, + X = X_test, + Z = Z_test, + propensity = pi_test, + type = "posterior", + terms = "tau" +) +t_x_posterior_reparam <- colMeans(t_x_posterior_reparam) +plot( + tau_0_posterior, + t_x_posterior_reparam, + xlab = "tau_0", + ylab = "t(X)", + main = "Posterior of tau_0 vs t(X), averaged over X" +) +abline(0, 1, col = "red", lty = 3, lwd = 3) +``` + +We can further regularize estimation of the ATE by reducing $\sigma_{\tau_0}^2$ + +```{r} +general_params <- list( + adaptive_coding = FALSE, + num_chains = 4, + random_seed = random_seed, + num_threads = 1 +) +treatment_effect_forest_params <- list( + num_trees = 1, + sample_intercept = TRUE, + sigma2_leaf_init = 1e-6, + tau_0_prior_var = 0.05 +) +bcf_model_reparam <- bcf( + X_train = X_train, + Z_train = Z_train, + y_train = y_train, + propensity_train = pi_train, + X_test = X_test, + Z_test = Z_test, + propensity_test = pi_test, + num_gfr = num_gfr, + num_burnin = num_burnin, + num_mcmc = num_mcmc, + general_params = general_params, + treatment_effect_forest_params = treatment_effect_forest_params +) +cate_posterior_reparam <- predict( + bcf_model_reparam, + X = X_test, + Z = Z_test, + propensity = pi_test, + type = "posterior", + terms = "cate" +) +ate_posterior_reparam <- colMeans(cate_posterior_reparam) +hist( + ate_posterior_reparam, + freq = F, + xlab = "ATE", + ylab = "Density", + main = "Posterior Distribution of ATE" +) +abline(v = tau_x, col = "red", lty = 3, lwd = 3) +``` + +# Continuous Treatment with Homogeneous Treatment Effect + +The $\tau_0 + t(x)$ reparameterization generalizes naturally to continuous treatment. 
With a continuous $Z$, $\tau(x)$ represents the marginal effect of a one-unit increase in $Z$, and $\tau_0$ captures the homogeneous component of that effect. + +Consider the following data generating process: + +\begin{equation*} +\begin{aligned} +y &= \mu(X) + \tau(X)\, Z + \epsilon\\ +\mu(X) &= 2 \sin(2 \pi X_1) - 2 (2 X_3 - 1)\\ +\tau(X) &= 2\\ +\pi(X) &= \mathrm{E}[Z \mid X] = \mu(X)/8\\ +Z \mid X &\sim \mathrm{N}(\pi(X),\, 1)\\ +\epsilon &\sim N\left(0,\sigma^2\right) +\end{aligned} +\end{equation*} + +### Simulation + +We draw from the DGP defined above + +```{r} +n <- 500 +p <- 20 +snr <- 2 +X <- matrix(runif(n * p), n, p) +mu_x <- 2 * sin(2 * pi * X[, 1]) - 2 * (2 * X[, 3] - 1) +tau_x <- 2 +pi_x <- mu_x / 8 +Z <- pi_x + rnorm(n, 0, 1) +E_XZ <- mu_x + Z * tau_x +sigma_true <- sd(E_XZ) / snr +y <- E_XZ + rnorm(n, 0, 1) * sigma_true +``` + +And split data into test and train sets + +```{r} +test_inds <- sort(sample(1:n, round(0.5 * n), replace = FALSE)) +train_inds <- (1:n)[!((1:n) %in% test_inds)] +X_test <- X[test_inds, ] +X_train <- X[train_inds, ] +pi_test <- pi_x[test_inds] +pi_train <- pi_x[train_inds] +Z_test <- Z[test_inds] +Z_train <- Z[train_inds] +y_test <- y[test_inds] +y_train <- y[train_inds] +``` + +## Sampling and Analysis + +Note that `adaptive_coding` must be `FALSE` for continuous treatment, since the adaptive coding scheme is designed for binary treatment. 
+ +### Classic BCF Model + +```{r} +num_gfr <- 0 +num_burnin <- 1000 +num_mcmc <- 500 +general_params <- list( + adaptive_coding = FALSE, + num_chains = 4, + random_seed = random_seed, + num_threads = 1 +) +num_trees_tau <- 50 +treatment_effect_forest_params <- list( + num_trees = num_trees_tau, + sample_intercept = FALSE, + sigma2_leaf_init = 1 / num_trees_tau +) +bcf_model_classic <- bcf( + X_train = X_train, + Z_train = Z_train, + y_train = y_train, + propensity_train = pi_train, + X_test = X_test, + Z_test = Z_test, + propensity_test = pi_test, + num_gfr = num_gfr, + num_burnin = num_burnin, + num_mcmc = num_mcmc, + general_params = general_params, + treatment_effect_forest_params = treatment_effect_forest_params +) +``` + +We compare the posterior distribution of the ATE to its true value + +```{r} +cate_posterior_classic <- predict( + bcf_model_classic, + X = X_test, + Z = Z_test, + propensity = pi_test, + type = "posterior", + terms = "cate" +) +ate_posterior_classic <- colMeans(cate_posterior_classic) +hist( + ate_posterior_classic, + freq = F, + xlab = "ATE", + ylab = "Density", + main = "Posterior Distribution of ATE (Classic BCF, Continuous Treatment)" +) +abline(v = tau_x, col = "red", lty = 3, lwd = 3) +``` + +As a rough convergence check, we inspect the traceplot of $\sigma^2$ + +```{r} +sigma2_samples <- extractParameter(bcf_model_classic, "sigma2") +plot( + sigma2_samples, + type = "l", + main = "Traceplot of Sigma^2 (Classic BCF, Continuous Treatment)", + ylab = "Sigma^2", + xlab = "Iteration" +) +abline(h = sigma_true^2, col = "red", lty = 3, lwd = 3) +``` + +### Reparameterized BCF Model + +```{r} +treatment_effect_forest_params <- list( + num_trees = num_trees_tau, + sample_intercept = TRUE, + sigma2_leaf_init = 0.25 / num_trees_tau, + tau_0_prior_var = 1 +) +bcf_model_reparam <- bcf( + X_train = X_train, + Z_train = Z_train, + y_train = y_train, + propensity_train = pi_train, + X_test = X_test, + Z_test = Z_test, + propensity_test = pi_test, + 
num_gfr = num_gfr, + num_burnin = num_burnin, + num_mcmc = num_mcmc, + general_params = general_params, + treatment_effect_forest_params = treatment_effect_forest_params +) +``` + +And we compare the posterior distribution of the ATE to its true value + +```{r} +cate_posterior_reparam <- predict( + bcf_model_reparam, + X = X_test, + Z = Z_test, + propensity = pi_test, + type = "posterior", + terms = "cate" +) +ate_posterior_reparam <- colMeans(cate_posterior_reparam) +hist( + ate_posterior_reparam, + freq = F, + xlab = "ATE", + ylab = "Density", + main = "Posterior Distribution of ATE (Reparameterized BCF, Continuous Treatment)" +) +abline(v = tau_x, col = "red", lty = 3, lwd = 3) +``` + +As above, we check convergence by inspecting the traceplot of $\sigma^2$ + +```{r} +sigma2_samples <- extractParameter(bcf_model_reparam, "sigma2") +plot( + sigma2_samples, + type = "l", + main = "Traceplot of Sigma^2 (Reparameterized BCF, Continuous Treatment)", + ylab = "Sigma^2", + xlab = "Iteration" +) +abline(h = sigma_true^2, col = "red", lty = 3, lwd = 3) +``` + +As in the binary treatment case, $\tau_0$ and $\bar{t}(X)$ are negatively correlated across posterior draws + +```{r} +tau_0_posterior <- extractParameter(bcf_model_reparam, "tau_0") +t_x_posterior <- predict( + bcf_model_reparam, + X = X_test, + Z = Z_test, + propensity = pi_test, + type = "posterior", + terms = "tau" +) +t_x_posterior <- colMeans(t_x_posterior) +plot( + tau_0_posterior, + t_x_posterior, + xlab = "tau_0", + ylab = "t(X)", + main = "Posterior of tau_0 vs t(X), averaged over X (Continuous Treatment)" +) +``` + +# References