Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
f89ef05
Added optional parametric treatment effect term to R BCF
andrewherren Mar 10, 2026
deebc33
Expose tau_0 (treatment effect intercept) in Python BCF interface
andrewherren Mar 10, 2026
f295248
Move tau_0 sampling before tau forest in BCF Gibbs sweep
andrewherren Mar 10, 2026
aa0273f
Expose tau_0 as a requestable term in computeBCFPosteriorInterval
andrewherren Mar 10, 2026
8b4d75f
Update BCF str test to reflect tau_0 as a default model term
andrewherren Mar 10, 2026
6e04eb6
Expose tau_0 as a requestable term in BCFModel.compute_posterior_inte…
andrewherren Mar 10, 2026
01c04ef
Fix KeyError when terms='all' in BCFModel.compute_posterior_interval
andrewherren Mar 10, 2026
ab1e8e1
Added simulation study for BCF with parametric treatment term
andrewherren Mar 13, 2026
0ffdeff
Merge branch 'main' into cate_forest_intercept
andrewherren Mar 13, 2026
3097d24
Update pybind11
andrewherren Mar 13, 2026
98edbc6
Updated implementation to only include tau_0 parameter in CATE, but n…
andrewherren Mar 14, 2026
c4a50a4
Reflect main branch update to the pybind11 dependency
andrewherren Mar 16, 2026
fe06967
Remove tau_0 from interfaces designed for covariate-dependent terms (…
andrewherren Mar 17, 2026
0274898
Added vignette on tau_0 and debug script
andrewherren Mar 17, 2026
5ccc38a
Updated tau_0 vignette
andrewherren Mar 17, 2026
515ef86
Updated R vignette
andrewherren Mar 17, 2026
4dc28c3
Updated R vignette and added python vignette
andrewherren Mar 17, 2026
bb4773d
Added continuous treatment example to the reparameterization vignette
andrewherren Mar 17, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
393 changes: 365 additions & 28 deletions R/bcf.R

Large diffs are not rendered by default.

86 changes: 46 additions & 40 deletions R/posterior_transformation.R
Original file line number Diff line number Diff line change
Expand Up @@ -1055,11 +1055,12 @@ computeBCFPosteriorInterval <- function(
)
}
}
needs_covariates_intermediate <- ((("y_hat" %in% terms) ||
("all" %in% terms)))
needs_covariates <- (("prognostic_function" %in% terms) ||
("cate" %in% terms) ||
("variance_forest" %in% terms) ||
predict_terms <- terms
needs_covariates_intermediate <- ((("y_hat" %in% predict_terms) ||
("all" %in% predict_terms)))
needs_covariates <- (("prognostic_function" %in% predict_terms) ||
("cate" %in% predict_terms) ||
("variance_forest" %in% predict_terms) ||
(needs_covariates_intermediate))
if (needs_covariates) {
if (is.null(X)) {
Expand Down Expand Up @@ -1119,10 +1120,10 @@ computeBCFPosteriorInterval <- function(
}
}
}
needs_rfx_data_intermediate <- ((("y_hat" %in% terms) ||
("all" %in% terms)) &&
needs_rfx_data_intermediate <- ((("y_hat" %in% predict_terms) ||
("all" %in% predict_terms)) &&
model_object$model_params$has_rfx)
needs_rfx_data <- (("rfx" %in% terms) ||
needs_rfx_data <- (("rfx" %in% predict_terms) ||
(needs_rfx_data_intermediate))
if (needs_rfx_data) {
if (is.null(rfx_group_ids)) {
Expand Down Expand Up @@ -1154,42 +1155,47 @@ computeBCFPosteriorInterval <- function(
}
}

# Compute posterior matrices for the requested model terms
predictions <- predict(
model_object,
X = X,
Z = Z,
propensity = propensity,
rfx_group_ids = rfx_group_ids,
rfx_basis = rfx_basis,
type = "posterior",
terms = terms,
scale = scale
)
has_multiple_terms <- ifelse(is.list(predictions), TRUE, FALSE)
result <- list()

# Compute the interval
if (has_multiple_terms) {
result <- list()
for (term_name in names(predictions)) {
if (!is.null(predictions[[term_name]])) {
result[[term_name]] <- summarize_interval(
predictions[[term_name]],
sample_dim = 2,
level = level
)
} else {
result[[term_name]] <- NULL
# Compute posterior matrices for predict-able terms (if any)
if (length(predict_terms) > 0) {
predictions <- predict(
model_object,
X = X,
Z = Z,
propensity = propensity,
rfx_group_ids = rfx_group_ids,
rfx_basis = rfx_basis,
type = "posterior",
terms = predict_terms,
scale = scale
)
if (is.list(predictions)) {
for (term_name in names(predictions)) {
if (!is.null(predictions[[term_name]])) {
result[[term_name]] <- summarize_interval(
predictions[[term_name]],
sample_dim = 2,
level = level
)
} else {
result[[term_name]] <- NULL
}
}
} else {
result[[predict_terms]] <- summarize_interval(
predictions,
sample_dim = 2,
level = level
)
}
return(result)
} else {
return(summarize_interval(
predictions,
sample_dim = 2,
level = level
))
}

# Return single interval directly if only one term was requested
if (length(terms) == 1) {
return(result[[terms]])
}
return(result)
}

#' @title Compute BART Posterior Credible Intervals
Expand Down
Loading
Loading