Skip to content

Commit e4b6c06

Browse files
authored
Merge pull request #327 from StochasticTree/python_docs_rework
Moving parameter list documentation to notes section in python docstrings for BART and BCF and change BCF adaptive coding default to False
2 parents be809c6 + fa7a629 commit e4b6c06

9 files changed

Lines changed: 237 additions & 158 deletions

File tree

R/bcf.R

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ NULL
102102
#' - `sigma2_global_scale` Scale parameter in the `IG(sigma2_global_shape, sigma2_global_scale)` global error variance model. Default: `0`.
103103
#' - `variable_weights` Numeric weights reflecting the relative probability of splitting on each variable. Does not need to sum to 1 but cannot be negative. Defaults to `rep(1/ncol(X_train), ncol(X_train))` if not set here. Note that if the propensity score is included as a covariate in either forest, its weight will default to `1/ncol(X_train)`. A workaround if you wish to provide a custom weight for the propensity score is to include it as a column in `X_train` and then set `propensity_covariate` to `'none'` and adjust `keep_vars` accordingly for the `prognostic` or `treatment_effect` forests.
104104
#' - `propensity_covariate` Whether to include the propensity score as a covariate in either or both of the forests. Enter `"none"` for neither, `"prognostic"` for the prognostic forest, `"treatment_effect"` for the treatment forest, and `"both"` for both forests. If this is not `"none"` and a propensity score is not provided, it will be estimated from (`X_train`, `Z_train`) using `stochtree::bart()`. Default: `"prognostic"`.
105-
#' - `adaptive_coding` Whether or not to use an "adaptive coding" scheme in which a binary treatment variable is not coded manually as (0,1) or (-1,1) but learned via parameters `b_0` and `b_1` that attach to the outcome model `[b_0 (1-Z) + b_1 Z] tau(X)`. This is ignored when Z is not binary. Default: `TRUE`.
105+
#' - `adaptive_coding` Whether or not to use an "adaptive coding" scheme in which a binary treatment variable is not coded manually as (0,1) or (-1,1) but learned via parameters `b_0` and `b_1` that attach to the outcome model `[b_0 (1-Z) + b_1 Z] tau(X)`. This is ignored when Z is not binary. Default: `FALSE`.
106106
#' - `control_coding_init` Initial value of the "control" group coding parameter. This is ignored when Z is not binary. Default: `-0.5`.
107107
#' - `treated_coding_init` Initial value of the "treatment" group coding parameter. This is ignored when Z is not binary. Default: `0.5`.
108108
#' - `rfx_prior_var` Prior on the (diagonals of the) covariance of the additive group-level random regression coefficients. Must be a vector of length `ncol(rfx_basis_train)`. Default: `rep(1, ncol(rfx_basis_train))`.
@@ -258,7 +258,7 @@ bcf <- function(
258258
sigma2_global_scale = 0,
259259
variable_weights = NULL,
260260
propensity_covariate = "prognostic",
261-
adaptive_coding = TRUE,
261+
adaptive_coding = FALSE,
262262
control_coding_init = -0.5,
263263
treated_coding_init = 0.5,
264264
rfx_prior_var = NULL,

stochtree/bart.py

Lines changed: 88 additions & 60 deletions
Large diffs are not rendered by default.

stochtree/bcf.py

Lines changed: 115 additions & 86 deletions
Large diffs are not rendered by default.

test/R/testthat/test-multi-chain.R

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -224,7 +224,7 @@ test_that("BCF multi-chain: sample counts with GFR warm-start", {
224224
propensity_train = d$pi_train,
225225
X_test = d$X_test, Z_test = d$Z_test, propensity_test = d$pi_test,
226226
num_gfr = n_gfr, num_burnin = 5, num_mcmc = n_mcmc,
227-
general_params = list(num_chains = n_chains, num_threads = 1)
227+
general_params = list(num_chains = n_chains, num_threads = 1, adaptive_coding = TRUE)
228228
)
229229
expected <- n_chains * n_mcmc
230230
expect_length(m$sigma2_global_samples, expected)
@@ -279,7 +279,7 @@ test_that("BCF multi-chain: all samples finite with GFR + multiple chains", {
279279
propensity_train = d$pi_train,
280280
X_test = d$X_test, Z_test = d$Z_test, propensity_test = d$pi_test,
281281
num_gfr = 6, num_burnin = 20, num_mcmc = 10,
282-
general_params = list(num_chains = 3, num_threads = 1)
282+
general_params = list(num_chains = 3, num_threads = 1, adaptive_coding = TRUE)
283283
)
284284
expect_true(all(is.finite(m$sigma2_global_samples)),
285285
label = "sigma2 samples must be finite (no chain-transition blowup)")

test/R/testthat/test-print-summary.R

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -234,7 +234,7 @@ test_that("BCF print method", {
234234
pi_train <- pi_X[train_inds]; pi_test <- pi_X[test_inds]
235235
y_train <- y[train_inds]; y_test <- y[test_inds]
236236

237-
# --- User-provided propensity, binary treatment, adaptive coding (defaults) ---
237+
# --- User-provided propensity, binary treatment, default coding (defaults) ---
238238
bcf_model <- bcf(
239239
X_train = X_train, y_train = y_train, Z_train = Z_train,
240240
propensity_train = pi_train,
@@ -249,7 +249,7 @@ test_that("BCF print method", {
249249
expect_true(any(grepl("prognostic forest", out, fixed = TRUE)))
250250
expect_true(any(grepl("treatment effect forest", out, fixed = TRUE)))
251251
expect_true(any(grepl("User-provided propensity scores", out, fixed = TRUE)))
252-
expect_true(any(grepl("adaptive coding", out, fixed = TRUE)))
252+
expect_true(any(grepl("default coding", out, fixed = TRUE)))
253253
expect_true(any(grepl("1 chain of", out, fixed = TRUE)))
254254
expect_true(any(grepl("retaining every iteration", out, fixed = TRUE)))
255255

@@ -327,7 +327,7 @@ test_that("BCF summary method", {
327327
propensity_train = pi_train,
328328
X_test = X_test, Z_test = Z_test, propensity_test = pi_test,
329329
num_gfr = 0, num_burnin = 10, num_mcmc = 10,
330-
general_params = list(sample_sigma2_global = TRUE),
330+
general_params = list(sample_sigma2_global = TRUE, adaptive_coding = TRUE),
331331
prognostic_forest_params = list(sample_sigma2_leaf = TRUE),
332332
treatment_effect_forest_params = list(sample_sigma2_leaf = TRUE)
333333
)

test/R/testthat/test-serialization.R

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,8 @@ test_that("BCF JSON uses canonical field names (sigma2_init, b1_samples, b0_samp
193193
bcf_model <- bcf(
194194
X_train = X, Z_train = Z, y_train = y,
195195
propensity_train = pi_x,
196-
num_gfr = 0, num_burnin = 0, num_mcmc = 10
196+
num_gfr = 0, num_burnin = 0, num_mcmc = 10,
197+
general_params = list(adaptive_coding = TRUE)
197198
)
198199
json_string <- saveBCFModelToJsonString(bcf_model)
199200

@@ -224,7 +225,8 @@ test_that("BCF JSON deserialization handles legacy field names with warnings", {
224225
bcf_model <- bcf(
225226
X_train = X, Z_train = Z, y_train = y,
226227
propensity_train = pi_x,
227-
num_gfr = 0, num_burnin = 0, num_mcmc = 10
228+
num_gfr = 0, num_burnin = 0, num_mcmc = 10,
229+
general_params = list(adaptive_coding = TRUE)
228230
)
229231
preds_orig <- predict(bcf_model, X_test, Z_test, pi_test)
230232

test/python/test_bcf.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -247,6 +247,7 @@ def test_continuous_univariate_bcf(self):
247247
num_mcmc = 10
248248

249249
# Run BCF with test set and propensity score
250+
# adaptive_coding=True triggers a UserWarning for non-binary treatment
250251
with pytest.warns(UserWarning):
251252
bcf_model = BCFModel()
252253
variance_forest_params = {"num_trees": 0}
@@ -261,6 +262,7 @@ def test_continuous_univariate_bcf(self):
261262
num_gfr=num_gfr,
262263
num_burnin=num_burnin,
263264
num_mcmc=num_mcmc,
265+
general_params={"adaptive_coding": True},
264266
variance_forest_params=variance_forest_params,
265267
)
266268

@@ -304,6 +306,7 @@ def test_continuous_univariate_bcf(self):
304306
num_gfr=num_gfr,
305307
num_burnin=num_burnin,
306308
num_mcmc=num_mcmc,
309+
general_params={"adaptive_coding": True},
307310
variance_forest_params=variance_forest_params,
308311
)
309312

@@ -375,6 +378,7 @@ def test_continuous_univariate_bcf(self):
375378
num_gfr=num_gfr,
376379
num_burnin=num_burnin,
377380
num_mcmc=num_mcmc,
381+
general_params={"adaptive_coding": True},
378382
variance_forest_params=variance_forest_params,
379383
)
380384

@@ -413,6 +417,7 @@ def test_continuous_univariate_bcf(self):
413417
num_gfr=num_gfr,
414418
num_burnin=num_burnin,
415419
num_mcmc=num_mcmc,
420+
general_params={"adaptive_coding": True},
416421
variance_forest_params=variance_forest_params,
417422
)
418423

@@ -452,6 +457,7 @@ def test_continuous_univariate_bcf(self):
452457
num_gfr=num_gfr,
453458
num_burnin=num_burnin,
454459
num_mcmc=num_mcmc,
460+
general_params={"adaptive_coding": True},
455461
variance_forest_params=variance_forest_params,
456462
)
457463

@@ -486,6 +492,7 @@ def test_continuous_univariate_bcf(self):
486492
num_gfr=num_gfr,
487493
num_burnin=num_burnin,
488494
num_mcmc=num_mcmc,
495+
general_params={"adaptive_coding": True},
489496
variance_forest_params=variance_forest_params,
490497
)
491498

@@ -576,6 +583,7 @@ def test_multivariate_bcf(self):
576583
num_mcmc = 10
577584

578585
# Run BCF with test set and propensity score
586+
# adaptive_coding=True triggers a UserWarning for non-binary treatment
579587
with pytest.warns(UserWarning):
580588
bcf_model = BCFModel()
581589
variance_forest_params = {"num_trees": 0}
@@ -590,6 +598,7 @@ def test_multivariate_bcf(self):
590598
num_gfr=num_gfr,
591599
num_burnin=num_burnin,
592600
num_mcmc=num_mcmc,
601+
general_params={"adaptive_coding": True},
593602
variance_forest_params=variance_forest_params,
594603
)
595604

@@ -630,6 +639,7 @@ def test_multivariate_bcf(self):
630639
num_gfr=num_gfr,
631640
num_burnin=num_burnin,
632641
num_mcmc=num_mcmc,
642+
general_params={"adaptive_coding": True},
633643
variance_forest_params=variance_forest_params,
634644
)
635645

@@ -668,6 +678,7 @@ def test_multivariate_bcf(self):
668678
num_gfr=num_gfr,
669679
num_burnin=num_burnin,
670680
num_mcmc=num_mcmc,
681+
general_params={"adaptive_coding": True},
671682
variance_forest_params=variance_forest_params,
672683
)
673684

@@ -682,6 +693,7 @@ def test_multivariate_bcf(self):
682693
num_gfr=num_gfr,
683694
num_burnin=num_burnin,
684695
num_mcmc=num_mcmc,
696+
general_params={"adaptive_coding": True},
685697
variance_forest_params=variance_forest_params,
686698
)
687699

test/python/test_multi_chain.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -215,7 +215,14 @@ def test_sample_counts_with_gfr(self, bcf_data):
215215
n_chains = self.NUM_CHAINS
216216
n_mcmc = self.NUM_MCMC
217217
n_gfr = self.NUM_GFR
218-
m = _bcf(bcf_data, num_gfr=n_gfr, num_burnin=5, num_mcmc=n_mcmc, num_chains=n_chains)
218+
m = _bcf(
219+
bcf_data,
220+
num_gfr=n_gfr,
221+
num_burnin=5,
222+
num_mcmc=n_mcmc,
223+
num_chains=n_chains,
224+
adaptive_coding=True,
225+
)
219226
expected = n_chains * n_mcmc
220227
assert m.global_var_samples.shape == (expected,)
221228
# BCF-specific samples
@@ -254,6 +261,7 @@ def test_samples_finite_gfr_multi_chain(self, bcf_data):
254261
num_burnin=20,
255262
num_mcmc=self.NUM_MCMC,
256263
num_chains=self.NUM_CHAINS,
264+
adaptive_coding=True,
257265
)
258266
assert np.all(np.isfinite(m.global_var_samples)), (
259267
"sigma2 samples contain non-finite values; possible chain-transition blowup."

test/python/test_str.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -332,7 +332,7 @@ def test_unsampled_model(self):
332332
assert "Empty BCFModel()" in str(BCFModel())
333333

334334
def test_default_model(self, bcf_data):
335-
"""Binary treatment, user propensity, adaptive coding (defaults): 2 base terms."""
335+
"""Binary treatment, user propensity, default coding (defaults): 2 base terms."""
336336
model = BCFModel()
337337
model.sample(
338338
X_train=bcf_data["X_train"],
@@ -350,7 +350,7 @@ def test_default_model(self, bcf_data):
350350
assert "BCFModel run with prognostic forest" in s
351351
assert "treatment effect forest" in s
352352
assert "User-provided propensity scores" in s
353-
assert "adaptive coding" in s
353+
assert "default coding" in s
354354
assert "1 chain" in s
355355
assert "retaining every iteration" in s
356356

0 commit comments

Comments
 (0)