Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@

## Bug Fixes

* Fixed multi-chain BCF bugs with the parametric intercept term in R and Python [#326](https://github.com/StochasticTree/stochtree/pull/326)
* Fixed indexing bugs for multivariate treatment BCF in Python [#326](https://github.com/StochasticTree/stochtree/pull/326)

## Documentation and Other Maintenance

# stochtree 0.4.1
Expand Down
3 changes: 3 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@

## Bug Fixes

* Fixed multi-chain BCF bugs with the parametric intercept term in R and Python [#326](https://github.com/StochasticTree/stochtree/pull/326)
* Fixed indexing bugs for multivariate treatment BCF in Python [#326](https://github.com/StochasticTree/stochtree/pull/326)

## Documentation and Other Maintenance

# stochtree 0.4.1
Expand Down
79 changes: 77 additions & 2 deletions R/bcf.R
Original file line number Diff line number Diff line change
Expand Up @@ -583,6 +583,11 @@ bcf <- function(
previous_b_1_samples <- NULL
previous_b_0_samples <- NULL
}
if (previous_bcf_model$model_params$sample_tau_0) {
previous_tau_0_samples <- previous_bcf_model$tau_0_samples
} else {
previous_tau_0_samples <- NULL
}
previous_model_num_samples <- previous_bcf_model$model_params$num_samples
if (previous_model_warmstart_sample_num > previous_model_num_samples) {
stop(
Expand All @@ -601,6 +606,7 @@ bcf <- function(
previous_forest_samples_variance <- NULL
previous_b_1_samples <- NULL
previous_b_0_samples <- NULL
previous_tau_0_samples <- NULL
}

# Determine whether conditional variance will be modeled
Expand Down Expand Up @@ -2162,6 +2168,7 @@ bcf <- function(
)
}
if (adaptive_coding) {
tau_basis_train_old <- tau_basis_train
current_b_1 <- b_1_samples[forest_ind + 1]
current_b_0 <- b_0_samples[forest_ind + 1]
tau_basis_train <- (1 - Z_train) *
Expand All @@ -2179,6 +2186,21 @@ bcf <- function(
outcome_train,
active_forest_tau
)
# Correct residual for tau_0 component of the basis change
if (sample_tau_0) {
outcome_train$subtract_vector(
as.numeric((tau_basis_train - tau_basis_train_old) * tau_0[1])
)
}
}
# Reset tau_0 intercept and correct the running residual
if (sample_tau_0) {
tau_0_old <- tau_0
tau_0 <- tau_0_samples[, forest_ind + 1]
Z_basis_gfr <- as.matrix(tau_basis_train)
outcome_train$subtract_vector(
as.numeric(Z_basis_gfr %*% matrix(tau_0 - tau_0_old, ncol = 1))
)
}
if (sample_sigma2_global) {
current_sigma2 <- global_var_samples[forest_ind + 1]
Expand Down Expand Up @@ -2255,6 +2277,7 @@ bcf <- function(
)
}
if (adaptive_coding) {
tau_basis_train_old <- tau_basis_train
if (!is.null(previous_b_1_samples)) {
current_b_1 <- previous_b_1_samples[
warmstart_index
Expand All @@ -2280,6 +2303,22 @@ bcf <- function(
outcome_train,
active_forest_tau
)
# Correct residual for tau_0 component of the basis change
if (sample_tau_0) {
outcome_train$subtract_vector(
as.numeric((tau_basis_train - tau_basis_train_old) * tau_0[1])
)
}
}
# Reset tau_0 intercept and correct the running residual
if (sample_tau_0 && !is.null(previous_tau_0_samples)) {
tau_0_old <- tau_0
# previous model stores tau_0 in original scale; convert to standardized scale
tau_0 <- as.numeric(previous_tau_0_samples[, warmstart_index] / previous_y_scale)
Z_basis_ws <- as.matrix(tau_basis_train)
outcome_train$subtract_vector(
as.numeric(Z_basis_ws %*% matrix(tau_0 - tau_0_old, ncol = 1))
)
}
if (has_rfx) {
if (is.null(previous_rfx_samples)) {
Expand Down Expand Up @@ -2389,6 +2428,7 @@ bcf <- function(
)
}
if (adaptive_coding) {
tau_basis_train_old <- tau_basis_train
current_b_1 <- b_1
current_b_0 <- b_0
tau_basis_train <- (1 - Z_train) *
Expand All @@ -2406,6 +2446,21 @@ bcf <- function(
outcome_train,
active_forest_tau
)
# Correct residual for tau_0 component of the basis change
if (sample_tau_0) {
outcome_train$subtract_vector(
as.numeric((tau_basis_train - tau_basis_train_old) * tau_0[1])
)
}
}
# Reset tau_0 to initial value (0) and correct the running residual
if (sample_tau_0) {
tau_0_old <- tau_0
tau_0 <- rep(0.0, p_tau0)
Z_basis_reset <- as.matrix(tau_basis_train)
outcome_train$subtract_vector(
as.numeric(Z_basis_reset %*% matrix(tau_0 - tau_0_old, ncol = 1))
)
}
if (sample_sigma2_global) {
current_sigma2 <- sigma2_init
Expand Down Expand Up @@ -4314,7 +4369,27 @@ extractParameter.bcfmodel <- function(object, term) {
}
}

if (term %in% c("tau_hat_train")) {
if (term %in% c("mu_hat_train", "prognostic_function_train")) {
if (!is.null(object$mu_hat_train)) {
return(object$mu_hat_train)
} else {
stop(
"This model does not have in-sample prognostic function predictions"
)
}
}

if (term %in% c("mu_hat_test", "prognostic_function_test")) {
if (!is.null(object$mu_hat_test)) {
return(object$mu_hat_test)
} else {
stop(
"This model does not have test set prognostic function predictions"
)
}
}

if (term %in% c("tau_hat_train", "cate_train")) {
if (!is.null(object$tau_hat_train)) {
return(object$tau_hat_train)
} else {
Expand All @@ -4324,7 +4399,7 @@ extractParameter.bcfmodel <- function(object, term) {
}
}

if (term %in% c("tau_hat_test")) {
if (term %in% c("tau_hat_test", "cate_test")) {
if (!is.null(object$tau_hat_test)) {
return(object$tau_hat_test)
} else {
Expand Down
88 changes: 77 additions & 11 deletions stochtree/bcf.py
Original file line number Diff line number Diff line change
Expand Up @@ -1637,7 +1637,7 @@ def sample(
else:
raise ValueError("sigma2_leaf_mu must be a scalar")
sigma2_leaf_tau = (
np.squeeze(np.var(resid_train) * 0.5) / (num_trees_tau)
np.squeeze(0.5 * np.var(resid_train)) / (num_trees_tau)
if sigma2_leaf_tau is None
else sigma2_leaf_tau
)
Expand Down Expand Up @@ -2306,6 +2306,7 @@ def sample(
)
# Reset adaptive coding parameters
if self.adaptive_coding:
tau_basis_train_old = tau_basis_train.copy()
if self.b0_samples is not None:
current_b_0 = self.b0_samples[forest_ind]
else:
Expand All @@ -2326,6 +2327,23 @@ def sample(
forest_sampler_tau.propagate_basis_update(
forest_dataset_train, residual_train, active_forest_tau
)
# Correct residual for tau_0 component of the basis change
if self.sample_tau_0:
residual_train.add_vector(
-(np.squeeze(tau_basis_train) - np.squeeze(tau_basis_train_old)) * tau_0[0]
)
# Reset tau_0 intercept and correct the running residual
if self.sample_tau_0:
tau_0_old = tau_0.copy()
tau_0 = self.tau_0_samples[:, forest_ind].copy()
Z_basis_gfr = (
tau_basis_train.reshape(-1, 1)
if tau_basis_train.ndim == 1
else tau_basis_train
)
residual_train.add_vector(
-np.squeeze(Z_basis_gfr @ (tau_0 - tau_0_old))
)
# Reset random effects terms
if self.has_rfx:
rfx_model.reset(
Expand Down Expand Up @@ -2405,6 +2423,7 @@ def sample(
)
# Reset adaptive coding parameters
if self.adaptive_coding:
tau_basis_train_old = tau_basis_train.copy()
if previous_b0_samples is not None:
current_b_0 = previous_b0_samples[warmstart_index]
if previous_b1_samples is not None:
Expand All @@ -2421,6 +2440,26 @@ def sample(
forest_sampler_tau.propagate_basis_update(
forest_dataset_train, residual_train, active_forest_tau
)
# Correct residual for tau_0 component of the basis change
if self.sample_tau_0:
residual_train.add_vector(
-(np.squeeze(tau_basis_train) - np.squeeze(tau_basis_train_old)) * tau_0[0]
)
# Reset tau_0 intercept and correct the running residual
if self.sample_tau_0:
prev_tau_0_samples = getattr(previous_bcf_model, "tau_0_samples", None)
if prev_tau_0_samples is not None:
tau_0_old = tau_0.copy()
# tau_0_samples in previous model are in original scale; convert back
tau_0 = (prev_tau_0_samples[:, warmstart_index] / previous_bcf_model.y_std).copy()
Z_basis_ws = (
tau_basis_train.reshape(-1, 1)
if tau_basis_train.ndim == 1
else tau_basis_train
)
residual_train.add_vector(
-np.squeeze(Z_basis_ws @ (tau_0 - tau_0_old))
)
# Reset random effects terms
if self.has_rfx:
rfx_model.reset(
Expand Down Expand Up @@ -2495,6 +2534,7 @@ def sample(
)
# Reset adaptive coding parameters
if self.adaptive_coding:
tau_basis_train_old = tau_basis_train.copy()
current_b_0 = b_0
current_b_1 = b_1
tau_basis_train = (
Expand All @@ -2509,6 +2549,24 @@ def sample(
forest_sampler_tau.propagate_basis_update(
forest_dataset_train, residual_train, active_forest_tau
)
# Correct residual for tau_0 component of the basis change
if self.sample_tau_0:
residual_train.add_vector(
-(np.squeeze(tau_basis_train) - np.squeeze(tau_basis_train_old))
* tau_0[0]
)
# Reset tau_0 to initial value (0) and correct the running residual
if self.sample_tau_0:
tau_0_old = tau_0.copy()
tau_0 = np.zeros_like(tau_0)
Z_basis_reset = (
tau_basis_train.reshape(-1, 1)
if tau_basis_train.ndim == 1
else tau_basis_train
)
residual_train.add_vector(
-np.squeeze(Z_basis_reset @ (tau_0 - tau_0_old))
)
# Reset random effects terms
if self.has_rfx:
rfx_model.root_reset(
Expand Down Expand Up @@ -3351,17 +3409,11 @@ def predict(
if predict_mu_forest:
mu_x = np.mean(mu_x, axis=1)
if predict_tau_forest:
if Z.shape[1] > 1:
tau_x = np.mean(tau_x, axis=2)
else:
tau_x = np.mean(tau_x, axis=1)
tau_x = np.mean(tau_x, axis=1)
if predict_prog_function:
prognostic_function = np.mean(prognostic_function, axis=1)
if predict_cate_function:
if Z.shape[1] > 1:
cate = np.mean(cate, axis=2)
else:
cate = np.mean(cate, axis=1)
cate = np.mean(cate, axis=1)
if predict_rfx:
rfx_preds = np.mean(rfx_preds, axis=1)
if predict_y_hat:
Expand Down Expand Up @@ -4524,20 +4576,34 @@ def extract_parameter(self, term: str) -> np.array:
else:
raise ValueError("This model does not have test set mean function prediction samples")

if term in ["tau_hat_train"]:
if term in ["tau_hat_train", "cate_train"]:
tht = getattr(self, "tau_hat_train", None)
if tht is not None:
return tht
else:
raise ValueError("This model does not have in-sample treatment effect forest predictions")

if term in ["tau_hat_test"]:
if term in ["tau_hat_test", "cate_test"]:
tht = getattr(self, "tau_hat_test", None)
if tht is not None:
return tht
else:
raise ValueError("This model does not have test set treatment effect forest predictions")

if term in ["mu_hat_train", "prognostic_function_train"]:
mht = getattr(self, "mu_hat_train", None)
if mht is not None:
return mht
else:
raise ValueError("This model does not have in-sample prognostic function predictions")

if term in ["mu_hat_test", "prognostic_function_test"]:
mht = getattr(self, "mu_hat_test", None)
if mht is not None:
return mht
else:
raise ValueError("This model does not have test set prognostic function predictions")

if term in ["sigma2_x_train", "var_x_train"]:
s2x = getattr(self, "sigma2_x_train", None)
if s2x is not None:
Expand Down
16 changes: 15 additions & 1 deletion test/R/testthat/test-extract-parameter.R
Original file line number Diff line number Diff line change
Expand Up @@ -250,11 +250,25 @@ test_that("extractParameter.bcfmodel", {
yhtest <- extractParameter(bcf_base, "y_hat_test")
expect_equal(dim(yhtest), c(n_test, num_mcmc))

# tau_hat_train and tau_hat_test
# mu_hat_train / prognostic_function_train
mht <- extractParameter(bcf_base, "mu_hat_train")
expect_equal(dim(mht), c(n_train, num_mcmc))
expect_equal(mht, extractParameter(bcf_base, "prognostic_function_train"))

# mu_hat_test / prognostic_function_test
mhtest <- extractParameter(bcf_base, "mu_hat_test")
expect_equal(dim(mhtest), c(n_test, num_mcmc))
expect_equal(mhtest, extractParameter(bcf_base, "prognostic_function_test"))

# tau_hat_train / cate_train
tht <- extractParameter(bcf_base, "tau_hat_train")
expect_equal(dim(tht), c(n_train, num_mcmc))
expect_equal(tht, extractParameter(bcf_base, "cate_train"))

# tau_hat_test / cate_test
thtest <- extractParameter(bcf_base, "tau_hat_test")
expect_equal(dim(thtest), c(n_test, num_mcmc))
expect_equal(thtest, extractParameter(bcf_base, "cate_test"))

# sigma2_x_train / var_x_train and sigma2_x_test / var_x_test (variance forest)
bcf_var <- bcf(
Expand Down
Loading
Loading