Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions R/bcf.R
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ NULL
#' - `sigma2_global_scale` Scale parameter in the `IG(sigma2_global_shape, sigma2_global_scale)` global error variance model. Default: `0`.
#' - `variable_weights` Numeric weights reflecting the relative probability of splitting on each variable. Does not need to sum to 1 but cannot be negative. Defaults to `rep(1/ncol(X_train), ncol(X_train))` if not set here. Note that if the propensity score is included as a covariate in either forest, its weight will default to `1/ncol(X_train)`. A workaround if you wish to provide a custom weight for the propensity score is to include it as a column in `X_train` and then set `propensity_covariate` to `'none'` adjust `keep_vars` accordingly for the `prognostic` or `treatment_effect` forests.
#' - `propensity_covariate` Whether to include the propensity score as a covariate in either or both of the forests. Enter `"none"` for neither, `"prognostic"` for the prognostic forest, `"treatment_effect"` for the treatment forest, and `"both"` for both forests. If this is not `"none"` and a propensity score is not provided, it will be estimated from (`X_train`, `Z_train`) using `stochtree::bart()`. Default: `"mu"`.
#' - `adaptive_coding` Whether or not to use an "adaptive coding" scheme in which a binary treatment variable is not coded manually as (0,1) or (-1,1) but learned via parameters `b_0` and `b_1` that attach to the outcome model `[b_0 (1-Z) + b_1 Z] tau(X)`. This is ignored when Z is not binary. Default: `TRUE`.
#' - `adaptive_coding` Whether or not to use an "adaptive coding" scheme in which a binary treatment variable is not coded manually as (0,1) or (-1,1) but learned via parameters `b_0` and `b_1` that attach to the outcome model `[b_0 (1-Z) + b_1 Z] tau(X)`. This is ignored when Z is not binary. Default: `FALSE`.
#' - `control_coding_init` Initial value of the "control" group coding parameter. This is ignored when Z is not binary. Default: `-0.5`.
#' - `treated_coding_init` Initial value of the "treatment" group coding parameter. This is ignored when Z is not binary. Default: `0.5`.
#' - `rfx_prior_var` Prior on the (diagonals of the) covariance of the additive group-level random regression coefficients. Must be a vector of length `ncol(rfx_basis_train)`. Default: `rep(1, ncol(rfx_basis_train))`
Expand Down Expand Up @@ -258,7 +258,7 @@ bcf <- function(
sigma2_global_scale = 0,
variable_weights = NULL,
propensity_covariate = "prognostic",
adaptive_coding = TRUE,
adaptive_coding = FALSE,
control_coding_init = -0.5,
treated_coding_init = 0.5,
rfx_prior_var = NULL,
Expand Down
148 changes: 88 additions & 60 deletions stochtree/bart.py

Large diffs are not rendered by default.

201 changes: 115 additions & 86 deletions stochtree/bcf.py

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions test/R/testthat/test-multi-chain.R
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ test_that("BCF multi-chain: sample counts with GFR warm-start", {
propensity_train = d$pi_train,
X_test = d$X_test, Z_test = d$Z_test, propensity_test = d$pi_test,
num_gfr = n_gfr, num_burnin = 5, num_mcmc = n_mcmc,
general_params = list(num_chains = n_chains, num_threads = 1)
general_params = list(num_chains = n_chains, num_threads = 1, adaptive_coding = TRUE)
)
expected <- n_chains * n_mcmc
expect_length(m$sigma2_global_samples, expected)
Expand Down Expand Up @@ -279,7 +279,7 @@ test_that("BCF multi-chain: all samples finite with GFR + multiple chains", {
propensity_train = d$pi_train,
X_test = d$X_test, Z_test = d$Z_test, propensity_test = d$pi_test,
num_gfr = 6, num_burnin = 20, num_mcmc = 10,
general_params = list(num_chains = 3, num_threads = 1)
general_params = list(num_chains = 3, num_threads = 1, adaptive_coding = TRUE)
)
expect_true(all(is.finite(m$sigma2_global_samples)),
label = "sigma2 samples must be finite (no chain-transition blowup)")
Expand Down
6 changes: 3 additions & 3 deletions test/R/testthat/test-print-summary.R
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,7 @@ test_that("BCF print method", {
pi_train <- pi_X[train_inds]; pi_test <- pi_X[test_inds]
y_train <- y[train_inds]; y_test <- y[test_inds]

# --- User-provided propensity, binary treatment, adaptive coding (defaults) ---
# --- User-provided propensity, binary treatment, default coding (defaults) ---
bcf_model <- bcf(
X_train = X_train, y_train = y_train, Z_train = Z_train,
propensity_train = pi_train,
Expand All @@ -249,7 +249,7 @@ test_that("BCF print method", {
expect_true(any(grepl("prognostic forest", out, fixed = TRUE)))
expect_true(any(grepl("treatment effect forest", out, fixed = TRUE)))
expect_true(any(grepl("User-provided propensity scores", out, fixed = TRUE)))
expect_true(any(grepl("adaptive coding", out, fixed = TRUE)))
expect_true(any(grepl("default coding", out, fixed = TRUE)))
expect_true(any(grepl("1 chain of", out, fixed = TRUE)))
expect_true(any(grepl("retaining every iteration", out, fixed = TRUE)))

Expand Down Expand Up @@ -327,7 +327,7 @@ test_that("BCF summary method", {
propensity_train = pi_train,
X_test = X_test, Z_test = Z_test, propensity_test = pi_test,
num_gfr = 0, num_burnin = 10, num_mcmc = 10,
general_params = list(sample_sigma2_global = TRUE),
general_params = list(sample_sigma2_global = TRUE, adaptive_coding = TRUE),
prognostic_forest_params = list(sample_sigma2_leaf = TRUE),
treatment_effect_forest_params = list(sample_sigma2_leaf = TRUE)
)
Expand Down
6 changes: 4 additions & 2 deletions test/R/testthat/test-serialization.R
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,8 @@ test_that("BCF JSON uses canonical field names (sigma2_init, b1_samples, b0_samp
bcf_model <- bcf(
X_train = X, Z_train = Z, y_train = y,
propensity_train = pi_x,
num_gfr = 0, num_burnin = 0, num_mcmc = 10
num_gfr = 0, num_burnin = 0, num_mcmc = 10,
general_params = list(adaptive_coding = TRUE)
)
json_string <- saveBCFModelToJsonString(bcf_model)

Expand Down Expand Up @@ -224,7 +225,8 @@ test_that("BCF JSON deserialization handles legacy field names with warnings", {
bcf_model <- bcf(
X_train = X, Z_train = Z, y_train = y,
propensity_train = pi_x,
num_gfr = 0, num_burnin = 0, num_mcmc = 10
num_gfr = 0, num_burnin = 0, num_mcmc = 10,
general_params = list(adaptive_coding = TRUE)
)
preds_orig <- predict(bcf_model, X_test, Z_test, pi_test)

Expand Down
12 changes: 12 additions & 0 deletions test/python/test_bcf.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,7 @@ def test_continuous_univariate_bcf(self):
num_mcmc = 10

# Run BCF with test set and propensity score
# adaptive_coding=True triggers a UserWarning for non-binary treatment
with pytest.warns(UserWarning):
bcf_model = BCFModel()
variance_forest_params = {"num_trees": 0}
Expand All @@ -261,6 +262,7 @@ def test_continuous_univariate_bcf(self):
num_gfr=num_gfr,
num_burnin=num_burnin,
num_mcmc=num_mcmc,
general_params={"adaptive_coding": True},
variance_forest_params=variance_forest_params,
)

Expand Down Expand Up @@ -304,6 +306,7 @@ def test_continuous_univariate_bcf(self):
num_gfr=num_gfr,
num_burnin=num_burnin,
num_mcmc=num_mcmc,
general_params={"adaptive_coding": True},
variance_forest_params=variance_forest_params,
)

Expand Down Expand Up @@ -375,6 +378,7 @@ def test_continuous_univariate_bcf(self):
num_gfr=num_gfr,
num_burnin=num_burnin,
num_mcmc=num_mcmc,
general_params={"adaptive_coding": True},
variance_forest_params=variance_forest_params,
)

Expand Down Expand Up @@ -413,6 +417,7 @@ def test_continuous_univariate_bcf(self):
num_gfr=num_gfr,
num_burnin=num_burnin,
num_mcmc=num_mcmc,
general_params={"adaptive_coding": True},
variance_forest_params=variance_forest_params,
)

Expand Down Expand Up @@ -452,6 +457,7 @@ def test_continuous_univariate_bcf(self):
num_gfr=num_gfr,
num_burnin=num_burnin,
num_mcmc=num_mcmc,
general_params={"adaptive_coding": True},
variance_forest_params=variance_forest_params,
)

Expand Down Expand Up @@ -486,6 +492,7 @@ def test_continuous_univariate_bcf(self):
num_gfr=num_gfr,
num_burnin=num_burnin,
num_mcmc=num_mcmc,
general_params={"adaptive_coding": True},
variance_forest_params=variance_forest_params,
)

Expand Down Expand Up @@ -576,6 +583,7 @@ def test_multivariate_bcf(self):
num_mcmc = 10

# Run BCF with test set and propensity score
# adaptive_coding=True triggers a UserWarning for non-binary treatment
with pytest.warns(UserWarning):
bcf_model = BCFModel()
variance_forest_params = {"num_trees": 0}
Expand All @@ -590,6 +598,7 @@ def test_multivariate_bcf(self):
num_gfr=num_gfr,
num_burnin=num_burnin,
num_mcmc=num_mcmc,
general_params={"adaptive_coding": True},
variance_forest_params=variance_forest_params,
)

Expand Down Expand Up @@ -630,6 +639,7 @@ def test_multivariate_bcf(self):
num_gfr=num_gfr,
num_burnin=num_burnin,
num_mcmc=num_mcmc,
general_params={"adaptive_coding": True},
variance_forest_params=variance_forest_params,
)

Expand Down Expand Up @@ -668,6 +678,7 @@ def test_multivariate_bcf(self):
num_gfr=num_gfr,
num_burnin=num_burnin,
num_mcmc=num_mcmc,
general_params={"adaptive_coding": True},
variance_forest_params=variance_forest_params,
)

Expand All @@ -682,6 +693,7 @@ def test_multivariate_bcf(self):
num_gfr=num_gfr,
num_burnin=num_burnin,
num_mcmc=num_mcmc,
general_params={"adaptive_coding": True},
variance_forest_params=variance_forest_params,
)

Expand Down
10 changes: 9 additions & 1 deletion test/python/test_multi_chain.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,14 @@ def test_sample_counts_with_gfr(self, bcf_data):
n_chains = self.NUM_CHAINS
n_mcmc = self.NUM_MCMC
n_gfr = self.NUM_GFR
m = _bcf(bcf_data, num_gfr=n_gfr, num_burnin=5, num_mcmc=n_mcmc, num_chains=n_chains)
m = _bcf(
bcf_data,
num_gfr=n_gfr,
num_burnin=5,
num_mcmc=n_mcmc,
num_chains=n_chains,
adaptive_coding=True,
)
expected = n_chains * n_mcmc
assert m.global_var_samples.shape == (expected,)
# BCF-specific samples
Expand Down Expand Up @@ -254,6 +261,7 @@ def test_samples_finite_gfr_multi_chain(self, bcf_data):
num_burnin=20,
num_mcmc=self.NUM_MCMC,
num_chains=self.NUM_CHAINS,
adaptive_coding=True,
)
assert np.all(np.isfinite(m.global_var_samples)), (
"sigma2 samples contain non-finite values; possible chain-transition blowup."
Expand Down
4 changes: 2 additions & 2 deletions test/python/test_str.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,7 +332,7 @@ def test_unsampled_model(self):
assert "Empty BCFModel()" in str(BCFModel())

def test_default_model(self, bcf_data):
"""Binary treatment, user propensity, adaptive coding (defaults): 2 base terms."""
"""Binary treatment, user propensity, default coding (defaults): 2 base terms."""
model = BCFModel()
model.sample(
X_train=bcf_data["X_train"],
Expand All @@ -350,7 +350,7 @@ def test_default_model(self, bcf_data):
assert "BCFModel run with prognostic forest" in s
assert "treatment effect forest" in s
assert "User-provided propensity scores" in s
assert "adaptive coding" in s
assert "default coding" in s
assert "1 chain" in s
assert "retaining every iteration" in s

Expand Down
Loading