PR link: #4
Tracking issue: StochasticTree/stochtree#200
Add first-class support for sampling from the BART prior — the distribution over sum-of-trees functions before any data is observed. This is accomplished in two phases:
- Exposing observation weights ("case weights") in the high-level `bart()` and `bcf()` interfaces. Weights are broadly useful independent of prior sampling (survey-weighted regression, importance resampling, subgroup upweighting).
- A dedicated direct C++ prior sampler (`MCMCPriorSampleOneIter`) that samples tree structures from the branching process prior and leaf parameters from $N(0, \sigma^2_{\mu})$, with no outcome data, no MCMC, and independent draws. Thin wrappers `sample_bart_prior()` (Python) / `sampleBARTPrior()` (R) expose this to users.
Sampling from the BART prior is a standard Bayesian workflow step with at least three concrete use cases:
- **Prior calibration.** Users want to understand what functions the BART prior places non-negligible probability on before fitting, so they can choose hyperparameters (`num_trees`, $\alpha$, $\beta$, leaf scale $\sigma^2_{\mu}$) that encode their substantive prior beliefs rather than relying on defaults.
- **Prior predictive checks.** Generating outcome data conditioned on covariates and the model prior lets users verify that the prior is not inadvertently ruling out plausible outcomes (e.g., a too-narrow leaf scale causing all predicted values to cluster near zero).
- **Standalone prior documentation / vignette.** A worked example showing what BART "believes" out of the box is a valuable pedagogical resource.
The current stochtree high-level interface (`bart()` / `bcf()`) has no mechanism for users to tell the sampler to ignore the data. The only path today is through the low-level `ForestDataset` / `TreeSampler` API, which is not a particularly user-friendly interface.
Additionally, observation weights have independent value beyond prior sampling: survey-weighted regression, importance-weighted resampling, and upweighting rare subpopulations all require them. Exposing weights in the high-level interface is a broadly useful enhancement that prior sampling happens to need.
Observation weights are already fully supported in the stochtree C++ core. In `include/stochtree/leaf_model.h` the weighted BART likelihood is documented explicitly:

$$y_i \mid - \sim N\left(\mu(X_i),\ \sigma^2 / w_i\right)$$

Weighted sufficient statistics ($s_{w,\ell}$, $s_{wy,\ell}$, $s_{wyy,\ell}$) are used in both split evaluation and leaf parameter sampling. The `ForestDataset` class (`include/stochtree/data.h`) stores weights as a `ColumnVector var_weights_` and exposes `SetVarWeightValue()`. The R-level `ForestDataset` wrapper (`R/data.R`) already has `AddVarianceWeights()` and `update_variance_weights()` methods. No new C++ work is required.
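For reference, the standard normal-normal update implied by this weighted likelihood (a sketch of the textbook algebra, not a transcription of the stochtree C++ code) is

$$
s_{w,\ell} = \sum_{i \in \ell} w_i, \qquad s_{wy,\ell} = \sum_{i \in \ell} w_i y_i,
$$

$$
\mu_\ell \mid y \sim N\left( \frac{s_{wy,\ell}/\sigma^2}{s_{w,\ell}/\sigma^2 + 1/\sigma^2_{\mu}},\ \frac{1}{s_{w,\ell}/\sigma^2 + 1/\sigma^2_{\mu}} \right),
$$

so setting $w_i = 0$ removes observation $i$ from both statistics, which is what makes zero weights attractive (and, as Phase 2 explains, numerically treacherous) for prior sampling.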
Add an `observation_weights` keyword argument to `BARTModel.sample()` in `stochtree/bart.py`:

```python
def sample(self, X_train, y_train, ..., observation_weights=None, ...):
```

- Type: `np.ndarray` of shape `(n,)`, or `None` (default = all ones).
- Validated to be non-negative with `np.all(observation_weights >= 0)`.
- Passed to the `ForestDataset` constructor / `add_weights()` call that builds `forest_dataset_train` (currently constructed around line 210 of `bart.py`).
The same change applies to `BCFModel.sample()` in `stochtree/bcf.py`. Both the prognostic forest dataset and the treatment effect forest dataset should receive the same weights (observation weights are properties of the outcome, not the forest). A sketch of the shared validation logic follows.
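A minimal sketch of the validation helper both entry points could share (the helper name is hypothetical, and the plumbing into `ForestDataset` may differ):

```python
import numpy as np

def _preprocess_observation_weights(observation_weights, n):
    """Validate user-supplied weights; None means unweighted (all ones)."""
    if observation_weights is None:
        return np.ones(n)
    w = np.asarray(observation_weights, dtype=np.float64).squeeze()
    if w.ndim != 1 or w.shape[0] != n:
        raise ValueError(f"observation_weights must be a 1D array of length {n}")
    if not np.all(w >= 0):
        raise ValueError("observation_weights must be non-negative")
    return w
```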
Add an `observation_weights` argument (default `NULL`) to the `bart()` function in `R/bart.R` and to `bcf()` in `R/bcf.R`:

```r
bart <- function(X_train, y_train, ..., observation_weights = NULL, ...) {
```

- When non-`NULL`: validate as a numeric vector of length `nrow(X_train)` with all values `>= 0`.
- Pass to `forest_dataset$AddVarianceWeights()` after `ForestDataset` construction (currently done around line 120 of `bart.R`).
The `bcf()` function should apply weights to both `mu_forest_data` and `tau_forest_data`.
Observation weights are not part of the model — they are properties of the training data and do not need to be serialized to JSON. No changes to the serialization layer are required.
Most other stochtree modeling features should be compatible with general-purpose observation weights, though we need to be careful to accumulate weighted sufficient statistics for other model terms.
Variance forests and discrete outcome models should at minimum raise a warning and potentially preclude the use of observation weights. See [[#Open Questions]] below.
Phase 2 implements a standalone prior sampler that bypasses the MCMC machinery entirely. Rather than running the sampler on a degenerate dataset, a dedicated C++ function samples tree structures directly from the branching process prior and draws leaf parameters from their prior distribution.
Why not zero-weight MCMC (originally planned):
The original Phase 2 design called for running `BARTModel.sample()` / `bart()` with `y_dummy = zeros` and `observation_weights = zeros`. This was abandoned after implementation revealed two blocking problems:

- **C++ NaN with exact zero weights.** Setting `w_i = 0` produces NaN predictions from the C++ sampler, because weights are treated as unit-specific variance adjustments that accumulate as `1 / w_i`, with no special code path for zero weights.
- **Large-σ² workaround has MCMC autocorrelation.** An alternative, setting `sigma2_init = 1e6` and `sample_sigma2_global = False` so the leaf posterior collapses to its prior, works mathematically but produces correlated samples (the MCMC chain still mixes between tree topologies, so draws are not independent). It also forecloses sampling $\sigma^2$ from its prior.
Chosen design (`MCMCPriorSampleOneIter`):
Add a new C++ function `MCMCPriorSampleOneIter` (in `src/tree_sampler.cpp` / `include/stochtree/tree_sampler.h`) that:

- Accepts most of the same arguments as `MCMCSampleOneIter` (i.e. `TreeEnsemble`, `ForestTracker`, `ForestContainer`, etc.).
- For each sample, iterates over all trees and grows each tree by repeatedly sampling grow/prune moves from the branching process prior alone: no data, no sufficient statistics, no likelihood in the MH acceptance ratio (see the sketch after this list).
- After each tree has been sampled, draws each leaf parameter independently from $N(0, \sigma^2_{\mu})$.
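Illustratively, the following Python sketch draws a tree directly from the same branching-process distribution those grow/prune moves target (a node at depth $d$ becomes internal with probability $\alpha (1 + d)^{-\beta}$). Names are hypothetical, and split-rule bookkeeping is omitted; the real implementation is C++:

```python
import numpy as np

def sample_tree_from_prior(rng, alpha, beta, sigma2_mu, max_depth=10, depth=0):
    """Recursively grow one tree from the branching process prior.
    Split features/cutpoints and forest tracking are omitted for brevity."""
    if depth < max_depth and rng.random() < alpha * (1.0 + depth) ** (-beta):
        return {
            "left": sample_tree_from_prior(rng, alpha, beta, sigma2_mu, max_depth, depth + 1),
            "right": sample_tree_from_prior(rng, alpha, beta, sigma2_mu, max_depth, depth + 1),
        }
    # Terminal node: leaf parameter drawn from its prior N(0, sigma2_mu).
    return {"leaf": rng.normal(0.0, np.sqrt(sigma2_mu))}

# One independent prior draw of a 200-tree forest:
rng = np.random.default_rng(0)
forest = [sample_tree_from_prior(rng, alpha=0.95, beta=2.0, sigma2_mu=0.01)
          for _ in range(200)]
```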
This design produces independent samples (no chain, no autocorrelation), requires no outcome data, and involves no likelihood computation. The covariate matrix `X` is still needed to evaluate the forest at prediction time but plays no role during sampling.
New pybind11 bindings (`py_stochtree.cpp`) and cpp11 wrappers (`R/sampler.R`) will expose `MCMCPriorSampleOneIter` to the high-level interfaces.
```python
def sample_bart_prior(
    X: np.ndarray,
    num_samples: int = 100,
    general_params: dict | None = None,
    mean_forest_params: dict | None = None,
) -> BARTModel:
    """
    Sample from the BART prior over functions f: X -> R.

    Produces independent prior draws by sampling tree structures from
    the branching process prior and leaf parameters from N(0, sigma2_leaf).
    No outcome data is required or used.

    Parameters
    ----------
    X : np.ndarray, shape (n, p)
        Covariate matrix defining the domain.
    num_samples : int
        Number of independent prior draws.
    general_params, mean_forest_params : dict, optional
        Controls num_trees, alpha/beta, leaf scale, etc.

    Returns
    -------
    BARTModel
        Object whose y_hat_train contains num_samples prior draws of f(X).
    """
    # initialize data structures
    # call pybind11 wrapper around MCMCPriorSampleOneIter
    # unpack and return predictions and prior parameter samples
```

Module-level function in `stochtree/bart.py`, exported from `stochtree/__init__.py`.
```r
sampleBARTPrior <- function(
    X,
    num_samples = 100,
    general_params = list(),
    mean_forest_params = list()
) {
    # initialize data structures
    # call cpp11 wrapper around MCMCPriorSampleOneIter
    # unpack and return predictions and prior parameter samples
}
```

Lives in `R/bart.R`, exported from `NAMESPACE`.
Not applicable: the direct prior sampler has no GFR warm-start. The convenience wrappers do not expose a `num_gfr` argument.
A new multilingual vignette should demonstrate:

- Sampling $f \sim$ BART prior using `sampleBARTPrior()` (R) / `sample_bart_prior()` (Python)
- Plotting several prior draws as functions of a 1D covariate to build intuition about how `num_trees`, `alpha`, `beta`, and leaf scale affect the prior
- Computing prior predictive distributions $p(y \mid X, \text{prior})$ by adding $\sigma^2$ draws on top of the $f$ draws
- A simple calibration example: choosing `num_trees` so that the prior 95% interval for $f(x)$ spans a user-specified range
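To make this concrete, here is a hypothetical Python snippet covering the first two items. It assumes the proposed `sample_bart_prior()` lands as sketched above, that `y_hat_train` is laid out `(n, num_samples)`, and that the `mean_forest_params` keys shown are illustrative:

```python
import numpy as np
import matplotlib.pyplot as plt
from stochtree import sample_bart_prior  # proposed in this design, not yet released

# A 1D covariate grid defining the domain of f.
X = np.linspace(0, 1, 200).reshape(-1, 1)

prior_fit = sample_bart_prior(
    X,
    num_samples=20,
    mean_forest_params={"num_trees": 50, "alpha": 0.95, "beta": 2.0},
)

# Overlay the prior draws of f(x) to visualize the prior over functions.
for draw in prior_fit.y_hat_train.T:
    plt.plot(X[:, 0], draw, color="gray", alpha=0.5)
plt.xlabel("x")
plt.ylabel("f(x) under the BART prior")
plt.show()
```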
- Prior calibration becomes accessible to any stochtree user with a one-line call, rather than the low-level API
- Observation weights independently enable weighted regression for survey data, importance sampling, and subgroup upweighting — use cases that appear repeatedly in applied work
- Pedagogical value: a "what does BART believe before data?" vignette is a commonly requested resource and helps new users build correct intuition about the model
- **New C++ code path.** `MCMCPriorSampleOneIter` is a new function that bypasses the MCMC machinery entirely. It needs careful unit testing to confirm that the marginal variance of $f(x)$ matches $\text{num\_trees} \times \sigma^2_{\mu}$ for i.i.d. $N(0, \sigma^2_{\mu})$ leaves, and that tree depth distributions match the branching process prior analytically (a sketch of such a check follows this list).
- **`observation_weights` name vs. `variance_weights` in C++.** The C++ layer calls these `var_weights` (short for "variance weights") because they scale the residual variance as $\sigma^2 / w_i$. The high-level name `observation_weights` is more familiar to applied users. This naming gap should be documented in the API docstrings to avoid confusion.
- **BCF prior sampling deferred.** The BCF model has additional complexity (adaptive coding, parametric CATE intercept, propensity model) that makes designing a BCF prior sampler non-trivial. This is deferred until the BART prior sampler is complete and validated.
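For the first risk, a minimal sketch of the marginal-variance check (a hypothetical test helper; assumes prior draws arranged as `(num_samples, n)`):

```python
import numpy as np

def check_prior_marginal_variance(prior_draws, num_trees, sigma2_mu, rtol=0.1):
    """prior_draws: (num_samples, n) array of independent prior draws of f(X).
    Summing num_trees i.i.d. N(0, sigma2_mu) leaf values per point gives
    Var[f(x)] = num_trees * sigma2_mu at every x."""
    empirical = prior_draws.var(axis=0, ddof=1).mean()
    expected = num_trees * sigma2_mu
    assert np.isclose(empirical, expected, rtol=rtol), (empirical, expected)
```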
Run the MCMC sampler on a dummy dataset with `y_train = zeros` and `observation_weights = zeros`. With all weights zero, the weighted sufficient statistics vanish and the sampler should draw from the prior.
Why not chosen: Implemented and then abandoned. Exact zero weights produce NaN predictions from the C++ sampler (likely a $\sigma^2 / w_i$ evaluation in the residual computation). A large-σ² workaround (`sigma2_init = 1e6`, `sample_sigma2_global = False`) makes the leaf posterior collapse to its prior and gives numerically correct results, but the chain still mixes between tree topologies, so draws are correlated, not independent. The direct C++ prior sampler (chosen design) solves both problems cleanly.
Add `observation_weights` to the high-level interface (Phase 1) but skip the convenience wrapper (Phase 2), leaving users and documentation to assemble the prior-sampling workflow manually.
Why not chosen: The convenience wrapper is small and substantially lowers the barrier for the prior-sampling use case. Moot now that the zero-weight approach is off the table: the direct sampler has to be implemented regardless.
Use a tiny positive weight (e.g., `w_i = 1e-10`) instead of exact zero to avoid division by zero.
Why not chosen: Implemented and tested experimentally. With `w = 1e-10`, predictions are near zero (the data contribution is effectively silenced) but the sampler appears stuck and the variance behavior is unpredictable. With `w = 0.1`, the chain mixes but is clearly data-influenced. No epsilon value makes the sampler cleanly sample from the prior without NaN, stuck chains, or data contamination.
- Interactions with observation weights and complex models: resolved for Phase 1. Variance forests produce a warning (untested), the cloglog link produces an error (not compatible), and probit is allowed (latent Gaussian; weights are well-defined at the latent level). A sketch of this guard follows.
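A minimal sketch of that Phase 1 guard, with hypothetical flag names (`include_variance_forest`, `outcome_link`) standing in for whatever the high-level interface actually exposes:

```python
import warnings

def _check_weight_compatibility(observation_weights, include_variance_forest, outcome_link):
    """Enforce the Phase 1 policy: warn for variance forests, error for cloglog."""
    if observation_weights is None:
        return
    if include_variance_forest:
        warnings.warn("observation_weights with a variance forest is untested")
    if outcome_link == "cloglog":
        raise ValueError("observation_weights are not compatible with the cloglog link")
    # probit is allowed: weights are well-defined at the latent Gaussian level
```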