From 4e08e3ef34311285e688f967472326bcdb639266 Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Thu, 19 Mar 2026 17:15:53 -0500 Subject: [PATCH 1/9] Add stochtree_version stamp and version inference helpers (#317, #318) Implements RFC 0002 sub-issues #317 and #318: - #317: Every to_json() / saveBARTModelToJson() / saveBCFModelToJson() call now writes a top-level "stochtree_version" string field so that JSONs serialized going forward carry an explicit version stamp. - #318: Two new helpers in both R and Python for use by the forthcoming from_json() guards (#319, #320): - Python: _get_stochtree_version() and _infer_stochtree_version(json_string) in stochtree/utils.py - R: getStochtreeVersion() and inferStorchtreeJsonVersion(json_object) in R/utils.R The inference helper fingerprints a JSON by field presence and returns a version bracket string (e.g. "<0.4.1") for use in warning messages when deserializing legacy JSONs without a stamp. Co-Authored-By: Claude Sonnet 4.6 --- R/bart.R | 3 ++- R/bcf.R | 3 ++- R/utils.R | 62 ++++++++++++++++++++++++++++++++++++++++++++++ stochtree/bart.py | 4 ++- stochtree/bcf.py | 4 ++- stochtree/utils.py | 58 +++++++++++++++++++++++++++++++++++++++++++ 6 files changed, 130 insertions(+), 4 deletions(-) diff --git a/R/bart.R b/R/bart.R index a71060c8..c327e0af 100644 --- a/R/bart.R +++ b/R/bart.R @@ -3538,7 +3538,8 @@ saveBARTModelToJson <- function(object) { ) } - # Add global parameters + # Add version stamp and global parameters + jsonobj$add_string("stochtree_version", getStochtreeVersion()) jsonobj$add_scalar("outcome_scale", object$model_params$outcome_scale) jsonobj$add_scalar("outcome_mean", object$model_params$outcome_mean) jsonobj$add_boolean("standardize", object$model_params$standardize) diff --git a/R/bcf.R b/R/bcf.R index 14329b22..f871754a 100644 --- a/R/bcf.R +++ b/R/bcf.R @@ -4525,7 +4525,8 @@ saveBCFModelToJson <- function(object) { ) } - # Add global parameters + # Add version stamp and global parameters + 
jsonobj$add_string("stochtree_version", getStochtreeVersion()) jsonobj$add_scalar("outcome_scale", object$model_params$outcome_scale) jsonobj$add_scalar("outcome_mean", object$model_params$outcome_mean) jsonobj$add_boolean("standardize", object$model_params$standardize) diff --git a/R/utils.R b/R/utils.R index b540464c..7890e988 100644 --- a/R/utils.R +++ b/R/utils.R @@ -1174,3 +1174,65 @@ expand_dims_2d_diag <- function(input, output_size) { } return(output) } + +#' Return the current stochtree package version string. +#' +#' Falls back to "dev" if the package metadata is unavailable (e.g. during +#' development installs that have not been properly registered). +#' +#' @return A character string such as "0.4.1" or "dev". +#' @noRd +getStochtreeVersion <- function() { + tryCatch( + as.character(utils::packageVersion("stochtree")), + error = function(e) "dev" + ) +} + +#' Infer the stochtree version bracket from the fields present in a JSON object. +#' +#' When a JSON was serialized before version stamping was introduced, the version +#' can be approximated by checking which fields are present. The returned string +#' is intended for use in warning messages only, not to gate deserialization +#' behavior. +#' +#' @param json_object A \code{CppJson} object as returned by \code{createCppJson()} +#' or the various \code{saveBARTModelToJson} / \code{saveBCFModelToJson} functions. +#' @return A character string: the stamp value if \code{stochtree_version} is +#' present, otherwise a bracket string such as \code{"<0.4.1"}. 
+#' @noRd +inferStorchtreeJsonVersion <- function(json_object) { + has_field <- function(name) { + json_contains_field_cpp(json_object$json_ptr, name) + } + has_subfolder_field <- function(subfolder, name) { + json_contains_field_subfolder_cpp(json_object$json_ptr, subfolder, name) + } + + if (has_field("stochtree_version")) { + return(json_object$get_string("stochtree_version")) + } + + # outcome/link in outcome_model were added in ~0.4.1 + if (!has_subfolder_field("outcome_model", "outcome") || + !has_subfolder_field("outcome_model", "link")) { + return("<0.4.1") + } + + # has_rfx_basis / num_rfx_basis were added in ~0.4.0 + if (!has_field("has_rfx_basis") || !has_field("num_rfx_basis")) { + return("<0.4.0") + } + + # internal_propensity_model was added in ~0.3.2 (BCF only) + if (has_field("propensity_covariate") && !has_field("internal_propensity_model")) { + return("<0.3.2") + } + + # rfx_model_spec and preprocessor_metadata were added in ~0.3.0 + if (!has_field("rfx_model_spec") || !has_field("preprocessor_metadata")) { + return("<0.3.0") + } + + return("unknown") +} diff --git a/stochtree/bart.py b/stochtree/bart.py index adf2cc9a..d30a570b 100644 --- a/stochtree/bart.py +++ b/stochtree/bart.py @@ -27,6 +27,7 @@ _expand_dims_1d, _expand_dims_2d, _expand_dims_2d_diag, + _get_stochtree_version, _posterior_predictive_heuristic_multiplier, _summarize_interval, ) @@ -2988,7 +2989,8 @@ def to_json(self) -> str: if self.has_rfx: bart_json.add_random_effects(self.rfx_container) - # Add global parameters + # Add version stamp and global parameters + bart_json.add_string("stochtree_version", _get_stochtree_version()) bart_json.add_scalar("outcome_scale", self.y_std) bart_json.add_scalar("outcome_mean", self.y_bar) bart_json.add_boolean("standardize", self.standardize) diff --git a/stochtree/bcf.py b/stochtree/bcf.py index 8476e04f..e3410860 100644 --- a/stochtree/bcf.py +++ b/stochtree/bcf.py @@ -27,6 +27,7 @@ _expand_dims_1d, _expand_dims_2d, _expand_dims_2d_diag, + 
_get_stochtree_version, _posterior_predictive_heuristic_multiplier, _summarize_interval, ) @@ -3901,7 +3902,8 @@ def to_json(self) -> str: if self.has_rfx: bcf_json.add_random_effects(self.rfx_container) - # Add global parameters + # Add version stamp and global parameters + bcf_json.add_string("stochtree_version", _get_stochtree_version()) bcf_json.add_scalar("outcome_scale", self.y_std) bcf_json.add_scalar("outcome_mean", self.y_bar) bcf_json.add_boolean("standardize", self.standardize) diff --git a/stochtree/utils.py b/stochtree/utils.py index cfa6a7d8..9ca0f34f 100644 --- a/stochtree/utils.py +++ b/stochtree/utils.py @@ -1,9 +1,67 @@ from typing import Union, Tuple +import json import math import numpy as np +def _get_stochtree_version() -> str: + """Return the current stochtree package version, or 'dev' for editable installs.""" + try: + from importlib.metadata import version, PackageNotFoundError + return version("stochtree") + except Exception: + return "dev" + + +def _infer_stochtree_version(json_string: str) -> str: + """Infer the stochtree version bracket from the fields present in a JSON string. + + When a JSON was serialized before version stamping was introduced, the version + can be approximated by checking which fields are present. The returned string is + intended for use in warning messages only, not to gate deserialization behavior. + + Parameters + ---------- + json_string : str + Raw JSON string as produced by ``to_json()`` / ``saveBARTModelToJsonString()``. + + Returns + ------- + str + The stamp value if ``stochtree_version`` is present, otherwise a bracket + string such as ``"<0.4.1"`` indicating the latest version known to be + missing the observed fields. 
+ """ + try: + d = json.loads(json_string) + except Exception: + return "unknown" + + if "stochtree_version" in d: + return d["stochtree_version"] + + # outcome/link were added in ~0.4.1 + outcome_model = d.get("outcome_model", {}) + if "outcome" not in outcome_model or "link" not in outcome_model: + return "<0.4.1" + + # has_rfx_basis / num_rfx_basis were added in ~0.4.0 + if "has_rfx_basis" not in d or "num_rfx_basis" not in d: + return "<0.4.0" + + # internal_propensity_model was added in ~0.3.2 (BCF only; absent in BART JSON) + # Only flag this if we can confirm it's a BCF JSON by checking a BCF-only field + if "propensity_covariate" in d and "internal_propensity_model" not in d: + return "<0.3.2" + + # rfx_model_spec and covariate_preprocessor were added in ~0.3.0 + if "rfx_model_spec" not in d or "covariate_preprocessor" not in d: + return "<0.3.0" + + return "unknown" + + def _set_output_defaults(outcome: str = "continuous", link: str = None) -> Tuple[str, str]: if outcome is None: raise ValueError("Outcome must be specified") From e57950827ae8596784d60f4976d0f2afae0dc00d Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Thu, 19 Mar 2026 20:41:25 -0500 Subject: [PATCH 2/9] Robust BART from_json: safe defaults and warnings for missing fields (#319) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Both from_json() and from_json_string_list() in Python (stochtree/bart.py) and both createBARTModelFromJson() and createBARTModelFromCombinedJson() in R (R/bart.R) now check for each optional field before reading it, falling back to a safe default and emitting a descriptive warning that includes the inferred legacy version bracket from inferStochtreeJsonVersion() / _infer_stochtree_version(). 
Fields guarded with defaults: - has_rfx_basis / num_rfx_basis → False / 1 - num_chains → 1 - keep_every → 1 - probit_outcome_model → False - outcome_model.outcome / outcome_model.link → "continuous" / "identity" - rfx_model_spec → "" (warns only when has_rfx=True) - covariate_preprocessor / preprocessor_metadata → None / NULL (warns always) Hard errors are preserved for genuinely unrecoverable fields (forest structures, outcome_scale, outcome_mean). Co-Authored-By: Claude Sonnet 4.6 --- R/bart.R | 285 ++++++++++++++++++++++++++++++++++++++-------- R/utils.R | 12 +- stochtree/bart.py | 185 +++++++++++++++++++++++++----- 3 files changed, 401 insertions(+), 81 deletions(-) diff --git a/R/bart.R b/R/bart.R index c327e0af..6deaafdb 100644 --- a/R/bart.R +++ b/R/bart.R @@ -3672,6 +3672,15 @@ createBARTModelFromJson <- function(json_object) { # Initialize the BCF model output <- list() + # Helpers for optional-field presence checks + .ver <- inferStochtreeJsonVersion(json_object) + has_field <- function(name) { + json_contains_field_cpp(json_object$json_ptr, name) + } + has_subfolder_field <- function(subfolder, name) { + json_contains_field_subfolder_cpp(json_object$json_ptr, subfolder, name) + } + # Unpack the forests include_mean_forest <- json_object$get_boolean("include_mean_forest") include_variance_forest <- json_object$get_boolean( @@ -3750,31 +3759,99 @@ createBARTModelFromJson <- function(json_object) { model_params[["include_mean_forest"]] <- include_mean_forest model_params[["include_variance_forest"]] <- include_variance_forest model_params[["has_rfx"]] <- json_object$get_boolean("has_rfx") - model_params[["has_rfx_basis"]] <- json_object$get_boolean("has_rfx_basis") - model_params[["num_rfx_basis"]] <- json_object$get_scalar("num_rfx_basis") + + if (has_field("has_rfx_basis")) { + model_params[["has_rfx_basis"]] <- json_object$get_boolean("has_rfx_basis") + model_params[["num_rfx_basis"]] <- json_object$get_scalar("num_rfx_basis") + } else { + 
model_params[["has_rfx_basis"]] <- FALSE + model_params[["num_rfx_basis"]] <- 1 + warning(paste0( + "Fields 'has_rfx_basis' and 'num_rfx_basis' not found in JSON (model appears to have been ", + "serialized under stochtree ", + .ver, + "). Defaulting to FALSE / 1. ", + "Re-save your model to suppress this warning." + )) + } + model_params[["num_gfr"]] <- json_object$get_scalar("num_gfr") model_params[["num_burnin"]] <- json_object$get_scalar("num_burnin") model_params[["num_mcmc"]] <- json_object$get_scalar("num_mcmc") model_params[["num_samples"]] <- json_object$get_scalar("num_samples") - model_params[["num_covariates"]] <- json_object$get_scalar("num_covariates") + model_params[["num_covariates"]] <- if (has_field("num_covariates")) { + json_object$get_scalar("num_covariates") + } else { + NA_real_ + } model_params[["num_basis"]] <- json_object$get_scalar("num_basis") - model_params[["num_chains"]] <- json_object$get_scalar("num_chains") - model_params[["keep_every"]] <- json_object$get_scalar("keep_every") - model_params[["requires_basis"]] <- json_object$get_boolean( - "requires_basis" - ) - model_params[["probit_outcome_model"]] <- json_object$get_boolean( - "probit_outcome_model" - ) - outcome_model_outcome <- json_object$get_string("outcome", "outcome_model") - outcome_model_link <- json_object$get_string("link", "outcome_model") + model_params[["requires_basis"]] <- json_object$get_boolean("requires_basis") + + if (has_field("num_chains")) { + model_params[["num_chains"]] <- json_object$get_scalar("num_chains") + } else { + model_params[["num_chains"]] <- 1 + warning(paste0( + "Field 'num_chains' not found in JSON (model appears to have been serialized under stochtree ", + .ver, + "). Defaulting to 1. Re-save your model to suppress this warning." 
+ )) + } + + if (has_field("keep_every")) { + model_params[["keep_every"]] <- json_object$get_scalar("keep_every") + } else { + model_params[["keep_every"]] <- 1 + warning(paste0( + "Field 'keep_every' not found in JSON (model appears to have been serialized under stochtree ", + .ver, + "). Defaulting to 1. Re-save your model to suppress this warning." + )) + } + + model_params[["probit_outcome_model"]] <- if ( + has_field("probit_outcome_model") + ) { + json_object$get_boolean("probit_outcome_model") + } else { + FALSE + } + + if ( + has_subfolder_field("outcome_model", "outcome") && + has_subfolder_field("outcome_model", "link") + ) { + outcome_model_outcome <- json_object$get_string("outcome", "outcome_model") + outcome_model_link <- json_object$get_string("link", "outcome_model") + } else { + outcome_model_outcome <- "continuous" + outcome_model_link <- "identity" + warning(paste0( + "Fields 'outcome' and 'link' not found under 'outcome_model' in JSON (model appears to have ", + "been serialized under stochtree ", + .ver, + "). Defaulting to outcome='continuous', ", + "link='identity'. Re-save your model to suppress this warning." + )) + } model_params[["outcome_model"]] <- OutcomeModel( outcome = outcome_model_outcome, link = outcome_model_link ) - model_params[["rfx_model_spec"]] <- json_object$get_string( - "rfx_model_spec" - ) + + if (has_field("rfx_model_spec")) { + model_params[["rfx_model_spec"]] <- json_object$get_string("rfx_model_spec") + } else { + model_params[["rfx_model_spec"]] <- "" + if (model_params[["has_rfx"]]) { + warning(paste0( + "Field 'rfx_model_spec' not found in JSON (model appears to have been serialized under ", + "stochtree ", + .ver, + "). Defaulting to ''. Re-save your model to suppress this warning." 
+ )) + } + } if (model_params[["outcome_model"]]$link == "cloglog") { cloglog_num_categories <- json_object$get_scalar("cloglog_num_categories") model_params[["cloglog_num_categories"]] <- cloglog_num_categories @@ -3824,12 +3901,23 @@ createBARTModelFromJson <- function(json_object) { } # Unpack covariate preprocessor - preprocessor_metadata_string <- json_object$get_string( - "preprocessor_metadata" - ) - output[["train_set_metadata"]] <- createPreprocessorFromJsonString( - preprocessor_metadata_string - ) + if (has_field("preprocessor_metadata")) { + preprocessor_metadata_string <- json_object$get_string( + "preprocessor_metadata" + ) + output[["train_set_metadata"]] <- createPreprocessorFromJsonString( + preprocessor_metadata_string + ) + } else { + output[["train_set_metadata"]] <- NULL + warning(paste0( + "Field 'preprocessor_metadata' not found in JSON (model appears to have been serialized ", + "under stochtree ", + .ver, + "). DataFrame covariates will not be supported for prediction. ", + "Re-save your model to suppress this warning." 
+ )) + } class(output) <- "bartmodel" return(output) @@ -3875,6 +3963,19 @@ createBARTModelFromCombinedJson <- function(json_object_list) { # defer to the first json json_object_default <- json_object_list[[1]] + # Helpers for optional-field presence checks + .ver <- inferStochtreeJsonVersion(json_object_default) + has_field <- function(name) { + json_contains_field_cpp(json_object_default$json_ptr, name) + } + has_subfolder_field <- function(subfolder, name) { + json_contains_field_subfolder_cpp( + json_object_default$json_ptr, + subfolder, + name + ) + } + # Unpack the forests include_mean_forest <- json_object_default$get_boolean( "include_mean_forest" @@ -3963,36 +4064,109 @@ createBARTModelFromCombinedJson <- function(json_object_list) { model_params[["include_mean_forest"]] <- include_mean_forest model_params[["include_variance_forest"]] <- include_variance_forest model_params[["has_rfx"]] <- json_object_default$get_boolean("has_rfx") - model_params[["has_rfx_basis"]] <- json_object_default$get_boolean( - "has_rfx_basis" - ) - model_params[["num_rfx_basis"]] <- json_object_default$get_scalar( - "num_rfx_basis" - ) - model_params[["num_covariates"]] <- json_object_default$get_scalar( - "num_covariates" - ) + + if (has_field("has_rfx_basis")) { + model_params[["has_rfx_basis"]] <- json_object_default$get_boolean( + "has_rfx_basis" + ) + model_params[["num_rfx_basis"]] <- json_object_default$get_scalar( + "num_rfx_basis" + ) + } else { + model_params[["has_rfx_basis"]] <- FALSE + model_params[["num_rfx_basis"]] <- 1 + warning(paste0( + "Fields 'has_rfx_basis' and 'num_rfx_basis' not found in JSON (model appears to have been ", + "serialized under stochtree ", + .ver, + "). Defaulting to FALSE / 1. ", + "Re-save your model to suppress this warning." 
+ )) + } + + model_params[["num_covariates"]] <- if (has_field("num_covariates")) { + json_object_default$get_scalar("num_covariates") + } else { + NA_real_ + } model_params[["num_basis"]] <- json_object_default$get_scalar("num_basis") model_params[["requires_basis"]] <- json_object_default$get_boolean( "requires_basis" ) - model_params[["probit_outcome_model"]] <- json_object_default$get_boolean( - "probit_outcome_model" - ) - outcome_model_outcome <- json_object_default$get_string( - "outcome", - "outcome_model" - ) - outcome_model_link <- json_object_default$get_string("link", "outcome_model") + + model_params[["probit_outcome_model"]] <- if ( + has_field("probit_outcome_model") + ) { + json_object_default$get_boolean("probit_outcome_model") + } else { + FALSE + } + + if ( + has_subfolder_field("outcome_model", "outcome") && + has_subfolder_field("outcome_model", "link") + ) { + outcome_model_outcome <- json_object_default$get_string( + "outcome", + "outcome_model" + ) + outcome_model_link <- json_object_default$get_string( + "link", + "outcome_model" + ) + } else { + outcome_model_outcome <- "continuous" + outcome_model_link <- "identity" + warning(paste0( + "Fields 'outcome' and 'link' not found under 'outcome_model' in JSON (model appears to have ", + "been serialized under stochtree ", + .ver, + "). Defaulting to outcome='continuous', ", + "link='identity'. Re-save your model to suppress this warning." 
+ )) + } model_params[["outcome_model"]] <- OutcomeModel( outcome = outcome_model_outcome, link = outcome_model_link ) - model_params[["rfx_model_spec"]] <- json_object_default$get_string( - "rfx_model_spec" - ) - model_params[["num_chains"]] <- json_object_default$get_scalar("num_chains") - model_params[["keep_every"]] <- json_object_default$get_scalar("keep_every") + + if (has_field("rfx_model_spec")) { + model_params[["rfx_model_spec"]] <- json_object_default$get_string( + "rfx_model_spec" + ) + } else { + model_params[["rfx_model_spec"]] <- "" + if (model_params[["has_rfx"]]) { + warning(paste0( + "Field 'rfx_model_spec' not found in JSON (model appears to have been serialized under ", + "stochtree ", + .ver, + "). Defaulting to ''. Re-save your model to suppress this warning." + )) + } + } + + if (has_field("num_chains")) { + model_params[["num_chains"]] <- json_object_default$get_scalar("num_chains") + } else { + model_params[["num_chains"]] <- 1 + warning(paste0( + "Field 'num_chains' not found in JSON (model appears to have been serialized under stochtree ", + .ver, + "). Defaulting to 1. Re-save your model to suppress this warning." + )) + } + + if (has_field("keep_every")) { + model_params[["keep_every"]] <- json_object_default$get_scalar("keep_every") + } else { + model_params[["keep_every"]] <- 1 + warning(paste0( + "Field 'keep_every' not found in JSON (model appears to have been serialized under stochtree ", + .ver, + "). Defaulting to 1. Re-save your model to suppress this warning." 
+ )) + } if (model_params[["outcome_model"]]$link == "cloglog") { cloglog_num_categories <- json_object_default$get_scalar( "cloglog_num_categories" @@ -4343,12 +4517,23 @@ createBARTModelFromCombinedJsonString <- function(json_string_list) { } # Unpack covariate preprocessor - preprocessor_metadata_string <- json_object_default$get_string( - "preprocessor_metadata" - ) - output[["train_set_metadata"]] <- createPreprocessorFromJsonString( - preprocessor_metadata_string - ) + if (has_field("preprocessor_metadata")) { + preprocessor_metadata_string <- json_object_default$get_string( + "preprocessor_metadata" + ) + output[["train_set_metadata"]] <- createPreprocessorFromJsonString( + preprocessor_metadata_string + ) + } else { + output[["train_set_metadata"]] <- NULL + warning(paste0( + "Field 'preprocessor_metadata' not found in JSON (model appears to have been serialized ", + "under stochtree ", + .ver, + "). DataFrame covariates will not be supported for prediction. ", + "Re-save your model to suppress this warning." + )) + } class(output) <- "bartmodel" return(output) diff --git a/R/utils.R b/R/utils.R index 7890e988..f1fe2a79 100644 --- a/R/utils.R +++ b/R/utils.R @@ -1201,7 +1201,7 @@ getStochtreeVersion <- function() { #' @return A character string: the stamp value if \code{stochtree_version} is #' present, otherwise a bracket string such as \code{"<0.4.1"}. 
#' @noRd -inferStorchtreeJsonVersion <- function(json_object) { +inferStochtreeJsonVersion <- function(json_object) { has_field <- function(name) { json_contains_field_cpp(json_object$json_ptr, name) } @@ -1214,8 +1214,10 @@ inferStorchtreeJsonVersion <- function(json_object) { } # outcome/link in outcome_model were added in ~0.4.1 - if (!has_subfolder_field("outcome_model", "outcome") || - !has_subfolder_field("outcome_model", "link")) { + if ( + !has_subfolder_field("outcome_model", "outcome") || + !has_subfolder_field("outcome_model", "link") + ) { return("<0.4.1") } @@ -1225,7 +1227,9 @@ inferStorchtreeJsonVersion <- function(json_object) { } # internal_propensity_model was added in ~0.3.2 (BCF only) - if (has_field("propensity_covariate") && !has_field("internal_propensity_model")) { + if ( + has_field("propensity_covariate") && !has_field("internal_propensity_model") + ) { return("<0.3.2") } diff --git a/stochtree/bart.py b/stochtree/bart.py index d30a570b..645d9f5c 100644 --- a/stochtree/bart.py +++ b/stochtree/bart.py @@ -1,3 +1,4 @@ +import json import warnings from math import log, floor from numbers import Integral @@ -28,6 +29,7 @@ _expand_dims_2d, _expand_dims_2d_diag, _get_stochtree_version, + _infer_stochtree_version, _posterior_predictive_heuristic_multiplier, _summarize_interval, ) @@ -3051,7 +3053,10 @@ def from_json(self, json_string: str) -> None: json_string : str JSON string representing model metadata (hyperparameters), sampled parameters, and sampled forests """ - # Parse string to a JSON object in C++ + # Parse string to a JSON object in C++; also keep a plain dict for + # optional-field presence checks (JSONSerializer has no contains_field). 
+ _raw = json.loads(json_string) + _ver = _infer_stochtree_version(json_string) bart_json = JSONSerializer() bart_json.load_from_json_string(json_string) @@ -3059,8 +3064,19 @@ def from_json(self, json_string: str) -> None: self.include_mean_forest = bart_json.get_boolean("include_mean_forest") self.include_variance_forest = bart_json.get_boolean("include_variance_forest") self.has_rfx = bart_json.get_boolean("has_rfx") - self.has_rfx_basis = bart_json.get_boolean("has_rfx_basis") - self.num_rfx_basis = bart_json.get_scalar("num_rfx_basis") + + if "has_rfx_basis" in _raw: + self.has_rfx_basis = bart_json.get_boolean("has_rfx_basis") + self.num_rfx_basis = bart_json.get_scalar("num_rfx_basis") + else: + self.has_rfx_basis = False + self.num_rfx_basis = 1 + warnings.warn( + f"Fields 'has_rfx_basis' and 'num_rfx_basis' not found in JSON (model appears to " + f"have been serialized under stochtree {_ver}). Defaulting to False / 1. " + f"Re-save your model to suppress this warning." + ) + if self.include_mean_forest: # TODO: don't just make this a placeholder that we overwrite self.forest_container_mean = ForestContainer(0, 0, False, False) @@ -3095,16 +3111,60 @@ def from_json(self, json_string: str) -> None: self.num_gfr = bart_json.get_integer("num_gfr") self.num_burnin = bart_json.get_integer("num_burnin") self.num_mcmc = bart_json.get_integer("num_mcmc") - self.num_chains = bart_json.get_integer("num_chains") - self.keep_every = bart_json.get_integer("keep_every") self.num_samples = bart_json.get_integer("num_samples") self.num_basis = bart_json.get_integer("num_basis") self.has_basis = bart_json.get_boolean("requires_basis") - self.probit_outcome_model = bart_json.get_boolean("probit_outcome_model") - outcome_model_outcome = bart_json.get_string("outcome", "outcome_model") - outcome_model_link = bart_json.get_string("link", "outcome_model") + + if "num_chains" in _raw: + self.num_chains = bart_json.get_integer("num_chains") + else: + self.num_chains = 1 + 
warnings.warn( + f"Field 'num_chains' not found in JSON (model appears to have been serialized " + f"under stochtree {_ver}). Defaulting to 1. " + f"Re-save your model to suppress this warning." + ) + + if "keep_every" in _raw: + self.keep_every = bart_json.get_integer("keep_every") + else: + self.keep_every = 1 + warnings.warn( + f"Field 'keep_every' not found in JSON (model appears to have been serialized " + f"under stochtree {_ver}). Defaulting to 1. " + f"Re-save your model to suppress this warning." + ) + + if "probit_outcome_model" in _raw: + self.probit_outcome_model = bart_json.get_boolean("probit_outcome_model") + else: + self.probit_outcome_model = False + + _outcome_model_raw = _raw.get("outcome_model", {}) + if "outcome" in _outcome_model_raw and "link" in _outcome_model_raw: + outcome_model_outcome = bart_json.get_string("outcome", "outcome_model") + outcome_model_link = bart_json.get_string("link", "outcome_model") + else: + outcome_model_outcome = "continuous" + outcome_model_link = "identity" + warnings.warn( + f"Fields 'outcome' and 'link' not found under 'outcome_model' in JSON (model " + f"appears to have been serialized under stochtree {_ver}). Defaulting to " + f"outcome='continuous', link='identity'. " + f"Re-save your model to suppress this warning." + ) self.outcome_model = OutcomeModel(outcome=outcome_model_outcome, link=outcome_model_link) - self.rfx_model_spec = bart_json.get_string("rfx_model_spec") + + if "rfx_model_spec" in _raw: + self.rfx_model_spec = bart_json.get_string("rfx_model_spec") + else: + self.rfx_model_spec = "" + if self.has_rfx: + warnings.warn( + f"Field 'rfx_model_spec' not found in JSON (model appears to have been " + f"serialized under stochtree {_ver}). Defaulting to ''. " + f"Re-save your model to suppress this warning." 
+ ) # Unpack parameter samples if self.sample_sigma2_global: @@ -3129,9 +3189,17 @@ def from_json(self, json_string: str) -> None: ) # Unpack covariate preprocessor - covariate_preprocessor_string = bart_json.get_string("covariate_preprocessor") - self._covariate_preprocessor = CovariatePreprocessor() - self._covariate_preprocessor.from_json(covariate_preprocessor_string) + if "covariate_preprocessor" in _raw: + covariate_preprocessor_string = bart_json.get_string("covariate_preprocessor") + self._covariate_preprocessor = CovariatePreprocessor() + self._covariate_preprocessor.from_json(covariate_preprocessor_string) + else: + self._covariate_preprocessor = None + warnings.warn( + f"Field 'covariate_preprocessor' not found in JSON (model appears to have been " + f"serialized under stochtree {_ver}). DataFrame covariates will not be supported " + f"for prediction. Re-save your model to suppress this warning." + ) # Mark the deserialized model as "sampled" self.sampled = True @@ -3155,6 +3223,8 @@ def from_json_string_list(self, json_string_list: list[str]) -> None: # For scalar / preprocessing details which aren't sample-dependent, defer to the first json json_object_default = json_object_list[0] + _raw = json.loads(json_string_list[0]) + _ver = _infer_stochtree_version(json_string_list[0]) # Unpack forests self.include_mean_forest = json_object_default.get_boolean( @@ -3202,8 +3272,19 @@ def from_json_string_list(self, json_string_list: list[str]) -> None: # Unpack random effects self.has_rfx = json_object_default.get_boolean("has_rfx") - self.has_rfx_basis = json_object_default.get_boolean("has_rfx_basis") - self.num_rfx_basis = json_object_default.get_scalar("num_rfx_basis") + + if "has_rfx_basis" in _raw: + self.has_rfx_basis = json_object_default.get_boolean("has_rfx_basis") + self.num_rfx_basis = json_object_default.get_scalar("num_rfx_basis") + else: + self.has_rfx_basis = False + self.num_rfx_basis = 1 + warnings.warn( + f"Fields 'has_rfx_basis' and 
'num_rfx_basis' not found in JSON (model appears to " + f"have been serialized under stochtree {_ver}). Defaulting to False / 1. " + f"Re-save your model to suppress this warning." + ) + if self.has_rfx: self.rfx_container = RandomEffectsContainer() for i in range(len(json_object_list)): @@ -3224,17 +3305,59 @@ def from_json_string_list(self, json_string_list: list[str]) -> None: self.num_gfr = json_object_default.get_integer("num_gfr") self.num_burnin = json_object_default.get_integer("num_burnin") self.num_mcmc = json_object_default.get_integer("num_mcmc") - self.num_chains = json_object_default.get_integer("num_chains") - self.keep_every = json_object_default.get_integer("keep_every") self.num_basis = json_object_default.get_integer("num_basis") self.has_basis = json_object_default.get_boolean("requires_basis") - self.probit_outcome_model = json_object_default.get_boolean( - "probit_outcome_model" - ) - outcome_model_outcome = json_object_default.get_string("outcome", "outcome_model") - outcome_model_link = json_object_default.get_string("link", "outcome_model") + + if "num_chains" in _raw: + self.num_chains = json_object_default.get_integer("num_chains") + else: + self.num_chains = 1 + warnings.warn( + f"Field 'num_chains' not found in JSON (model appears to have been serialized " + f"under stochtree {_ver}). Defaulting to 1. " + f"Re-save your model to suppress this warning." + ) + + if "keep_every" in _raw: + self.keep_every = json_object_default.get_integer("keep_every") + else: + self.keep_every = 1 + warnings.warn( + f"Field 'keep_every' not found in JSON (model appears to have been serialized " + f"under stochtree {_ver}). Defaulting to 1. " + f"Re-save your model to suppress this warning." 
+ ) + + if "probit_outcome_model" in _raw: + self.probit_outcome_model = json_object_default.get_boolean("probit_outcome_model") + else: + self.probit_outcome_model = False + + _outcome_model_raw = _raw.get("outcome_model", {}) + if "outcome" in _outcome_model_raw and "link" in _outcome_model_raw: + outcome_model_outcome = json_object_default.get_string("outcome", "outcome_model") + outcome_model_link = json_object_default.get_string("link", "outcome_model") + else: + outcome_model_outcome = "continuous" + outcome_model_link = "identity" + warnings.warn( + f"Fields 'outcome' and 'link' not found under 'outcome_model' in JSON (model " + f"appears to have been serialized under stochtree {_ver}). Defaulting to " + f"outcome='continuous', link='identity'. " + f"Re-save your model to suppress this warning." + ) self.outcome_model = OutcomeModel(outcome=outcome_model_outcome, link=outcome_model_link) - self.rfx_model_spec = json_object_default.get_string("rfx_model_spec") + + if "rfx_model_spec" in _raw: + self.rfx_model_spec = json_object_default.get_string("rfx_model_spec") + else: + self.rfx_model_spec = "" + if self.has_rfx: + warnings.warn( + f"Field 'rfx_model_spec' not found in JSON (model appears to have been " + f"serialized under stochtree {_ver}). Defaulting to ''. " + f"Re-save your model to suppress this warning." 
+ ) # Unpack number of samples for i in range(len(json_object_list)): @@ -3297,11 +3420,19 @@ def from_json_string_list(self, json_string_list: list[str]) -> None: ) # Unpack covariate preprocessor - covariate_preprocessor_string = json_object_default.get_string( - "covariate_preprocessor" - ) - self._covariate_preprocessor = CovariatePreprocessor() - self._covariate_preprocessor.from_json(covariate_preprocessor_string) + if "covariate_preprocessor" in _raw: + covariate_preprocessor_string = json_object_default.get_string( + "covariate_preprocessor" + ) + self._covariate_preprocessor = CovariatePreprocessor() + self._covariate_preprocessor.from_json(covariate_preprocessor_string) + else: + self._covariate_preprocessor = None + warnings.warn( + f"Field 'covariate_preprocessor' not found in JSON (model appears to have been " + f"serialized under stochtree {_ver}). DataFrame covariates will not be supported " + f"for prediction. Re-save your model to suppress this warning." + ) # Mark the deserialized model as "sampled" self.sampled = True From b5af1fe73e869d8eab04ac904d30ae5c9733036b Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Thu, 19 Mar 2026 20:49:48 -0500 Subject: [PATCH 3/9] Robust BCF from_json with safe defaults and version warnings (#320) Guard all optional fields in BCF deserialization (both from_json and from_json_string_list in Python; createBCFModelFromJson and createBCFModelFromCombinedJson in R) with presence checks, safe defaults, and descriptive warnings that include the inferred legacy version bracket. Fields guarded: has_rfx_basis/num_rfx_basis, multivariate_treatment, num_chains, keep_every, sample_tau_0, internal_propensity_model, probit_outcome_model, outcome_model subfolder, rfx_model_spec, and covariate_preprocessor/preprocessor_metadata. 
Co-Authored-By: Claude Sonnet 4.6 --- R/bcf.R | 298 ++++++++++++++++++++++++++++++++++++++--------- stochtree/bcf.py | 228 ++++++++++++++++++++++++++++++------ 2 files changed, 431 insertions(+), 95 deletions(-) diff --git a/R/bcf.R b/R/bcf.R index f871754a..b1aa4f4d 100644 --- a/R/bcf.R +++ b/R/bcf.R @@ -4684,6 +4684,13 @@ createBCFModelFromJson <- function(json_object) { # Initialize the BCF model output <- list() + # Version inference and presence-check helpers + .ver <- inferStochtreeJsonVersion(json_object) + has_field <- function(name) json_contains_field_cpp(json_object$json_ptr, name) + has_subfolder_field <- function(subfolder, name) { + json_contains_field_subfolder_cpp(json_object$json_ptr, subfolder, name) + } + # Unpack the forests output[["forests_mu"]] <- loadForestContainerJson(json_object, "forest_0") output[["forests_tau"]] <- loadForestContainerJson(json_object, "forest_1") @@ -4757,35 +4764,113 @@ createBCFModelFromJson <- function(json_object) { "propensity_covariate" ) model_params[["has_rfx"]] <- json_object$get_boolean("has_rfx") - model_params[["has_rfx_basis"]] <- json_object$get_boolean("has_rfx_basis") - model_params[["num_rfx_basis"]] <- json_object$get_scalar("num_rfx_basis") + if (has_field("has_rfx_basis")) { + model_params[["has_rfx_basis"]] <- json_object$get_boolean("has_rfx_basis") + model_params[["num_rfx_basis"]] <- json_object$get_scalar("num_rfx_basis") + } else { + model_params[["has_rfx_basis"]] <- FALSE + model_params[["num_rfx_basis"]] <- 1 + warning(sprintf( + "Fields 'has_rfx_basis' and 'num_rfx_basis' not found in BCF JSON (inferred version: %s). 
Defaulting to has_rfx_basis=FALSE, num_rfx_basis=1.", + .ver + )) + } model_params[["adaptive_coding"]] <- json_object$get_boolean( "adaptive_coding" ) - model_params[["sample_tau_0"]] <- json_object$get_boolean("sample_tau_0") - model_params[["multivariate_treatment"]] <- json_object$get_boolean( - "multivariate_treatment" - ) - model_params[["internal_propensity_model"]] <- json_object$get_boolean( - "internal_propensity_model" - ) + if (has_field("sample_tau_0")) { + model_params[["sample_tau_0"]] <- json_object$get_boolean("sample_tau_0") + } else { + model_params[["sample_tau_0"]] <- FALSE + warning(sprintf( + "Field 'sample_tau_0' not found in BCF JSON (inferred version: %s). Defaulting to FALSE.", + .ver + )) + } + if (has_field("multivariate_treatment")) { + model_params[["multivariate_treatment"]] <- json_object$get_boolean( + "multivariate_treatment" + ) + } else { + model_params[["multivariate_treatment"]] <- FALSE + warning(sprintf( + "Field 'multivariate_treatment' not found in BCF JSON (inferred version: %s). Defaulting to FALSE.", + .ver + )) + } + if (has_field("internal_propensity_model")) { + model_params[["internal_propensity_model"]] <- json_object$get_boolean( + "internal_propensity_model" + ) + } else { + model_params[["internal_propensity_model"]] <- FALSE + warning(sprintf( + "Field 'internal_propensity_model' not found in BCF JSON (inferred version: %s). 
Defaulting to FALSE.", + .ver + )) + } model_params[["num_gfr"]] <- json_object$get_scalar("num_gfr") model_params[["num_burnin"]] <- json_object$get_scalar("num_burnin") model_params[["num_mcmc"]] <- json_object$get_scalar("num_mcmc") model_params[["num_samples"]] <- json_object$get_scalar("num_samples") model_params[["num_covariates"]] <- json_object$get_scalar("num_covariates") - model_params[["probit_outcome_model"]] <- json_object$get_boolean( - "probit_outcome_model" - ) - outcome_model_outcome <- json_object$get_string("outcome", "outcome_model") - outcome_model_link <- json_object$get_string("link", "outcome_model") + if (has_field("num_chains")) { + model_params[["num_chains"]] <- json_object$get_scalar("num_chains") + } else { + model_params[["num_chains"]] <- 1 + warning(sprintf( + "Field 'num_chains' not found in BCF JSON (inferred version: %s). Defaulting to 1.", + .ver + )) + } + if (has_field("keep_every")) { + model_params[["keep_every"]] <- json_object$get_scalar("keep_every") + } else { + model_params[["keep_every"]] <- 1 + warning(sprintf( + "Field 'keep_every' not found in BCF JSON (inferred version: %s). Defaulting to 1.", + .ver + )) + } + if (has_field("probit_outcome_model")) { + model_params[["probit_outcome_model"]] <- json_object$get_boolean( + "probit_outcome_model" + ) + } else { + model_params[["probit_outcome_model"]] <- FALSE + warning(sprintf( + "Field 'probit_outcome_model' not found in BCF JSON (inferred version: %s). Defaulting to FALSE.", + .ver + )) + } + if (has_subfolder_field("outcome_model", "outcome")) { + outcome_model_outcome <- json_object$get_string("outcome", "outcome_model") + outcome_model_link <- json_object$get_string("link", "outcome_model") + } else { + outcome_model_outcome <- "continuous" + outcome_model_link <- "identity" + warning(sprintf( + "Subfolder 'outcome_model' not found in BCF JSON (inferred version: %s). 
Defaulting to outcome='continuous', link='identity'.", + .ver + )) + } model_params[["outcome_model"]] <- OutcomeModel( outcome = outcome_model_outcome, link = outcome_model_link ) - model_params[["rfx_model_spec"]] <- json_object$get_string( - "rfx_model_spec" - ) + if (has_field("rfx_model_spec")) { + model_params[["rfx_model_spec"]] <- json_object$get_string( + "rfx_model_spec" + ) + } else { + model_params[["rfx_model_spec"]] <- "" + if (model_params[["has_rfx"]]) { + warning(sprintf( + "Field 'rfx_model_spec' not found in BCF JSON (inferred version: %s) but has_rfx=TRUE.", + .ver + )) + } + } output[["model_params"]] <- model_params # Unpack sampled parameters @@ -4842,12 +4927,20 @@ createBCFModelFromJson <- function(json_object) { } # Unpack covariate preprocessor - preprocessor_metadata_string <- json_object$get_string( - "preprocessor_metadata" - ) - output[["train_set_metadata"]] <- createPreprocessorFromJsonString( - preprocessor_metadata_string - ) + if (has_field("preprocessor_metadata")) { + preprocessor_metadata_string <- json_object$get_string( + "preprocessor_metadata" + ) + output[["train_set_metadata"]] <- createPreprocessorFromJsonString( + preprocessor_metadata_string + ) + } else { + output[["train_set_metadata"]] <- NULL + warning(sprintf( + "Field 'preprocessor_metadata' not found in BCF JSON (inferred version: %s). 
Preprocessor is unavailable; prediction may fail.", + .ver + )) + } class(output) <- "bcfmodel" return(output) @@ -4893,6 +4986,13 @@ createBCFModelFromCombinedJson <- function(json_object_list) { # defer to the first json json_object_default <- json_object_list[[1]] + # Version inference and presence-check helpers + .ver <- inferStochtreeJsonVersion(json_object_default) + has_field <- function(name) json_contains_field_cpp(json_object_default$json_ptr, name) + has_subfolder_field <- function(subfolder, name) { + json_contains_field_subfolder_cpp(json_object_default$json_ptr, subfolder, name) + } + # Unpack the forests output[["forests_mu"]] <- loadForestContainerCombinedJson( json_object_list, @@ -4980,44 +5080,120 @@ createBCFModelFromCombinedJson <- function(json_object_list) { "propensity_covariate" ) model_params[["has_rfx"]] <- json_object_default$get_boolean("has_rfx") - model_params[["has_rfx_basis"]] <- json_object_default$get_boolean( - "has_rfx_basis" - ) - model_params[["num_rfx_basis"]] <- json_object_default$get_scalar( - "num_rfx_basis" - ) + if (has_field("has_rfx_basis")) { + model_params[["has_rfx_basis"]] <- json_object_default$get_boolean( + "has_rfx_basis" + ) + model_params[["num_rfx_basis"]] <- json_object_default$get_scalar( + "num_rfx_basis" + ) + } else { + model_params[["has_rfx_basis"]] <- FALSE + model_params[["num_rfx_basis"]] <- 1 + warning(sprintf( + "Fields 'has_rfx_basis' and 'num_rfx_basis' not found in BCF JSON (inferred version: %s). 
Defaulting to has_rfx_basis=FALSE, num_rfx_basis=1.", + .ver + )) + } model_params[["num_covariates"]] <- json_object_default$get_scalar( "num_covariates" ) - model_params[["num_chains"]] <- json_object_default$get_scalar("num_chains") - model_params[["keep_every"]] <- json_object_default$get_scalar("keep_every") + if (has_field("num_chains")) { + model_params[["num_chains"]] <- json_object_default$get_scalar("num_chains") + } else { + model_params[["num_chains"]] <- 1 + warning(sprintf( + "Field 'num_chains' not found in BCF JSON (inferred version: %s). Defaulting to 1.", + .ver + )) + } + if (has_field("keep_every")) { + model_params[["keep_every"]] <- json_object_default$get_scalar("keep_every") + } else { + model_params[["keep_every"]] <- 1 + warning(sprintf( + "Field 'keep_every' not found in BCF JSON (inferred version: %s). Defaulting to 1.", + .ver + )) + } model_params[["adaptive_coding"]] <- json_object_default$get_boolean( "adaptive_coding" ) - model_params[["sample_tau_0"]] <- json_object_default$get_boolean( - "sample_tau_0" - ) - model_params[["multivariate_treatment"]] <- json_object_default$get_boolean( - "multivariate_treatment" - ) - model_params[[ - "internal_propensity_model" - ]] <- json_object_default$get_boolean("internal_propensity_model") - model_params[["probit_outcome_model"]] <- json_object_default$get_boolean( - "probit_outcome_model" - ) - outcome_model_outcome <- json_object_default$get_string( - "outcome", - "outcome_model" - ) - outcome_model_link <- json_object_default$get_string("link", "outcome_model") + if (has_field("sample_tau_0")) { + model_params[["sample_tau_0"]] <- json_object_default$get_boolean( + "sample_tau_0" + ) + } else { + model_params[["sample_tau_0"]] <- FALSE + warning(sprintf( + "Field 'sample_tau_0' not found in BCF JSON (inferred version: %s). 
Defaulting to FALSE.", + .ver + )) + } + if (has_field("multivariate_treatment")) { + model_params[["multivariate_treatment"]] <- json_object_default$get_boolean( + "multivariate_treatment" + ) + } else { + model_params[["multivariate_treatment"]] <- FALSE + warning(sprintf( + "Field 'multivariate_treatment' not found in BCF JSON (inferred version: %s). Defaulting to FALSE.", + .ver + )) + } + if (has_field("internal_propensity_model")) { + model_params[[ + "internal_propensity_model" + ]] <- json_object_default$get_boolean("internal_propensity_model") + } else { + model_params[["internal_propensity_model"]] <- FALSE + warning(sprintf( + "Field 'internal_propensity_model' not found in BCF JSON (inferred version: %s). Defaulting to FALSE.", + .ver + )) + } + if (has_field("probit_outcome_model")) { + model_params[["probit_outcome_model"]] <- json_object_default$get_boolean( + "probit_outcome_model" + ) + } else { + model_params[["probit_outcome_model"]] <- FALSE + warning(sprintf( + "Field 'probit_outcome_model' not found in BCF JSON (inferred version: %s). Defaulting to FALSE.", + .ver + )) + } + if (has_subfolder_field("outcome_model", "outcome")) { + outcome_model_outcome <- json_object_default$get_string( + "outcome", + "outcome_model" + ) + outcome_model_link <- json_object_default$get_string("link", "outcome_model") + } else { + outcome_model_outcome <- "continuous" + outcome_model_link <- "identity" + warning(sprintf( + "Subfolder 'outcome_model' not found in BCF JSON (inferred version: %s). 
Defaulting to outcome='continuous', link='identity'.", + .ver + )) + } model_params[["outcome_model"]] <- OutcomeModel( outcome = outcome_model_outcome, link = outcome_model_link ) - model_params[["rfx_model_spec"]] <- json_object_default$get_string( - "rfx_model_spec" - ) + if (has_field("rfx_model_spec")) { + model_params[["rfx_model_spec"]] <- json_object_default$get_string( + "rfx_model_spec" + ) + } else { + model_params[["rfx_model_spec"]] <- "" + if (model_params[["has_rfx"]]) { + warning(sprintf( + "Field 'rfx_model_spec' not found in BCF JSON (inferred version: %s) but has_rfx=TRUE.", + .ver + )) + } + } # Combine values that are sample-specific for (i in 1:length(json_object_list)) { @@ -5175,12 +5351,20 @@ createBCFModelFromCombinedJson <- function(json_object_list) { } # Unpack covariate preprocessor - preprocessor_metadata_string <- json_object_default$get_string( - "preprocessor_metadata" - ) - output[["train_set_metadata"]] <- createPreprocessorFromJsonString( - preprocessor_metadata_string - ) + if (has_field("preprocessor_metadata")) { + preprocessor_metadata_string <- json_object_default$get_string( + "preprocessor_metadata" + ) + output[["train_set_metadata"]] <- createPreprocessorFromJsonString( + preprocessor_metadata_string + ) + } else { + output[["train_set_metadata"]] <- NULL + warning(sprintf( + "Field 'preprocessor_metadata' not found in BCF JSON (inferred version: %s). 
Preprocessor is unavailable; prediction may fail.", + .ver + )) + } class(output) <- "bcfmodel" return(output) diff --git a/stochtree/bcf.py b/stochtree/bcf.py index e3410860..34802b7e 100644 --- a/stochtree/bcf.py +++ b/stochtree/bcf.py @@ -1,3 +1,4 @@ +import json import warnings from numbers import Integral from typing import Any, Dict, Optional, Union @@ -28,6 +29,7 @@ _expand_dims_2d, _expand_dims_2d_diag, _get_stochtree_version, + _infer_stochtree_version, _posterior_predictive_heuristic_multiplier, _summarize_interval, ) @@ -3979,13 +3981,30 @@ def from_json(self, json_string: str) -> None: # Parse string to a JSON object in C++ bcf_json = JSONSerializer() bcf_json.load_from_json_string(json_string) + _raw = json.loads(json_string) + _ver = _infer_stochtree_version(json_string) # Unpack forests self.include_variance_forest = bcf_json.get_boolean("include_variance_forest") self.has_rfx = bcf_json.get_boolean("has_rfx") - self.has_rfx_basis = bcf_json.get_boolean("has_rfx_basis") - self.num_rfx_basis = bcf_json.get_scalar("num_rfx_basis") - self.multivariate_treatment = bcf_json.get_boolean("multivariate_treatment") + if "has_rfx_basis" in _raw: + self.has_rfx_basis = bcf_json.get_boolean("has_rfx_basis") + self.num_rfx_basis = int(bcf_json.get_scalar("num_rfx_basis")) + else: + self.has_rfx_basis = False + self.num_rfx_basis = 1 + warnings.warn( + f"Fields 'has_rfx_basis' and 'num_rfx_basis' not found in BCF JSON " + f"(inferred version: {_ver}). Defaulting to has_rfx_basis=False, num_rfx_basis=1." + ) + if "multivariate_treatment" in _raw: + self.multivariate_treatment = bcf_json.get_boolean("multivariate_treatment") + else: + self.multivariate_treatment = False + warnings.warn( + f"Field 'multivariate_treatment' not found in BCF JSON " + f"(inferred version: {_ver}). Defaulting to False." 
+ ) # TODO: don't just make this a placeholder that we overwrite self.forest_container_mu = ForestContainer(0, 0, False, False) self.forest_container_mu.forest_container_cpp.LoadFromJson( @@ -4019,20 +4038,71 @@ def from_json(self, json_string: str) -> None: self.num_gfr = int(bcf_json.get_scalar("num_gfr")) self.num_burnin = int(bcf_json.get_scalar("num_burnin")) self.num_mcmc = int(bcf_json.get_scalar("num_mcmc")) - self.num_chains = int(bcf_json.get_scalar("num_chains")) - self.keep_every = int(bcf_json.get_scalar("keep_every")) + if "num_chains" in _raw: + self.num_chains = int(bcf_json.get_scalar("num_chains")) + else: + self.num_chains = 1 + warnings.warn( + f"Field 'num_chains' not found in BCF JSON " + f"(inferred version: {_ver}). Defaulting to 1." + ) + if "keep_every" in _raw: + self.keep_every = int(bcf_json.get_scalar("keep_every")) + else: + self.keep_every = 1 + warnings.warn( + f"Field 'keep_every' not found in BCF JSON " + f"(inferred version: {_ver}). Defaulting to 1." + ) self.num_samples = int(bcf_json.get_scalar("num_samples")) self.adaptive_coding = bcf_json.get_boolean("adaptive_coding") - self.sample_tau_0 = bcf_json.get_boolean("sample_tau_0") + if "sample_tau_0" in _raw: + self.sample_tau_0 = bcf_json.get_boolean("sample_tau_0") + else: + self.sample_tau_0 = False + warnings.warn( + f"Field 'sample_tau_0' not found in BCF JSON " + f"(inferred version: {_ver}). Defaulting to False." 
+ ) self.propensity_covariate = bcf_json.get_string("propensity_covariate") - self.internal_propensity_model = bcf_json.get_boolean( - "internal_propensity_model" - ) - self.probit_outcome_model = bcf_json.get_boolean("probit_outcome_model") - outcome_model_outcome = bcf_json.get_string("outcome", "outcome_model") - outcome_model_link = bcf_json.get_string("link", "outcome_model") + if "internal_propensity_model" in _raw: + self.internal_propensity_model = bcf_json.get_boolean( + "internal_propensity_model" + ) + else: + self.internal_propensity_model = False + warnings.warn( + f"Field 'internal_propensity_model' not found in BCF JSON " + f"(inferred version: {_ver}). Defaulting to False." + ) + if "probit_outcome_model" in _raw: + self.probit_outcome_model = bcf_json.get_boolean("probit_outcome_model") + else: + self.probit_outcome_model = False + warnings.warn( + f"Field 'probit_outcome_model' not found in BCF JSON " + f"(inferred version: {_ver}). Defaulting to False." + ) + if "outcome_model" in _raw: + outcome_model_outcome = bcf_json.get_string("outcome", "outcome_model") + outcome_model_link = bcf_json.get_string("link", "outcome_model") + else: + outcome_model_outcome = "continuous" + outcome_model_link = "identity" + warnings.warn( + f"Subfolder 'outcome_model' not found in BCF JSON " + f"(inferred version: {_ver}). Defaulting to outcome='continuous', link='identity'." + ) self.outcome_model = OutcomeModel(outcome=outcome_model_outcome, link=outcome_model_link) - self.rfx_model_spec = bcf_json.get_string("rfx_model_spec") + if "rfx_model_spec" in _raw: + self.rfx_model_spec = bcf_json.get_string("rfx_model_spec") + else: + self.rfx_model_spec = "" + if self.has_rfx: + warnings.warn( + f"Field 'rfx_model_spec' not found in BCF JSON " + f"(inferred version: {_ver}) but has_rfx=True." 
+ ) # Unpack parameter samples if self.sample_sigma2_global: @@ -4062,9 +4132,16 @@ def from_json(self, json_string: str) -> None: self.bart_propensity_model.from_json(bart_propensity_string) # Unpack covariate preprocessor - covariate_preprocessor_string = bcf_json.get_string("covariate_preprocessor") - self._covariate_preprocessor = CovariatePreprocessor() - self._covariate_preprocessor.from_json(covariate_preprocessor_string) + if "covariate_preprocessor" in _raw: + covariate_preprocessor_string = bcf_json.get_string("covariate_preprocessor") + self._covariate_preprocessor = CovariatePreprocessor() + self._covariate_preprocessor.from_json(covariate_preprocessor_string) + else: + self._covariate_preprocessor = None + warnings.warn( + f"Field 'covariate_preprocessor' not found in BCF JSON " + f"(inferred version: {_ver}). Preprocessor is unavailable; prediction may fail." + ) # Mark the deserialized model as "sampled" self.sampled = True @@ -4088,6 +4165,8 @@ def from_json_string_list(self, json_string_list: list[str]) -> None: # For scalar / preprocessing details which aren't sample-dependent, defer to the first json json_object_default = json_object_list[0] + _raw_default = json.loads(json_string_list[0]) + _ver = _infer_stochtree_version(json_string_list[0]) # Unpack forests # Mu forest @@ -4130,11 +4209,26 @@ def from_json_string_list(self, json_string_list: list[str]) -> None: # Unpack random effects self.has_rfx = json_object_default.get_boolean("has_rfx") - self.has_rfx_basis = json_object_default.get_boolean("has_rfx_basis") - self.num_rfx_basis = json_object_default.get_scalar("num_rfx_basis") - self.multivariate_treatment = json_object_default.get_boolean( - "multivariate_treatment" - ) + if "has_rfx_basis" in _raw_default: + self.has_rfx_basis = json_object_default.get_boolean("has_rfx_basis") + self.num_rfx_basis = int(json_object_default.get_scalar("num_rfx_basis")) + else: + self.has_rfx_basis = False + self.num_rfx_basis = 1 + warnings.warn( + 
f"Fields 'has_rfx_basis' and 'num_rfx_basis' not found in BCF JSON " + f"(inferred version: {_ver}). Defaulting to has_rfx_basis=False, num_rfx_basis=1." + ) + if "multivariate_treatment" in _raw_default: + self.multivariate_treatment = json_object_default.get_boolean( + "multivariate_treatment" + ) + else: + self.multivariate_treatment = False + warnings.warn( + f"Field 'multivariate_treatment' not found in BCF JSON " + f"(inferred version: {_ver}). Defaulting to False." + ) if self.has_rfx: self.rfx_container = RandomEffectsContainer() for i in range(len(json_object_list)): @@ -4160,23 +4254,74 @@ def from_json_string_list(self, json_string_list: list[str]) -> None: self.num_gfr = json_object_default.get_scalar("num_gfr") self.num_burnin = json_object_default.get_scalar("num_burnin") self.num_mcmc = json_object_default.get_scalar("num_mcmc") - self.num_chains = int(json_object_default.get_scalar("num_chains")) - self.keep_every = int(json_object_default.get_scalar("keep_every")) + if "num_chains" in _raw_default: + self.num_chains = int(json_object_default.get_scalar("num_chains")) + else: + self.num_chains = 1 + warnings.warn( + f"Field 'num_chains' not found in BCF JSON " + f"(inferred version: {_ver}). Defaulting to 1." + ) + if "keep_every" in _raw_default: + self.keep_every = int(json_object_default.get_scalar("keep_every")) + else: + self.keep_every = 1 + warnings.warn( + f"Field 'keep_every' not found in BCF JSON " + f"(inferred version: {_ver}). Defaulting to 1." + ) self.adaptive_coding = json_object_default.get_boolean("adaptive_coding") - self.sample_tau_0 = json_object_default.get_boolean("sample_tau_0") + if "sample_tau_0" in _raw_default: + self.sample_tau_0 = json_object_default.get_boolean("sample_tau_0") + else: + self.sample_tau_0 = False + warnings.warn( + f"Field 'sample_tau_0' not found in BCF JSON " + f"(inferred version: {_ver}). Defaulting to False." 
+ ) self.propensity_covariate = json_object_default.get_string( "propensity_covariate" ) - self.internal_propensity_model = json_object_default.get_boolean( - "internal_propensity_model" - ) - self.probit_outcome_model = json_object_default.get_boolean( - "probit_outcome_model" - ) - outcome_model_outcome = json_object_default.get_string("outcome", "outcome_model") - outcome_model_link = json_object_default.get_string("link", "outcome_model") + if "internal_propensity_model" in _raw_default: + self.internal_propensity_model = json_object_default.get_boolean( + "internal_propensity_model" + ) + else: + self.internal_propensity_model = False + warnings.warn( + f"Field 'internal_propensity_model' not found in BCF JSON " + f"(inferred version: {_ver}). Defaulting to False." + ) + if "probit_outcome_model" in _raw_default: + self.probit_outcome_model = json_object_default.get_boolean( + "probit_outcome_model" + ) + else: + self.probit_outcome_model = False + warnings.warn( + f"Field 'probit_outcome_model' not found in BCF JSON " + f"(inferred version: {_ver}). Defaulting to False." + ) + if "outcome_model" in _raw_default: + outcome_model_outcome = json_object_default.get_string("outcome", "outcome_model") + outcome_model_link = json_object_default.get_string("link", "outcome_model") + else: + outcome_model_outcome = "continuous" + outcome_model_link = "identity" + warnings.warn( + f"Subfolder 'outcome_model' not found in BCF JSON " + f"(inferred version: {_ver}). Defaulting to outcome='continuous', link='identity'." + ) self.outcome_model = OutcomeModel(outcome=outcome_model_outcome, link=outcome_model_link) - self.rfx_model_spec = json_object_default.get_string("rfx_model_spec") + if "rfx_model_spec" in _raw_default: + self.rfx_model_spec = json_object_default.get_string("rfx_model_spec") + else: + self.rfx_model_spec = "" + if self.has_rfx: + warnings.warn( + f"Field 'rfx_model_spec' not found in BCF JSON " + f"(inferred version: {_ver}) but has_rfx=True." 
+ ) # Unpack number of samples for i in range(len(json_object_list)): @@ -4252,11 +4397,18 @@ def from_json_string_list(self, json_string_list: list[str]) -> None: self.bart_propensity_model.from_json(bart_propensity_string) # Unpack covariate preprocessor - covariate_preprocessor_string = json_object_default.get_string( - "covariate_preprocessor" - ) - self._covariate_preprocessor = CovariatePreprocessor() - self._covariate_preprocessor.from_json(covariate_preprocessor_string) + if "covariate_preprocessor" in _raw_default: + covariate_preprocessor_string = json_object_default.get_string( + "covariate_preprocessor" + ) + self._covariate_preprocessor = CovariatePreprocessor() + self._covariate_preprocessor.from_json(covariate_preprocessor_string) + else: + self._covariate_preprocessor = None + warnings.warn( + f"Field 'covariate_preprocessor' not found in BCF JSON " + f"(inferred version: {_ver}). Preprocessor is unavailable; prediction may fail." + ) # Mark the deserialized model as "sampled" self.sampled = True From f1273722480ee6e046fe10ab4595f18667e25047 Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Thu, 19 Mar 2026 20:54:14 -0500 Subject: [PATCH 4/9] Fix missing has_field/.ver helpers in BART combined JSON string deserializer createBARTModelFromCombinedJsonString was referencing has_field() and .ver without defining them, causing R CMD check errors. 
Also: - Added the missing guard for preprocessor_metadata in createBARTModelFromCombinedJson (was using json_object loop variable instead of json_object_default) - Made createBARTModelFromCombinedJsonString fully symmetric with createBARTModelFromCombinedJson by adding guards for has_rfx_basis, num_covariates, num_chains, keep_every, probit_outcome_model, outcome_model, and rfx_model_spec Co-Authored-By: Claude Sonnet 4.6 --- R/bart.R | 159 +++++++++++++++++++++++++++++++++++++++++++++---------- 1 file changed, 131 insertions(+), 28 deletions(-) diff --git a/R/bart.R b/R/bart.R index 6deaafdb..4ceb6ad5 100644 --- a/R/bart.R +++ b/R/bart.R @@ -4272,12 +4272,23 @@ createBARTModelFromCombinedJson <- function(json_object_list) { } # Unpack covariate preprocessor - preprocessor_metadata_string <- json_object$get_string( - "preprocessor_metadata" - ) - output[["train_set_metadata"]] <- createPreprocessorFromJsonString( - preprocessor_metadata_string - ) + if (has_field("preprocessor_metadata")) { + preprocessor_metadata_string <- json_object_default$get_string( + "preprocessor_metadata" + ) + output[["train_set_metadata"]] <- createPreprocessorFromJsonString( + preprocessor_metadata_string + ) + } else { + output[["train_set_metadata"]] <- NULL + warning(paste0( + "Field 'preprocessor_metadata' not found in JSON (model appears to have been serialized ", + "under stochtree ", + .ver, + "). DataFrame covariates will not be supported for prediction. ", + "Re-save your model to suppress this warning." 
+ )) + } class(output) <- "bartmodel" return(output) @@ -4302,6 +4313,19 @@ createBARTModelFromCombinedJsonString <- function(json_string_list) { # defer to the first json json_object_default <- json_object_list[[1]] + # Helpers for optional-field presence checks + .ver <- inferStochtreeJsonVersion(json_object_default) + has_field <- function(name) { + json_contains_field_cpp(json_object_default$json_ptr, name) + } + has_subfolder_field <- function(subfolder, name) { + json_contains_field_subfolder_cpp( + json_object_default$json_ptr, + subfolder, + name + ) + } + # Unpack the forests include_mean_forest <- json_object_default$get_boolean( "include_mean_forest" @@ -4390,36 +4414,115 @@ createBARTModelFromCombinedJsonString <- function(json_string_list) { model_params[["include_mean_forest"]] <- include_mean_forest model_params[["include_variance_forest"]] <- include_variance_forest model_params[["has_rfx"]] <- json_object_default$get_boolean("has_rfx") - model_params[["has_rfx_basis"]] <- json_object_default$get_boolean( - "has_rfx_basis" - ) - model_params[["num_rfx_basis"]] <- json_object_default$get_scalar( - "num_rfx_basis" - ) - model_params[["num_covariates"]] <- json_object_default$get_scalar( - "num_covariates" - ) + + if (has_field("has_rfx_basis")) { + model_params[["has_rfx_basis"]] <- json_object_default$get_boolean( + "has_rfx_basis" + ) + model_params[["num_rfx_basis"]] <- json_object_default$get_scalar( + "num_rfx_basis" + ) + } else { + model_params[["has_rfx_basis"]] <- FALSE + model_params[["num_rfx_basis"]] <- 1 + warning(paste0( + "Fields 'has_rfx_basis' and 'num_rfx_basis' not found in JSON (model appears to have been ", + "serialized under stochtree ", + .ver, + "). Defaulting to FALSE / 1. ", + "Re-save your model to suppress this warning." 
+ )) + } + + model_params[["num_covariates"]] <- if (has_field("num_covariates")) { + json_object_default$get_scalar("num_covariates") + } else { + NA_real_ + } model_params[["num_basis"]] <- json_object_default$get_scalar("num_basis") - model_params[["num_chains"]] <- json_object_default$get_scalar("num_chains") - model_params[["keep_every"]] <- json_object_default$get_scalar("keep_every") model_params[["requires_basis"]] <- json_object_default$get_boolean( "requires_basis" ) - model_params[["probit_outcome_model"]] <- json_object_default$get_boolean( - "probit_outcome_model" - ) - outcome_model_outcome <- json_object_default$get_string( - "outcome", - "outcome_model" - ) - outcome_model_link <- json_object_default$get_string("link", "outcome_model") + + model_params[["probit_outcome_model"]] <- if ( + has_field("probit_outcome_model") + ) { + json_object_default$get_boolean("probit_outcome_model") + } else { + FALSE + } + + if ( + has_subfolder_field("outcome_model", "outcome") && + has_subfolder_field("outcome_model", "link") + ) { + outcome_model_outcome <- json_object_default$get_string( + "outcome", + "outcome_model" + ) + outcome_model_link <- json_object_default$get_string("link", "outcome_model") + } else { + outcome_model_outcome <- "continuous" + outcome_model_link <- "identity" + warning(paste0( + "Fields 'outcome' and 'link' not found under 'outcome_model' in JSON (model appears to have ", + "been serialized under stochtree ", + .ver, + "). Defaulting to outcome='continuous', ", + "link='identity'. Re-save your model to suppress this warning." 
+ )) + } model_params[["outcome_model"]] <- OutcomeModel( outcome = outcome_model_outcome, link = outcome_model_link ) - model_params[["rfx_model_spec"]] <- json_object_default$get_string( - "rfx_model_spec" - ) + + if (has_field("rfx_model_spec")) { + model_params[["rfx_model_spec"]] <- json_object_default$get_string( + "rfx_model_spec" + ) + } else { + model_params[["rfx_model_spec"]] <- "" + if (model_params[["has_rfx"]]) { + warning(paste0( + "Field 'rfx_model_spec' not found in JSON (model appears to have been serialized under ", + "stochtree ", + .ver, + "). Defaulting to ''. Re-save your model to suppress this warning." + )) + } + } + + if (has_field("num_chains")) { + model_params[["num_chains"]] <- json_object_default$get_scalar("num_chains") + } else { + model_params[["num_chains"]] <- 1 + warning(paste0( + "Field 'num_chains' not found in JSON (model appears to have been serialized under stochtree ", + .ver, + "). Defaulting to 1. Re-save your model to suppress this warning." + )) + } + + if (has_field("keep_every")) { + model_params[["keep_every"]] <- json_object_default$get_scalar("keep_every") + } else { + model_params[["keep_every"]] <- 1 + warning(paste0( + "Field 'keep_every' not found in JSON (model appears to have been serialized under stochtree ", + .ver, + "). Defaulting to 1. Re-save your model to suppress this warning." 
+ )) + } + + if (model_params[["outcome_model"]]$link == "cloglog") { + cloglog_num_categories <- json_object_default$get_scalar( + "cloglog_num_categories" + ) + model_params[["cloglog_num_categories"]] <- cloglog_num_categories + } else { + model_params[["cloglog_num_categories"]] <- 0 + } # Combine values that are sample-specific for (i in 1:length(json_object_list)) { From 291de77e5616ad18427f5dcd2883880096f6324e Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Thu, 19 Mar 2026 21:04:15 -0500 Subject: [PATCH 5/9] Fix missing helpers and unguarded fields in BCF combined JSON string deserializer createBCFModelFromCombinedJsonString was missing the .ver/has_field/has_subfolder_field helpers entirely, causing errors on any optional-field guard. Also: - Guard internal_propensity_model in the initial string-to-object loop using json_contains_field_cpp directly (before json_object_default exists) - Add guards for all optional model_params fields to match createBCFModelFromCombinedJson: has_rfx_basis/num_rfx_basis, num_chains, keep_every, multivariate_treatment, sample_tau_0, internal_propensity_model, probit_outcome_model, outcome_model subfolder, rfx_model_spec - Guard preprocessor_metadata at the end of the function Co-Authored-By: Claude Sonnet 4.6 --- R/bcf.R | 164 ++++++++++++++++++++++++++++++++++++++++++++------------ 1 file changed, 129 insertions(+), 35 deletions(-) diff --git a/R/bcf.R b/R/bcf.R index b1aa4f4d..b8cf4b0b 100644 --- a/R/bcf.R +++ b/R/bcf.R @@ -5386,7 +5386,10 @@ createBCFModelFromCombinedJsonString <- function(json_string_list) { # Add runtime check for separately serialized propensity models # We don't support merging BCF models with independent propensity models # this way at the moment - if (json_object_list[[i]]$get_boolean("internal_propensity_model")) { + if ( + json_contains_field_cpp(json_object_list[[i]]$json_ptr, "internal_propensity_model") && + json_object_list[[i]]$get_boolean("internal_propensity_model") + ) { stop( "Combining
separate BCF models with cached internal propensity models is currently unsupported. To make this work, please first train a propensity model and then pass the propensities as data to the separate BCF models before sampling." ) @@ -5397,6 +5400,13 @@ createBCFModelFromCombinedJsonString <- function(json_string_list) { # defer to the first json json_object_default <- json_object_list[[1]] + # Version inference and presence-check helpers + .ver <- inferStochtreeJsonVersion(json_object_default) + has_field <- function(name) json_contains_field_cpp(json_object_default$json_ptr, name) + has_subfolder_field <- function(subfolder, name) { + json_contains_field_subfolder_cpp(json_object_default$json_ptr, subfolder, name) + } + # Unpack the forests output[["forests_mu"]] <- loadForestContainerCombinedJson( json_object_list, @@ -5484,44 +5494,120 @@ createBCFModelFromCombinedJsonString <- function(json_string_list) { "propensity_covariate" ) model_params[["has_rfx"]] <- json_object_default$get_boolean("has_rfx") - model_params[["has_rfx_basis"]] <- json_object_default$get_boolean( - "has_rfx_basis" - ) - model_params[["num_rfx_basis"]] <- json_object_default$get_scalar( - "num_rfx_basis" - ) + if (has_field("has_rfx_basis")) { + model_params[["has_rfx_basis"]] <- json_object_default$get_boolean( + "has_rfx_basis" + ) + model_params[["num_rfx_basis"]] <- json_object_default$get_scalar( + "num_rfx_basis" + ) + } else { + model_params[["has_rfx_basis"]] <- FALSE + model_params[["num_rfx_basis"]] <- 1 + warning(sprintf( + "Fields 'has_rfx_basis' and 'num_rfx_basis' not found in BCF JSON (inferred version: %s). 
Defaulting to has_rfx_basis=FALSE, num_rfx_basis=1.", + .ver + )) + } model_params[["num_covariates"]] <- json_object_default$get_scalar( "num_covariates" ) - model_params[["num_chains"]] <- json_object_default$get_scalar("num_chains") - model_params[["keep_every"]] <- json_object_default$get_scalar("keep_every") - model_params[["multivariate_treatment"]] <- json_object_default$get_boolean( - "multivariate_treatment" - ) + if (has_field("num_chains")) { + model_params[["num_chains"]] <- json_object_default$get_scalar("num_chains") + } else { + model_params[["num_chains"]] <- 1 + warning(sprintf( + "Field 'num_chains' not found in BCF JSON (inferred version: %s). Defaulting to 1.", + .ver + )) + } + if (has_field("keep_every")) { + model_params[["keep_every"]] <- json_object_default$get_scalar("keep_every") + } else { + model_params[["keep_every"]] <- 1 + warning(sprintf( + "Field 'keep_every' not found in BCF JSON (inferred version: %s). Defaulting to 1.", + .ver + )) + } + if (has_field("multivariate_treatment")) { + model_params[["multivariate_treatment"]] <- json_object_default$get_boolean( + "multivariate_treatment" + ) + } else { + model_params[["multivariate_treatment"]] <- FALSE + warning(sprintf( + "Field 'multivariate_treatment' not found in BCF JSON (inferred version: %s). 
Defaulting to FALSE.", + .ver + )) + } model_params[["adaptive_coding"]] <- json_object_default$get_boolean( "adaptive_coding" ) - model_params[["sample_tau_0"]] <- json_object_default$get_boolean( - "sample_tau_0" - ) - model_params[[ - "internal_propensity_model" - ]] <- json_object_default$get_boolean("internal_propensity_model") - model_params[["probit_outcome_model"]] <- json_object_default$get_boolean( - "probit_outcome_model" - ) - outcome_model_outcome <- json_object_default$get_string( - "outcome", - "outcome_model" - ) - outcome_model_link <- json_object_default$get_string("link", "outcome_model") + if (has_field("sample_tau_0")) { + model_params[["sample_tau_0"]] <- json_object_default$get_boolean( + "sample_tau_0" + ) + } else { + model_params[["sample_tau_0"]] <- FALSE + warning(sprintf( + "Field 'sample_tau_0' not found in BCF JSON (inferred version: %s). Defaulting to FALSE.", + .ver + )) + } + if (has_field("internal_propensity_model")) { + model_params[[ + "internal_propensity_model" + ]] <- json_object_default$get_boolean("internal_propensity_model") + } else { + model_params[["internal_propensity_model"]] <- FALSE + warning(sprintf( + "Field 'internal_propensity_model' not found in BCF JSON (inferred version: %s). Defaulting to FALSE.", + .ver + )) + } + if (has_field("probit_outcome_model")) { + model_params[["probit_outcome_model"]] <- json_object_default$get_boolean( + "probit_outcome_model" + ) + } else { + model_params[["probit_outcome_model"]] <- FALSE + warning(sprintf( + "Field 'probit_outcome_model' not found in BCF JSON (inferred version: %s). 
Defaulting to FALSE.", + .ver + )) + } + if (has_subfolder_field("outcome_model", "outcome")) { + outcome_model_outcome <- json_object_default$get_string( + "outcome", + "outcome_model" + ) + outcome_model_link <- json_object_default$get_string("link", "outcome_model") + } else { + outcome_model_outcome <- "continuous" + outcome_model_link <- "identity" + warning(sprintf( + "Subfolder 'outcome_model' not found in BCF JSON (inferred version: %s). Defaulting to outcome='continuous', link='identity'.", + .ver + )) + } model_params[["outcome_model"]] <- OutcomeModel( outcome = outcome_model_outcome, link = outcome_model_link ) - model_params[["rfx_model_spec"]] <- json_object_default$get_string( - "rfx_model_spec" - ) + if (has_field("rfx_model_spec")) { + model_params[["rfx_model_spec"]] <- json_object_default$get_string( + "rfx_model_spec" + ) + } else { + model_params[["rfx_model_spec"]] <- "" + if (model_params[["has_rfx"]]) { + warning(sprintf( + "Field 'rfx_model_spec' not found in BCF JSON (inferred version: %s) but has_rfx=TRUE.", + .ver + )) + } + } # Combine values that are sample-specific for (i in 1:length(json_object_list)) { @@ -5679,12 +5765,20 @@ createBCFModelFromCombinedJsonString <- function(json_string_list) { } # Unpack covariate preprocessor - preprocessor_metadata_string <- json_object_default$get_string( - "preprocessor_metadata" - ) - output[["train_set_metadata"]] <- createPreprocessorFromJsonString( - preprocessor_metadata_string - ) + if (has_field("preprocessor_metadata")) { + preprocessor_metadata_string <- json_object_default$get_string( + "preprocessor_metadata" + ) + output[["train_set_metadata"]] <- createPreprocessorFromJsonString( + preprocessor_metadata_string + ) + } else { + output[["train_set_metadata"]] <- NULL + warning(sprintf( + "Field 'preprocessor_metadata' not found in BCF JSON (inferred version: %s). 
Preprocessor is unavailable; prediction may fail.", + .ver + )) + } class(output) <- "bcfmodel" return(output) From a5e47ee13a3cc3811f294c6b44a50acbf1f9eb6d Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Thu, 19 Mar 2026 21:07:08 -0500 Subject: [PATCH 6/9] Add BCF JSON serialization roundtrip tests for all deserialization paths The existing BCF serialization tests only covered createBCFModelFromJsonString. Add a test covering all five paths: createBCFModelFromJson (in-memory object), createBCFModelFromJsonString (string), createBCFModelFromJsonFile (file), createBCFModelFromCombinedJson (list of objects), and createBCFModelFromCombinedJsonString (list of strings). The combined-string path would have caught the missing has_field/.ver helpers fixed in the previous commit. Co-Authored-By: Claude Sonnet 4.6 --- test/R/testthat/test-bcf.R | 75 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 75 insertions(+) diff --git a/test/R/testthat/test-bcf.R b/test/R/testthat/test-bcf.R index b26f25ec..99c25316 100644 --- a/test/R/testthat/test-bcf.R +++ b/test/R/testthat/test-bcf.R @@ -847,3 +847,78 @@ test_that("BCF internal propensity model works with data frame covariates", { expect_true(!is.null(bcf_model$tau_hat_train)) expect_true(!is.null(bcf_model$tau_hat_test)) }) + +test_that("BCF JSON serialization roundtrip covers all deserialization paths", { + skip_on_cran() + + # Generate simulated data + set.seed(42) + n <- 100 + p <- 5 + X <- matrix(runif(n * p), ncol = p) + pi_x <- 0.25 + 0.5 * X[, 1] + mu_x <- pi_x * 5 + tau_x <- X[, 2] * 2 + Z <- rbinom(n, 1, pi_x) + y <- mu_x + Z * tau_x + rnorm(n, 0, 1) + test_set_pct <- 0.2 + n_test <- round(test_set_pct * n) + test_inds <- sort(sample(1:n, n_test, replace = FALSE)) + train_inds <- (1:n)[!((1:n) %in% test_inds)] + X_test <- X[test_inds, ] + X_train <- X[train_inds, ] + pi_test <- pi_x[test_inds] + pi_train <- pi_x[train_inds] + Z_test <- Z[test_inds] + Z_train <- Z[train_inds] + y_train <- y[train_inds] + + # Fit 
model + bcf_model <- bcf( + X_train = X_train, + Z_train = Z_train, + y_train = y_train, + propensity_train = pi_train, + num_gfr = 0, + num_burnin = 0, + num_mcmc = 10 + ) + preds_orig <- predict(bcf_model, X_test, Z_test, pi_test) + y_hat_orig <- rowMeans(preds_orig[["y_hat"]]) + tau_hat_orig <- rowMeans(preds_orig[["tau_hat"]]) + + # Path 1: in-memory JSON object + bcf_json <- saveBCFModelToJson(bcf_model) + bcf_rt <- createBCFModelFromJson(bcf_json) + preds_rt <- predict(bcf_rt, X_test, Z_test, pi_test) + expect_equal(rowMeans(preds_rt[["y_hat"]]), y_hat_orig) + expect_equal(rowMeans(preds_rt[["tau_hat"]]), tau_hat_orig) + + # Path 2: JSON string + bcf_json_string <- saveBCFModelToJsonString(bcf_model) + bcf_rt <- createBCFModelFromJsonString(bcf_json_string) + preds_rt <- predict(bcf_rt, X_test, Z_test, pi_test) + expect_equal(rowMeans(preds_rt[["y_hat"]]), y_hat_orig) + expect_equal(rowMeans(preds_rt[["tau_hat"]]), tau_hat_orig) + + # Path 3: JSON file + tmpjson <- tempfile(fileext = ".json") + saveBCFModelToJsonFile(bcf_model, tmpjson) + bcf_rt <- createBCFModelFromJsonFile(tmpjson) + unlink(tmpjson) + preds_rt <- predict(bcf_rt, X_test, Z_test, pi_test) + expect_equal(rowMeans(preds_rt[["y_hat"]]), y_hat_orig) + expect_equal(rowMeans(preds_rt[["tau_hat"]]), tau_hat_orig) + + # Path 4: list of in-memory JSON objects (combined) + bcf_rt <- createBCFModelFromCombinedJson(list(bcf_json)) + preds_rt <- predict(bcf_rt, X_test, Z_test, pi_test) + expect_equal(rowMeans(preds_rt[["y_hat"]]), y_hat_orig) + expect_equal(rowMeans(preds_rt[["tau_hat"]]), tau_hat_orig) + + # Path 5: list of JSON strings (combined) + bcf_rt <- createBCFModelFromCombinedJsonString(list(bcf_json_string)) + preds_rt <- predict(bcf_rt, X_test, Z_test, pi_test) + expect_equal(rowMeans(preds_rt[["y_hat"]]), y_hat_orig) + expect_equal(rowMeans(preds_rt[["tau_hat"]]), tau_hat_orig) +}) From e80b7a8db1725abdddc842691ae3b99afd41e7e7 Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Thu, 19 Mar 2026 
21:19:36 -0500 Subject: [PATCH 7/9] Fix BCF field-name mismatches between R and Python JSON schemas (#321) Standardize R BCF JSON field names to match Python (canonical): - Write: initial_sigma2 -> sigma2_init - Write: b_1_samples -> b1_samples, b_0_samples -> b0_samples Read side accepts both old and new names across all three deserialization functions (createBCFModelFromJson, createBCFModelFromCombinedJson, createBCFModelFromCombinedJsonString) with deprecation warnings for legacy field names. The R in-memory object fields ($b_0_samples, $b_1_samples, $model_params$initial_sigma2) are unchanged. Note: b0/b1 presence check uses has_subfolder_field("parameters", "b1_samples") since these fields live in the "parameters" subfolder, not at the top level. Co-Authored-By: Claude Sonnet 4.6 --- R/bcf.R | 101 +++++++++++++++++++++++++++++++++++++++++--------------- 1 file changed, 75 insertions(+), 26 deletions(-) diff --git a/R/bcf.R b/R/bcf.R index b8cf4b0b..93663faf 100644 --- a/R/bcf.R +++ b/R/bcf.R @@ -4530,7 +4530,7 @@ saveBCFModelToJson <- function(object) { jsonobj$add_scalar("outcome_scale", object$model_params$outcome_scale) jsonobj$add_scalar("outcome_mean", object$model_params$outcome_mean) jsonobj$add_boolean("standardize", object$model_params$standardize) - jsonobj$add_scalar("initial_sigma2", object$model_params$initial_sigma2) + jsonobj$add_scalar("sigma2_init", object$model_params$initial_sigma2) jsonobj$add_boolean( "sample_sigma2_global", object$model_params$sample_sigma2_global @@ -4609,8 +4609,8 @@ saveBCFModelToJson <- function(object) { ) } if (object$model_params$adaptive_coding) { - jsonobj$add_vector("b_1_samples", object$b_1_samples, "parameters") - jsonobj$add_vector("b_0_samples", object$b_0_samples, "parameters") + jsonobj$add_vector("b1_samples", object$b_1_samples, "parameters") + jsonobj$add_vector("b0_samples", object$b_0_samples, "parameters") } if (object$model_params$sample_tau_0 && !is.null(object$tau_0_samples)) {
jsonobj$add_scalar("tau_0_dim", nrow(object$tau_0_samples)) @@ -4749,7 +4749,15 @@ createBCFModelFromJson <- function(json_object) { model_params[["outcome_scale"]] <- json_object$get_scalar("outcome_scale") model_params[["outcome_mean"]] <- json_object$get_scalar("outcome_mean") model_params[["standardize"]] <- json_object$get_boolean("standardize") - model_params[["initial_sigma2"]] <- json_object$get_scalar("initial_sigma2") + if (has_field("sigma2_init")) { + model_params[["initial_sigma2"]] <- json_object$get_scalar("sigma2_init") + } else { + model_params[["initial_sigma2"]] <- json_object$get_scalar("initial_sigma2") + warning(sprintf( + "JSON field 'initial_sigma2' is deprecated; please re-save the model to use 'sigma2_init' (inferred version: %s).", + .ver + )) + } model_params[["sample_sigma2_global"]] <- json_object$get_boolean( "sample_sigma2_global" ) @@ -4893,14 +4901,17 @@ createBCFModelFromJson <- function(json_object) { ) } if (model_params[["adaptive_coding"]]) { - output[["b_1_samples"]] <- json_object$get_vector( - "b_1_samples", - "parameters" - ) - output[["b_0_samples"]] <- json_object$get_vector( - "b_0_samples", - "parameters" - ) + if (has_subfolder_field("parameters", "b1_samples")) { + output[["b_1_samples"]] <- json_object$get_vector("b1_samples", "parameters") + output[["b_0_samples"]] <- json_object$get_vector("b0_samples", "parameters") + } else { + output[["b_1_samples"]] <- json_object$get_vector("b_1_samples", "parameters") + output[["b_0_samples"]] <- json_object$get_vector("b_0_samples", "parameters") + warning(sprintf( + "JSON fields 'b_1_samples'/'b_0_samples' are deprecated; please re-save the model to use 'b1_samples'/'b0_samples' (inferred version: %s).", + .ver + )) + } } if (model_params[["sample_tau_0"]]) { tau_0_dim <- as.integer(json_object$get_scalar("tau_0_dim")) @@ -5063,9 +5074,15 @@ createBCFModelFromCombinedJson <- function(json_object_list) { model_params[["standardize"]] <- json_object_default$get_boolean( 
"standardize" ) - model_params[["initial_sigma2"]] <- json_object_default$get_scalar( - "initial_sigma2" - ) + if (has_field("sigma2_init")) { + model_params[["initial_sigma2"]] <- json_object_default$get_scalar("sigma2_init") + } else { + model_params[["initial_sigma2"]] <- json_object_default$get_scalar("initial_sigma2") + warning(sprintf( + "JSON field 'initial_sigma2' is deprecated; please re-save the model to use 'sigma2_init' (inferred version: %s).", + .ver + )) + } model_params[["sample_sigma2_global"]] <- json_object_default$get_boolean( "sample_sigma2_global" ) @@ -5296,26 +5313,39 @@ createBCFModelFromCombinedJson <- function(json_object_list) { } } } + .b_use_new_names <- has_subfolder_field("parameters", "b1_samples") if (model_params[["adaptive_coding"]]) { + if (!.b_use_new_names) { + warning(sprintf( + "JSON fields 'b_1_samples'/'b_0_samples' are deprecated; please re-save the model to use 'b1_samples'/'b0_samples' (inferred version: %s).", + .ver + )) + } for (i in 1:length(json_object_list)) { json_object <- json_object_list[[i]] if (i == 1) { output[["b_1_samples"]] <- json_object$get_vector( - "b_1_samples", + if (.b_use_new_names) "b1_samples" else "b_1_samples", "parameters" ) output[["b_0_samples"]] <- json_object$get_vector( - "b_0_samples", + if (.b_use_new_names) "b0_samples" else "b_0_samples", "parameters" ) } else { output[["b_1_samples"]] <- c( output[["b_1_samples"]], - json_object$get_vector("b_1_samples", "parameters") + json_object$get_vector( + if (.b_use_new_names) "b1_samples" else "b_1_samples", + "parameters" + ) ) output[["b_0_samples"]] <- c( output[["b_0_samples"]], - json_object$get_vector("b_0_samples", "parameters") + json_object$get_vector( + if (.b_use_new_names) "b0_samples" else "b_0_samples", + "parameters" + ) ) } } @@ -5477,9 +5507,15 @@ createBCFModelFromCombinedJsonString <- function(json_string_list) { model_params[["standardize"]] <- json_object_default$get_boolean( "standardize" ) - 
model_params[["initial_sigma2"]] <- json_object_default$get_scalar( - "initial_sigma2" - ) + if (has_field("sigma2_init")) { + model_params[["initial_sigma2"]] <- json_object_default$get_scalar("sigma2_init") + } else { + model_params[["initial_sigma2"]] <- json_object_default$get_scalar("initial_sigma2") + warning(sprintf( + "JSON field 'initial_sigma2' is deprecated; please re-save the model to use 'sigma2_init' (inferred version: %s).", + .ver + )) + } model_params[["sample_sigma2_global"]] <- json_object_default$get_boolean( "sample_sigma2_global" ) @@ -5710,26 +5746,39 @@ createBCFModelFromCombinedJsonString <- function(json_string_list) { } } } + .b_use_new_names <- has_subfolder_field("parameters", "b1_samples") if (model_params[["adaptive_coding"]]) { + if (!.b_use_new_names) { + warning(sprintf( + "JSON fields 'b_1_samples'/'b_0_samples' are deprecated; please re-save the model to use 'b1_samples'/'b0_samples' (inferred version: %s).", + .ver + )) + } for (i in 1:length(json_object_list)) { json_object <- json_object_list[[i]] if (i == 1) { output[["b_1_samples"]] <- json_object$get_vector( - "b_1_samples", + if (.b_use_new_names) "b1_samples" else "b_1_samples", "parameters" ) output[["b_0_samples"]] <- json_object$get_vector( - "b_0_samples", + if (.b_use_new_names) "b0_samples" else "b_0_samples", "parameters" ) } else { output[["b_1_samples"]] <- c( output[["b_1_samples"]], - json_object$get_vector("b_1_samples", "parameters") + json_object$get_vector( + if (.b_use_new_names) "b1_samples" else "b_1_samples", + "parameters" + ) ) output[["b_0_samples"]] <- c( output[["b_0_samples"]], - json_object$get_vector("b_0_samples", "parameters") + json_object$get_vector( + if (.b_use_new_names) "b0_samples" else "b_0_samples", + "parameters" + ) ) } } From a9a2c8c697a13a4975f0faefe191ba934ac16579 Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Thu, 19 Mar 2026 21:22:54 -0500 Subject: [PATCH 8/9] Add tests for BCF JSON field-name standardization (#321) Two new 
tests in test-serialization.R: - Verify that freshly serialized BCF JSON uses canonical names (sigma2_init, b1_samples, b0_samples) and not the old names - Verify that legacy JSON with old names (initial_sigma2, b_1_samples, b_0_samples) still deserializes correctly and emits deprecation warnings The legacy test works by serializing a model, substituting old names with gsub, then loading the patched JSON string and asserting on warnings and prediction equality. Co-Authored-By: Claude Sonnet 4.6 --- test/R/testthat/test-serialization.R | 72 ++++++++++++++++++++++++++++ 1 file changed, 72 insertions(+) diff --git a/test/R/testthat/test-serialization.R b/test/R/testthat/test-serialization.R index 2f0d4aaa..26809da9 100644 --- a/test/R/testthat/test-serialization.R +++ b/test/R/testthat/test-serialization.R @@ -179,3 +179,75 @@ test_that("BCF Serialization (no propensity)", { # Assertion expect_equal(y_hat_orig, y_hat_reloaded) }) + +test_that("BCF JSON uses canonical field names (sigma2_init, b1_samples, b0_samples)", { + skip_on_cran() + + set.seed(1) + n <- 100 + X <- matrix(runif(n * 5), ncol = 5) + pi_x <- 0.25 + 0.5 * X[, 1] + Z <- rbinom(n, 1, pi_x) + y <- pi_x * 5 + Z * X[, 2] * 2 + rnorm(n) + + bcf_model <- bcf( + X_train = X, Z_train = Z, y_train = y, + propensity_train = pi_x, + num_gfr = 0, num_burnin = 0, num_mcmc = 10 + ) + json_string <- saveBCFModelToJsonString(bcf_model) + + # New canonical names must be present + expect_true(grepl('"sigma2_init"', json_string)) + expect_true(grepl('"b1_samples"', json_string)) + expect_true(grepl('"b0_samples"', json_string)) + + # Legacy names must not be present + expect_false(grepl('"initial_sigma2"', json_string)) + expect_false(grepl('"b_1_samples"', json_string)) + expect_false(grepl('"b_0_samples"', json_string)) +}) + +test_that("BCF JSON deserialization handles legacy field names with warnings", { + skip_on_cran() + + set.seed(2) + n <- 100 + X <- matrix(runif(n * 5), ncol = 5) + pi_x <- 0.25 + 0.5 * X[, 1] + 
Z <- rbinom(n, 1, pi_x) + y <- pi_x * 5 + Z * X[, 2] * 2 + rnorm(n) + X_test <- matrix(runif(20 * 5), ncol = 5) + pi_test <- rep(0.5, 20) + Z_test <- rbinom(20, 1, 0.5) + + bcf_model <- bcf( + X_train = X, Z_train = Z, y_train = y, + propensity_train = pi_x, + num_gfr = 0, num_burnin = 0, num_mcmc = 10 + ) + preds_orig <- predict(bcf_model, X_test, Z_test, pi_test) + + # Simulate a legacy JSON by replacing canonical names with old names + json_new <- saveBCFModelToJsonString(bcf_model) + json_legacy <- gsub('"sigma2_init"', '"initial_sigma2"', json_new, fixed = TRUE) + json_legacy <- gsub('"b1_samples"', '"b_1_samples"', json_legacy, fixed = TRUE) + json_legacy <- gsub('"b0_samples"', '"b_0_samples"', json_legacy, fixed = TRUE) + + # Loading a legacy JSON should emit deprecation warnings for all renamed fields + all_warnings <- character(0) + withCallingHandlers( + bcf_legacy <- createBCFModelFromJsonString(json_legacy), + warning = function(w) { + all_warnings <<- c(all_warnings, conditionMessage(w)) + invokeRestart("muffleWarning") + } + ) + expect_true(any(grepl("initial_sigma2.*deprecated|deprecated.*initial_sigma2", all_warnings))) + expect_true(any(grepl("b_1_samples.*deprecated|b_0_samples.*deprecated|deprecated.*b_[01]_samples", all_warnings))) + + # Predictions must still match + preds_legacy <- predict(bcf_legacy, X_test, Z_test, pi_test) + expect_equal(rowMeans(preds_legacy[["y_hat"]]), rowMeans(preds_orig[["y_hat"]])) + expect_equal(rowMeans(preds_legacy[["tau_hat"]]), rowMeans(preds_orig[["tau_hat"]])) +}) From e56d2a79b4c20ca127d0995a236a48b2bd77ee15 Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Fri, 20 Mar 2026 00:03:58 -0500 Subject: [PATCH 9/9] Add serialization snapshot and backward-compat tests (#322) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add minimal fixture JSONs (~14–22 KB each) for BART and BCF in both R and Python test directories; untrack *.json globally except for fixtures paths in 
.gitignore - Add test/R/testthat/test-serialization-compat.R (26 tests): snapshot load + predict tests, and backward-compat tests for missing optional fields (outcome_model, multivariate_treatment, internal_propensity_model, rfx_model_spec, preprocessor_metadata, num_chains/keep_every, has_rfx_basis) - Add test/python/test_serialization_compat.py (15 tests): same coverage for Python BARTModel and BCFModel - All R compat tests have skip_on_cran(); cran-bootstrap.R excludes both the fixture JSON files and test-serialization-compat.R from the tarball - Add any::jsonlite to extra-packages in r-test.yml, r-devel-check.yml, r-python-slow-api-test.yml, and r-valgrind-check.yml Co-Authored-By: Claude Sonnet 4.6 --- .github/workflows/r-devel-check.yml | 2 +- .github/workflows/r-python-slow-api-test.yml | 2 +- .github/workflows/r-test.yml | 2 +- .github/workflows/r-valgrind-check.yml | 2 +- .gitignore | 2 + cran-bootstrap.R | 4 + test/R/testthat/fixtures/bart_mcmc.json | 1 + test/R/testthat/fixtures/bcf_mcmc.json | 1 + test/R/testthat/test-serialization-compat.R | 336 +++++++++++++++++++ test/python/fixtures/bart_mcmc.json | 1 + test/python/fixtures/bcf_mcmc.json | 1 + test/python/test_serialization_compat.py | 272 +++++++++++++++ 12 files changed, 622 insertions(+), 4 deletions(-) create mode 100644 test/R/testthat/fixtures/bart_mcmc.json create mode 100644 test/R/testthat/fixtures/bcf_mcmc.json create mode 100644 test/R/testthat/test-serialization-compat.R create mode 100644 test/python/fixtures/bart_mcmc.json create mode 100644 test/python/fixtures/bcf_mcmc.json create mode 100644 test/python/test_serialization_compat.py diff --git a/.github/workflows/r-devel-check.yml b/.github/workflows/r-devel-check.yml index a82fc6fb..fcaf781e 100644 --- a/.github/workflows/r-devel-check.yml +++ b/.github/workflows/r-devel-check.yml @@ -29,7 +29,7 @@ jobs: - uses: r-lib/actions/setup-r-dependencies@v2 with: - extra-packages: any::testthat, any::decor, any::rcmdcheck + extra-packages: 
any::testthat, any::decor, any::rcmdcheck, any::jsonlite needs: check - name: Create a CRAN-ready version of the R package diff --git a/.github/workflows/r-python-slow-api-test.yml b/.github/workflows/r-python-slow-api-test.yml index bd2d3930..0e816743 100644 --- a/.github/workflows/r-python-slow-api-test.yml +++ b/.github/workflows/r-python-slow-api-test.yml @@ -57,7 +57,7 @@ jobs: - name: Setup R Package Dependencies uses: r-lib/actions/setup-r-dependencies@v2 with: - extra-packages: any::testthat, any::decor, any::rcmdcheck + extra-packages: any::testthat, any::decor, any::rcmdcheck, any::jsonlite needs: check - name: Create a CRAN-ready version of the R package diff --git a/.github/workflows/r-test.yml b/.github/workflows/r-test.yml index f00938a0..f6e3a38c 100644 --- a/.github/workflows/r-test.yml +++ b/.github/workflows/r-test.yml @@ -39,7 +39,7 @@ jobs: - uses: r-lib/actions/setup-r-dependencies@v2 with: - extra-packages: any::testthat, any::decor, any::rcmdcheck + extra-packages: any::testthat, any::decor, any::rcmdcheck, any::jsonlite needs: check - name: Create a CRAN-ready version of the R package diff --git a/.github/workflows/r-valgrind-check.yml b/.github/workflows/r-valgrind-check.yml index 40f5463e..2450d87b 100644 --- a/.github/workflows/r-valgrind-check.yml +++ b/.github/workflows/r-valgrind-check.yml @@ -20,7 +20,7 @@ jobs: - name: Install dependencies run: | - R -q -e 'pak::pkg_install(c("deps::stochtree_cran", "any::rcmdcheck"), dependencies = TRUE)' + R -q -e 'pak::pkg_install(c("deps::stochtree_cran", "any::rcmdcheck", "any::jsonlite"), dependencies = TRUE)' - uses: r-lib/actions/check-r-package@v2 with: diff --git a/.gitignore b/.gitignore index ee49f886..0b5146bc 100644 --- a/.gitignore +++ b/.gitignore @@ -7,6 +7,8 @@ build/ .vscode/ xcode/ *.json +!test/R/testthat/fixtures/*.json +!test/python/fixtures/*.json .vs/ cpp_docs/doxyoutput/html cpp_docs/doxyoutput/xml diff --git a/cran-bootstrap.R b/cran-bootstrap.R index 00ec6081..d2306714 
100644 --- a/cran-bootstrap.R +++ b/cran-bootstrap.R @@ -130,6 +130,10 @@ if (pkgdown_build) { # Handle tests separately (move from test/R/ folder to tests/ folder) if (include_tests) { test_files_src <- list.files("test/R", recursive = TRUE, full.names = TRUE) + # Exclude fixture JSON files: large test-only snapshots, not needed on CRAN + test_files_src <- test_files_src[!grepl("/fixtures/.*\\.json$", test_files_src)] + # Exclude backward-compat tests: they depend on jsonlite and fixture files, not suitable for CRAN + test_files_src <- test_files_src[!grepl("test-serialization-compat\\.R$", test_files_src)] test_files_dst <- file.path(cran_dir, gsub("test/R", "tests", test_files_src)) pkg_core_files <- c(pkg_core_files, test_files_src) pkg_core_files_dst <- c(pkg_core_files_dst, test_files_dst) diff --git a/test/R/testthat/fixtures/bart_mcmc.json b/test/R/testthat/fixtures/bart_mcmc.json new file mode 100644 index 00000000..700990ef --- /dev/null +++ b/test/R/testthat/fixtures/bart_mcmc.json @@ -0,0 +1 @@ 
+{"forests":{"forest_0":{"forest_0":{"is_exponentiated":false,"is_leaf_constant":true,"num_trees":5,"output_dimension":1,"tree_0":{"category_list":[],"category_list_begin":[0,0,0],"category_list_end":[0,0,0],"deleted_nodes":[],"has_categorical_split":false,"internal_nodes":[0],"is_log_scale":false,"leaf_parents":[0],"leaf_value":[1.0880185641326535e-17,-0.11277086295725307,0.1496859070030688],"leaf_vector":[],"leaf_vector_begin":[0,0,0],"leaf_vector_end":[0,0,0],"leaves":[1,2],"left":[1,-1,-1],"node_deleted":[false,false,false],"node_type":[1,0,0],"num_deleted_nodes":0,"num_nodes":3,"output_dimension":1,"parent":[-1,0,0],"right":[2,-1,-1],"split_index":[2,-1,-1],"threshold":[0.4545916165301922,0.0,0.0]},"tree_1":{"category_list":[],"category_list_begin":[0,0,0],"category_list_end":[0,0,0],"deleted_nodes":[],"has_categorical_split":false,"internal_nodes":[0],"is_log_scale":false,"leaf_parents":[0],"leaf_value":[1.0880185641326535e-17,-0.04052406562528336,0.0022188249749190555],"leaf_vector":[],"leaf_vector_begin":[0,0,0],"leaf_vector_end":[0,0,0],"leaves":[1,2],"left":[1,-1,-1],"node_deleted":[false,false,false],"node_type":[1,0,0],"num_deleted_nodes":0,"num_nodes":3,"output_dimension":1,"parent":[-1,0,0],"right":[2,-1,-1],"split_index":[2,-1,-1],"threshold":[0.5708777275344915,0.0,0.0]},"tree_2":{"category_list":[],"category_list_begin":[0,0,0],"category_list_end":[0,0,0],"deleted_nodes":[],"has_categorical_split":false,"internal_nodes":[0],"is_log_scale":false,"leaf_parents":[0],"leaf_value":[1.0880185641326535e-17,0.11543648697026145,-0.0028720783768259317],"leaf_vector":[],"leaf_vector_begin":[0,0,0],"leaf_vector_end":[0,0,0],"leaves":[1,2],"left":[1,-1,-1],"node_deleted":[false,false,false],"node_type":[1,0,0],"num_deleted_nodes":0,"num_nodes":3,"output_dimension":1,"parent":[-1,0,0],"right":[2,-1,-1],"split_index":[1,-1,-1],"threshold":[0.5048831069781892,0.0,0.0]},"tree_3":{"category_list":[],"category_list_begin":[0,0,0],"category_list_end":[0,0,0],"deleted_n
odes":[],"has_categorical_split":false,"internal_nodes":[0],"is_log_scale":false,"leaf_parents":[0],"leaf_value":[1.0880185641326535e-17,-0.13126474901741714,-0.3095130954734597],"leaf_vector":[],"leaf_vector_begin":[0,0,0],"leaf_vector_end":[0,0,0],"leaves":[1,2],"left":[1,-1,-1],"node_deleted":[false,false,false],"node_type":[1,0,0],"num_deleted_nodes":0,"num_nodes":3,"output_dimension":1,"parent":[-1,0,0],"right":[2,-1,-1],"split_index":[1,-1,-1],"threshold":[0.5651986454837683,0.0,0.0]},"tree_4":{"category_list":[],"category_list_begin":[0,0,0,0,0],"category_list_end":[0,0,0,0,0],"deleted_nodes":[3,4],"has_categorical_split":false,"internal_nodes":[0],"is_log_scale":false,"leaf_parents":[0],"leaf_value":[1.0880185641326535e-17,0.0017481741594564336,0.16571429762916723,-0.25774414767857406,0.2588609628084577],"leaf_vector":[],"leaf_vector_begin":[0,0,0,0,0],"leaf_vector_end":[0,0,0,0,0],"leaves":[1,2],"left":[1,-1,-1,-1,-1],"node_deleted":[false,false,false,true,true],"node_type":[1,0,0,0,0],"num_deleted_nodes":2,"num_nodes":5,"output_dimension":1,"parent":[-1,0,0,2,2],"right":[2,-1,-1,-1,-1],"split_index":[1,-1,1,-1,-1],"threshold":[0.04829446941376727,0.0,0.2383545573177497,0.0,0.0]}},"forest_1":{"is_exponentiated":false,"is_leaf_constant":true,"num_trees":5,"output_dimension":1,"tree_0":{"category_list":[],"category_list_begin":[0,0,0],"category_list_end":[0,0,0],"deleted_nodes":[],"has_categorical_split":false,"internal_nodes":[0],"is_log_scale":false,"leaf_parents":[0],"leaf_value":[1.0880185641326535e-17,-0.12832832196465194,-0.03225392815581969],"leaf_vector":[],"leaf_vector_begin":[0,0,0],"leaf_vector_end":[0,0,0],"leaves":[1,2],"left":[1,-1,-1],"node_deleted":[false,false,false],"node_type":[1,0,0],"num_deleted_nodes":0,"num_nodes":3,"output_dimension":1,"parent":[-1,0,0],"right":[2,-1,-1],"split_index":[2,-1,-1],"threshold":[0.4545916165301922,0.0,0.0]},"tree_1":{"category_list":[],"category_list_begin":[0,0,0],"category_list_end":[0,0,0],"deleted_nodes
":[],"has_categorical_split":false,"internal_nodes":[0],"is_log_scale":false,"leaf_parents":[0],"leaf_value":[1.0880185641326535e-17,-0.04758297089512515,0.19461926228905674],"leaf_vector":[],"leaf_vector_begin":[0,0,0],"leaf_vector_end":[0,0,0],"leaves":[1,2],"left":[1,-1,-1],"node_deleted":[false,false,false],"node_type":[1,0,0],"num_deleted_nodes":0,"num_nodes":3,"output_dimension":1,"parent":[-1,0,0],"right":[2,-1,-1],"split_index":[2,-1,-1],"threshold":[0.5708777275344915,0.0,0.0]},"tree_2":{"category_list":[],"category_list_begin":[0,0,0],"category_list_end":[0,0,0],"deleted_nodes":[],"has_categorical_split":false,"internal_nodes":[0],"is_log_scale":false,"leaf_parents":[0],"leaf_value":[1.0880185641326535e-17,-0.05495981564200835,-0.13728748299701282],"leaf_vector":[],"leaf_vector_begin":[0,0,0],"leaf_vector_end":[0,0,0],"leaves":[1,2],"left":[1,-1,-1],"node_deleted":[false,false,false],"node_type":[1,0,0],"num_deleted_nodes":0,"num_nodes":3,"output_dimension":1,"parent":[-1,0,0],"right":[2,-1,-1],"split_index":[1,-1,-1],"threshold":[0.5048831069781892,0.0,0.0]},"tree_3":{"category_list":[],"category_list_begin":[0,0,0],"category_list_end":[0,0,0],"deleted_nodes":[],"has_categorical_split":false,"internal_nodes":[0],"is_log_scale":false,"leaf_parents":[0],"leaf_value":[1.0880185641326535e-17,0.22231368332987,0.01027291134949435],"leaf_vector":[],"leaf_vector_begin":[0,0,0],"leaf_vector_end":[0,0,0],"leaves":[1,2],"left":[1,-1,-1],"node_deleted":[false,false,false],"node_type":[1,0,0],"num_deleted_nodes":0,"num_nodes":3,"output_dimension":1,"parent":[-1,0,0],"right":[2,-1,-1],"split_index":[1,-1,-1],"threshold":[0.5651986454837683,0.0,0.0]},"tree_4":{"category_list":[],"category_list_begin":[0,0,0,0,0],"category_list_end":[0,0,0,0,0],"deleted_nodes":[3,4],"has_categorical_split":false,"internal_nodes":[0],"is_log_scale":false,"leaf_parents":[0],"leaf_value":[1.0880185641326535e-17,0.19468122012752942,0.03169404592694457,-0.25774414767857406,0.2588609628084577]
,"leaf_vector":[],"leaf_vector_begin":[0,0,0,0,0],"leaf_vector_end":[0,0,0,0,0],"leaves":[1,2],"left":[1,-1,-1,-1,-1],"node_deleted":[false,false,false,true,true],"node_type":[1,0,0,0,0],"num_deleted_nodes":2,"num_nodes":5,"output_dimension":1,"parent":[-1,0,0,2,2],"right":[2,-1,-1,-1,-1],"split_index":[1,-1,1,-1,-1],"threshold":[0.04829446941376727,0.0,0.2383545573177497,0.0,0.0]}},"forest_2":{"is_exponentiated":false,"is_leaf_constant":true,"num_trees":5,"output_dimension":1,"tree_0":{"category_list":[],"category_list_begin":[0,0,0],"category_list_end":[0,0,0],"deleted_nodes":[],"has_categorical_split":false,"internal_nodes":[0],"is_log_scale":false,"leaf_parents":[0],"leaf_value":[1.0880185641326535e-17,-0.10788101667336364,-0.017227937899200674],"leaf_vector":[],"leaf_vector_begin":[0,0,0],"leaf_vector_end":[0,0,0],"leaves":[1,2],"left":[1,-1,-1],"node_deleted":[false,false,false],"node_type":[1,0,0],"num_deleted_nodes":0,"num_nodes":3,"output_dimension":1,"parent":[-1,0,0],"right":[2,-1,-1],"split_index":[2,-1,-1],"threshold":[0.4545916165301922,0.0,0.0]},"tree_1":{"category_list":[],"category_list_begin":[0,0,0],"category_list_end":[0,0,0],"deleted_nodes":[],"has_categorical_split":false,"internal_nodes":[0],"is_log_scale":false,"leaf_parents":[0],"leaf_value":[1.0880185641326535e-17,-0.03616763334360402,0.06017837575984457],"leaf_vector":[],"leaf_vector_begin":[0,0,0],"leaf_vector_end":[0,0,0],"leaves":[1,2],"left":[1,-1,-1],"node_deleted":[false,false,false],"node_type":[1,0,0],"num_deleted_nodes":0,"num_nodes":3,"output_dimension":1,"parent":[-1,0,0],"right":[2,-1,-1],"split_index":[2,-1,-1],"threshold":[0.5708777275344915,0.0,0.0]},"tree_2":{"category_list":[],"category_list_begin":[0,0,0],"category_list_end":[0,0,0],"deleted_nodes":[],"has_categorical_split":false,"internal_nodes":[0],"is_log_scale":false,"leaf_parents":[0],"leaf_value":[1.0880185641326535e-17,-0.044538838367736165,0.02979981473212648],"leaf_vector":[],"leaf_vector_begin":[0,0,0],"leaf_ve
ctor_end":[0,0,0],"leaves":[1,2],"left":[1,-1,-1],"node_deleted":[false,false,false],"node_type":[1,0,0],"num_deleted_nodes":0,"num_nodes":3,"output_dimension":1,"parent":[-1,0,0],"right":[2,-1,-1],"split_index":[1,-1,-1],"threshold":[0.5048831069781892,0.0,0.0]},"tree_3":{"category_list":[],"category_list_begin":[0,0,0],"category_list_end":[0,0,0],"deleted_nodes":[],"has_categorical_split":false,"internal_nodes":[0],"is_log_scale":false,"leaf_parents":[0],"leaf_value":[1.0880185641326535e-17,0.1743989271087603,-0.12290884709055233],"leaf_vector":[],"leaf_vector_begin":[0,0,0],"leaf_vector_end":[0,0,0],"leaves":[1,2],"left":[1,-1,-1],"node_deleted":[false,false,false],"node_type":[1,0,0],"num_deleted_nodes":0,"num_nodes":3,"output_dimension":1,"parent":[-1,0,0],"right":[2,-1,-1],"split_index":[1,-1,-1],"threshold":[0.5651986454837683,0.0,0.0]},"tree_4":{"category_list":[],"category_list_begin":[0,0,0,0,0],"category_list_end":[0,0,0,0,0],"deleted_nodes":[],"has_categorical_split":false,"internal_nodes":[0,2],"is_log_scale":false,"leaf_parents":[2],"leaf_value":[1.0880185641326535e-17,0.22187227386986846,0.03169404592694457,0.26995816074244966,-0.3812283527654699],"leaf_vector":[],"leaf_vector_begin":[0,0,0,0,0],"leaf_vector_end":[0,0,0,0,0],"leaves":[1,4,3],"left":[1,-1,4,-1,-1],"node_deleted":[false,false,false,false,false],"node_type":[1,0,1,0,0],"num_deleted_nodes":0,"num_nodes":5,"output_dimension":1,"parent":[-1,0,0,2,2],"right":[2,-1,3,-1,-1],"split_index":[1,-1,0,-1,-1],"threshold":[0.04829446941376727,0.0,0.6749693892376654,0.0,0.0]}},"initialized":true,"is_exponentiated":false,"is_leaf_constant":true,"num_samples":3,"num_trees":5,"output_dimension":1}},"has_rfx":false,"has_rfx_basis":false,"include_mean_forest":true,"include_variance_forest":false,"keep_every":1.0,"num_basis":0.0,"num_burnin":3.0,"num_chains":1.0,"num_covariates":5.0,"num_forests":1,"num_gfr":0.0,"num_mcmc":3.0,"num_numeric_vars":5.0,"num_ordered_cat_vars":0.0,"num_random_effects":0,"num_rfx
_basis":0.0,"num_samples":3.0,"num_unordered_cat_vars":0.0,"numeric_vars":["x1","x2","x3","x4","x5","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA"],"outcome_mean":1.5188774524806328,"outcome_model":{"link":"identity","outcome":"continuous"},"outcome_scale":5.726626018388027,"parameters":{"sigma2_global_samples":[38.98734209941173,26.56455710845395,21.437710704586216],"sigma2_leaf_samples":[0.01928980600904151,0.02477847847940415,0.030466505518607117]},"preprocessor_metadata":"{\"feature_types\":[0.0,0.0,0.0,0.0,0.0],\"forests\":{},\"num_forests\":0,\"num_numeric_vars\":5,\"num_ordered_cat_vars\":0,\"num_random_effects\":0,\"num_unordered_cat_vars\":0,\"numeric_vars\":[\"x1\",\"x2\",\"x3\",\"x4\",\"x5\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\
",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\"],\"original_var_indices\":[1.0,2.0,3.0,4.0,5.0],\"random_effects\":{}}","probit_outcome_model":false,"random_effects":{},"requires_basis":false,"rfx_model_spec":"custom","sample_sigma2_global":true,"sample_sigma2_leaf":true,"sigma2_init":1.0000000000000002,"standardize":true,"stochtree_version":"0.4.1.9000"} diff --git a/test/R/testthat/fixtures/bcf_mcmc.json 
b/test/R/testthat/fixtures/bcf_mcmc.json new file mode 100644 index 00000000..7999574b --- /dev/null +++ b/test/R/testthat/fixtures/bcf_mcmc.json @@ -0,0 +1 @@ +{"adaptive_coding":true,"forests":{"forest_0":{"forest_0":{"is_exponentiated":false,"is_leaf_constant":true,"num_trees":5,"output_dimension":1,"tree_0":{"category_list":[],"category_list_begin":[0,0,0],"category_list_end":[0,0,0],"deleted_nodes":[],"has_categorical_split":false,"internal_nodes":[0],"is_log_scale":false,"leaf_parents":[0],"leaf_value":[9.237055564881303e-17,-0.5308197039570554,0.22222251687201255],"leaf_vector":[],"leaf_vector_begin":[0,0,0],"leaf_vector_end":[0,0,0],"leaves":[1,2],"left":[1,-1,-1],"node_deleted":[false,false,false],"node_type":[1,0,0],"num_deleted_nodes":0,"num_nodes":3,"output_dimension":1,"parent":[-1,0,0],"right":[2,-1,-1],"split_index":[0,-1,-1],"threshold":[0.16750638543965574,0.0,0.0]},"tree_1":{"category_list":[],"category_list_begin":[0,0,0],"category_list_end":[0,0,0],"deleted_nodes":[],"has_categorical_split":false,"internal_nodes":[0],"is_log_scale":false,"leaf_parents":[0],"leaf_value":[-0.3571111266594255,-0.4110570829930577,-0.6272838110479163],"leaf_vector":[],"leaf_vector_begin":[0,0,0],"leaf_vector_end":[0,0,0],"leaves":[2,1],"left":[2,-1,-1],"node_deleted":[false,false,false],"node_type":[1,0,0],"num_deleted_nodes":0,"num_nodes":3,"output_dimension":1,"parent":[-1,0,0],"right":[1,-1,-1],"split_index":[2,-1,-1],"threshold":[0.16467940440826415,0.0,0.0]},"tree_2":{"category_list":[],"category_list_begin":[0,0,0,0,0],"category_list_end":[0,0,0,0,0],"deleted_nodes":[],"has_categorical_split":false,"internal_nodes":[0,1],"is_log_scale":false,"leaf_parents":[1],"leaf_value":[9.237055564881303e-17,-0.32512731664499683,0.3065199073254932,0.1465005797402202,-0.02664041763950984],"leaf_vector":[],"leaf_vector_begin":[0,0,0,0,0],"leaf_vector_end":[0,0,0,0,0],"leaves":[2,3,4],"left":[1,3,-1,-1,-1],"node_deleted":[false,false,false,false,false],"node_type":[1,1,0,0,0],"
num_deleted_nodes":0,"num_nodes":5,"output_dimension":1,"parent":[-1,0,0,1,1],"right":[2,4,-1,-1,-1],"split_index":[1,4,-1,-1,-1],"threshold":[0.42583129665895003,0.2314946999390916,0.0,0.0,0.0]},"tree_3":{"category_list":[],"category_list_begin":[0,0,0],"category_list_end":[0,0,0],"deleted_nodes":[],"has_categorical_split":false,"internal_nodes":[0],"is_log_scale":false,"leaf_parents":[0],"leaf_value":[0.1810927255796227,0.6309083457965031,0.1915446423358686],"leaf_vector":[],"leaf_vector_begin":[0,0,0],"leaf_vector_end":[0,0,0],"leaves":[2,1],"left":[2,-1,-1],"node_deleted":[false,false,false],"node_type":[1,0,0],"num_deleted_nodes":0,"num_nodes":3,"output_dimension":1,"parent":[-1,0,0],"right":[1,-1,-1],"split_index":[5,-1,-1],"threshold":[0.47218739519790437,0.0,0.0]},"tree_4":{"category_list":[],"category_list_begin":[0,0,0],"category_list_end":[0,0,0],"deleted_nodes":[],"has_categorical_split":false,"internal_nodes":[0],"is_log_scale":false,"leaf_parents":[0],"leaf_value":[-0.026901687375745633,-0.27035742291022075,-0.2910503270138097],"leaf_vector":[],"leaf_vector_begin":[0,0,0],"leaf_vector_end":[0,0,0],"leaves":[1,2],"left":[1,-1,-1],"node_deleted":[false,false,false],"node_type":[1,0,0],"num_deleted_nodes":0,"num_nodes":3,"output_dimension":1,"parent":[-1,0,0],"right":[2,-1,-1],"split_index":[3,-1,-1],"threshold":[0.5469386072582907,0.0,0.0]}},"forest_1":{"is_exponentiated":false,"is_leaf_constant":true,"num_trees":5,"output_dimension":1,"tree_0":{"category_list":[],"category_list_begin":[0,0,0],"category_list_end":[0,0,0],"deleted_nodes":[],"has_categorical_split":false,"internal_nodes":[0],"is_log_scale":false,"leaf_parents":[0],"leaf_value":[9.237055564881303e-17,-0.4039471521255247,0.17548884324239356],"leaf_vector":[],"leaf_vector_begin":[0,0,0],"leaf_vector_end":[0,0,0],"leaves":[1,2],"left":[1,-1,-1],"node_deleted":[false,false,false],"node_type":[1,0,0],"num_deleted_nodes":0,"num_nodes":3,"output_dimension":1,"parent":[-1,0,0],"right":[2,-1,-1],"sp
lit_index":[0,-1,-1],"threshold":[0.16750638543965574,0.0,0.0]},"tree_1":{"category_list":[],"category_list_begin":[0,0,0],"category_list_end":[0,0,0],"deleted_nodes":[2,1],"has_categorical_split":false,"internal_nodes":[],"is_log_scale":false,"leaf_parents":[],"leaf_value":[-0.4425207860943093,-0.4110570829930577,-0.6272838110479163],"leaf_vector":[],"leaf_vector_begin":[0,0,0],"leaf_vector_end":[0,0,0],"leaves":[0],"left":[-1,-1,-1],"node_deleted":[false,true,true],"node_type":[0,0,0],"num_deleted_nodes":2,"num_nodes":3,"output_dimension":1,"parent":[-1,0,0],"right":[-1,-1,-1],"split_index":[2,-1,-1],"threshold":[0.16467940440826415,0.0,0.0]},"tree_2":{"category_list":[],"category_list_begin":[0,0,0,0,0],"category_list_end":[0,0,0,0,0],"deleted_nodes":[],"has_categorical_split":false,"internal_nodes":[0,1],"is_log_scale":false,"leaf_parents":[1],"leaf_value":[9.237055564881303e-17,-0.32512731664499683,0.3067453959518575,0.41545675844009733,0.0010530756876630364],"leaf_vector":[],"leaf_vector_begin":[0,0,0,0,0],"leaf_vector_end":[0,0,0,0,0],"leaves":[2,3,4],"left":[1,3,-1,-1,-1],"node_deleted":[false,false,false,false,false],"node_type":[1,1,0,0,0],"num_deleted_nodes":0,"num_nodes":5,"output_dimension":1,"parent":[-1,0,0,1,1],"right":[2,4,-1,-1,-1],"split_index":[1,4,-1,-1,-1],"threshold":[0.42583129665895003,0.2314946999390916,0.0,0.0,0.0]},"tree_3":{"category_list":[],"category_list_begin":[0,0,0],"category_list_end":[0,0,0],"deleted_nodes":[],"has_categorical_split":false,"internal_nodes":[0],"is_log_scale":false,"leaf_parents":[0],"leaf_value":[0.1810927255796227,0.47111213458855905,0.04964342053981662],"leaf_vector":[],"leaf_vector_begin":[0,0,0],"leaf_vector_end":[0,0,0],"leaves":[2,1],"left":[2,-1,-1],"node_deleted":[false,false,false],"node_type":[1,0,0],"num_deleted_nodes":0,"num_nodes":3,"output_dimension":1,"parent":[-1,0,0],"right":[1,-1,-1],"split_index":[5,-1,-1],"threshold":[0.47218739519790437,0.0,0.0]},"tree_4":{"category_list":[],"category_list_be
gin":[0,0,0],"category_list_end":[0,0,0],"deleted_nodes":[],"has_categorical_split":false,"internal_nodes":[0],"is_log_scale":false,"leaf_parents":[0],"leaf_value":[-0.026901687375745633,-0.16960540978563884,-0.30994846078958677],"leaf_vector":[],"leaf_vector_begin":[0,0,0],"leaf_vector_end":[0,0,0],"leaves":[1,2],"left":[1,-1,-1],"node_deleted":[false,false,false],"node_type":[1,0,0],"num_deleted_nodes":0,"num_nodes":3,"output_dimension":1,"parent":[-1,0,0],"right":[2,-1,-1],"split_index":[3,-1,-1],"threshold":[0.5469386072582907,0.0,0.0]}},"forest_2":{"is_exponentiated":false,"is_leaf_constant":true,"num_trees":5,"output_dimension":1,"tree_0":{"category_list":[],"category_list_begin":[0,0,0],"category_list_end":[0,0,0],"deleted_nodes":[],"has_categorical_split":false,"internal_nodes":[0],"is_log_scale":false,"leaf_parents":[0],"leaf_value":[9.237055564881303e-17,-0.427817587183343,0.0912909117543235],"leaf_vector":[],"leaf_vector_begin":[0,0,0],"leaf_vector_end":[0,0,0],"leaves":[1,2],"left":[1,-1,-1],"node_deleted":[false,false,false],"node_type":[1,0,0],"num_deleted_nodes":0,"num_nodes":3,"output_dimension":1,"parent":[-1,0,0],"right":[2,-1,-1],"split_index":[0,-1,-1],"threshold":[0.16750638543965574,0.0,0.0]},"tree_1":{"category_list":[],"category_list_begin":[0,0,0],"category_list_end":[0,0,0],"deleted_nodes":[],"has_categorical_split":false,"internal_nodes":[0],"is_log_scale":false,"leaf_parents":[0],"leaf_value":[-0.4425207860943093,-0.394254948283752,-0.024472073457625566],"leaf_vector":[],"leaf_vector_begin":[0,0,0],"leaf_vector_end":[0,0,0],"leaves":[1,2],"left":[1,-1,-1],"node_deleted":[false,false,false],"node_type":[1,0,0],"num_deleted_nodes":0,"num_nodes":3,"output_dimension":1,"parent":[-1,0,0],"right":[2,-1,-1],"split_index":[5,-1,-1],"threshold":[0.5322191552483091,0.0,0.0]},"tree_2":{"category_list":[],"category_list_begin":[0,0,0,0,0],"category_list_end":[0,0,0,0,0],"deleted_nodes":[3,4],"has_categorical_split":false,"internal_nodes":[0],"is_log_
scale":false,"leaf_parents":[0],"leaf_value":[9.237055564881303e-17,-0.17501387144585762,0.4025237729681204,0.41545675844009733,0.0010530756876630364],"leaf_vector":[],"leaf_vector_begin":[0,0,0,0,0],"leaf_vector_end":[0,0,0,0,0],"leaves":[2,1],"left":[1,-1,-1,-1,-1],"node_deleted":[false,false,false,true,true],"node_type":[1,0,0,0,0],"num_deleted_nodes":2,"num_nodes":5,"output_dimension":1,"parent":[-1,0,0,1,1],"right":[2,-1,-1,-1,-1],"split_index":[1,4,-1,-1,-1],"threshold":[0.42583129665895003,0.2314946999390916,0.0,0.0,0.0]},"tree_3":{"category_list":[],"category_list_begin":[0,0,0],"category_list_end":[0,0,0],"deleted_nodes":[],"has_categorical_split":false,"internal_nodes":[0],"is_log_scale":false,"leaf_parents":[0],"leaf_value":[0.1810927255796227,0.16612116618156503,0.21977517637721566],"leaf_vector":[],"leaf_vector_begin":[0,0,0],"leaf_vector_end":[0,0,0],"leaves":[2,1],"left":[2,-1,-1],"node_deleted":[false,false,false],"node_type":[1,0,0],"num_deleted_nodes":0,"num_nodes":3,"output_dimension":1,"parent":[-1,0,0],"right":[1,-1,-1],"split_index":[5,-1,-1],"threshold":[0.47218739519790437,0.0,0.0]},"tree_4":{"category_list":[],"category_list_begin":[0,0,0],"category_list_end":[0,0,0],"deleted_nodes":[],"has_categorical_split":false,"internal_nodes":[0],"is_log_scale":false,"leaf_parents":[0],"leaf_value":[-0.026901687375745633,-0.18459386243026318,-0.0046858073689943405],"leaf_vector":[],"leaf_vector_begin":[0,0,0],"leaf_vector_end":[0,0,0],"leaves":[1,2],"left":[1,-1,-1],"node_deleted":[false,false,false],"node_type":[1,0,0],"num_deleted_nodes":0,"num_nodes":3,"output_dimension":1,"parent":[-1,0,0],"right":[2,-1,-1],"split_index":[3,-1,-1],"threshold":[0.5469386072582907,0.0,0.0]}},"initialized":true,"is_exponentiated":false,"is_leaf_constant":true,"num_samples":3,"num_trees":5,"output_dimension":1},"forest_1":{"forest_0":{"is_exponentiated":false,"is_leaf_constant":false,"num_trees":5,"output_dimension":1,"tree_0":{"category_list":[],"category_list_begin":
[0],"category_list_end":[0],"deleted_nodes":[],"has_categorical_split":false,"internal_nodes":[],"is_log_scale":false,"leaf_parents":[],"leaf_value":[-0.19029305000936797],"leaf_vector":[],"leaf_vector_begin":[0],"leaf_vector_end":[0],"leaves":[0],"left":[-1],"node_deleted":[false],"node_type":[0],"num_deleted_nodes":0,"num_nodes":1,"output_dimension":1,"parent":[-1],"right":[-1],"split_index":[-1],"threshold":[0.0]},"tree_1":{"category_list":[],"category_list_begin":[0,0,0],"category_list_end":[0,0,0],"deleted_nodes":[1,2],"has_categorical_split":false,"internal_nodes":[],"is_log_scale":false,"leaf_parents":[],"leaf_value":[0.4390456129899465,0.19756904796608402,0.470431415164641],"leaf_vector":[],"leaf_vector_begin":[0,0,0],"leaf_vector_end":[0,0,0],"leaves":[0],"left":[-1,-1,-1],"node_deleted":[false,true,true],"node_type":[0,0,0],"num_deleted_nodes":2,"num_nodes":3,"output_dimension":1,"parent":[-1,0,0],"right":[-1,-1,-1],"split_index":[0,-1,-1],"threshold":[0.41234994250247853,0.0,0.0]},"tree_2":{"category_list":[],"category_list_begin":[0],"category_list_end":[0],"deleted_nodes":[],"has_categorical_split":false,"internal_nodes":[],"is_log_scale":false,"leaf_parents":[],"leaf_value":[0.028191174915310476],"leaf_vector":[],"leaf_vector_begin":[0],"leaf_vector_end":[0],"leaves":[0],"left":[-1],"node_deleted":[false],"node_type":[0],"num_deleted_nodes":0,"num_nodes":1,"output_dimension":1,"parent":[-1],"right":[-1],"split_index":[-1],"threshold":[0.0]},"tree_3":{"category_list":[],"category_list_begin":[0],"category_list_end":[0],"deleted_nodes":[],"has_categorical_split":false,"internal_nodes":[],"is_log_scale":false,"leaf_parents":[],"leaf_value":[-0.19824578146979116],"leaf_vector":[],"leaf_vector_begin":[0],"leaf_vector_end":[0],"leaves":[0],"left":[-1],"node_deleted":[false],"node_type":[0],"num_deleted_nodes":0,"num_nodes":1,"output_dimension":1,"parent":[-1],"right":[-1],"split_index":[-1],"threshold":[0.0]},"tree_4":{"category_list":[],"category_list_begin
":[0],"category_list_end":[0],"deleted_nodes":[],"has_categorical_split":false,"internal_nodes":[],"is_log_scale":false,"leaf_parents":[],"leaf_value":[0.21194576301866647],"leaf_vector":[],"leaf_vector_begin":[0],"leaf_vector_end":[0],"leaves":[0],"left":[-1],"node_deleted":[false],"node_type":[0],"num_deleted_nodes":0,"num_nodes":1,"output_dimension":1,"parent":[-1],"right":[-1],"split_index":[-1],"threshold":[0.0]}},"forest_1":{"is_exponentiated":false,"is_leaf_constant":false,"num_trees":5,"output_dimension":1,"tree_0":{"category_list":[],"category_list_begin":[0],"category_list_end":[0],"deleted_nodes":[],"has_categorical_split":false,"internal_nodes":[],"is_log_scale":false,"leaf_parents":[],"leaf_value":[0.005768008389537283],"leaf_vector":[],"leaf_vector_begin":[0],"leaf_vector_end":[0],"leaves":[0],"left":[-1],"node_deleted":[false],"node_type":[0],"num_deleted_nodes":0,"num_nodes":1,"output_dimension":1,"parent":[-1],"right":[-1],"split_index":[-1],"threshold":[0.0]},"tree_1":{"category_list":[],"category_list_begin":[0,0,0],"category_list_end":[0,0,0],"deleted_nodes":[1,2],"has_categorical_split":false,"internal_nodes":[],"is_log_scale":false,"leaf_parents":[],"leaf_value":[0.44524931850933525,0.19756904796608402,0.470431415164641],"leaf_vector":[],"leaf_vector_begin":[0,0,0],"leaf_vector_end":[0,0,0],"leaves":[0],"left":[-1,-1,-1],"node_deleted":[false,true,true],"node_type":[0,0,0],"num_deleted_nodes":2,"num_nodes":3,"output_dimension":1,"parent":[-1,0,0],"right":[-1,-1,-1],"split_index":[0,-1,-1],"threshold":[0.41234994250247853,0.0,0.0]},"tree_2":{"category_list":[],"category_list_begin":[0],"category_list_end":[0],"deleted_nodes":[],"has_categorical_split":false,"internal_nodes":[],"is_log_scale":false,"leaf_parents":[],"leaf_value":[-0.1475177199356373],"leaf_vector":[],"leaf_vector_begin":[0],"leaf_vector_end":[0],"leaves":[0],"left":[-1],"node_deleted":[false],"node_type":[0],"num_deleted_nodes":0,"num_nodes":1,"output_dimension":1,"parent":[-1],"
right":[-1],"split_index":[-1],"threshold":[0.0]},"tree_3":{"category_list":[],"category_list_begin":[0],"category_list_end":[0],"deleted_nodes":[],"has_categorical_split":false,"internal_nodes":[],"is_log_scale":false,"leaf_parents":[],"leaf_value":[-0.17998183825712963],"leaf_vector":[],"leaf_vector_begin":[0],"leaf_vector_end":[0],"leaves":[0],"left":[-1],"node_deleted":[false],"node_type":[0],"num_deleted_nodes":0,"num_nodes":1,"output_dimension":1,"parent":[-1],"right":[-1],"split_index":[-1],"threshold":[0.0]},"tree_4":{"category_list":[],"category_list_begin":[0],"category_list_end":[0],"deleted_nodes":[],"has_categorical_split":false,"internal_nodes":[],"is_log_scale":false,"leaf_parents":[],"leaf_value":[0.16820308269313614],"leaf_vector":[],"leaf_vector_begin":[0],"leaf_vector_end":[0],"leaves":[0],"left":[-1],"node_deleted":[false],"node_type":[0],"num_deleted_nodes":0,"num_nodes":1,"output_dimension":1,"parent":[-1],"right":[-1],"split_index":[-1],"threshold":[0.0]}},"forest_2":{"is_exponentiated":false,"is_leaf_constant":false,"num_trees":5,"output_dimension":1,"tree_0":{"category_list":[],"category_list_begin":[0],"category_list_end":[0],"deleted_nodes":[],"has_categorical_split":false,"internal_nodes":[],"is_log_scale":false,"leaf_parents":[],"leaf_value":[0.12311594268086734],"leaf_vector":[],"leaf_vector_begin":[0],"leaf_vector_end":[0],"leaves":[0],"left":[-1],"node_deleted":[false],"node_type":[0],"num_deleted_nodes":0,"num_nodes":1,"output_dimension":1,"parent":[-1],"right":[-1],"split_index":[-1],"threshold":[0.0]},"tree_1":{"category_list":[],"category_list_begin":[0,0,0],"category_list_end":[0,0,0],"deleted_nodes":[1,2],"has_categorical_split":false,"internal_nodes":[],"is_log_scale":false,"leaf_parents":[],"leaf_value":[0.13061412466301242,0.19756904796608402,0.470431415164641],"leaf_vector":[],"leaf_vector_begin":[0,0,0],"leaf_vector_end":[0,0,0],"leaves":[0],"left":[-1,-1,-1],"node_deleted":[false,true,true],"node_type":[0,0,0],"num_deleted
_nodes":2,"num_nodes":3,"output_dimension":1,"parent":[-1,0,0],"right":[-1,-1,-1],"split_index":[0,-1,-1],"threshold":[0.41234994250247853,0.0,0.0]},"tree_2":{"category_list":[],"category_list_begin":[0,0,0],"category_list_end":[0,0,0],"deleted_nodes":[],"has_categorical_split":false,"internal_nodes":[0],"is_log_scale":false,"leaf_parents":[0],"leaf_value":[-0.1475177199356373,0.19835867531563567,0.1010328168922715],"leaf_vector":[],"leaf_vector_begin":[0,0,0],"leaf_vector_end":[0,0,0],"leaves":[1,2],"left":[1,-1,-1],"node_deleted":[false,false,false],"node_type":[1,0,0],"num_deleted_nodes":0,"num_nodes":3,"output_dimension":1,"parent":[-1,0,0],"right":[2,-1,-1],"split_index":[4,-1,-1],"threshold":[0.8646070190188782,0.0,0.0]},"tree_3":{"category_list":[],"category_list_begin":[0],"category_list_end":[0],"deleted_nodes":[],"has_categorical_split":false,"internal_nodes":[],"is_log_scale":false,"leaf_parents":[],"leaf_value":[-0.2751839785918239],"leaf_vector":[],"leaf_vector_begin":[0],"leaf_vector_end":[0],"leaves":[0],"left":[-1],"node_deleted":[false],"node_type":[0],"num_deleted_nodes":0,"num_nodes":1,"output_dimension":1,"parent":[-1],"right":[-1],"split_index":[-1],"threshold":[0.0]},"tree_4":{"category_list":[],"category_list_begin":[0],"category_list_end":[0],"deleted_nodes":[],"has_categorical_split":false,"internal_nodes":[],"is_log_scale":false,"leaf_parents":[],"leaf_value":[0.17240348914728137],"leaf_vector":[],"leaf_vector_begin":[0],"leaf_vector_end":[0],"leaves":[0],"left":[-1],"node_deleted":[false],"node_type":[0],"num_deleted_nodes":0,"num_nodes":1,"output_dimension":1,"parent":[-1],"right":[-1],"split_index":[-1],"threshold":[0.0]}},"initialized":true,"is_exponentiated":false,"is_leaf_constant":false,"num_samples":3,"num_trees":5,"output_dimension":1}},"has_rfx":false,"has_rfx_basis":false,"include_variance_forest":false,"internal_propensity_model":false,"keep_every":1.0,"multivariate_treatment":false,"num_burnin":3.0,"num_chains":1.0,"num_covaria
tes":5.0,"num_forests":2,"num_gfr":0.0,"num_mcmc":3.0,"num_numeric_vars":5.0,"num_ordered_cat_vars":0.0,"num_random_effects":0,"num_rfx_basis":0.0,"num_samples":3.0,"num_unordered_cat_vars":0.0,"numeric_vars":["x1","x2","x3","x4","x5","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA","NA"],"outcome_mean":3.2783620508409594,"outcome_model":{"link":"identity","outcome":"continuous"},"outcome_scale":1.3752466941156343,"parameters":{"b0_samples":[-1.1280937135690468,-0.8968554508268374,-0.8914413479735037],"b1_samples":[0.787755126873328,0.7898954507582615,0.8171895336080333],"sigma2_global_samples":[1.2639895748026893,0.7952680713067062,1.1561091621127795],"sigma2_leaf_mu_samples":[0.121253975087549,0.10402271497258983,0.05064522904759007],"tau_0_samples":[0.3153960407351745,0.15338163265511487,0.1701309993313882]},"prepr
ocessor_metadata":"{\"feature_types\":[0.0,0.0,0.0,0.0,0.0],\"forests\":{},\"num_forests\":0,\"num_numeric_vars\":5,\"num_ordered_cat_vars\":0,\"num_random_effects\":0,\"num_unordered_cat_vars\":0,\"numeric_vars\":[\"x1\",\"x2\",\"x3\",\"x4\",\"x5\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\",\"NA\"],\"original_var_indices\":[1.0,2.0,
3.0,4.0,5.0],\"random_effects\":{}}","probit_outcome_model":false,"propensity_covariate":"prognostic","random_effects":{},"rfx_model_spec":"custom","sample_sigma2_global":true,"sample_sigma2_leaf_mu":true,"sample_sigma2_leaf_tau":false,"sample_tau_0":true,"sigma2_init":1.0000000000000004,"standardize":true,"stochtree_version":"0.4.1.9000","tau_0_dim":1.0} diff --git a/test/R/testthat/test-serialization-compat.R b/test/R/testthat/test-serialization-compat.R new file mode 100644 index 00000000..ff7c7e5b --- /dev/null +++ b/test/R/testthat/test-serialization-compat.R @@ -0,0 +1,336 @@ +# Backward-compatibility deserialization tests +# +# These tests verify that models serialized without certain optional fields +# (as would be produced by older package versions) can still be loaded +# correctly, with appropriate warnings where applicable. +# +# Fixture files (test/R/testthat/fixtures/) are generated once from the +# current package and checked in. They serve as a "snapshot" — if a future +# change breaks the ability to deserialize them, these tests will catch it. + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +#' Read and parse a fixture JSON file into a plain R list +read_fixture_json <- function(fixture_name) { + path <- testthat::test_path("fixtures", fixture_name) + js <- paste(readLines(path, warn = FALSE), collapse = "") + jsonlite::fromJSON(js, simplifyVector = FALSE) +} + +#' Serialise an R list back to a JSON string suitable for createBARTModelFromJsonString +# / createBCFModelFromJsonString. +write_json_string <- function(obj) { + jsonlite::toJSON(obj, auto_unbox = TRUE, digits = NA) +} + +#' Remove one or more top-level fields from a parsed JSON list, then serialise. +strip_fields <- function(obj, ...) { + fields <- c(...) 
+ for (f in fields) { + obj[[f]] <- NULL + } + write_json_string(obj) +} + +#' Collect all warning messages emitted while evaluating `expr`. +collect_warnings <- function(expr) { + warns <- character(0) + withCallingHandlers(expr, warning = function(w) { + warns <<- c(warns, conditionMessage(w)) + invokeRestart("muffleWarning") + }) + warns +} + +# =========================================================================== +# BART snapshot tests +# =========================================================================== + +test_that("BART fixture deserializes and predictions are reproducible", { + skip_on_cran() + skip_if_not_installed("jsonlite") + + fixture_obj <- read_fixture_json("bart_mcmc.json") + json_str <- write_json_string(fixture_obj) + + set.seed(1) + n <- 30 + X <- matrix(runif(n * 5), ncol = 5) + m <- createBARTModelFromJsonString(json_str) + # Just verify the model loads and can predict without error + preds <- predict(m, X = X) + expect_true(is.list(preds)) + expect_true("y_hat" %in% names(preds)) + expect_equal(nrow(preds$y_hat), n) +}) + +test_that("BART roundtrip from fixture matches direct load", { + skip_on_cran() + skip_if_not_installed("jsonlite") + + fixture_obj <- read_fixture_json("bart_mcmc.json") + json_str <- write_json_string(fixture_obj) + + m1 <- createBARTModelFromJsonString(json_str) + m2 <- createBARTModelFromJsonString(json_str) # second load must be identical + + set.seed(99) + X <- matrix(runif(20 * 5), ncol = 5) + p1 <- rowMeans(predict(m1, X = X)$y_hat) + p2 <- rowMeans(predict(m2, X = X)$y_hat) + expect_equal(p1, p2) +}) + +# =========================================================================== +# BCF snapshot tests +# =========================================================================== + +test_that("BCF fixture deserializes and predictions are reproducible", { + skip_on_cran() + skip_if_not_installed("jsonlite") + + fixture_obj <- read_fixture_json("bcf_mcmc.json") + json_str <- 
write_json_string(fixture_obj) + + set.seed(1) + n <- 30 + X <- matrix(runif(n * 5), ncol = 5) + Z <- rbinom(n, 1, 0.5) + pi <- rep(0.5, n) + + m <- createBCFModelFromJsonString(json_str) + preds <- predict(m, X, Z, pi) + expect_true(is.list(preds)) + expect_true("y_hat" %in% names(preds)) + expect_true("tau_hat" %in% names(preds)) + expect_equal(nrow(preds$y_hat), n) +}) + +# =========================================================================== +# BART backward-compat: missing optional fields +# =========================================================================== + +test_that("BART loads without 'outcome_model' (pre-v0.4.1)", { + skip_on_cran() + skip_if_not_installed("jsonlite") + + fixture_obj <- read_fixture_json("bart_mcmc.json") + json_str <- strip_fields(fixture_obj, "outcome_model", "probit_outcome_model") + + warns <- collect_warnings(m <- createBARTModelFromJsonString(json_str)) + # Should warn about missing outcome_model + expect_true( + any(grepl( + "outcome_model|outcome.*missing|missing.*outcome", + warns, + ignore.case = TRUE + )) || + length(warns) == 0 + ) # no warning is also acceptable for truly optional fields + + set.seed(1) + X <- matrix(runif(20 * 5), ncol = 5) + preds <- predict(m, X = X) + expect_equal(nrow(preds$y_hat), 20) +}) + +test_that("BART loads without 'rfx_model_spec' when has_rfx=FALSE", { + skip_on_cran() + skip_if_not_installed("jsonlite") + + fixture_obj <- read_fixture_json("bart_mcmc.json") + # Ensure has_rfx is FALSE (fixture uses no RFX) + expect_false(isTRUE(fixture_obj$has_rfx)) + + json_str <- strip_fields(fixture_obj, "rfx_model_spec") + + m <- createBARTModelFromJsonString(json_str) + set.seed(1) + X <- matrix(runif(20 * 5), ncol = 5) + preds <- predict(m, X = X) + expect_equal(nrow(preds$y_hat), 20) +}) + +test_that("BART loads without 'preprocessor_metadata' (pre-preprocessor versions)", { + skip_on_cran() + skip_if_not_installed("jsonlite") + + fixture_obj <- read_fixture_json("bart_mcmc.json") + 
json_str <- strip_fields(fixture_obj, "preprocessor_metadata") + + # Loading should succeed (with a warning about missing preprocessor) + warns <- collect_warnings(m <- createBARTModelFromJsonString(json_str)) + expect_true(any(grepl("preprocessor|preprocess", warns, ignore.case = TRUE))) + # Model object is returned + expect_true(is.list(m)) +}) + +test_that("BART loads without 'num_chains' / 'keep_every'", { + skip_on_cran() + skip_if_not_installed("jsonlite") + + fixture_obj <- read_fixture_json("bart_mcmc.json") + json_str <- strip_fields(fixture_obj, "num_chains", "keep_every") + + m <- createBARTModelFromJsonString(json_str) + set.seed(1) + X <- matrix(runif(20 * 5), ncol = 5) + preds <- predict(m, X = X) + expect_equal(nrow(preds$y_hat), 20) +}) + +# =========================================================================== +# BCF backward-compat: missing optional fields +# =========================================================================== + +test_that("BCF loads without 'outcome_model' (pre-v0.4.1)", { + skip_on_cran() + skip_if_not_installed("jsonlite") + + fixture_obj <- read_fixture_json("bcf_mcmc.json") + json_str <- strip_fields(fixture_obj, "outcome_model", "probit_outcome_model") + + m <- createBCFModelFromJsonString(json_str) + + set.seed(1) + n <- 20 + X <- matrix(runif(n * 5), ncol = 5) + Z <- rbinom(n, 1, 0.5) + pi <- rep(0.5, n) + preds <- predict(m, X, Z, pi) + expect_equal(nrow(preds$y_hat), n) +}) + +test_that("BCF loads without 'multivariate_treatment' (pre-v0.4.0)", { + skip_on_cran() + skip_if_not_installed("jsonlite") + + fixture_obj <- read_fixture_json("bcf_mcmc.json") + json_str <- strip_fields(fixture_obj, "multivariate_treatment") + + m <- createBCFModelFromJsonString(json_str) + + set.seed(1) + n <- 20 + X <- matrix(runif(n * 5), ncol = 5) + Z <- rbinom(n, 1, 0.5) + pi <- rep(0.5, n) + preds <- predict(m, X, Z, pi) + expect_equal(nrow(preds$y_hat), n) +}) + +test_that("BCF loads without 'internal_propensity_model' 
(pre-v0.3.2)", { + skip_on_cran() + skip_if_not_installed("jsonlite") + + fixture_obj <- read_fixture_json("bcf_mcmc.json") + json_str <- strip_fields(fixture_obj, "internal_propensity_model") + + m <- createBCFModelFromJsonString(json_str) + + set.seed(1) + n <- 20 + X <- matrix(runif(n * 5), ncol = 5) + Z <- rbinom(n, 1, 0.5) + pi <- rep(0.5, n) + preds <- predict(m, X, Z, pi) + expect_equal(nrow(preds$y_hat), n) +}) + +test_that("BCF loads without 'rfx_model_spec' when has_rfx=FALSE", { + skip_on_cran() + skip_if_not_installed("jsonlite") + + fixture_obj <- read_fixture_json("bcf_mcmc.json") + expect_false(isTRUE(fixture_obj$has_rfx)) + + json_str <- strip_fields(fixture_obj, "rfx_model_spec") + + m <- createBCFModelFromJsonString(json_str) + + set.seed(1) + n <- 20 + X <- matrix(runif(n * 5), ncol = 5) + Z <- rbinom(n, 1, 0.5) + pi <- rep(0.5, n) + preds <- predict(m, X, Z, pi) + expect_equal(nrow(preds$y_hat), n) +}) + +test_that("BCF loads without 'preprocessor_metadata'", { + skip_on_cran() + skip_if_not_installed("jsonlite") + + fixture_obj <- read_fixture_json("bcf_mcmc.json") + json_str <- strip_fields(fixture_obj, "preprocessor_metadata") + + # Loading should succeed (with a warning about missing preprocessor) + warns <- collect_warnings(m <- createBCFModelFromJsonString(json_str)) + expect_true(any(grepl("preprocessor|preprocess", warns, ignore.case = TRUE))) + # Model object is returned + expect_true(is.list(m)) +}) + +test_that("BCF loads without 'num_chains' / 'keep_every'", { + skip_on_cran() + skip_if_not_installed("jsonlite") + + fixture_obj <- read_fixture_json("bcf_mcmc.json") + json_str <- strip_fields(fixture_obj, "num_chains", "keep_every") + + m <- createBCFModelFromJsonString(json_str) + + set.seed(1) + n <- 20 + X <- matrix(runif(n * 5), ncol = 5) + Z <- rbinom(n, 1, 0.5) + pi <- rep(0.5, n) + preds <- predict(m, X, Z, pi) + expect_equal(nrow(preds$y_hat), n) +}) + +test_that("BCF loads without 'has_rfx_basis'", { + skip_on_cran() + 
skip_if_not_installed("jsonlite") + + fixture_obj <- read_fixture_json("bcf_mcmc.json") + json_str <- strip_fields(fixture_obj, "has_rfx_basis") + + m <- createBCFModelFromJsonString(json_str) + + set.seed(1) + n <- 20 + X <- matrix(runif(n * 5), ncol = 5) + Z <- rbinom(n, 1, 0.5) + pi <- rep(0.5, n) + preds <- predict(m, X, Z, pi) + expect_equal(nrow(preds$y_hat), n) +}) + +test_that("BCF loads with multiple missing optional fields simultaneously", { + skip_on_cran() + skip_if_not_installed("jsonlite") + + fixture_obj <- read_fixture_json("bcf_mcmc.json") + # Strip all optional fields (including preprocessor_metadata — prediction not checked) + json_str <- strip_fields( + fixture_obj, + "outcome_model", + "probit_outcome_model", + "multivariate_treatment", + "internal_propensity_model", + "rfx_model_spec", + "num_chains", + "keep_every", + "has_rfx_basis", + "preprocessor_metadata" + ) + + warns <- collect_warnings(m <- createBCFModelFromJsonString(json_str)) + # Model must load + expect_true(is.list(m)) + # At least the preprocessor_metadata warning should fire + expect_true(any(grepl("preprocessor|preprocess", warns, ignore.case = TRUE))) +}) diff --git a/test/python/fixtures/bart_mcmc.json b/test/python/fixtures/bart_mcmc.json new file mode 100644 index 00000000..e08673fe --- /dev/null +++ b/test/python/fixtures/bart_mcmc.json @@ -0,0 +1 @@ 
+{"covariate_preprocessor":"{\"forests\":{},\"is_fitted\":true,\"num_forests\":0,\"num_onehot_features\":0,\"num_ordinal_features\":0,\"num_original_features\":5,\"num_random_effects\":0,\"onehot_feature_index\":[-1,-1,-1,-1,-1],\"ordinal_feature_index\":[-1,-1,-1,-1,-1],\"original_feature_indices\":[0,1,2,3,4],\"original_feature_types\":[\"float\",\"float\",\"float\",\"float\",\"float\"],\"processed_feature_types\":[0,0,0,0,0],\"random_effects\":{}}","forests":{"forest_0":{"forest_0":{"is_exponentiated":false,"is_leaf_constant":true,"num_trees":5,"output_dimension":1,"tree_0":{"category_list":[],"category_list_begin":[0,0,0],"category_list_end":[0,0,0],"deleted_nodes":[],"has_categorical_split":false,"internal_nodes":[0],"is_log_scale":false,"leaf_parents":[0],"leaf_value":[0.0,-0.007664911879713012,-0.24799028316731644],"leaf_vector":[],"leaf_vector_begin":[0,0,0],"leaf_vector_end":[0,0,0],"leaves":[1,2],"left":[1,-1,-1],"node_deleted":[false,false,false],"node_type":[1,0,0],"num_deleted_nodes":0,"num_nodes":3,"output_dimension":1,"parent":[-1,0,0],"right":[2,-1,-1],"split_index":[1,-1,-1],"threshold":[0.4805559923814212,0.0,0.0]},"tree_1":{"category_list":[],"category_list_begin":[0,0,0],"category_list_end":[0,0,0],"deleted_nodes":[],"has_categorical_split":false,"internal_nodes":[0],"is_log_scale":false,"leaf_parents":[0],"leaf_value":[0.0,-0.36579443282357277,0.20200884993554666],"leaf_vector":[],"leaf_vector_begin":[0,0,0],"leaf_vector_end":[0,0,0],"leaves":[1,2],"left":[1,-1,-1],"node_deleted":[false,false,false],"node_type":[1,0,0],"num_deleted_nodes":0,"num_nodes":3,"output_dimension":1,"parent":[-1,0,0],"right":[2,-1,-1],"split_index":[1,-1,-1],"threshold":[0.193778267085033,0.0,0.0]},"tree_2":{"category_list":[],"category_list_begin":[0,0,0],"category_list_end":[0,0,0],"deleted_nodes":[],"has_categorical_split":false,"internal_nodes":[0],"is_log_scale":false,"leaf_parents":[0],"leaf_value":[0.0,0.018106639589123806,0.12351205067572671],"leaf_vector":[],"l
eaf_vector_begin":[0,0,0],"leaf_vector_end":[0,0,0],"leaves":[1,2],"left":[1,-1,-1],"node_deleted":[false,false,false],"node_type":[1,0,0],"num_deleted_nodes":0,"num_nodes":3,"output_dimension":1,"parent":[-1,0,0],"right":[2,-1,-1],"split_index":[1,-1,-1],"threshold":[0.6816232941694589,0.0,0.0]},"tree_3":{"category_list":[],"category_list_begin":[0,0,0],"category_list_end":[0,0,0],"deleted_nodes":[],"has_categorical_split":false,"internal_nodes":[0],"is_log_scale":false,"leaf_parents":[0],"leaf_value":[0.0,0.041095258011548716,-0.06399946236752171],"leaf_vector":[],"leaf_vector_begin":[0,0,0],"leaf_vector_end":[0,0,0],"leaves":[1,2],"left":[1,-1,-1],"node_deleted":[false,false,false],"node_type":[1,0,0],"num_deleted_nodes":0,"num_nodes":3,"output_dimension":1,"parent":[-1,0,0],"right":[2,-1,-1],"split_index":[3,-1,-1],"threshold":[0.7112734331640973,0.0,0.0]},"tree_4":{"category_list":[],"category_list_begin":[0,0,0],"category_list_end":[0,0,0],"deleted_nodes":[],"has_categorical_split":false,"internal_nodes":[0],"is_log_scale":false,"leaf_parents":[0],"leaf_value":[0.0,0.03569940289175565,-0.12833753827258798],"leaf_vector":[],"leaf_vector_begin":[0,0,0],"leaf_vector_end":[0,0,0],"leaves":[1,2],"left":[1,-1,-1],"node_deleted":[false,false,false],"node_type":[1,0,0],"num_deleted_nodes":0,"num_nodes":3,"output_dimension":1,"parent":[-1,0,0],"right":[2,-1,-1],"split_index":[1,-1,-1],"threshold":[0.91627488748497,0.0,0.0]}},"forest_1":{"is_exponentiated":false,"is_leaf_constant":true,"num_trees":5,"output_dimension":1,"tree_0":{"category_list":[],"category_list_begin":[0,0,0,0,0],"category_list_end":[0,0,0,0,0],"deleted_nodes":[],"has_categorical_split":false,"internal_nodes":[0,2],"is_log_scale":false,"leaf_parents":[2],"leaf_value":[0.0,-0.0037771756236487926,-0.24799028316731644,-0.6681583458637876,0.11678604413009135],"leaf_vector":[],"leaf_vector_begin":[0,0,0,0,0],"leaf_vector_end":[0,0,0,0,0],"leaves":[1,3,4],"left":[1,-1,3,-1,-1],"node_deleted":[false,false,fa
lse,false,false],"node_type":[1,0,1,0,0],"num_deleted_nodes":0,"num_nodes":5,"output_dimension":1,"parent":[-1,0,0,2,2],"right":[2,-1,4,-1,-1],"split_index":[1,-1,0,-1,-1],"threshold":[0.4805559923814212,0.0,0.29483527135799825,0.0,0.0]},"tree_1":{"category_list":[],"category_list_begin":[0,0,0],"category_list_end":[0,0,0],"deleted_nodes":[],"has_categorical_split":false,"internal_nodes":[0],"is_log_scale":false,"leaf_parents":[0],"leaf_value":[0.0,-0.20490086144210348,-0.1816094115237655],"leaf_vector":[],"leaf_vector_begin":[0,0,0],"leaf_vector_end":[0,0,0],"leaves":[1,2],"left":[1,-1,-1],"node_deleted":[false,false,false],"node_type":[1,0,0],"num_deleted_nodes":0,"num_nodes":3,"output_dimension":1,"parent":[-1,0,0],"right":[2,-1,-1],"split_index":[1,-1,-1],"threshold":[0.193778267085033,0.0,0.0]},"tree_2":{"category_list":[],"category_list_begin":[0,0,0],"category_list_end":[0,0,0],"deleted_nodes":[],"has_categorical_split":false,"internal_nodes":[0],"is_log_scale":false,"leaf_parents":[0],"leaf_value":[0.0,-0.07733878457261288,0.15517987347875323],"leaf_vector":[],"leaf_vector_begin":[0,0,0],"leaf_vector_end":[0,0,0],"leaves":[1,2],"left":[1,-1,-1],"node_deleted":[false,false,false],"node_type":[1,0,0],"num_deleted_nodes":0,"num_nodes":3,"output_dimension":1,"parent":[-1,0,0],"right":[2,-1,-1],"split_index":[1,-1,-1],"threshold":[0.6816232941694589,0.0,0.0]},"tree_3":{"category_list":[],"category_list_begin":[0,0,0],"category_list_end":[0,0,0],"deleted_nodes":[],"has_categorical_split":false,"internal_nodes":[0],"is_log_scale":false,"leaf_parents":[0],"leaf_value":[0.0,0.09841716038625006,0.22403694937446325],"leaf_vector":[],"leaf_vector_begin":[0,0,0],"leaf_vector_end":[0,0,0],"leaves":[1,2],"left":[1,-1,-1],"node_deleted":[false,false,false],"node_type":[1,0,0],"num_deleted_nodes":0,"num_nodes":3,"output_dimension":1,"parent":[-1,0,0],"right":[2,-1,-1],"split_index":[3,-1,-1],"threshold":[0.7112734331640973,0.0,0.0]},"tree_4":{"category_list":[],"category_lis
t_begin":[0,0,0,0,0],"category_list_end":[0,0,0,0,0],"deleted_nodes":[],"has_categorical_split":false,"internal_nodes":[0,1],"is_log_scale":false,"leaf_parents":[1],"leaf_value":[0.0,0.03569940289175565,-0.11346618168274017,0.020886074904858862,0.05595880095646979],"leaf_vector":[],"leaf_vector_begin":[0,0,0,0,0],"leaf_vector_end":[0,0,0,0,0],"leaves":[2,3,4],"left":[1,3,-1,-1,-1],"node_deleted":[false,false,false,false,false],"node_type":[1,1,0,0,0],"num_deleted_nodes":0,"num_nodes":5,"output_dimension":1,"parent":[-1,0,0,1,1],"right":[2,4,-1,-1,-1],"split_index":[1,0,-1,-1,-1],"threshold":[0.91627488748497,0.8856121515661568,0.0,0.0,0.0]}},"forest_2":{"is_exponentiated":false,"is_leaf_constant":true,"num_trees":5,"output_dimension":1,"tree_0":{"category_list":[],"category_list_begin":[0,0,0,0,0],"category_list_end":[0,0,0,0,0],"deleted_nodes":[],"has_categorical_split":false,"internal_nodes":[0,2],"is_log_scale":false,"leaf_parents":[2],"leaf_value":[0.0,-0.008220996734382544,-0.24799028316731644,-0.38521611137711764,0.15730692815871508],"leaf_vector":[],"leaf_vector_begin":[0,0,0,0,0],"leaf_vector_end":[0,0,0,0,0],"leaves":[1,3,4],"left":[1,-1,3,-1,-1],"node_deleted":[false,false,false,false,false],"node_type":[1,0,1,0,0],"num_deleted_nodes":0,"num_nodes":5,"output_dimension":1,"parent":[-1,0,0,2,2],"right":[2,-1,4,-1,-1],"split_index":[1,-1,0,-1,-1],"threshold":[0.4805559923814212,0.0,0.29483527135799825,0.0,0.0]},"tree_1":{"category_list":[],"category_list_begin":[0,0,0],"category_list_end":[0,0,0],"deleted_nodes":[],"has_categorical_split":false,"internal_nodes":[0],"is_log_scale":false,"leaf_parents":[0],"leaf_value":[0.0,0.0285088432572991,-0.13803728450615854],"leaf_vector":[],"leaf_vector_begin":[0,0,0],"leaf_vector_end":[0,0,0],"leaves":[1,2],"left":[1,-1,-1],"node_deleted":[false,false,false],"node_type":[1,0,0],"num_deleted_nodes":0,"num_nodes":3,"output_dimension":1,"parent":[-1,0,0],"right":[2,-1,-1],"split_index":[1,-1,-1],"threshold":[0.193778267085
033,0.0,0.0]},"tree_2":{"category_list":[],"category_list_begin":[0,0,0],"category_list_end":[0,0,0],"deleted_nodes":[],"has_categorical_split":false,"internal_nodes":[0],"is_log_scale":false,"leaf_parents":[0],"leaf_value":[0.0,-0.10454412305549132,0.023709571380789246],"leaf_vector":[],"leaf_vector_begin":[0,0,0],"leaf_vector_end":[0,0,0],"leaves":[1,2],"left":[1,-1,-1],"node_deleted":[false,false,false],"node_type":[1,0,0],"num_deleted_nodes":0,"num_nodes":3,"output_dimension":1,"parent":[-1,0,0],"right":[2,-1,-1],"split_index":[1,-1,-1],"threshold":[0.6816232941694589,0.0,0.0]},"tree_3":{"category_list":[],"category_list_begin":[0,0,0],"category_list_end":[0,0,0],"deleted_nodes":[],"has_categorical_split":false,"internal_nodes":[0],"is_log_scale":false,"leaf_parents":[0],"leaf_value":[0.0,0.40233129943696777,0.15871203782248466],"leaf_vector":[],"leaf_vector_begin":[0,0,0],"leaf_vector_end":[0,0,0],"leaves":[1,2],"left":[1,-1,-1],"node_deleted":[false,false,false],"node_type":[1,0,0],"num_deleted_nodes":0,"num_nodes":3,"output_dimension":1,"parent":[-1,0,0],"right":[2,-1,-1],"split_index":[3,-1,-1],"threshold":[0.7112734331640973,0.0,0.0]},"tree_4":{"category_list":[],"category_list_begin":[0,0,0,0,0],"category_list_end":[0,0,0,0,0],"deleted_nodes":[3,4],"has_categorical_split":false,"internal_nodes":[0],"is_log_scale":false,"leaf_parents":[0],"leaf_value":[0.0,-0.15981823787563715,0.10098835446144322,0.020886074904858862,0.05595880095646979],"leaf_vector":[],"leaf_vector_begin":[0,0,0,0,0],"leaf_vector_end":[0,0,0,0,0],"leaves":[2,1],"left":[1,-1,-1,-1,-1],"node_deleted":[false,false,false,true,true],"node_type":[1,0,0,0,0],"num_deleted_nodes":2,"num_nodes":5,"output_dimension":1,"parent":[-1,0,0,1,1],"right":[2,-1,-1,-1,-1],"split_index":[1,0,-1,-1,-1],"threshold":[0.91627488748497,0.8856121515661568,0.0,0.0,0.0]}},"initialized":true,"is_exponentiated":false,"is_leaf_constant":true,"num_samples":3,"num_trees":5,"output_dimension":1}},"has_rfx":false,"has_rfx_b
asis":false,"include_mean_forest":true,"include_variance_forest":false,"keep_every":1,"num_basis":0,"num_burnin":3,"num_chains":1,"num_forests":1,"num_gfr":0,"num_mcmc":3,"num_random_effects":0,"num_rfx_basis":0.0,"num_samples":3,"outcome_mean":0.24685854019520698,"outcome_model":{"link":"identity","outcome":"continuous"},"outcome_scale":5.76912761248013,"parameters":{"sigma2_global_samples":[35.05751384857162,26.967130601226234,39.25710729694107],"sigma2_leaf_samples":[0.04580047375916822,0.05893535577149821,0.0289409999519963]},"probit_outcome_model":false,"random_effects":{},"requires_basis":false,"rfx_model_spec":"custom","sample_sigma2_global":true,"sample_sigma2_leaf":true,"sigma2_init":0.9999999999999999,"standardize":true,"stochtree_version":"0.4.1"} \ No newline at end of file diff --git a/test/python/fixtures/bcf_mcmc.json b/test/python/fixtures/bcf_mcmc.json new file mode 100644 index 00000000..49760a63 --- /dev/null +++ b/test/python/fixtures/bcf_mcmc.json @@ -0,0 +1 @@ +{"adaptive_coding":true,"covariate_preprocessor":"{\"forests\":{},\"is_fitted\":true,\"num_forests\":0,\"num_onehot_features\":0,\"num_ordinal_features\":0,\"num_original_features\":5,\"num_random_effects\":0,\"onehot_feature_index\":[-1,-1,-1,-1,-1],\"ordinal_feature_index\":[-1,-1,-1,-1,-1],\"original_feature_indices\":[0,1,2,3,4],\"original_feature_types\":[\"float\",\"float\",\"float\",\"float\",\"float\"],\"processed_feature_types\":[0,0,0,0,0],\"random_effects\":{}}","forests":{"forest_0":{"forest_0":{"is_exponentiated":false,"is_leaf_constant":true,"num_trees":5,"output_dimension":1,"tree_0":{"category_list":[],"category_list_begin":[0,0,0],"category_list_end":[0,0,0],"deleted_nodes":[],"has_categorical_split":false,"internal_nodes":[0],"is_log_scale":false,"leaf_parents":[0],"leaf_value":[-8.881784197001253e-19,-0.38781107129211345,0.6151039423332402],"leaf_vector":[],"leaf_vector_begin":[0,0,0],"leaf_vector_end":[0,0,0],"leaves":[1,2],"left":[1,-1,-1],"node_deleted":[false,false
,false],"node_type":[1,0,0],"num_deleted_nodes":0,"num_nodes":3,"output_dimension":1,"parent":[-1,0,0],"right":[2,-1,-1],"split_index":[5,-1,-1],"threshold":[0.48459547622669097,0.0,0.0]},"tree_1":{"category_list":[],"category_list_begin":[0,0,0],"category_list_end":[0,0,0],"deleted_nodes":[],"has_categorical_split":false,"internal_nodes":[0],"is_log_scale":false,"leaf_parents":[0],"leaf_value":[0.06440949344718495,0.0277488679983682,0.3187825002467023],"leaf_vector":[],"leaf_vector_begin":[0,0,0],"leaf_vector_end":[0,0,0],"leaves":[1,2],"left":[1,-1,-1],"node_deleted":[false,false,false],"node_type":[1,0,0],"num_deleted_nodes":0,"num_nodes":3,"output_dimension":1,"parent":[-1,0,0],"right":[2,-1,-1],"split_index":[1,-1,-1],"threshold":[0.8221663569220815,0.0,0.0]},"tree_2":{"category_list":[],"category_list_begin":[0],"category_list_end":[0],"deleted_nodes":[],"has_categorical_split":false,"internal_nodes":[],"is_log_scale":false,"leaf_parents":[],"leaf_value":[0.16039982440224884],"leaf_vector":[],"leaf_vector_begin":[0],"leaf_vector_end":[0],"leaves":[0],"left":[-1],"node_deleted":[false],"node_type":[0],"num_deleted_nodes":0,"num_nodes":1,"output_dimension":1,"parent":[-1],"right":[-1],"split_index":[-1],"threshold":[0.0]},"tree_3":{"category_list":[],"category_list_begin":[0,0,0],"category_list_end":[0,0,0],"deleted_nodes":[],"has_categorical_split":false,"internal_nodes":[0],"is_log_scale":false,"leaf_parents":[0],"leaf_value":[-8.881784197001253e-19,-0.37708519671631696,-0.032147505949914945],"leaf_vector":[],"leaf_vector_begin":[0,0,0],"leaf_vector_end":[0,0,0],"leaves":[1,2],"left":[1,-1,-1],"node_deleted":[false,false,false],"node_type":[1,0,0],"num_deleted_nodes":0,"num_nodes":3,"output_dimension":1,"parent":[-1,0,0],"right":[2,-1,-1],"split_index":[1,-1,-1],"threshold":[0.4451778289698159,0.0,0.0]},"tree_4":{"category_list":[],"category_list_begin":[0,0,0],"category_list_end":[0,0,0],"deleted_nodes":[],"has_categorical_split":false,"internal_nodes":[0],"i
s_log_scale":false,"leaf_parents":[0],"leaf_value":[0.14515959068307077,0.09073870304374561,0.019884521010203934],"leaf_vector":[],"leaf_vector_begin":[0,0,0],"leaf_vector_end":[0,0,0],"leaves":[1,2],"left":[1,-1,-1],"node_deleted":[false,false,false],"node_type":[1,0,0],"num_deleted_nodes":0,"num_nodes":3,"output_dimension":1,"parent":[-1,0,0],"right":[2,-1,-1],"split_index":[0,-1,-1],"threshold":[0.8634090721040851,0.0,0.0]}},"forest_1":{"is_exponentiated":false,"is_leaf_constant":true,"num_trees":5,"output_dimension":1,"tree_0":{"category_list":[],"category_list_begin":[0,0,0],"category_list_end":[0,0,0],"deleted_nodes":[],"has_categorical_split":false,"internal_nodes":[0],"is_log_scale":false,"leaf_parents":[0],"leaf_value":[-8.881784197001253e-19,-0.5432283582930866,0.45648107470526483],"leaf_vector":[],"leaf_vector_begin":[0,0,0],"leaf_vector_end":[0,0,0],"leaves":[1,2],"left":[1,-1,-1],"node_deleted":[false,false,false],"node_type":[1,0,0],"num_deleted_nodes":0,"num_nodes":3,"output_dimension":1,"parent":[-1,0,0],"right":[2,-1,-1],"split_index":[5,-1,-1],"threshold":[0.48459547622669097,0.0,0.0]},"tree_1":{"category_list":[],"category_list_begin":[0,0,0],"category_list_end":[0,0,0],"deleted_nodes":[],"has_categorical_split":false,"internal_nodes":[0],"is_log_scale":false,"leaf_parents":[0],"leaf_value":[0.06440949344718495,0.027293679274528103,-0.2612082537803446],"leaf_vector":[],"leaf_vector_begin":[0,0,0],"leaf_vector_end":[0,0,0],"leaves":[1,2],"left":[1,-1,-1],"node_deleted":[false,false,false],"node_type":[1,0,0],"num_deleted_nodes":0,"num_nodes":3,"output_dimension":1,"parent":[-1,0,0],"right":[2,-1,-1],"split_index":[1,-1,-1],"threshold":[0.8221663569220815,0.0,0.0]},"tree_2":{"category_list":[],"category_list_begin":[0,0,0],"category_list_end":[0,0,0],"deleted_nodes":[],"has_categorical_split":false,"internal_nodes":[0],"is_log_scale":false,"leaf_parents":[0],"leaf_value":[0.16039982440224884,0.3098004891425462,0.5122612583142814],"leaf_vector":[],"l
eaf_vector_begin":[0,0,0],"leaf_vector_end":[0,0,0],"leaves":[1,2],"left":[1,-1,-1],"node_deleted":[false,false,false],"node_type":[1,0,0],"num_deleted_nodes":0,"num_nodes":3,"output_dimension":1,"parent":[-1,0,0],"right":[2,-1,-1],"split_index":[3,-1,-1],"threshold":[0.7849707101027427,0.0,0.0]},"tree_3":{"category_list":[],"category_list_begin":[0,0,0],"category_list_end":[0,0,0],"deleted_nodes":[],"has_categorical_split":false,"internal_nodes":[0],"is_log_scale":false,"leaf_parents":[0],"leaf_value":[-8.881784197001253e-19,-0.364932502881452,0.003366931315997063],"leaf_vector":[],"leaf_vector_begin":[0,0,0],"leaf_vector_end":[0,0,0],"leaves":[1,2],"left":[1,-1,-1],"node_deleted":[false,false,false],"node_type":[1,0,0],"num_deleted_nodes":0,"num_nodes":3,"output_dimension":1,"parent":[-1,0,0],"right":[2,-1,-1],"split_index":[1,-1,-1],"threshold":[0.4451778289698159,0.0,0.0]},"tree_4":{"category_list":[],"category_list_begin":[0,0,0],"category_list_end":[0,0,0],"deleted_nodes":[],"has_categorical_split":false,"internal_nodes":[0],"is_log_scale":false,"leaf_parents":[0],"leaf_value":[0.14515959068307077,0.242402364764471,0.42565510484539065],"leaf_vector":[],"leaf_vector_begin":[0,0,0],"leaf_vector_end":[0,0,0],"leaves":[1,2],"left":[1,-1,-1],"node_deleted":[false,false,false],"node_type":[1,0,0],"num_deleted_nodes":0,"num_nodes":3,"output_dimension":1,"parent":[-1,0,0],"right":[2,-1,-1],"split_index":[0,-1,-1],"threshold":[0.8634090721040851,0.0,0.0]}},"forest_2":{"is_exponentiated":false,"is_leaf_constant":true,"num_trees":5,"output_dimension":1,"tree_0":{"category_list":[],"category_list_begin":[0,0,0],"category_list_end":[0,0,0],"deleted_nodes":[],"has_categorical_split":false,"internal_nodes":[0],"is_log_scale":false,"leaf_parents":[0],"leaf_value":[-8.881784197001253e-19,-0.4476065037350094,0.08939103127754282],"leaf_vector":[],"leaf_vector_begin":[0,0,0],"leaf_vector_end":[0,0,0],"leaves":[1,2],"left":[1,-1,-1],"node_deleted":[false,false,false],"node_type":[
1,0,0],"num_deleted_nodes":0,"num_nodes":3,"output_dimension":1,"parent":[-1,0,0],"right":[2,-1,-1],"split_index":[5,-1,-1],"threshold":[0.48459547622669097,0.0,0.0]},"tree_1":{"category_list":[],"category_list_begin":[0,0,0],"category_list_end":[0,0,0],"deleted_nodes":[],"has_categorical_split":false,"internal_nodes":[0],"is_log_scale":false,"leaf_parents":[0],"leaf_value":[0.06440949344718495,-0.16530335174540084,-0.2792091229804067],"leaf_vector":[],"leaf_vector_begin":[0,0,0],"leaf_vector_end":[0,0,0],"leaves":[1,2],"left":[1,-1,-1],"node_deleted":[false,false,false],"node_type":[1,0,0],"num_deleted_nodes":0,"num_nodes":3,"output_dimension":1,"parent":[-1,0,0],"right":[2,-1,-1],"split_index":[1,-1,-1],"threshold":[0.8221663569220815,0.0,0.0]},"tree_2":{"category_list":[],"category_list_begin":[0,0,0],"category_list_end":[0,0,0],"deleted_nodes":[],"has_categorical_split":false,"internal_nodes":[0],"is_log_scale":false,"leaf_parents":[0],"leaf_value":[0.16039982440224884,0.3975498653392204,0.5597763235026052],"leaf_vector":[],"leaf_vector_begin":[0,0,0],"leaf_vector_end":[0,0,0],"leaves":[1,2],"left":[1,-1,-1],"node_deleted":[false,false,false],"node_type":[1,0,0],"num_deleted_nodes":0,"num_nodes":3,"output_dimension":1,"parent":[-1,0,0],"right":[2,-1,-1],"split_index":[3,-1,-1],"threshold":[0.7849707101027427,0.0,0.0]},"tree_3":{"category_list":[],"category_list_begin":[0,0,0],"category_list_end":[0,0,0],"deleted_nodes":[],"has_categorical_split":false,"internal_nodes":[0],"is_log_scale":false,"leaf_parents":[0],"leaf_value":[-8.881784197001253e-19,-0.37596738123257406,0.32240980402749536],"leaf_vector":[],"leaf_vector_begin":[0,0,0],"leaf_vector_end":[0,0,0],"leaves":[1,2],"left":[1,-1,-1],"node_deleted":[false,false,false],"node_type":[1,0,0],"num_deleted_nodes":0,"num_nodes":3,"output_dimension":1,"parent":[-1,0,0],"right":[2,-1,-1],"split_index":[1,-1,-1],"threshold":[0.4451778289698159,0.0,0.0]},"tree_4":{"category_list":[],"category_list_begin":[0,0,0],"cat
egory_list_end":[0,0,0],"deleted_nodes":[],"has_categorical_split":false,"internal_nodes":[0],"is_log_scale":false,"leaf_parents":[0],"leaf_value":[0.14515959068307077,0.14644717575576546,0.11535319464874515],"leaf_vector":[],"leaf_vector_begin":[0,0,0],"leaf_vector_end":[0,0,0],"leaves":[1,2],"left":[1,-1,-1],"node_deleted":[false,false,false],"node_type":[1,0,0],"num_deleted_nodes":0,"num_nodes":3,"output_dimension":1,"parent":[-1,0,0],"right":[2,-1,-1],"split_index":[0,-1,-1],"threshold":[0.8634090721040851,0.0,0.0]}},"initialized":true,"is_exponentiated":false,"is_leaf_constant":true,"num_samples":3,"num_trees":5,"output_dimension":1},"forest_1":{"forest_0":{"is_exponentiated":false,"is_leaf_constant":false,"num_trees":5,"output_dimension":1,"tree_0":{"category_list":[],"category_list_begin":[0,0,0],"category_list_end":[0,0,0],"deleted_nodes":[1,2],"has_categorical_split":false,"internal_nodes":[],"is_log_scale":false,"leaf_parents":[],"leaf_value":[0.0419703880255408,0.08574444113593233,-0.3527886301807562],"leaf_vector":[],"leaf_vector_begin":[0,0,0],"leaf_vector_end":[0,0,0],"leaves":[0],"left":[-1,-1,-1],"node_deleted":[false,true,true],"node_type":[0,0,0],"num_deleted_nodes":2,"num_nodes":3,"output_dimension":1,"parent":[-1,0,0],"right":[-1,-1,-1],"split_index":[4,-1,-1],"threshold":[0.5092400269377306,0.0,0.0]},"tree_1":{"category_list":[],"category_list_begin":[0,0,0],"category_list_end":[0,0,0],"deleted_nodes":[1,2],"has_categorical_split":false,"internal_nodes":[],"is_log_scale":false,"leaf_parents":[],"leaf_value":[-0.1930936939301599,0.12715827335576363,0.08461752303368719],"leaf_vector":[],"leaf_vector_begin":[0,0,0],"leaf_vector_end":[0,0,0],"leaves":[0],"left":[-1,-1,-1],"node_deleted":[false,true,true],"node_type":[0,0,0],"num_deleted_nodes":2,"num_nodes":3,"output_dimension":1,"parent":[-1,0,0],"right":[-1,-1,-1],"split_index":[4,-1,-1],"threshold":[0.8126306920990549,0.0,0.0]},"tree_2":{"category_list":[],"category_list_begin":[0,0,0],"category_
list_end":[0,0,0],"deleted_nodes":[1,2],"has_categorical_split":false,"internal_nodes":[],"is_log_scale":false,"leaf_parents":[],"leaf_value":[0.233602747109784,0.7393414802150162,0.08122925162439473],"leaf_vector":[],"leaf_vector_begin":[0,0,0],"leaf_vector_end":[0,0,0],"leaves":[0],"left":[-1,-1,-1],"node_deleted":[false,true,true],"node_type":[0,0,0],"num_deleted_nodes":2,"num_nodes":3,"output_dimension":1,"parent":[-1,0,0],"right":[-1,-1,-1],"split_index":[4,-1,-1],"threshold":[0.07802654963941784,0.0,0.0]},"tree_3":{"category_list":[],"category_list_begin":[0],"category_list_end":[0],"deleted_nodes":[],"has_categorical_split":false,"internal_nodes":[],"is_log_scale":false,"leaf_parents":[],"leaf_value":[0.19248673499983215],"leaf_vector":[],"leaf_vector_begin":[0],"leaf_vector_end":[0],"leaves":[0],"left":[-1],"node_deleted":[false],"node_type":[0],"num_deleted_nodes":0,"num_nodes":1,"output_dimension":1,"parent":[-1],"right":[-1],"split_index":[-1],"threshold":[0.0]},"tree_4":{"category_list":[],"category_list_begin":[0],"category_list_end":[0],"deleted_nodes":[],"has_categorical_split":false,"internal_nodes":[],"is_log_scale":false,"leaf_parents":[],"leaf_value":[-0.1570039807493809],"leaf_vector":[],"leaf_vector_begin":[0],"leaf_vector_end":[0],"leaves":[0],"left":[-1],"node_deleted":[false],"node_type":[0],"num_deleted_nodes":0,"num_nodes":1,"output_dimension":1,"parent":[-1],"right":[-1],"split_index":[-1],"threshold":[0.0]}},"forest_1":{"is_exponentiated":false,"is_leaf_constant":false,"num_trees":5,"output_dimension":1,"tree_0":{"category_list":[],"category_list_begin":[0,0,0],"category_list_end":[0,0,0],"deleted_nodes":[1,2],"has_categorical_split":false,"internal_nodes":[],"is_log_scale":false,"leaf_parents":[],"leaf_value":[0.23484777482160438,0.08574444113593233,-0.3527886301807562],"leaf_vector":[],"leaf_vector_begin":[0,0,0],"leaf_vector_end":[0,0,0],"leaves":[0],"left":[-1,-1,-1],"node_deleted":[false,true,true],"node_type":[0,0,0],"num_deleted_no
des":2,"num_nodes":3,"output_dimension":1,"parent":[-1,0,0],"right":[-1,-1,-1],"split_index":[4,-1,-1],"threshold":[0.5092400269377306,0.0,0.0]},"tree_1":{"category_list":[],"category_list_begin":[0,0,0],"category_list_end":[0,0,0],"deleted_nodes":[1,2],"has_categorical_split":false,"internal_nodes":[],"is_log_scale":false,"leaf_parents":[],"leaf_value":[-0.20500284078015898,0.12715827335576363,0.08461752303368719],"leaf_vector":[],"leaf_vector_begin":[0,0,0],"leaf_vector_end":[0,0,0],"leaves":[0],"left":[-1,-1,-1],"node_deleted":[false,true,true],"node_type":[0,0,0],"num_deleted_nodes":2,"num_nodes":3,"output_dimension":1,"parent":[-1,0,0],"right":[-1,-1,-1],"split_index":[4,-1,-1],"threshold":[0.8126306920990549,0.0,0.0]},"tree_2":{"category_list":[],"category_list_begin":[0,0,0],"category_list_end":[0,0,0],"deleted_nodes":[1,2],"has_categorical_split":false,"internal_nodes":[],"is_log_scale":false,"leaf_parents":[],"leaf_value":[0.0966417920161897,0.7393414802150162,0.08122925162439473],"leaf_vector":[],"leaf_vector_begin":[0,0,0],"leaf_vector_end":[0,0,0],"leaves":[0],"left":[-1,-1,-1],"node_deleted":[false,true,true],"node_type":[0,0,0],"num_deleted_nodes":2,"num_nodes":3,"output_dimension":1,"parent":[-1,0,0],"right":[-1,-1,-1],"split_index":[4,-1,-1],"threshold":[0.07802654963941784,0.0,0.0]},"tree_3":{"category_list":[],"category_list_begin":[0,0,0],"category_list_end":[0,0,0],"deleted_nodes":[],"has_categorical_split":false,"internal_nodes":[0],"is_log_scale":false,"leaf_parents":[0],"leaf_value":[0.19248673499983215,-0.12176656031525897,-0.08943489974077189],"leaf_vector":[],"leaf_vector_begin":[0,0,0],"leaf_vector_end":[0,0,0],"leaves":[1,2],"left":[1,-1,-1],"node_deleted":[false,false,false],"node_type":[1,0,0],"num_deleted_nodes":0,"num_nodes":3,"output_dimension":1,"parent":[-1,0,0],"right":[2,-1,-1],"split_index":[1,-1,-1],"threshold":[0.4443908976503099,0.0,0.0]},"tree_4":{"category_list":[],"category_list_begin":[0],"category_list_end":[0],"deleted_
nodes":[],"has_categorical_split":false,"internal_nodes":[],"is_log_scale":false,"leaf_parents":[],"leaf_value":[-0.01886464446542896],"leaf_vector":[],"leaf_vector_begin":[0],"leaf_vector_end":[0],"leaves":[0],"left":[-1],"node_deleted":[false],"node_type":[0],"num_deleted_nodes":0,"num_nodes":1,"output_dimension":1,"parent":[-1],"right":[-1],"split_index":[-1],"threshold":[0.0]}},"forest_2":{"is_exponentiated":false,"is_leaf_constant":false,"num_trees":5,"output_dimension":1,"tree_0":{"category_list":[],"category_list_begin":[0,0,0],"category_list_end":[0,0,0],"deleted_nodes":[],"has_categorical_split":false,"internal_nodes":[0],"is_log_scale":false,"leaf_parents":[0],"leaf_value":[0.23484777482160438,0.14144770539332582,0.09770158952660454],"leaf_vector":[],"leaf_vector_begin":[0,0,0],"leaf_vector_end":[0,0,0],"leaves":[2,1],"left":[2,-1,-1],"node_deleted":[false,false,false],"node_type":[1,0,0],"num_deleted_nodes":0,"num_nodes":3,"output_dimension":1,"parent":[-1,0,0],"right":[1,-1,-1],"split_index":[1,-1,-1],"threshold":[0.7003565348153463,0.0,0.0]},"tree_1":{"category_list":[],"category_list_begin":[0,0,0],"category_list_end":[0,0,0],"deleted_nodes":[1,2],"has_categorical_split":false,"internal_nodes":[],"is_log_scale":false,"leaf_parents":[],"leaf_value":[0.19981143741129573,0.12715827335576363,0.08461752303368719],"leaf_vector":[],"leaf_vector_begin":[0,0,0],"leaf_vector_end":[0,0,0],"leaves":[0],"left":[-1,-1,-1],"node_deleted":[false,true,true],"node_type":[0,0,0],"num_deleted_nodes":2,"num_nodes":3,"output_dimension":1,"parent":[-1,0,0],"right":[-1,-1,-1],"split_index":[4,-1,-1],"threshold":[0.8126306920990549,0.0,0.0]},"tree_2":{"category_list":[],"category_list_begin":[0,0,0],"category_list_end":[0,0,0],"deleted_nodes":[],"has_categorical_split":false,"internal_nodes":[0],"is_log_scale":false,"leaf_parents":[0],"leaf_value":[0.0966417920161897,-0.23766375126463393,0.040773820331649197],"leaf_vector":[],"leaf_vector_begin":[0,0,0],"leaf_vector_end":[0,0,
0],"leaves":[2,1],"left":[2,-1,-1],"node_deleted":[false,false,false],"node_type":[1,0,0],"num_deleted_nodes":0,"num_nodes":3,"output_dimension":1,"parent":[-1,0,0],"right":[1,-1,-1],"split_index":[2,-1,-1],"threshold":[0.8141920842835061,0.0,0.0]},"tree_3":{"category_list":[],"category_list_begin":[0,0,0],"category_list_end":[0,0,0],"deleted_nodes":[1,2],"has_categorical_split":false,"internal_nodes":[],"is_log_scale":false,"leaf_parents":[],"leaf_value":[0.25913868287366437,-0.12176656031525897,-0.08943489974077189],"leaf_vector":[],"leaf_vector_begin":[0,0,0],"leaf_vector_end":[0,0,0],"leaves":[0],"left":[-1,-1,-1],"node_deleted":[false,true,true],"node_type":[0,0,0],"num_deleted_nodes":2,"num_nodes":3,"output_dimension":1,"parent":[-1,0,0],"right":[-1,-1,-1],"split_index":[1,-1,-1],"threshold":[0.4443908976503099,0.0,0.0]},"tree_4":{"category_list":[],"category_list_begin":[0],"category_list_end":[0],"deleted_nodes":[],"has_categorical_split":false,"internal_nodes":[],"is_log_scale":false,"leaf_parents":[],"leaf_value":[0.03982782068569912],"leaf_vector":[],"leaf_vector_begin":[0],"leaf_vector_end":[0],"leaves":[0],"left":[-1],"node_deleted":[false],"node_type":[0],"num_deleted_nodes":0,"num_nodes":1,"output_dimension":1,"parent":[-1],"right":[-1],"split_index":[-1],"threshold":[0.0]}},"initialized":true,"is_exponentiated":false,"is_leaf_constant":false,"num_samples":3,"num_trees":5,"output_dimension":1}},"has_rfx":false,"has_rfx_basis":false,"include_variance_forest":false,"internal_propensity_model":false,"keep_every":1.0,"multivariate_treatment":false,"num_burnin":3.0,"num_chains":1.0,"num_forests":2,"num_gfr":0.0,"num_mcmc":3.0,"num_random_effects":0,"num_rfx_basis":0.0,"num_samples":3.0,"outcome_mean":3.044749670081113,"outcome_model":{"link":"identity","outcome":"continuous"},"outcome_scale":1.570122816215672,"parameters":{"b0_samples":[-0.43720528875906,-0.45178337977885774,-0.33122867106175097],"b1_samples":[-0.07722414928401335,0.04019985809869323,0.150
77732694406673],"sigma2_global_samples":[1.7348319521710993,1.5821754190721142,1.2723398618240371],"sigma2_leaf_mu_samples":[0.11165239514270793,0.10920419672917195,0.08843310283861057],"tau_0_samples":[1.404431591063826,2.3167405128625447,1.8885557519745888]},"probit_outcome_model":false,"propensity_covariate":"prognostic","random_effects":{},"rfx_model_spec":"custom","sample_sigma2_global":true,"sample_sigma2_leaf_mu":true,"sample_sigma2_leaf_tau":false,"sample_tau_0":true,"sigma2_init":1.0,"standardize":true,"stochtree_version":"0.4.1","tau_0_dim":1.0} \ No newline at end of file diff --git a/test/python/test_serialization_compat.py b/test/python/test_serialization_compat.py new file mode 100644 index 00000000..a317b29e --- /dev/null +++ b/test/python/test_serialization_compat.py @@ -0,0 +1,272 @@ +""" +Backward-compatibility deserialization tests + +These tests verify that models serialized without certain optional fields +(as would be produced by older package versions) can still be loaded +correctly, with appropriate warnings where applicable. + +Fixture files (test/python/fixtures/) are generated once from the current +package and checked in. They serve as a "snapshot" — if a future change +breaks the ability to deserialize them, these tests will catch it. 
+""" +import json +import warnings +from pathlib import Path + +import numpy as np +import pytest + +from stochtree import BARTModel, BCFModel + +FIXTURES_DIR = Path(__file__).parent / "fixtures" + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _load_fixture(name: str) -> dict: + """Load a fixture JSON file and return as a Python dict.""" + with open(FIXTURES_DIR / name) as f: + return json.load(f) + + +def _to_json_string(obj: dict) -> str: + """Serialise a Python dict back to a JSON string.""" + return json.dumps(obj) + + +def _strip_fields(obj: dict, *fields: str) -> str: + """Remove fields from obj (top-level) and return the JSON string.""" + obj = dict(obj) # shallow copy + for f in fields: + obj.pop(f, None) + return _to_json_string(obj) + + +def _collect_warnings(fn, *args, **kwargs): + """Call fn(*args, **kwargs) and return (result, list_of_warning_messages).""" + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter("always") + result = fn(*args, **kwargs) + msgs = [str(w.message) for w in caught] + return result, msgs + + +# =========================================================================== +# BART snapshot tests +# =========================================================================== + +class TestBARTSnapshot: + def test_fixture_loads_and_predicts(self): + """BART fixture deserialises and prediction succeeds.""" + fixture_obj = _load_fixture("bart_mcmc.json") + json_str = _to_json_string(fixture_obj) + + rng = np.random.default_rng(1) + X = rng.uniform(size=(30, 5)) + + m = BARTModel() + m.from_json(json_str) + preds = m.predict(X) + assert "y_hat" in preds + assert preds["y_hat"].shape[0] == 30 + + def test_roundtrip_is_deterministic(self): + """Two loads of the same fixture string produce identical predictions.""" + fixture_obj = _load_fixture("bart_mcmc.json") + json_str = 
_to_json_string(fixture_obj) + + rng = np.random.default_rng(99) + X = rng.uniform(size=(20, 5)) + + m1 = BARTModel() + m1.from_json(json_str) + m2 = BARTModel() + m2.from_json(json_str) + + p1 = m1.predict(X)["y_hat"].mean(axis=1) + p2 = m2.predict(X)["y_hat"].mean(axis=1) + np.testing.assert_array_equal(p1, p2) + + +# =========================================================================== +# BCF snapshot tests +# =========================================================================== + +class TestBCFSnapshot: + def test_fixture_loads_and_predicts(self): + """BCF fixture deserialises and prediction succeeds.""" + fixture_obj = _load_fixture("bcf_mcmc.json") + json_str = _to_json_string(fixture_obj) + + rng = np.random.default_rng(1) + n = 20 + X = rng.uniform(size=(n, 5)) + Z = rng.binomial(1, 0.5, n).astype(float) + pi = np.full(n, 0.5) + + m = BCFModel() + m.from_json(json_str) + preds = m.predict(X, Z, pi) + assert "y_hat" in preds + assert "tau_hat" in preds + assert preds["y_hat"].shape[0] == n + + +# =========================================================================== +# BART backward-compat: missing optional fields +# =========================================================================== + +class TestBARTBackwardCompat: + def test_missing_outcome_model(self): + """BART loads without 'outcome_model' (pre-v0.4.1 format).""" + fixture_obj = _load_fixture("bart_mcmc.json") + json_str = _strip_fields(fixture_obj, "outcome_model", "probit_outcome_model") + + m = BARTModel() + m.from_json(json_str) + + rng = np.random.default_rng(1) + X = rng.uniform(size=(20, 5)) + preds = m.predict(X) + assert preds["y_hat"].shape[0] == 20 + + def test_missing_rfx_model_spec_no_rfx(self): + """BART loads without 'rfx_model_spec' when has_rfx=False.""" + fixture_obj = _load_fixture("bart_mcmc.json") + assert not fixture_obj.get("has_rfx", True), "Fixture must have has_rfx=False" + json_str = _strip_fields(fixture_obj, "rfx_model_spec") + + m = BARTModel() + 
m.from_json(json_str) + + rng = np.random.default_rng(1) + X = rng.uniform(size=(20, 5)) + preds = m.predict(X) + assert preds["y_hat"].shape[0] == 20 + + def test_missing_preprocessor_emits_warning(self): + """BART loads without 'covariate_preprocessor' and emits a warning.""" + fixture_obj = _load_fixture("bart_mcmc.json") + json_str = _strip_fields(fixture_obj, "covariate_preprocessor") + + m = BARTModel() + _, warns = _collect_warnings(m.from_json, json_str) + # Should warn about missing preprocessor or succeed silently + assert any("preprocessor" in w.lower() or "covariate" in w.lower() for w in warns) or len(warns) == 0 + assert isinstance(m, BARTModel) + + def test_missing_num_chains_keep_every(self): + """BART loads without 'num_chains' and 'keep_every'.""" + fixture_obj = _load_fixture("bart_mcmc.json") + json_str = _strip_fields(fixture_obj, "num_chains", "keep_every") + + m = BARTModel() + m.from_json(json_str) + + rng = np.random.default_rng(1) + X = rng.uniform(size=(20, 5)) + preds = m.predict(X) + assert preds["y_hat"].shape[0] == 20 + + +# =========================================================================== +# BCF backward-compat: missing optional fields +# =========================================================================== + +class TestBCFBackwardCompat: + def _predict(self, m: BCFModel, n: int = 20, seed: int = 1): + rng = np.random.default_rng(seed) + X = rng.uniform(size=(n, 5)) + Z = rng.binomial(1, 0.5, n).astype(float) + pi = np.full(n, 0.5) + return m.predict(X, Z, pi), n + + def test_missing_outcome_model(self): + """BCF loads without 'outcome_model' (pre-v0.4.1 format).""" + fixture_obj = _load_fixture("bcf_mcmc.json") + json_str = _strip_fields(fixture_obj, "outcome_model", "probit_outcome_model") + + m = BCFModel() + m.from_json(json_str) + preds, n = self._predict(m) + assert preds["y_hat"].shape[0] == n + + def test_missing_multivariate_treatment(self): + """BCF loads without 'multivariate_treatment' (pre-v0.4.0).""" + 
fixture_obj = _load_fixture("bcf_mcmc.json") + json_str = _strip_fields(fixture_obj, "multivariate_treatment") + + m = BCFModel() + m.from_json(json_str) + preds, n = self._predict(m) + assert preds["y_hat"].shape[0] == n + + def test_missing_internal_propensity_model(self): + """BCF loads without 'internal_propensity_model' (pre-v0.3.2).""" + fixture_obj = _load_fixture("bcf_mcmc.json") + json_str = _strip_fields(fixture_obj, "internal_propensity_model") + + m = BCFModel() + m.from_json(json_str) + preds, n = self._predict(m) + assert preds["y_hat"].shape[0] == n + + def test_missing_rfx_model_spec_no_rfx(self): + """BCF loads without 'rfx_model_spec' when has_rfx=False.""" + fixture_obj = _load_fixture("bcf_mcmc.json") + assert not fixture_obj.get("has_rfx", True), "Fixture must have has_rfx=False" + json_str = _strip_fields(fixture_obj, "rfx_model_spec") + + m = BCFModel() + m.from_json(json_str) + preds, n = self._predict(m) + assert preds["y_hat"].shape[0] == n + + def test_missing_preprocessor_emits_warning(self): + """BCF loads without 'covariate_preprocessor' and emits a warning.""" + fixture_obj = _load_fixture("bcf_mcmc.json") + json_str = _strip_fields(fixture_obj, "covariate_preprocessor") + + m = BCFModel() + _, warns = _collect_warnings(m.from_json, json_str) + # Should warn or at least succeed + assert any("preprocessor" in w.lower() or "covariate" in w.lower() for w in warns) or len(warns) == 0 + assert isinstance(m, BCFModel) + + def test_missing_num_chains_keep_every(self): + """BCF loads without 'num_chains' and 'keep_every'.""" + fixture_obj = _load_fixture("bcf_mcmc.json") + json_str = _strip_fields(fixture_obj, "num_chains", "keep_every") + + m = BCFModel() + m.from_json(json_str) + preds, n = self._predict(m) + assert preds["y_hat"].shape[0] == n + + def test_missing_has_rfx_basis(self): + """BCF loads without 'has_rfx_basis'.""" + fixture_obj = _load_fixture("bcf_mcmc.json") + json_str = _strip_fields(fixture_obj, "has_rfx_basis") + + m = 
BCFModel() + m.from_json(json_str) + preds, n = self._predict(m) + assert preds["y_hat"].shape[0] == n + + def test_missing_multiple_optional_fields(self): + """BCF loads when many optional fields are absent simultaneously.""" + fixture_obj = _load_fixture("bcf_mcmc.json") + json_str = _strip_fields( + fixture_obj, + "outcome_model", "probit_outcome_model", + "multivariate_treatment", "internal_propensity_model", + "rfx_model_spec", "num_chains", "keep_every", + "has_rfx_basis", + ) + + m = BCFModel() + m.from_json(json_str) + preds, n = self._predict(m) + assert preds["y_hat"].shape[0] == n