diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml index 27000aaf5..04a0e545a 100644 --- a/.github/workflows/docs.yml +++ b/.github/workflows/docs.yml @@ -61,6 +61,9 @@ jobs: pip install . cd .. + - name: Set RETICULATE_PYTHON for vignette rendering + run: echo "RETICULATE_PYTHON=$(which python)" >> $GITHUB_ENV + - name: Setup pandoc uses: r-lib/actions/setup-pandoc@v2 @@ -73,57 +76,42 @@ jobs: uses: r-lib/actions/setup-r-dependencies@v2 with: working-directory: 'stochtree_repo' - extra-packages: any::latex2exp, any::ggplot2, any::decor, any::pkgdown + extra-packages: any::latex2exp, any::ggplot2, any::decor, any::pkgdown, any::reticulate, any::bayesplot, any::coda, any::doParallel, any::foreach, any::mvtnorm, any::rpart, any::rpart.plot, any::tgp, any::rprojroot needs: website - name: Build R doc site run: | cd stochtree_repo - Rscript cran-bootstrap.R 1 1 1 + Rscript cran-bootstrap.R 0 1 0 cd .. mkdir -p docs/R_docs/pkgdown Rscript -e 'pkgdown::build_site_github_pages("stochtree_repo/stochtree_cran", dest_dir = "../../docs/R_docs/pkgdown", install = TRUE)' + - name: Install stochtree R package for vignettes + run: Rscript -e 'install.packages("stochtree_repo/stochtree_cran", repos = NULL, type = "source")' + - name: Clean up the temporary stochtree_cran directory created run: | cd stochtree_repo - Rscript cran-cleanup.R + Rscript cran-cleanup.R cd .. 
- - name: Copy Jupyter notebook demos over to docs directory - run: | - cp stochtree_repo/demo/notebooks/supervised_learning.ipynb docs/python_docs/demo/supervised_learning.ipynb - cp stochtree_repo/demo/notebooks/causal_inference.ipynb docs/python_docs/demo/causal_inference.ipynb - cp stochtree_repo/demo/notebooks/heteroskedastic_supervised_learning.ipynb docs/python_docs/demo/heteroskedastic_supervised_learning.ipynb - cp stochtree_repo/demo/notebooks/multivariate_treatment_causal_inference.ipynb docs/python_docs/demo/multivariate_treatment_causal_inference.ipynb - cp stochtree_repo/demo/notebooks/reparameterized_causal_inference.ipynb docs/python_docs/demo/reparameterized_causal_inference.ipynb - cp stochtree_repo/demo/notebooks/serialization.ipynb docs/python_docs/demo/serialization.ipynb - cp stochtree_repo/demo/notebooks/tree_inspection.ipynb docs/python_docs/demo/tree_inspection.ipynb - cp stochtree_repo/demo/notebooks/summary.ipynb docs/python_docs/demo/summary.ipynb - cp stochtree_repo/demo/notebooks/ordinal_outcome.ipynb docs/python_docs/demo/ordinal_outcome.ipynb - cp stochtree_repo/demo/notebooks/prototype_interface.ipynb docs/python_docs/demo/prototype_interface.ipynb - cp stochtree_repo/demo/notebooks/sklearn_wrappers.ipynb docs/python_docs/demo/sklearn_wrappers.ipynb - cp stochtree_repo/demo/notebooks/multi_chain.ipynb docs/python_docs/demo/multi_chain.ipynb - - - name: Copy static vignettes over to docs directory - run: | - cp vignettes/Python/RDD/rdd.html docs/vignettes/Python/rdd.html - cp vignettes/Python/RDD/RDD_DAG.png docs/vignettes/Python/RDD_DAG.png - cp vignettes/Python/RDD/trees1.png docs/vignettes/Python/trees1.png - cp vignettes/Python/RDD/trees2.png docs/vignettes/Python/trees2.png - cp vignettes/Python/RDD/trees3.png docs/vignettes/Python/trees3.png - cp vignettes/R/RDD/rdd.html docs/vignettes/R/rdd.html - cp vignettes/Python/IV/iv.html docs/vignettes/Python/iv.html - cp vignettes/Python/IV/IV_CDAG.png docs/vignettes/Python/IV_CDAG.png 
- cp vignettes/R/IV/iv.html docs/vignettes/R/iv.html - + - name: Install Quarto + uses: quarto-dev/quarto-actions/setup@v2 + + - name: Install quartodoc + run: pip install quartodoc "griffe<1.0" + + - name: Regenerate Python API reference pages + run: quartodoc build + - name: Build the overall doc site run: | - mkdocs build + quarto render - name: Deploy to GitHub pages πŸš€ if: github.event_name != 'pull_request' uses: JamesIves/github-pages-deploy-action@v4 with: branch: gh-pages - folder: site \ No newline at end of file + folder: _site \ No newline at end of file diff --git a/.gitignore b/.gitignore index 0c5a93ab1..94c7985dd 100644 --- a/.gitignore +++ b/.gitignore @@ -34,9 +34,24 @@ yarn-error.log* /site/ /docs/cpp_docs/doxygen/ /docs/R_docs/pkgdown/* +_site/ +_freeze/ + +# Quarto render artifacts (output-dir should be _site/, but these appear in-place too) +/*.html +/site_libs/ +/objects.json +/development/*.html +/python-api/**/*.html +/python-api/reference/*.qmd # Virtual environments /python_venv /cpp_venv /venv +.venv .Rproj.user + +/.quarto/ +**/*.quarto_ipynb +**/*.rmarkdown diff --git a/.here b/.here new file mode 100644 index 000000000..e69de29bb diff --git a/README.md b/README.md index 7dc24bbd9..cdf0ad8df 100644 --- a/README.md +++ b/README.md @@ -4,80 +4,64 @@ ### MacOS -#### Cloning the repo +#### Software dependencies -First, you will need the stochtree repo on your local machine. -Navigate to the `documentation` repo in your terminal (*replace `~/path/to/documentation` with the path to the documentation repo on your local system*). 
+You'll need to have the following software installed -```{bash} -cd ~/path/to/documentation -``` - -Now, recursively clone the main `stochtree` repo into a `stochtree_repo` subfolder of the `documentation` repo - -```{bash} -git clone --recursive git@github.com:StochasticTree/stochtree.git stochtree_repo -``` - -#### Setting up build dependencies +- Python: can be installed via [homebrew](https://formulae.brew.sh/formula/python@3.14), [conda](https://www.anaconda.com/download), and [directly from the python site](https://www.python.org/downloads/) +- R: can be installed via [CRAN](https://cran.r-project.org/) or [homebrew](https://formulae.brew.sh/formula/r) +- Quarto: can be installed [directly from the Quarto site](https://quarto.org/docs/get-started/) or [homebrew](https://formulae.brew.sh/cask/quarto) +- Doxygen: can be installed [directly from the Doxygen site](https://www.doxygen.nl/) or [homebrew](https://formulae.brew.sh/formula/doxygen) -The docs are largely built using [`mkdocs`](https://www.mkdocs.org), [`pkgdown`](https://pkgdown.r-lib.org) and [`doxygen`](https://www.doxygen.nl/index.html), -with everything tied together using the ["Material for MkDocs"](https://squidfunk.github.io/mkdocs-material/) theme. +#### Setting up R and Python build dependencies -We first create a virtual environment and install the dependencies for `stochtree` as well as the doc site (several python packages: `mkdocs-material`, `mkdocstrings-python`, and `mkdocs-jupyter`). +Building multi-lingual (R and Python) vignettes requires installing the vignettes' package dependencies. 
In Python, this is done via a virtual environment (local `.venv`) ```{bash} -python -m venv venv -source venv/bin/activate +python -m venv .venv +source .venv/bin/activate pip install --upgrade pip pip install -r requirements.txt +pip install git+https://github.com/StochasticTree/stochtree.git ``` -##### stochtree - -Now, we build the `stochtree` python library locally in the virtual environment activated above +And in R, this is typically done as a global system install, though you might also consider [`renv`](https://rstudio.github.io/renv/) for managing project-specific R dependencies ```{bash} -cd stochtree_repo -pip install . -cd .. +Rscript -e 'install.packages(c("remotes", "devtools", "roxygen2", "ggplot2", "latex2exp", "decor", "pkgdown", "cpp11", "BH", "doParallel", "foreach", "knitr", "Matrix", "MASS", "mvtnorm", "rmarkdown", "testthat", "tgp", "here", "reticulate"), repos="https://cloud.r-project.org/")' +Rscript -e 'remotes::install_github("StochasticTree/stochtree", ref = "r-dev")' ``` -##### pkgdown +#### Cloning the stochtree repo -To use `pkgdown`, you need to install [R](https://cran.r-project.org). -One R is installed, make sure the dependendencies of the pkgdown build are installed +Building the R API docs (roxygen2 / pkgdown) and C++ API docs (Doxygen) requires the stochtree source. 
Clone it into `stochtree_repo/` at the repo root: -```{bash} -Rscript -e 'install.packages(c("remotes", "devtools", "roxygen2", "ggplot2", "latex2exp", "decor", "pkgdown", "cpp11", "BH", "doParallel", "foreach", "knitr", "Matrix", "MASS", "mvtnorm", "rmarkdown", "testthat", "tgp"), repos="https://cloud.r-project.org/")' +```bash +git clone --recurse-submodules https://github.com/StochasticTree/stochtree.git stochtree_repo ``` -### Building the R docs +#### Building the pkgdown site for the R API -To build the R docs, first run a script that lays out the package as needed +With the stochtree repo checked out and R dependencies installed (see above), run: -```{bash} +```bash cd stochtree_repo -Rscript cran-bootstrap.R 1 1 1 +Rscript cran-bootstrap.R 0 1 0 cd .. -``` - -Then run the `pkgdown` build workflow to put the R docs in the correct folder - -```{bash} mkdir -p docs/R_docs/pkgdown Rscript -e 'pkgdown::build_site_github_pages("stochtree_repo/stochtree_cran", dest_dir = "../../docs/R_docs/pkgdown", install = TRUE)' -rm -rf stochtree_repo/stochtree_cran +cd stochtree_repo +Rscript cran-cleanup.R +cd .. ``` -### Building the doxygen site for the C++ API +`cran-bootstrap.R 0 1 0` prepares a `stochtree_cran/` subdirectory with the pkgdown config but without vignette source (vignettes are served by the Quarto site instead). `cran-cleanup.R` removes that temporary directory when done. The output is written to `docs/R_docs/pkgdown/`. -First, ensure that you have [doxygen](https://www.doxygen.nl/index.html) installed. -On MacOS, this can be [done via homebrew](https://formulae.brew.sh/formula/doxygen) (i.e. `brew install doxygen`). 
+#### Building the Doxygen site for the C++ API -Then, modify the `Doxyfile` to build the C++ documentation as desired and build the doxygen site +With the stochtree repo checked out and Doxygen installed (`brew install doxygen` on macOS), run: -```{bash} +```bash sed -i '' 's|^OUTPUT_DIRECTORY *=.*|OUTPUT_DIRECTORY = ../docs/cpp_docs/|' stochtree_repo/Doxyfile sed -i '' 's|^GENERATE_XML *=.*|GENERATE_XML = NO|' stochtree_repo/Doxyfile sed -i '' 's|^GENERATE_HTML *=.*|GENERATE_HTML = YES|' stochtree_repo/Doxyfile @@ -87,47 +71,58 @@ doxygen Doxyfile cd .. ``` -### Copying Jupyter notebook demos to the docs directory +The output is written to `docs/cpp_docs/doxygen/`. -```{bash} -cp stochtree_repo/demo/notebooks/supervised_learning.ipynb docs/python_docs/demo/supervised_learning.ipynb -cp stochtree_repo/demo/notebooks/causal_inference.ipynb docs/python_docs/demo/causal_inference.ipynb -cp stochtree_repo/demo/notebooks/heteroskedastic_supervised_learning.ipynb docs/python_docs/demo/heteroskedastic_supervised_learning.ipynb -cp stochtree_repo/demo/notebooks/ordinal_outcome.ipynb docs/python_docs/demo/ordinal_outcome.ipynb -cp stochtree_repo/demo/notebooks/multivariate_treatment_causal_inference.ipynb docs/python_docs/demo/multivariate_treatment_causal_inference.ipynb -cp stochtree_repo/demo/notebooks/reparameterized_causal_inference.ipynb docs/python_docs/demo/reparameterized_causal_inference.ipynb -cp stochtree_repo/demo/notebooks/serialization.ipynb docs/python_docs/demo/serialization.ipynb -cp stochtree_repo/demo/notebooks/tree_inspection.ipynb docs/python_docs/demo/tree_inspection.ipynb -cp stochtree_repo/demo/notebooks/summary.ipynb docs/python_docs/demo/summary.ipynb -cp stochtree_repo/demo/notebooks/prototype_interface.ipynb docs/python_docs/demo/prototype_interface.ipynb -cp stochtree_repo/demo/notebooks/sklearn_wrappers.ipynb docs/python_docs/demo/sklearn_wrappers.ipynb -cp stochtree_repo/demo/notebooks/multi_chain.ipynb 
docs/python_docs/demo/multi_chain.ipynb +#### Building the vignettes with quarto + +The vignettes live in the `vignettes/` directory and are configured as a standalone Quarto website via `vignettes/_quarto.yml`. Each `.qmd` file uses `{.panel-tabset group="language"}` tabsets to present R and Python code side-by-side. Python cells are executed via `reticulate`; set the `RETICULATE_PYTHON` environment variable to point at your `.venv` interpreter if it isn't picked up automatically. + +To render all vignettes at once: + +```bash +cd vignettes +quarto render ``` -### Copy static vignettes over to docs directory +To render a single vignette: -```{bash} -cp vignettes/Python/RDD/rdd.html docs/vignettes/Python/rdd.html -cp vignettes/Python/RDD/RDD_DAG.png docs/vignettes/Python/RDD_DAG.png -cp vignettes/Python/RDD/trees1.png docs/vignettes/Python/trees1.png -cp vignettes/Python/RDD/trees2.png docs/vignettes/Python/trees2.png -cp vignettes/Python/RDD/trees3.png docs/vignettes/Python/trees3.png -cp vignettes/R/RDD/rdd.html docs/vignettes/R/rdd.html -cp vignettes/Python/IV/iv.html docs/vignettes/Python/iv.html -cp vignettes/Python/IV/IV_CDAG.png docs/vignettes/Python/IV_CDAG.png -cp vignettes/R/IV/iv.html docs/vignettes/R/iv.html +```bash +cd vignettes +quarto render bart.qmd ``` -### Building the overall website +To preview the vignette site locally with live reload: -To build and preview the site locally, run +```bash +cd vignettes +quarto preview +``` -```{bash} -mkdocs serve +The rendered site is written to `vignettes/_site/`. Individual vignettes use `freeze: auto` in their frontmatter, so re-renders only re-execute cells whose source has changed. To force a full re-execution, delete `vignettes/_freeze/` before rendering. + +#### Regenerating the Python API reference pages + +The `python-api/reference/*.qmd` files are generated by quartodoc from the stochtree package's docstrings. 
They are checked into the repo, so a normal `quarto render` will render whatever is already there. If you've updated docstrings in the stochtree Python package, regenerate them first: + +```bash +source .venv/bin/activate +quartodoc build ``` -To build the files underlying the static site, run +This reads the `quartodoc:` block in `_quarto.yml` and writes updated `.qmd` files to `python-api/reference/`. Run it from the repo root. After regenerating, commit the updated `.qmd` files and run `quarto render` as normal. -```{bash} -mkdocs build +The CI workflow always runs `quartodoc build` before `quarto render` so the live site stays in sync with the latest package docstrings. + +### Building the overall website + +The full site (vignettes + Python API reference + embedded pkgdown/Doxygen) is built from the repo root with: + +```bash +quarto render ``` + +This requires pkgdown and Doxygen output to already exist at `docs/R_docs/pkgdown/` and `docs/cpp_docs/doxygen/` respectively (the CI workflow builds these before running `quarto render`). For iterating on vignettes alone, the `cd vignettes && quarto render` workflow described above is faster. + +**Freeze cache note:** The vignette `.qmd` files use `freeze: auto`, so re-renders only re-execute cells whose source has changed. The freeze cache lives at `_freeze/vignettes/` (top-level render) or `vignettes/_freeze/` (standalone vignette render). If you switch between the two render modes, copy the cache to the appropriate location before rendering to avoid unnecessary re-execution. + +The CI workflow (`.github/workflows/docs.yml`) handles the full build and deploys the output `_site/` directory to the `gh-pages` branch. 
diff --git a/_quarto.yml b/_quarto.yml new file mode 100644 index 000000000..da63788d9 --- /dev/null +++ b/_quarto.yml @@ -0,0 +1,155 @@ +project: + type: website + output-dir: _site + render: + - "*.qmd" + - "vignettes/*.qmd" + - "python-api/reference/*.qmd" + - "development/*.qmd" + resources: + - docs/R_docs/pkgdown/ + - docs/cpp_docs/doxygen/ + +website: + title: "StochTree" + site-url: "https://stochtree.ai/" + repo-url: "https://github.com/StochasticTree/stochtree" + repo-actions: [issue] + + navbar: + left: + - href: index.qmd + text: Home + - href: getting-started.qmd + text: Getting Started + - href: about.qmd + text: About + - text: R Package + href: docs/R_docs/pkgdown/index.html + - text: Python API + href: python-api/reference/index.qmd + - text: C++ API + href: docs/cpp_docs/doxygen/index.html + - text: Vignettes + href: vignettes/index.qmd + - text: Development + menu: + - text: Overview + href: development/index.qmd + - text: Contributing + href: development/contributing.qmd + - text: Adding New Models + href: development/new-models.qmd + - text: Roadmap + href: development/roadmap.qmd + + sidebar: + - id: python-api + title: "Python API" + style: docked + contents: + - python-api/reference/index.qmd + - section: "Core Models" + contents: + - python-api/reference/bart.BARTModel.qmd + - python-api/reference/bcf.BCFModel.qmd + - section: "Scikit-Learn Interface" + contents: + - python-api/reference/sklearn.StochTreeBARTRegressor.qmd + - python-api/reference/sklearn.StochTreeBARTBinaryClassifier.qmd + - section: "Low-Level API" + contents: + - python-api/reference/data.Dataset.qmd + - python-api/reference/data.Residual.qmd + - python-api/reference/forest.Forest.qmd + - python-api/reference/forest.ForestContainer.qmd + - python-api/reference/sampler.ForestSampler.qmd + - python-api/reference/sampler.GlobalVarianceModel.qmd + - python-api/reference/sampler.LeafVarianceModel.qmd + + - id: vignettes + title: "Vignettes" + style: docked + contents: + - 
vignettes/index.qmd + - section: "Core Models" + contents: + - vignettes/bart.qmd + - vignettes/bcf.qmd + - vignettes/heteroskedastic.qmd + - vignettes/ordinal-outcome.qmd + - vignettes/multivariate-bcf.qmd + - section: "Practical Topics" + contents: + - vignettes/serialization.qmd + - vignettes/multi-chain.qmd + - vignettes/tree-inspection.qmd + - vignettes/summary-plotting.qmd + - vignettes/prior-calibration.qmd + - vignettes/sklearn.qmd + - section: "Low-Level Interface" + contents: + - vignettes/custom-sampling.qmd + - vignettes/ensemble-kernel.qmd + - section: "Advanced Methods" + contents: + - vignettes/rdd.qmd + - vignettes/iv.qmd + +format: + html: + theme: [minty, assets/custom.scss] + css: assets/api.css + toc: true + toc-depth: 3 + grid: + body-width: 1000px + margin-width: 200px + +execute: + freeze: auto + +quartodoc: + package: stochtree + dir: python-api/reference + style: pkgdown + parser: numpy + render_interlinks: false + sections: + - title: Core Models + desc: High-level model interfaces for supervised learning and causal inference. + contents: + - name: bart.BARTModel + member_order: source + - name: bcf.BCFModel + member_order: source + - title: Scikit-Learn Interface + desc: stochtree models wrapped as sklearn-compatible estimators. + contents: + - name: sklearn.StochTreeBARTRegressor + member_order: source + - name: sklearn.StochTreeBARTBinaryClassifier + member_order: source + - title: Low-Level API β€” Data + desc: Data structures for custom sampling workflows. + contents: + - name: data.Dataset + member_order: source + - name: data.Residual + member_order: source + - title: Low-Level API β€” Forest + desc: Forest containers and inspection. + contents: + - name: forest.Forest + member_order: source + - name: forest.ForestContainer + member_order: source + - title: Low-Level API β€” Samplers + desc: Sampler classes for building custom models. 
+ contents: + - name: sampler.ForestSampler + member_order: source + - name: sampler.GlobalVarianceModel + member_order: source + - name: sampler.LeafVarianceModel + member_order: source diff --git a/about.qmd b/about.qmd new file mode 100644 index 000000000..9a8159ab5 --- /dev/null +++ b/about.qmd @@ -0,0 +1,105 @@ +--- +title: "Overview of Stochastic Tree Models" +--- + +Stochastic tree models are a powerful addition to your modeling toolkit. +As with many machine learning methods, understanding these models in depth is an involved task. + +There are many excellent published papers on stochastic tree models +(to name a few, the [original BART paper](https://projecteuclid.org/journals/annals-of-applied-statistics/volume-4/issue-1/BART-Bayesian-additive-regression-trees/10.1214/09-AOAS285.full), +[the XBART paper](https://www.tandfonline.com/doi/full/10.1080/01621459.2021.1942012), +and [the BCF paper](https://projecteuclid.org/journals/bayesian-analysis/volume-15/issue-3/Bayesian-Regression-Tree-Models-for-Causal-Inference--Regularization-Confounding/10.1214/19-BA1195.full)). +Here, we aim to build up an abbreviated intuition for these models from their conceptually-simple building blocks. + +## Notation + +We're going to introduce some notation to make these concepts precise. +In a traditional supervised learning setting, we hope to predict some **outcome** from **features** in a training dataset. +We'll call the outcome $y$ and the features $X$. +Our goal is to come up with a function $f$ that predicts the outcome $y$ as well as possible from $X$ alone. + +## Decision Trees + +[Decision tree learning](https://en.wikipedia.org/wiki/Decision_tree_learning) is a simple machine learning method that +constructs a function $f$ from a series of conditional statements. Consider the tree below. 
+
+```{mermaid}
+stateDiagram-v2
+  state split_one <<choice>>
+  state split_two <<choice>>
+  split_one --> split_two: if x1 <= 1
+  split_one --> c : if x1 > 1
+  split_two --> a: if x2 <= -2
+  split_two --> b : if x2 > -2
+```
+
+We evaluate two conditional statements (`X[,1] > 1` and `X[,2] > -2`), arranged in a tree-like sequence of branches,
+which determine whether the model predicts `a`, `b`, or `c`. We could similarly express this tree in math notation as
+
+$$
+f(X_i) = \begin{cases}
+a & ; \;\;\; X_{i,1} \leq 1, \;\; X_{i,2} \leq -2\\
+b & ; \;\;\; X_{i,1} \leq 1, \;\; X_{i,2} > -2\\
+c & ; \;\;\; X_{i,1} > 1
+\end{cases}
+$$
+
+We won't belabor the discussion of trees as there are many good textbooks and online articles on the topic,
+but we'll close by noting that training decision trees introduces a delicate balance between
+[overfitting and underfitting](https://en.wikipedia.org/wiki/Overfitting).
+Simple trees like the one above do not capture much complexity in a dataset and may potentially be underfit
+while deep, complex trees are vulnerable to overfitting and tend to have high variance.
+
+## Boosted Decision Tree Ensembles
+
+One way to address the overfitting-underfitting tradeoff of decision trees is to build an "ensemble" of decision
+trees, so that the function $f$ is defined by a sum of $k$ individual decision trees $g_i$
+
+$$
+f(X_i) = g_1(X_i) + \dots + g_k(X_i)
+$$
+
+There are several ways to train an ensemble of decision trees (sometimes called "forests"), the most popular of which are [random forests](https://en.wikipedia.org/wiki/Random_forest) and
+[gradient boosting](https://en.wikipedia.org/wiki/Gradient_boosting). Their main difference is that random forests train
+all $k$ trees independently of one another, while boosting trains trees sequentially, so that tree $j$ depends on the result of training trees 1 through $j-1$.
+Libraries like [xgboost](https://xgboost.readthedocs.io/en/stable/) and [LightGBM](https://lightgbm.readthedocs.io/en/latest/) are popular examples of boosted tree ensembles. + +Tree ensembles often [outperform neural networks and other machine learning methods on tabular datasets](https://arxiv.org/abs/2207.08815), +but classic tree ensemble methods return a single estimated function $f$, without expressing uncertainty around its estimates. + +## Stochastic Tree Ensembles + +[Stochastic](https://en.wikipedia.org/wiki/Stochastic) tree ensembles differ from their classical counterparts in their use of randomness in learning a function. +Rather than returning a single "best" tree ensemble, stochastic tree ensembles return a range of tree ensembles that fit the data well. +Mechanically, it's useful to think of "sampling" -- rather than "fitting" -- a stochastic tree ensemble model. + +Why is this useful? Suppose we've sampled $m$ forests. For each observation $i$, we obtain $m$ predictions: $[f_1(X_i), \dots, f_m(X_i)]$. +From this "dataset" of predictions, we can compute summary statistics, where a mean or median would give something akin to the predictions of an xgboost or lightgbm model, +and the $\alpha$ and $1-\alpha$ quantiles give a [credible interval](https://en.wikipedia.org/wiki/Credible_interval). + +Rather than explain each of the models that `stochtree` supports in depth here, we provide a high-level overview, with pointers to the relevant literature. + +### Supervised Learning + +The [`bart`](docs/R_docs/pkgdown/reference/bart.html) R function and the [`BARTModel`](python-api/reference/bart.BARTModel.qmd) Python class are the primary interface for supervised +prediction tasks in `stochtree`. 
The primary references for these models are +[BART (Chipman, George, McCulloch 2010)](https://projecteuclid.org/journals/annals-of-applied-statistics/volume-4/issue-1/BART-Bayesian-additive-regression-trees/10.1214/09-AOAS285.full) and +[XBART (He and Hahn 2021)](https://www.tandfonline.com/doi/full/10.1080/01621459.2021.1942012). + +In addition to the standard BART / XBART models, in which each tree's leaves return a constant prediction, `stochtree` also supports +arbitrary leaf regression on a user-provided basis (i.e. an expanded version of [Chipman et al 2002](https://link.springer.com/article/10.1023/A:1013916107446) and [Gramacy and Lee 2012](https://www.tandfonline.com/doi/abs/10.1198/016214508000000689)). + +### Causal Inference + +The [`bcf`](docs/R_docs/pkgdown/reference/bcf.html) R function and the [`BCFModel`](python-api/reference/bcf.BCFModel.qmd) Python class are the primary interface for causal effect +estimation in `stochtree`. The primary references for these models are +[BCF (Hahn, Murray, Carvalho 2021)](https://projecteuclid.org/journals/bayesian-analysis/volume-15/issue-3/Bayesian-Regression-Tree-Models-for-Causal-Inference--Regularization-Confounding/10.1214/19-BA1195.full) and +[XBCF (Krantsevich, He, Hahn 2022)](https://arxiv.org/abs/2209.06998). + +### Additional Modeling Features + +Both the BART and BCF interfaces in `stochtree` support the following extensions: + +* Accelerated / "warm-start" sampling of forests (i.e. [He and Hahn 2021](https://www.tandfonline.com/doi/full/10.1080/01621459.2021.1942012)) +* Forest-based heteroskedasticity (i.e. [Murray 2021](https://www.tandfonline.com/doi/abs/10.1080/01621459.2020.1813587)) +* Additive random effects (i.e. 
[Gelman et al 2008](https://www.tandfonline.com/doi/abs/10.1198/106186008X287337)) diff --git a/assets/api.css b/assets/api.css new file mode 100644 index 000000000..e5b738be8 --- /dev/null +++ b/assets/api.css @@ -0,0 +1,37 @@ +/* ── Secondary nav bar (breadcrumb bar below main navbar) ──────────────── + The bar uses var(--bs-breadcrumb-bg) for its background, so we override + the variable rather than the property. + ──────────────────────────────────────────────────────────────────────── */ +:root { + --bs-breadcrumb-bg: #ffffff; +} + +/* ── Secondary nav bar: links (e.g. "Core Models") ────────────────────────── + Breadcrumb item links inherit the global link color (teal) which disappears + on the now-white breadcrumb bar. Override to the same dark slate as the text. + ──────────────────────────────────────────────────────────────────────────── */ +.quarto-secondary-nav a, +.breadcrumb-item a { + color: #2d3748 !important; +} + +/* ── Methods summary table ─────────────────────────────────────────────── + The methods table lives in
`section#methods`.
+   First column (method name): never wrap, pin a minimum width so long
+   names like `compute_posterior_interval` stay on one line.
+   ──────────────────────────────────────────────────────────────────────── */
+section#methods table.table td:first-child,
+section#methods table.table th:first-child {
+  min-width: 240px;
+  white-space: nowrap;
+}
+
+/* ── Parameter / Returns tables ──────────────────────────────────────────
+   These live in `.doc-section` containers.
+ Give the Name column a floor so it isn't crowded by Type/Description. + ──────────────────────────────────────────────────────────────────────── */ +.doc-section table.table td:first-child, +.doc-section table.table th:first-child { + min-width: 140px; + white-space: nowrap; +} diff --git a/assets/custom.scss b/assets/custom.scss new file mode 100644 index 000000000..4e818bc9d --- /dev/null +++ b/assets/custom.scss @@ -0,0 +1,16 @@ +/*-- scss:defaults --*/ + +// Replace minty's green primary with a teal-blue +$primary: #1a7a9c; + +// Breadcrumb bar: white background, dark text +$breadcrumb-bg: #ffffff; +$breadcrumb-color: #2d3748; +$breadcrumb-active-color: #2d3748; +$breadcrumb-divider-color: #2d3748; + +// Dark navbar +$navbar-bg: #2d3748; +$navbar-fg: #ffffff; +$navbar-hl: #7dd3e8; // lighter teal-blue for active/hover links + diff --git a/development/contributing.qmd b/development/contributing.qmd new file mode 100644 index 000000000..44b951a21 --- /dev/null +++ b/development/contributing.qmd @@ -0,0 +1,277 @@ +--- +title: "Contributing" +--- + +`stochtree` is hosted on [Github](https://github.com/StochasticTree/stochtree/). +Any feedback, requests, or bug reports can be submitted as [issues](https://github.com/StochasticTree/stochtree/issues). +Moreover, if you have ideas for how to improve stochtree, we welcome [pull requests](https://github.com/StochasticTree/stochtree/pulls). + +## Building StochTree + +Any local stochtree development will require cloning the repository from Github. +If you don't have git installed, you can do so following [these instructions](https://learn.microsoft.com/en-us/devops/develop/git/install-and-set-up-git). + +Once git is available at the command line, navigate to the folder that will store this project (in bash / zsh, this is done by running `cd` followed by the path to the directory). 
+Then, clone the `stochtree` repo as a subfolder by running + +```bash +git clone --recursive https://github.com/StochasticTree/stochtree.git +``` + +*NOTE*: this project incorporates several C++ dependencies as [git submodules](https://git-scm.com/book/en/v2/Git-Tools-Submodules), +which is why the `--recursive` flag is necessary. If you have already cloned the repo without the `--recursive` flag, +you can retrieve the submodules recursively by running `git submodule update --init --recursive` in the main repo directory. + +### R + +This section will detail how to use RStudio to build and make changes to stochtree. There are other tools that are useful for R +package development (for example, [Positron](https://github.com/posit-dev/positron), [VS Code](https://code.visualstudio.com/docs/languages/r), +and [ESS](https://ess.r-project.org/)), but we will focus on RStudio in this walkthrough. + +Once you've cloned the stochtree repository, follow these steps to build stochtree: + +1. [Create an RStudio project in the stochtree directory](https://support.posit.co/hc/en-us/articles/200526207-Using-RStudio-Projects) +2. [Build the package in RStudio](https://docs.posit.co/ide/user/ide/guide/pkg-devel/writing-packages.html#building-a-package) + +Note that due to the complicated folder structure of the stochtree repo, step 2 might not work out of the box on all platforms. +If stochtree fails to build, you can use the script that we use to create a CRAN-friendly stochtree R package directory, which +creates a `stochtree_cran` subdirectory of the stochtree folder and copies the relevant R package files into this subfolder. +You can run this script by entering `Rscript cran-bootstrap.R 1 1 1` in the terminal in RStudio. +Once you have a `stochtree_cran` subfolder, you can build stochtree using + +```r +devtools::install_local("stochtree_cran") +``` + +Since this is a temporary folder, you can clean it up by running `Rscript cran-cleanup.R` in the terminal in RStudio. 
+ +### Python + +Building and making changes to the python library is best done in an isolated virtual environment. There are many different ways of +managing virtual environments in Python, but here we focus on `conda` and `venv`. + +#### Conda + +Conda provides a straightforward experience in managing python dependencies, avoiding version conflicts / ABI issues / etc. + +To build stochtree using a `conda` based workflow, first create and activate a conda environment with the requisite dependencies + +```bash +conda create -n stochtree-dev -c conda-forge python=3.10 numpy scipy pytest pandas pybind11 scikit-learn matplotlib seaborn +conda activate stochtree-dev +pip install jupyterlab +``` + +Then install the package by navigating to the stochtree directory and running + +```bash +pip install . +``` + +Note that if you are making changes and finding that they aren't reflected after a reinstall of stochtree, you can +clear all of the python package build artifacts with + +```bash +rm -rf stochtree.egg-info; rm -rf .pytest_cache; rm -rf build +``` + +and then rerun `pip install .` + +#### Venv + +You could also use venv for environment management. First, navigate to the folder in which you usually store virtual environments +(i.e. `cd /path/to/envs`) and create and activate a virtual environment: + +```bash +python -m venv venv +source venv/bin/activate +``` + +Install all of the package (and demo notebook) dependencies + +```bash +pip install numpy scipy pytest pandas scikit-learn pybind11 matplotlib seaborn jupyterlab +``` + +Then install the package by navigating to the stochtree directory and running + +```bash +pip install . 
+``` + +Note that if you are making changes and finding that they aren't reflected after a reinstall of stochtree, you can +clear all of the python package development artifacts with + +```bash +rm -rf stochtree.egg-info; rm -rf .pytest_cache; rm -rf build +``` + +and then rerun `pip install .` + +### C++ + +#### CMake + +The C++ project can be built independently from the R / Python packages using `cmake`. +See [here](https://cmake.org/install/) for details on installing cmake (alternatively, +on MacOS, `cmake` can be installed using [homebrew](https://formulae.brew.sh/formula/cmake)). +Once `cmake` is installed, you can build the CLI by navigating to the main +project directory at your command line (i.e. `cd /path/to/stochtree`) and +running the following code + +```bash +rm -rf build +mkdir build +cmake -S . -B build +cmake --build build +``` + +The CMake build has two primary targets, which are detailed below. + +##### Debug Program + +`debug/api_debug.cpp` defines a standalone target that can be straightforwardly run with a debugger (i.e. `lldb`, `gdb`) +while making non-trivial changes to the C++ code. +This debugging program is compiled as part of the CMake build if the `BUILD_DEBUG_TARGETS` option in `CMakeLists.txt` is set to `ON`. + +Once the program has been built, it can be run from the command line via `./build/debugstochtree` or attached to a debugger +via `lldb ./build/debugstochtree` (clang) or `gdb ./build/debugstochtree` (gcc). + +##### Unit Tests + +We test `stochtree` using the [GoogleTest](https://google.github.io/googletest/) framework. +Unit tests are compiled into a single target as part of the CMake build if the `BUILD_TEST` option is set to `ON` +and the test suite can be run after compilation via `./build/teststochtree`. + +## Debugging + +Debugging stochtree invariably leads to the "core" C++ codebase, which requires care to debug correctly. 
+Below we detail how to debug stochtree's C++ core through each of the three interfaces (C++, R and Python). + +### C++ Program + +The `debugstochtree` cmake target exists precisely to quickly debug the C++ core of stochtree. + +First, you must build the program using debug symbols, which you can do by enabling the `USE_DEBUG` option +and ensuring that `BUILD_DEBUG_TARGETS` is also switched on, as below + +```bash +rm -rf build +mkdir build +cmake -S . -B build -DBUILD_DEBUG_TARGETS=ON -DUSE_DEBUG=ON +cmake --build build +``` + +From here, you can debug at the command line using [lldb](https://lldb.llvm.org/) on MacOS or [gdb](https://sourceware.org/gdb/) on Linux by running +either `lldb ./build/debugstochtree` or `gdb ./build/debugstochtree` and using the appropriate shortcuts to navigate your program. + +#### Xcode + +While using `gdb` or `lldb` on `debugstochtree` at the command line is very helpful, users may prefer debugging in a full-fledged IDE like Xcode (if working on MacOS). +This project's C++ core can be converted to an Xcode project from `CMakeLists.txt`, but first you must turn off sanitizers +(Xcode has its own way of setting this at build time, and having injected +`-fsanitize=address` statically into compiler arguments will cause errors). To do this, modify the `USE_SANITIZER` line in `CMakeLists.txt`: + +``` +option(USE_SANITIZER "Use santizer flags" OFF) +``` + +To generate an Xcode project, navigate to the main project folder and run: + +```bash +rm -rf xcode/ +mkdir xcode +cd xcode +cmake -G Xcode .. -DCMAKE_C_COMPILER=cc -DCMAKE_CXX_COMPILER=c++ -DUSE_SANITIZER=OFF -DUSE_DEBUG=OFF +cd .. +``` + +Now, if you navigate to the xcode subfolder (in Finder), you should be able to click on a `.xcodeproj` file and the project will open in Xcode. + +### R Package + +Debugging stochtree R code requires building the R package with debug symbols. 
+The simplest way to do this is to open your R installation's `Makevars` file +by running `usethis::edit_r_makevars()` in RStudio which will open `Makevars` +in a code editor. + +If your `Makevars` file already has a line that begins with `CXX17FLAGS = ...`, +look for a `-g -O2` compiler flag and change this to `-g -O0`. If this flag isn't +set in the `CXX17FLAGS = ` line, then simply add `-g -O0` to this line after the ` = `. +If your `Makevars` file does not have a line that begins with `CXX17FLAGS = ...`, +add `CXX17FLAGS = -g -O0`. + +Now, rebuild the R package as above. Save the R code you'd like to debug to an R script. +Suppose for the sake of illustration that the code you want to debug is saved in +`path/to/debug_script.R`. + +At the command line (either the terminal in RStudio or your local terminal program), +run `R -d lldb` if you are using MacOS (or `R -d gdb` if you are using Linux). + +Now, you'll see an lldb prompt which should look like below with a blinking cursor after it + +``` +(lldb) +``` + +From there, you can set breakpoints, either to specific lines of specific files like `b src/tree.cpp:2117` or to break whenever there is an error using `breakpoint set -E c++`. +(**Note**: in gdb, the breakpoint and control flow commands are slightly different, see [here](https://www.maths.ed.ac.uk/~swood34/RCdebug/RCdebug.html) for more detail on debugging R through `gdb`.) +Now, you can run R through the debugger by typing + +``` +r +``` + +This should load an R console, from which you can execute a script you've set up to run your code using + +```r +source("path/to/debug_script.R") +``` + +The code will either stop when it hits your first line-based breakpoint or when it runs into an error if you set the error-based breakpoint. +From there, you can navigate using `lldb` (or `gdb`) commands. + +**Note**: once you've loaded the R console, you can also simply interactively run commands that call stochtree's C++ code (i.e. 
running the `bart()` or `bcf()` functions). If you're debugging at this level, you probably have a specific problem in mind, and using a repeatable script will be worth your while, but it is not strictly necessary. + +### Python Package + +First, you need to build stochtree's C++ extension with debug symbols. +As always, start by navigating to the stochtree directory (i.e. `cd /path/to/stochtree/`) +and activating your development virtual environment (i.e. `conda activate [env_name]` or `source venv/bin/activate`). + +Since stochtree builds its C++ extension via cmake [following this example](https://github.com/pybind/cmake_example), +you'll need to ensure that the `self.debug` field in the `CMakeBuild` class is set to `True`. +You can do this by setting an environment variable of `DEBUG` equal to 1. +In bash, you can do this with `export DEBUG=1` at the command line. + +Once this is done, build the python library using + +```bash +pip install . +``` + +Suppose you'd like to debug stochtree through a script called `/path/to/script.py`. + +First, target a python process with `lldb` (or, alternatively, replace with `gdb` below if you use `gcc` as your compiler) via + +``` +lldb python +``` + +Now, you'll see an lldb (or gdb) prompt which should look like below with a blinking cursor after it + +``` +(lldb) +``` + +From there, you can set breakpoints, either to specific lines of specific files like `b src/tree.cpp:2117` or to break whenever there is an error using `breakpoint set -E c++`. +(If you're using `gdb`, see [here](https://lldb.llvm.org/use/map.html) for a comparison between lldb commands and gdb commands for setting breakpoints and navigating your program.) +Now you can run your python script through the debugger by typing + +``` +r /path/to/script.py +``` + +The program will run until the first breakpoint is hit, and at this point you can navigate using lldb (or gdb) commands. 
+ +**Note**: rather than running a script like `/path/to/script.py` above, you can also simply load the python console by typing `r` at the `(lldb)` terminal and then interactively run commands that call stochtree's C++ code (i.e. sampling `BARTModel` or `BCFModel` objects). If you're debugging at this level, you probably have a specific problem in mind, and using a repeatable script will be worth your while, but it is not strictly necessary. diff --git a/development/index.qmd b/development/index.qmd new file mode 100644 index 000000000..9a308254b --- /dev/null +++ b/development/index.qmd @@ -0,0 +1,9 @@ +--- +title: "Development" +--- + +`stochtree` is in active development. Here, we detail some aspects of the development process: + +* [Contributing](contributing.qmd): how to get involved with stochtree, by contributing code, documentation, or helpful feedback +* [Adding New Models](new-models.qmd): how to add a new outcome model in C++ and make it available through the R and Python frontends +* [Roadmap](roadmap.qmd): timelines for new feature development and releases diff --git a/development/new-models.qmd b/development/new-models.qmd new file mode 100644 index 000000000..90a6cda11 --- /dev/null +++ b/development/new-models.qmd @@ -0,0 +1,273 @@ +--- +title: "Adding New Models to stochtree" +--- + +While the process of working with `stochtree`'s codebase to add +functionality or fix bugs is covered in the [contributing](contributing.qmd) +page, this page discusses a specific type of contribution in detail: +contributing new models (i.e. likelihoods and leaf parameter priors). + +Our C++ core is designed to support any conditionally-conjugate model, but this flexibility requires some explanation in order to be easily modified. + +## Overview + +The key components of `stochtree`'s models are: + +1. A **SuffStat** class that stores and accumulates sufficient statistics +2. 
A **LeafModel** class that computes marginal likelihoods / posterior parameters and samples leaf node parameters + +Each model implements a different version of these two classes. For example, the "classic" +BART model with constant Gaussian leaves and a Gaussian likelihood is represented by the +`GaussianConstantSuffStat` and `GaussianConstantLeafModel` classes. + +Each class implements a common API, and we use a [factory pattern](https://en.wikipedia.org/wiki/Factory_(object-oriented_programming)) and the C++17 +[std::variant](https://www.cppreference.com/w/cpp/utility/variant.html) +feature to dispatch the correct model at runtime. +Finally, R and Python wrappers expose this flexibility through the BART / BCF interfaces. + +Adding a new leaf model thus requires implementing new `SuffStat` and `LeafModel` +classes, then updating the factory functions and R / Python logic. + +## SuffStat Class + +As a pattern, sufficient statistic classes end in `*SuffStat` and implement several methods: + +* `IncrementSuffStat`: Increment a model's sufficient statistics by one data observation +* `ResetSuffStat`: Reset a model's sufficient statistics to zero / empty +* `AddSuffStat`: Combine two sufficient statistics, storing their sum in the sufficient statistic object that calls this method (without modifying the supplied `SuffStat` objects) +* `SubtractSuffStat`: Same as above but subtracting the second `SuffStat` argument from the first, rather than adding +* `SampleGreaterThan`: Checks whether the current sample size of a `SuffStat` object is greater than some threshold +* `SampleGreaterThanEqual`: Checks whether the current sample size of a `SuffStat` object is greater than or equal to some threshold +* `SampleSize`: Returns the current sample size of a `SuffStat` object + +For the sake of illustration, imagine we are adding a model called `OurNewModel`. 
The new sufficient statistic class should look something like: + +```cpp +class OurNewModelSuffStat { + public: + data_size_t n; + // Custom sufficient statistics for `OurNewModel` + double stat1; + double stat2; + + OurNewModelSuffStat() { + n = 0; + stat1 = 0.0; + stat2 = 0.0; + } + + void IncrementSuffStat(ForestDataset& dataset, Eigen::VectorXd& outcome, + ForestTracker& tracker, data_size_t row_idx, int tree_idx) { + n += 1; + stat1 += /* accumulate from outcome, dataset, or tracker as needed */; + stat2 += /* accumulate from outcome, dataset, or tracker as needed */; + } + + void ResetSuffStat() { + n = 0; + stat1 = 0.0; + stat2 = 0.0; + } + + void AddSuffStat(OurNewModelSuffStat& lhs, OurNewModelSuffStat& rhs) { + n = lhs.n + rhs.n; + stat1 = lhs.stat1 + rhs.stat1; + stat2 = lhs.stat2 + rhs.stat2; + } + + void SubtractSuffStat(OurNewModelSuffStat& lhs, OurNewModelSuffStat& rhs) { + n = lhs.n - rhs.n; + stat1 = lhs.stat1 - rhs.stat1; + stat2 = lhs.stat2 - rhs.stat2; + } + + bool SampleGreaterThan(data_size_t threshold) { return n > threshold; } + bool SampleGreaterThanEqual(data_size_t threshold) { return n >= threshold; } + data_size_t SampleSize() { return n; } +}; +``` + +## LeafModel Class + +Leaf model classes end in `*LeafModel` and implement several methods: + +* `SplitLogMarginalLikelihood`: the log marginal likelihood of a potential split, as a function of the sufficient statistics for the newly proposed left and right node (i.e. 
ignoring data points unaffected by a split) +* `NoSplitLogMarginalLikelihood`: the log marginal likelihood of a node without splitting, as a function of the sufficient statistics for that node +* `SampleLeafParameters`: Sample the leaf node parameters for every leaf in a provided tree, according to this model's conditionally conjugate leaf node posterior +* `RequiresBasis`: Whether or not a model requires regressing on "basis functions" in the leaves + +As above, imagine that we are implementing a new model called `OurNewModel`. The new leaf model class should look something like: + +```cpp +class OurNewModelLeafModel { + public: + OurNewModelLeafModel(/* model parameters */) { + // Set model parameters + } + + double SplitLogMarginalLikelihood(OurNewModelSuffStat& left_stat, + OurNewModelSuffStat& right_stat, + double global_variance) { + double left_log_ml = /* calculate left node log ML */; + double right_log_ml = /* calculate right node log ML */; + return left_log_ml + right_log_ml; + } + + double NoSplitLogMarginalLikelihood(OurNewModelSuffStat& suff_stat, + double global_variance) { + double log_ml = /* calculate node log ML */; + return log_ml; + } + + void SampleLeafParameters(ForestDataset& dataset, ForestTracker& tracker, + ColumnVector& residual, Tree* tree, int tree_num, + double global_variance, std::mt19937& gen) { + // Sample parameters for every leaf in a tree, update `tree` directly + } + + inline bool RequiresBasis() { return /* true/false based on your model */; } + + // Helper methods below for `SampleLeafParameters`, which depend on the + // nature of the leaf model (i.e. location-scale, shape-scale, etc...) 
+
+  double PosteriorParameterMean(OurNewModelSuffStat& suff_stat, 
+                                double global_variance) {
+    return /* calculate posterior mean */;
+  }
+
+  double PosteriorParameterVariance(OurNewModelSuffStat& suff_stat, 
+                                    double global_variance) {
+    return /* calculate posterior variance */;
+  }
+
+ private:
+  // Leaf model parameters
+  double param1_;
+  double param2_;
+};
+```
+
+## Factory Functions
+
+Updating the factory pattern to be able to dispatch `OurNewModel` has several steps.
+
+First, we add our model to the `ModelType` enum in `include/stochtree/leaf_model.h`:
+
+```cpp
+enum ModelType {
+  kConstantLeafGaussian, 
+  kUnivariateRegressionLeafGaussian, 
+  kMultivariateRegressionLeafGaussian, 
+  kLogLinearVariance, 
+  kOurNewModel // New model
+};
+```
+
+Next, we add the `OurNewModelSuffStat` and `OurNewModelLeafModel` classes to the `std::variant` unions in `include/stochtree/leaf_model.h`:
+
+```cpp
+using SuffStatVariant = std::variant<GaussianConstantSuffStat, 
+                                     GaussianUnivariateRegressionSuffStat, 
+                                     GaussianMultivariateRegressionSuffStat, 
+                                     LogLinearVarianceSuffStat, 
+                                     OurNewModelSuffStat>; // New model
+
+using LeafModelVariant = std::variant<GaussianConstantLeafModel, 
+                                      GaussianUnivariateRegressionLeafModel, 
+                                      GaussianMultivariateRegressionLeafModel, 
+                                      LogLinearVarianceLeafModel, 
+                                      OurNewModelLeafModel>; // New model
+```
+
+Finally, we update the factory functions to dispatch the correct class from the union based on the `ModelType` integer code
+
+```cpp
+static inline SuffStatVariant suffStatFactory(ModelType model_type, int basis_dim = 0) {
+  if (model_type == kConstantLeafGaussian) {
+    return createSuffStat<GaussianConstantSuffStat>();
+  } else if (model_type == kUnivariateRegressionLeafGaussian) {
+    return createSuffStat<GaussianUnivariateRegressionSuffStat>();
+  } else if (model_type == kMultivariateRegressionLeafGaussian) {
+    return createSuffStat<GaussianMultivariateRegressionSuffStat>(basis_dim);
+  } else if (model_type == kLogLinearVariance) {
+    return createSuffStat<LogLinearVarianceSuffStat>();
+  } else if (model_type == kOurNewModel) { // New model
+    return createSuffStat<OurNewModelSuffStat>();
+  } else {
+    Log::Fatal("Incompatible model type provided to suff stat factory");
+  }
+}
+
+static inline LeafModelVariant leafModelFactory(ModelType model_type, double tau, 
+                                                Eigen::MatrixXd& Sigma0, double a, double b) {
+  if (model_type == kConstantLeafGaussian) {
+    return createLeafModel<GaussianConstantLeafModel>(tau);
+  } else if (model_type
== kUnivariateRegressionLeafGaussian) {
+    return createLeafModel<GaussianUnivariateRegressionLeafModel>(tau);
+  } else if (model_type == kMultivariateRegressionLeafGaussian) {
+    return createLeafModel<GaussianMultivariateRegressionLeafModel>(Sigma0);
+  } else if (model_type == kLogLinearVariance) {
+    return createLeafModel<LogLinearVarianceLeafModel>(a, b);
+  } else if (model_type == kOurNewModel) { // New model
+    return createLeafModel<OurNewModelLeafModel>(/* initializer values */);
+  } else {
+    Log::Fatal("Incompatible model type provided to leaf model factory");
+  }
+}
+```
+
+## R Wrapper
+
+To reflect this change through to the R interface, we first add the new model to the logic in the `sample_gfr_one_iteration_cpp`
+and `sample_mcmc_one_iteration_cpp` functions in the `src/sampler.cpp` file
+
+```cpp
+// Convert leaf model type to enum
+StochTree::ModelType model_type;
+if (leaf_model_int == 0) model_type = StochTree::ModelType::kConstantLeafGaussian;
+else if (leaf_model_int == 1) model_type = StochTree::ModelType::kUnivariateRegressionLeafGaussian;
+else if (leaf_model_int == 2) model_type = StochTree::ModelType::kMultivariateRegressionLeafGaussian;
+else if (leaf_model_int == 3) model_type = StochTree::ModelType::kLogLinearVariance;
+else if (leaf_model_int == 4) model_type = StochTree::ModelType::kOurNewModel; // New model
+else StochTree::Log::Fatal("Invalid model type");
+```
+
+Then we add the integer code for `OurNewModel` to the `leaf_model_type` field signature in `R/config.R`
+
+```r
+#' @field leaf_model_type Integer specifying the leaf model type (0 = constant leaf, 1 = univariate leaf regression, 2 = multivariate leaf regression, 3 = log-linear variance, 4 = your new model)
+leaf_model_type = NULL,
+```
+
+## Python Wrapper
+
+Python's C++ wrapper code contains similar logic to that of the `src/sampler.cpp` file in the R interface.
+Add the new model to the `SampleOneIteration` method of the `ForestSamplerCpp` class in the `src/py_stochtree.cpp` file.
+
+```cpp
+// Convert leaf model type to enum
+StochTree::ModelType model_type;
+if (leaf_model_int == 0) model_type = StochTree::ModelType::kConstantLeafGaussian;
+else if (leaf_model_int == 1) model_type = StochTree::ModelType::kUnivariateRegressionLeafGaussian;
+else if (leaf_model_int == 2) model_type = StochTree::ModelType::kMultivariateRegressionLeafGaussian;
+else if (leaf_model_int == 3) model_type = StochTree::ModelType::kLogLinearVariance;
+else if (leaf_model_int == 4) model_type = StochTree::ModelType::kOurNewModel; // New model
+else StochTree::Log::Fatal("Invalid model type");
+```
+
+And then add the integer code for your new model to the `leaf_model_type` documentation in `stochtree/config.py`.
+
+## Additional Considerations
+
+Some of the `SuffStat` and `LeafModel` classes currently supported by stochtree require extra initialization parameters.
+We support this via [variadic templates](https://en.cppreference.com/w/cpp/language/parameter_pack.html) in C++
+
+```cpp
+template <typename LeafModel, typename LeafSuffStat, typename... LeafSuffStatConstructorArgs>
+static inline void GFRSampleOneIter(TreeEnsemble& active_forest, ForestTracker& tracker, ForestContainer& forests, LeafModel& leaf_model, ForestDataset& dataset, 
+                                    ColumnVector& residual, TreePrior& tree_prior, std::mt19937& gen, std::vector<double>& variable_weights, 
+                                    std::vector<int>& sweep_update_indices, double global_variance, std::vector<FeatureType>& feature_types, int cutpoint_grid_size, 
+                                    bool keep_forest, bool pre_initialized, bool backfitting, int num_features_subsample, LeafSuffStatConstructorArgs&... leaf_suff_stat_args)
+```
+
+If your new classes take any initialization arguments, these are provided in the factory functions, so you might also need to edit the signature of the factory functions.
diff --git a/development/roadmap.qmd b/development/roadmap.qmd new file mode 100644 index 000000000..f3715e5e5 --- /dev/null +++ b/development/roadmap.qmd @@ -0,0 +1,23 @@ +--- +title: "Development Roadmap" +--- + +We are working hard to make `stochtree` faster, easier to use, and more flexible! Below is a snapshot of our development roadmap. We categorize new product enhancements into four categories: + +1. **User Interface**: the way that a user can build, store, and use models +2. **Performance**: program runtime and memory usage of various models +3. **Modeling Features**: scope of modeling tools provided +4. **Interoperability**: compatibility with other computing and data libraries + +Our development goals are prioritized along three broad timelines + +1. **Now**: development is currently underway or planned for a near-term release +2. **Next**: design / research needed; development hinges on feasibility and time demands +3. **Later**: long-term goal; exploratory + +| Category | Now | Next | Later | +| --- | --- | --- | --- | +| User Interface | | | | +| Performance | | | Hardware acceleration (Apple Silicon GPU)
Hardware acceleration (NVIDIA GPU)
Out-of-memory sampler | +| Modeling Features | Quantile cutpoint sampling
Probit BART and BCF | Monotonicity constraints
Multiclass classification | | +| Interoperability | | | PyMC (Python)
Stan (R / Python)
Apache Arrow (R / Python)
Polars (Python) | diff --git a/getting-started.qmd b/getting-started.qmd new file mode 100644 index 000000000..f52461ef6 --- /dev/null +++ b/getting-started.qmd @@ -0,0 +1,182 @@ +--- +title: "Getting Started" +--- + +`stochtree` is composed of a C++ "core" and R / Python interfaces to that core. +Below, we detail how to install the R / Python packages, or work directly with the C++ codebase. + +## R Package + +### CRAN + +The R package can be installed from CRAN via + +```r +install.packages("stochtree") +``` + +### Development Version (Local Build) + +The development version of `stochtree` can be installed from Github via + +```r +remotes::install_github("StochasticTree/stochtree", ref="r-dev") +``` + +## Python Package + +### PyPI + +`stochtree`'s Python package can be installed from PyPI via + +```bash +pip install stochtree +``` + +### Development Version (Local Build) + +The development version of `stochtree` can be installed from source using pip's [git interface](https://pip.pypa.io/en/stable/topics/vcs-support/). +To proceed, you will need a working version of [git](https://git-scm.com) and python 3.8 or greater (available from several sources, one of the most +straightforward being the [anaconda](https://docs.conda.io/projects/conda/en/stable/user-guide/install/index.html) suite). + +#### Quick start + +Without worrying about virtual environments (detailed further below), `stochtree` can be installed from the command line + +```bash +pip install numpy scipy pytest pandas scikit-learn pybind11 +pip install git+https://github.com/StochasticTree/stochtree.git +``` + +#### Virtual environment installation + +Often, users prefer to manage different projects (with different package / python version requirements) in virtual environments. + +##### Conda + +Conda provides a straightforward experience in managing python dependencies, avoiding version conflicts / ABI issues / etc. 
+ +To build stochtree using a `conda` based workflow, first create and activate a conda environment with the requisite dependencies + +```bash +conda create -n stochtree-dev -c conda-forge python=3.10 numpy scipy pytest pandas pybind11 scikit-learn +conda activate stochtree-dev +``` + +Then install the package from github via pip + +```bash +pip install git+https://github.com/StochasticTree/stochtree.git +``` + +(*Note*: if you'd like to run `stochtree`'s notebook examples, you will also need `jupyterlab`, `seaborn`, and `matplotlib`) + +```bash +conda install matplotlib seaborn +pip install jupyterlab +``` + +##### Venv + +You could also use venv for environment management. First, navigate to the folder in which you usually store virtual environments +(i.e. `cd /path/to/envs`) and create and activate a virtual environment: + +```bash +python -m venv venv +source venv/bin/activate +``` + +Install all of the package (and demo notebook) dependencies + +```bash +pip install numpy scipy pytest pandas scikit-learn pybind11 +``` + +Then install stochtree via + +```bash +pip install git+https://github.com/StochasticTree/stochtree.git +``` + +As above, if you'd like to run the notebook examples, you will also need `jupyterlab`, `seaborn`, and `matplotlib`: + +```bash +pip install matplotlib seaborn jupyterlab +``` + +## C++ Core + +While the C++ core links to both R and Python for a performant, high-level interface, +the C++ code can be compiled and unit-tested and compiled into a standalone +[debug program](https://github.com/StochasticTree/stochtree/tree/main/debug). + +### Compilation + +#### Cloning the Repository + +To clone the repository, you must have git installed, which you can do following [these instructions](https://learn.microsoft.com/en-us/devops/develop/git/install-and-set-up-git). 
+ +Once git is available at the command line, navigate to the folder that will store this project (in bash / zsh, this is done by running `cd` followed by the path to the directory). +Then, clone the `stochtree` repo as a subfolder by running + +```bash +git clone --recursive https://github.com/StochasticTree/stochtree.git +``` + +*NOTE*: this project incorporates several dependencies as [git submodules](https://git-scm.com/book/en/v2/Git-Tools-Submodules), +which is why the `--recursive` flag is necessary (some systems may perform a recursive clone without this flag, but +`--recursive` ensures this behavior on all platforms). If you have already cloned the repo without the `--recursive` flag, +you can retrieve the submodules recursively by running `git submodule update --init --recursive` in the main repo directory. + +#### CMake Build + +The C++ project can be built independently from the R / Python packages using `cmake`. +See [here](https://cmake.org/install/) for details on installing cmake (alternatively, +on MacOS, `cmake` can be installed using [homebrew](https://formulae.brew.sh/formula/cmake)). +Once `cmake` is installed, you can build the CLI by navigating to the main +project directory at your command line (i.e. `cd /path/to/stochtree`) and +running the following code + +```bash +rm -rf build +mkdir build +cmake -S . -B build +cmake --build build +``` + +The CMake build has two primary targets, which are detailed below. + +##### Debug Program + +`debug/api_debug.cpp` defines a standalone target that can be straightforwardly run with a debugger (i.e. `lldb`, `gdb`) +while making non-trivial changes to the C++ code. +This debugging program is compiled as part of the CMake build if the `BUILD_DEBUG_TARGETS` option in `CMakeLists.txt` is set to `ON`. + +Once the program has been built, it can be run from the command line via `./build/debugstochtree` or attached to a debugger +via `lldb ./build/debugstochtree` (clang) or `gdb ./build/debugstochtree` (gcc). 
+ +##### Unit Tests + +We test `stochtree` using the [GoogleTest](https://google.github.io/googletest/) framework. +Unit tests are compiled into a single target as part of the CMake build if the `BUILD_TEST` option is set to `ON` +and the test suite can be run after compilation via `./build/teststochtree`. + +### Xcode + +While using `gdb` or `lldb` on `debugstochtree` at the command line is very helpful, users may prefer debugging in a full-fledged IDE like Xcode. This project's C++ core can be converted to an Xcode project from `CMakeLists.txt`, but first you must turn off sanitizers. To do this, modify the `USE_SANITIZER` line in `CMakeLists.txt`: + +``` +option(USE_SANITIZER "Use santizer flags" OFF) +``` + +To generate an Xcode project, navigate to the main project folder and run: + +```bash +rm -rf xcode/ +mkdir xcode +cd xcode +cmake -G Xcode .. -DCMAKE_C_COMPILER=cc -DCMAKE_CXX_COMPILER=c++ -DUSE_SANITIZER=OFF -DUSE_DEBUG=OFF +cd .. +``` + +Now, if you navigate to the xcode subfolder (in Finder), you should be able to click on a `.xcodeproj` file and the project will open in Xcode. diff --git a/index.qmd b/index.qmd new file mode 100644 index 000000000..da5561fad --- /dev/null +++ b/index.qmd @@ -0,0 +1,45 @@ +--- +title: "StochTree" +--- + +`stochtree` (short for "stochastic trees") unlocks flexible decision tree modeling in R or Python. + +## What does the software do? + +Boosted decision tree models (like [xgboost](https://xgboost.readthedocs.io/en/stable/), +[LightGBM](https://lightgbm.readthedocs.io/en/latest/), or +[scikit-learn's HistGradientBoostingRegressor](https://scikit-learn.org/stable/modules/generated/sklearn.ensemble.HistGradientBoostingRegressor.html)) +are great, but often require time-consuming hyperparameter tuning. +`stochtree` can help you avoid this, by running a fast Bayesian analog of gradient boosting (called BART -- Bayesian Additive Regression Trees). + +`stochtree` has two primary interfaces: + +1. 
"High-level": robust implementations of many popular stochastic tree algorithms (BART, XBART, BCF, XBCF), with support for serialization and parallelism. +2. "Low-level": access to the "inner loop" of a forest sampler, allowing custom tree algorithm development in <50 lines of code. + +The "core" of the software is written in C++, but it provides R and Python APIs. +The R package is [available on CRAN](https://cran.r-project.org/web/packages/stochtree/index.html) and the Python package is [available on PyPI](https://pypi.org/project/stochtree/). + +## Why "stochastic" trees? + +"Stochastic" loosely means the same thing as "random." This naturally raises the question: how is `stochtree` different from a random forest library? +At a superficial level, both are decision tree ensembles that use randomness in training. + +The difference lies in how that "randomness" is deployed. +Random forests take random subsets of a training dataset, and then run a deterministic decision tree fitting algorithm ([recursive partitioning](https://en.wikipedia.org/wiki/Recursive_partitioning)). +Stochastic tree algorithms use randomness to construct decision tree ensembles from a fixed training dataset. + +The original stochastic tree model, [Bayesian Additive Regression Trees (BART)](https://projecteuclid.org/journals/annals-of-applied-statistics/volume-4/issue-1/BART-Bayesian-additive-regression-trees/10.1214/09-AOAS285.full), used [Markov Chain Monte Carlo (MCMC)](https://en.wikipedia.org/wiki/Markov_chain_Monte_Carlo) to sample forests from their posterior distribution. + +So why not call our project `bayesiantree`? + +Some algorithms implemented in `stochtree` are "quasi-Bayesian" in that they are inspired by a Bayesian model, but are sampled with fast algorithms that do not provide a valid Bayesian posterior distribution. + +Moreover, we think of stochastic forests as general-purpose modeling tools. 
+What makes them useful is their strong empirical performance -- especially on small or noisy datasets -- not their adherence to any statistical framework. + +So why not just call our project `decisiontree`? + +Put simply, the sampling approach is part of what makes BART and other `stochtree` algorithms work so well -- we know because we have tested out versions that did not do stochastic sampling of the tree fits. + +So we settled on the term "stochastic trees", or "stochtree" for short (pronounced "stoke-tree"). diff --git a/requirements.txt b/requirements.txt index 382bb4851..a7255f35c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,5 +8,5 @@ matplotlib seaborn mkdocs-material mkdocstrings-python -mkdocs-jupyter +mkdocs-jupyter<0.25 arviz[all] diff --git a/vignettes/.gitignore b/vignettes/.gitignore index 6041614a6..0ff578e35 100644 --- a/vignettes/.gitignore +++ b/vignettes/.gitignore @@ -1,4 +1,9 @@ /.quarto/ **/*.quarto_ipynb _freeze/ -_site/ \ No newline at end of file +_site/ +*.Renviron +*.json +*_files/ +*.rds +*_libs/ diff --git a/vignettes/Python/IV/IV_CDAG.png b/vignettes/Python/IV/IV_CDAG.png deleted file mode 100644 index 7900ff501..000000000 Binary files a/vignettes/Python/IV/IV_CDAG.png and /dev/null differ diff --git a/vignettes/Python/IV/iv.html b/vignettes/Python/IV/iv.html deleted file mode 100644 index c99a5e793..000000000 --- a/vignettes/Python/IV/iv.html +++ /dev/null @@ -1,8882 +0,0 @@ - - - - - -iv - - - - - - - - - - - - -
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
- - diff --git a/vignettes/Python/IV/iv.ipynb b/vignettes/Python/IV/iv.ipynb deleted file mode 100644 index 053fcc77d..000000000 --- a/vignettes/Python/IV/iv.ipynb +++ /dev/null @@ -1,1130 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Instrumental Variables (IV) with `stochtree`\n", - "\n", - "### P. Richard Hahn, Arizona State University\n", - "\n", - "### Drew Herren, University of Texas at Austin\n", - "\n", - "### 2025-04-26" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Introduction\n", - "\n", - "Here we consider a causal inference problem with a binary treatment and a binary outcome where there is unobserved confounding, but an exogenous instrument is available (also binary). This problem will require a number of extensions to the basic BART model, all of which can be implemented straightforwardly as Gibbs samplers using `stochtree`. We'll go through all of the model fitting steps in quite a lot of detail here.\n", - "\n", - "## Background\n", - "\n", - "To be concrete, suppose we wish to measure the effect of receiving a flu vaccine on the probability of getting the flu. Individuals who opt to get a flu shot differ in many ways from those that don't, and these lifestyle differences presumably also affect their respective chances of getting the flu. Consequently, comparing the percentage of individuals who get the flu in the vaccinated and unvaccinated groups does not give a clear picture of the vaccine efficacy. \n", - "\n", - "However, a so-called encouragement design can be implemented, where some individuals are selected at random to be given some extra incentive to get a flu shot (free clinics at the workplace or a personalized reminder, for example). Studying the impact of this randomized encouragement allows us to tease apart the impact of the vaccine from the confounding factors, at least to some extent. 
This exact problem has been considered several times in the literature, starting with McDonald, Hiu, and Tierny (1992) with follow-on analysis by Hirano et. al. (2000), Richardson and Robins (2011), and Imbens and Rubin (2015).\n", - "\n", - "Our analysis here follows the Bayesian nonparametric approach described in the supplement to Hahn, Murray, and Manolopoulou (2016)." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "First, load requisite libraries" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import numpy as np\n", - "import matplotlib.pyplot as plt\n", - "from scipy.stats import norm\n", - "\n", - "from stochtree import (\n", - " RNG,\n", - " Dataset,\n", - " Forest,\n", - " ForestContainer,\n", - " ForestSampler,\n", - " Residual, \n", - " ForestModelConfig, \n", - " GlobalModelConfig,\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Notation\n", - "\n", - "Let $V$ denote the treatment variable (as in \"vaccine\"). Let $Y$ denote the response variable (getting the flu), $Z$ denote the instrument (encouragement or reminder to get a flu shot), and $X$ denote an additional observable covariate (for instance, patient age).\n", - "\n", - "Further, let $S$ denote the so-called *principal strata*, which is an exhaustive characterization of how individuals' might be affected by the encouragement regarding the flu shot. Some people will get a flu shot no matter what: these are the *always takers* (a). Some people will not get the flu shot no matter what: these are the *never takers* (n). For both always-takers and never-takers, the randomization of the encouragement is irrelevant and our data set contains no always takers who skipped the vaccine and no never takers who got the vaccine and so the treatment effect of the vaccine in these groups is fundamentally non-identifiable. 
\n", - "\n", - "By contrast, we also have *compliers* (c): folks who would not have gotten the shot but for the fact that they were encouraged to do so. These are the people about whom our randomized encouragement provides some information, because they are precisely the ones that have been randomized to treatment. \n", - "\n", - "Lastly, we could have *defiers* (d): contrarians who who were planning on getting the shot, but -- upon being reminded -- decided not to! For our analysis we will do the usual thing of assuming that there are no defiers. And because we are going to simulate our data, we can make sure that this assumption is true." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## The causal diagram\n", - "\n", - "The causal diagram for this model can be expressed as follows. Here we are considering one confounder and moderator variable ($X$), which is the patient's age. In our data generating process (which we know because this is a simulation demonstration) higher age will make it more likely that a person is an always taker or complier and less likely that they are a never taker, which in turn has an effect on flu risk. We stipulate here that always takers are at lower risk and never takers at higher risk. Simultaneously, age has an increasing and then decreasing direct effect on flu risk; very young and very old are at higher risk, while young and middle age adults are at lower risk. In this DGP the flu efficacy has a multiplicative effect, reducing flu risk as a fixed proportion of baseline risk -- accordingly, the treatment effect (as a difference) is nonlinear in Age (for each principal stratum).\n", - "\n", - "![IV_CDAG](IV_CDAG.png)\n", - "\n", - "The biggest question about this graph concerns the dashed red arrow from the putative instrument $Z$ to the outcome (flu). We say \"putative\" because if that dashed red arrow is there, then technically $Z$ is not a valid instrument. 
The assumption/assertion that there is no dashed red arrow is called the \"exclusion restriction\". In this vignette, we will explore what sorts of inferences are possible if we remain agnostic about the presence or absence of that dashed red arrow." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Potential outcomes\n", - "\n", - "There are two relevant potential outcomes in an instrumental variables analysis, corresponding to the causal effect of the instrument on the treatment and the causal effect of the treatment on the outcome. In this example, that is the effect of the reminder/encouragement on vaccine status and the effect of the vaccine itself on the flu. The notation is $V(Z)$ and $Y(V(Z),Z)$ respectively, so that we have six distinct random variables: $V(0)$, $V(1)$, $Y(0,0)$, $Y(1,0)$, $Y(0,1)$ and $Y(1,1)$. The problem -- sometimes called the *fundamental problem of causal inference* -- is that some of these random variables can never be seen simultaneously, they are observationally mutually exclusive. For this reason, it may be helpful to think about causal inference as a missing data problem, as depicted in the following table." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "\n", - "\n", - "| $i$ | $Z_i$ | $V_i(0)$ | $V_i(1)$ | $Y_i(0,0)$ | $Y_i(1,0)$ | $Y_i(0,1)$ | $Y_i(1,1)$ |\n", - "| :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |\n", - "| 1 | 1 | ? | 1 | ? | ? | ? | 0 |\n", - "| 2 | 0 | 1 | ? | ? | 1 | ? | ? |\n", - "| 3 | 0 | 0 | ? | 1 | ? | ? | ? |\n", - "| 4 | 1 | ? | 0 | ? | ? | 0 | ? 
|" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Likewise, with this notation we can formally define the principal strata:" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "| $V_i(0)$ | $V_i(1)$ | $S_i$ |\n", - "| :---: | :---: | :---: |\n", - "| 0 | 0 | Never Taker ($n$) |\n", - "| 1 | 1 | Always Taker ($a$) |\n", - "| 0 | 1 | Complier ($c$) |\n", - "| 1 | 0 | Defier ($d$) |" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Estimands and Identification\n", - "\n", - "Let $\\pi_s(x)$ denote the conditional (on $x$) probability that an individual belongs to principal stratum $s$:\n", - "\n", - "\\begin{equation}\n", - "\\pi_s(x)=\\operatorname{Pr}(S=s \\mid X=x),\n", - "\\end{equation}\n", - "\n", - "and let $\\gamma_s^{v z}(x)$ denote the potential outcome probability for given values $v$ and $z$:\n", - "\n", - "\\begin{equation}\n", - "\\gamma_s^{v z}(x)=\\operatorname{Pr}(Y(v, z)=1 \\mid S=s, X=x)\n", - "\\end{equation}\n", - "\n", - "Various estimands of interest may be expressed in terms of the functions $\\gamma_c^{vz}(x)$. In particular, the complier conditional average treatment effect $$\\gamma_c^{1,z}(x) - \\gamma_c^{0,z}(x)$$ is the ultimate goal (for either $z=0$ or $z=1$). Under an exclusion restriction, we would have $\\gamma_s^{vz}(x) = \\gamma_s^{v}(x)$ and the reminder status $z$ itself would not matter. In that case, we can estimate $$\\gamma_c^{1,z}(x) - \\gamma_c^{0,z}$$ and $$\\gamma_c^{1,1}(x) - \\gamma_c^{0,0}(x).$$ This latter quantity is called the complier intent-to-treat effect, or $ITT_c$, and it can be partially identify even if the exclusion restriction is violated, as follows. 
\n", - "\n", - "The left-hand side of the following system of equations are all estimable quantities that can be learned from observable data, while the right hand side expressions involve the unknown functions of interest, $\\gamma_s^{vz}(x)$:\n", - "\n", - "\\begin{equation}\n", - "\\begin{aligned}\n", - "p_{1 \\mid 00}(x) = \\operatorname{Pr}(Y=1 \\mid V=0, Z=0, X=x)=\\frac{\\pi_c(x)}{\\pi_c(x)+\\pi_n(x)} \\gamma_c^{00}(x)+\\frac{\\pi_n(x)}{\\pi_c(x)+\\pi_n(x)} \\gamma_n^{00}(x) \\\\\n", - "p_{1 \\mid 11}(x) =\\operatorname{Pr}(Y=1 \\mid V=1, Z=1, X=x)=\\frac{\\pi_c(x)}{\\pi_c(x)+\\pi_a(x)} \\gamma_c^{11}(x)+\\frac{\\pi_a(x)}{\\pi_c(x)+\\pi_a(x)} \\gamma_a^{11}(x) \\\\\n", - "p_{1 \\mid 01}(x) =\\operatorname{Pr}(Y=1 \\mid V=0, Z=1, X=x)=\\frac{\\pi_d(x)}{\\pi_d(x)+\\pi_n(x)} \\gamma_d^{01}(x)+\\frac{\\pi_n(x)}{\\pi_d(x)+\\pi_n(x)} \\gamma_n^{01}(x) \\\\\n", - "p_{1 \\mid 10}(x) =\\operatorname{Pr}(Y=1 \\mid V=1, Z=0, X=x)=\\frac{\\pi_d(x)}{\\pi_d(x)+\\pi_a(x)} \\gamma_d^{10}(x)+\\frac{\\pi_a(x)}{\\pi_d(x)+\\pi_a(x)} \\gamma_a^{10}(x)\n", - "\\end{aligned}\n", - "\\end{equation}\n", - "\n", - "Furthermore, we have\n", - "\n", - "\\begin{equation}\n", - "\\begin{aligned}\n", - "\\operatorname{Pr}(V=1 \\mid Z=0, X=x)&=\\pi_a(x)+\\pi_d(x)\\\\\n", - "\\operatorname{Pr}(V=1 \\mid Z=1, X=x)&=\\pi_a(x)+\\pi_c(x)\n", - "\\end{aligned}\n", - "\\end{equation}\n", - "\n", - "Under the monotonicy assumption, $\\pi_d(x) = 0$ and these expressions simplify somewhat.\n", - "\n", - "\\begin{equation}\n", - "\\begin{aligned}\n", - "p_{1 \\mid 00}(x)&=\\frac{\\pi_c(x)}{\\pi_c(x)+\\pi_n(x)} \\gamma_c^{00}(x)+\\frac{\\pi_n(x)}{\\pi_c(x)+\\pi_n(x)} \\gamma_n^{00}(x) \\\\\n", - "p_{1 \\mid 11}(x)&=\\frac{\\pi_c(x)}{\\pi_c(x)+\\pi_a(x)} \\gamma_c^{11}(x)+\\frac{\\pi_a(x)}{\\pi_c(x)+\\pi_a(x)} \\gamma_a^{11}(x) \\\\\n", - "p_{1 \\mid 01}(x)&=\\gamma_n^{01}(x) \\\\\n", - "p_{1 \\mid 10}(x)&=\\gamma_a^{10}(x)\n", - "\\end{aligned}\n", - "\\end{equation}\n", - "\n", - "and\n", - "\n", - 
"\\begin{equation}\n", - "\\begin{aligned}\n", - "\\operatorname{Pr}(V=1 \\mid Z=0, X=x)&=\\pi_a(x)\\\\\n", - "\\operatorname{Pr}(V=1 \\mid Z=1, X=x)&=\\pi_a(x)+\\pi_c(x)\n", - "\\end{aligned}\n", - "\\end{equation}\n", - "\n", - "The exclusion restriction would dictate that $\\gamma_s^{01}(x) = \\gamma_s^{00}(x)$ and $\\gamma_s^{11}(x) = \\gamma_s^{10}(x)$ for all $s$. This has two implications. One, $\\gamma_n^{01}(x) = \\gamma_n^{00}(x)$ and $\\gamma_a^{10}(x) = \\gamma_a^{11}(x)$,and because the left-hand terms are identified, this permits $\\gamma_c^{11}(x)$ and $\\gamma_c^{00}(x)$ to be solved for by substitution. Two, with these two quantities solved for, we also have the two other quantities (the different settings of $z$), since $\\gamma_c^{11}(x) = \\gamma_c^{10}(x)$ and $\\gamma_c^{00}(x) = \\gamma_c^{01}(x)$. Consequently, both of our estimands from above can be estimated:\n", - "\n", - "$$\\gamma_c^{11}(x) - \\gamma_c^{01}(x)$$\n", - "and \n", - "\n", - "$$\\gamma_c^{10}(x) - \\gamma_c^{00}(x)$$\n", - "because they are both (supposing the exclusion restriction holds) the same as\n", - "\n", - "$$\\gamma_c^{11}(x) - \\gamma_c^{00}(x).$$\n", - "If the exclusion restriction does *not* hold, then the three above treatment effects are all (potentially) distinct and not much can be said about the former two. The latter one, the $ITT_c$, however, can be partially identified, by recognizing that the first two equations (in our four equation system) provide non-trivial bounds based on the fact that while $\\gamma_c^{11}(x)$ and $\\gamma_c^{00}(x)$ are no longer identified, as probabilities both must lie between 0 and 1. 
Thus, \n", - "\n", - "\\begin{equation}\n", - "\\begin{aligned}\n", - "\t\\max\\left(\n", - "\t\t0, \\frac{\\pi_c(x)+\\pi_n(x)}{\\pi_c(x)}p_{1\\mid 00}(x) - \\frac{\\pi_n(x)}{\\pi_c(x)}\n", - "\t\\right)\n", - "&\\leq\\gamma^{00}_c(x)\\leq\n", - "\t\\min\\left(\n", - "\t\t1, \\frac{\\pi_c(x)+\\pi_n(x)}{\\pi_c(x)}p_{1\\mid 00}(x)\n", - "\t\\right)\\\\\\\\\n", - "%\n", - "\\max\\left(\n", - " 0, \\frac{\\pi_a(x)+\\pi_c(x)}{\\pi_c(x)}p_{1\\mid 11}(x) - \\frac{\\pi_a(x)}{\\pi_c(x)}\n", - "\\right)\n", - "&\\leq\\gamma^{11}_c(x)\\leq\n", - "\\min\\left(\n", - " 1, \\frac{\\pi_a(x)+\\pi_c(x)}{\\pi_c(x)}p_{1\\mid 11}(x)\n", - "\\right)\n", - "\\end{aligned}\n", - "\\end{equation}\n", - "\n", - "The point of all this is that the data (plus a no-defiers assumption) lets us estimate all the necessary inputs to these upper and lower bounds on $\\gamma^{11}_c(x)$ and $\\gamma^{00}_c(x)$ which in turn define our estimand. What remains is to estimate those inputs, as functions of $x$, and to do so while enforcing the monotonicty restriction $$\\operatorname{Pr}(V=1 \\mid Z=0, X=x)=\\pi_a(x) \\leq \n", - "\\operatorname{Pr}(V=1 \\mid Z=1, X=x)=\\pi_a(x)+\\pi_c(x).$$\n", - "\n", - "We can do all of this with calls to stochtree from R (or Python). But first, let's generate some test data. 
" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Simulate the data" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Start with some initial setup / housekeeping" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Size of the training sample\n", - "n = 20000\n", - "\n", - "# To set the seed for reproducibility/illustration purposes, replace \"None\" with a positive integer\n", - "random_seed = None\n", - "if random_seed is not None:\n", - " rng = np.random.default_rng(random_seed)\n", - "else:\n", - " rng = np.random.default_rng()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "First, we generate the instrument exogenously" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "z = rng.binomial(n=1, p=0.5, size=n)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Next, we generate the covariate. (For this example, let's think of it as patient age, although we are generating it from a uniform distribution between 0 and 3, so you have to imagine that it has been pre-standardized to this scale. It keeps the DGPs cleaner for illustration purposes.)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "p_X = 1\n", - "X = rng.uniform(low=0., high=3., size=(n,p_X))\n", - "x = X[:,0] # for ease of reference later" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Next, we generate the principal strata $S$ based on the observed value of $X$. We generate it according to a logistic regression with two coefficients per strata, an intercept and a slope. Here, these coefficients are set so that the probability of being a never taker decreases with age." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "alpha_a = 0\n", - "beta_a = 1\n", - "\n", - "alpha_n = 1\n", - "beta_n = -1\n", - "\n", - "alpha_c = 1\n", - "beta_c = 1\n", - "\n", - "# Define function (a logistic model) to generate Pr(S = s | X = x)\n", - "def pi_s(xval, alpha_a, beta_a, alpha_n, beta_n, alpha_c, beta_c):\n", - " w_a = np.exp(alpha_a + beta_a*xval)\n", - " w_n = np.exp(alpha_n + beta_n*xval)\n", - " w_c = np.exp(alpha_c + beta_c*xval)\n", - " w = np.column_stack((w_a, w_n, w_c))\n", - " w_rowsum = np.sum(w, axis=1, keepdims=True)\n", - " return np.divide(w, w_rowsum)\n", - " \n", - "# Sample principal strata based on observed probabilities\n", - "strata_probs = pi_s(X[:,0], alpha_a, beta_a, alpha_n, beta_n, alpha_c, beta_c)\n", - "s = np.empty_like(X[:,0], dtype=str)\n", - "for i in range(s.size):\n", - " s[i] = rng.choice(a=['a','n','c'], size=1, p=strata_probs[i,:])[0]\n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Next, we generate the treatment variable, here denoted $V$ (for \"vaccine\"), as a *deterministic* function of $S$ and $Z$; this is what gives the principal strata their meaning." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "v = 1*(s=='a') + 0*(s=='n') + z*(s==\"c\") + (1-z)*(s == \"d\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Finally, the outcome structural model is specified, based on which the outcome is sampled. By varying this function in particular ways, we can alter the identification conditions." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "def gamfun(xval, vval, zval, sval):\n", - " \"\"\"\n", - " If this function depends on zval, then exclusion restriction is violated.\n", - " If this function does not depend on sval, then IV analysis wasn't necessary.\n", - " If this function does not depend on x, then there are no HTEs.\n", - " \"\"\"\n", - " baseline = norm.cdf(2 - 1*xval - 2.5*((xval-1.5)**2) - 0.5*zval + 1*(sval==\"n\") - 1*(sval==\"a\"))\n", - " return baseline - 0.5*vval*baseline\n", - "\n", - "y = rng.binomial(n=1, p=gamfun(X[:,0],v,z,s), size=n)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Lastly, we perform some organization for our supervised learning algorithms later on." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Concatenate X, v and z for our supervised learning algorithms\n", - "Xall = np.concatenate((X, np.column_stack((v,z))), axis=1)\n", - "\n", - "# Update the size of \"X\" to be the size of Xall\n", - "p_X = p_X + 2\n", - "\n", - "# For the monotone probit model it is necessary to sort the observations so that the Z=1 cases are all together\n", - "# at the start of the outcome vector. \n", - "sort_index = np.argsort(z)[::-1]\n", - "X = X[sort_index,:]\n", - "Xall = Xall[sort_index,:]\n", - "z = z[sort_index]\n", - "v = v[sort_index]\n", - "s = s[sort_index]\n", - "y = y[sort_index]\n", - "x = x[sort_index]" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Now let's see if we can recover these functions from the observed data." 
- ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Fit the outcome model\n", - "\n", - "We have to fit three models here, the treatment models: $\\operatorname{Pr}(V = 1 | Z = 1, X=x)$ and $\\operatorname{Pr}(V = 1 | Z = 0,X = x)$, subject to the monotonicity constraint $\\operatorname{Pr}(V = 1 | Z = 1, X=x) \\geq \\operatorname{Pr}(V = 1 | Z = 0,X = x)$, and an outcome model $\\operatorname{Pr}(Y = 1 | Z = 1, V = 1, X = x)$. All of this will be done with stochtree. \n", - "\n", - "The outcome model is fit with a single (S-learner) BART model. This part of the model could be fit as a T-Learner or as a BCF model. Here we us an S-Learner for simplicity. Both models are probit models, and use the well-known Albert and Chib (1993) data augmentation Gibbs sampler. This section covers the more straightforward outcome model. The next section describes how the monotonicity constraint is handled with a data augmentation Gibbs sampler. \n", - "\n", - "These models could (and probably should) be wrapped as functions. Here they are implemented as scripts, with the full loops shown. The output -- at the end of the loops -- are stochtree forest objects from which we can extract posterior samples and generate predictions. In particular, the $ITT_c$ will be constructed using posterior counterfactual predictions derived from these forest objects. \n", - "\n", - "We begin by setting a bunch of hyperparameters and instantiating the forest objects to be operated upon in the main sampling loop. We also initialize the latent variables." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Fit the BART model for Pr(Y = 1 | Z = 1, V = 1, X = x)\n", - "\n", - "# Set number of iterations\n", - "num_warmstart = 10\n", - "num_mcmc = 1000\n", - "num_samples = num_warmstart + num_mcmc\n", - "\n", - "# Set a bunch of hyperparameters. 
These are ballpark default values.\n", - "alpha = 0.95\n", - "beta = 2\n", - "min_samples_leaf = 1\n", - "max_depth = 20\n", - "num_trees = 50\n", - "cutpoint_grid_size = 100\n", - "global_variance_init = 1.\n", - "tau_init = 0.5\n", - "leaf_prior_scale = np.array([[tau_init]])\n", - "leaf_regression = False\n", - "feature_types = np.append(np.repeat(0, p_X - 2), [1,1]).astype(int)\n", - "var_weights = np.repeat(1.0/p_X, p_X)\n", - "outcome_model_type = 0\n", - "\n", - "# C++ dataset\n", - "forest_dataset = Dataset()\n", - "forest_dataset.add_covariates(Xall)\n", - "\n", - "# Random number generator (std::mt19937)\n", - "if random_seed is not None:\n", - " cpp_rng = RNG(random_seed)\n", - "else:\n", - " cpp_rng = RNG()\n", - "\n", - "# Sampling data structures\n", - "forest_model_config = ForestModelConfig(\n", - " feature_types = feature_types, \n", - " num_trees = num_trees, \n", - " num_features = p_X, \n", - " num_observations = n, \n", - " variable_weights = var_weights, \n", - " leaf_dimension = 1, \n", - " alpha = alpha, \n", - " beta = beta, \n", - " min_samples_leaf = min_samples_leaf, \n", - " max_depth = max_depth, \n", - " leaf_model_type = outcome_model_type, \n", - " leaf_model_scale = leaf_prior_scale, \n", - " cutpoint_grid_size = cutpoint_grid_size\n", - ")\n", - "global_model_config = GlobalModelConfig(global_error_variance=1.0)\n", - "forest_sampler = ForestSampler(\n", - " forest_dataset, global_model_config, forest_model_config\n", - ")\n", - "\n", - "# Container of forest samples\n", - "forest_samples = ForestContainer(num_trees, 1, True, False)\n", - "\n", - "# \"Active\" forest state\n", - "active_forest = Forest(num_trees, 1, True, False)\n", - "\n", - "# Initialize the latent outcome zed\n", - "n1 = np.sum(y)\n", - "zed = 0.25*(2.0*y - 1.0)\n", - "\n", - "# C++ outcome variable\n", - "outcome = Residual(zed)\n", - "\n", - "# Initialize the active forest and subtract each root tree's predictions from outcome\n", - "forest_init_val = 
np.array([0.0])\n", - "forest_sampler.prepare_for_sampler(\n", - " forest_dataset,\n", - " outcome,\n", - " active_forest,\n", - " outcome_model_type,\n", - " forest_init_val,\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Now we enter the main loop, which involves only two steps: sample the forest, given the latent utilities, then sample the latent utilities given the estimated conditional means defined by the forest and its parameters. " - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "gfr_flag = True\n", - "for i in range(num_samples):\n", - " # The first num_warmstart iterations use the grow-from-root algorithm of He and Hahn\n", - " if i >= num_warmstart:\n", - " gfr_flag = False\n", - " \n", - " # Sample forest\n", - " forest_sampler.sample_one_iteration(\n", - " forest_samples, active_forest, forest_dataset, outcome, cpp_rng, \n", - " global_model_config, forest_model_config, keep_forest=True, gfr = gfr_flag\n", - " )\n", - "\n", - " # Get the current means\n", - " eta = np.squeeze(forest_samples.predict_raw_single_forest(forest_dataset, i))\n", - "\n", - " # Sample latent normals, truncated according to the observed outcome y\n", - " mu0 = eta[y == 0]\n", - " mu1 = eta[y == 1]\n", - " u0 = rng.uniform(\n", - " low=0.0,\n", - " high=norm.cdf(0 - mu0),\n", - " size=n-n1,\n", - " )\n", - " u1 = rng.uniform(\n", - " low=norm.cdf(0 - mu1),\n", - " high=1.0,\n", - " size=n1,\n", - " )\n", - " zed[y == 0] = mu0 + norm.ppf(u0)\n", - " zed[y == 1] = mu1 + norm.ppf(u1)\n", - "\n", - " # Update outcome\n", - " new_outcome = np.squeeze(zed) - eta\n", - " outcome.update_data(new_outcome)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Fit the monotone probit model(s)\n", - "\n", - "The monotonicty constraint relies on a data augmentation as described in Papakostas et al (2023). 
The implementation of this sampler is inherently cumbersome, as one of the \"data\" vectors is constructed from some observed data and some latent data and there are two forest objects, one of which applies to all of the observations and one of which applies to only those observations with $Z = 0$. We go into more details about this sampler in a dedicated vignette. Here we include the code, but without producing the equations derived in Papakostas (2023). What is most important is simply that\n", - "\n", - "\\begin{equation}\n", - "\\begin{aligned}\n", - "\\operatorname{Pr}(V=1 \\mid Z=0, X=x)&=\\pi_a(x) = \\Phi_f(x)\\Phi_h(x),\\\\\n", - "\\operatorname{Pr}(V=1 \\mid Z=1, X=x)&=\\pi_a(x)+\\pi_c(x) = \\Phi_f(x),\n", - "\\end{aligned}\n", - "\\end{equation}\n", - "where $\\Phi_{\\mu}(x)$ denotes the normal cumulative distribution function with mean $\\mu(x)$ and variance 1. \n", - "\n", - "We first create a secondary data matrix for the $Z=0$ group only. We also set all of the hyperparameters and initialize the latent variables." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Fit the monotone probit model to the treatment such that Pr(V = 1 | Z = 1, X=x) >= Pr(V = 1 | Z = 0,X = x) \n", - "X_h = X[z==0,:]\n", - "n0 = np.sum(z==0)\n", - "n1 = np.sum(z==1)\n", - "\n", - "num_trees_f = 50\n", - "num_trees_h = 20\n", - "feature_types = np.repeat(0, p_X-2).astype(int)\n", - "var_weights = np.repeat(1.0/(p_X - 2.0), p_X - 2)\n", - "cutpoint_grid_size = 100\n", - "global_variance_init = 1.\n", - "tau_init_f = 1/num_trees_f\n", - "tau_init_h = 1/num_trees_h\n", - "leaf_prior_scale_f = np.array([[tau_init_f]])\n", - "leaf_prior_scale_h = np.array([[tau_init_h]])\n", - "leaf_regression = False # fit a constant leaf mean BART model\n", - "\n", - "# Instantiate the C++ dataset objects\n", - "forest_dataset_f = Dataset()\n", - "forest_dataset_f.add_covariates(X)\n", - "forest_dataset_h = Dataset()\n", - "forest_dataset_h.add_covariates(X_h)\n", - "\n", - "# Tell it we're fitting a normal BART model\n", - "outcome_model_type = 0\n", - "\n", - "# Set up model configuration objects\n", - "forest_model_config_f = ForestModelConfig(\n", - " feature_types = feature_types, \n", - " num_trees = num_trees_f, \n", - " num_features = X.shape[1], \n", - " num_observations = n, \n", - " variable_weights = var_weights, \n", - " leaf_dimension = 1, \n", - " alpha = alpha, \n", - " beta = beta, \n", - " min_samples_leaf = min_samples_leaf, \n", - " max_depth = max_depth, \n", - " leaf_model_type = outcome_model_type, \n", - " leaf_model_scale = leaf_prior_scale_f, \n", - " cutpoint_grid_size = cutpoint_grid_size\n", - ")\n", - "forest_model_config_h = ForestModelConfig(\n", - " feature_types = feature_types, \n", - " num_trees = num_trees_h, \n", - " num_features = X_h.shape[1], \n", - " num_observations = n0, \n", - " variable_weights = var_weights, \n", - " leaf_dimension = 1, \n", - " alpha = alpha, \n", - " beta = beta, \n", - " min_samples_leaf = 
min_samples_leaf, \n", - " max_depth = max_depth, \n", - " leaf_model_type = outcome_model_type, \n", - " leaf_model_scale = leaf_prior_scale_h, \n", - " cutpoint_grid_size = cutpoint_grid_size\n", - ")\n", - "global_model_config = GlobalModelConfig(global_error_variance=global_variance_init)\n", - "\n", - "# Instantiate the sampling data structures\n", - "forest_sampler_f = ForestSampler(\n", - " forest_dataset_f, global_model_config, forest_model_config_f\n", - ")\n", - "forest_sampler_h = ForestSampler(\n", - " forest_dataset_h, global_model_config, forest_model_config_h\n", - ")\n", - "\n", - "# Instantiate containers of forest samples\n", - "forest_samples_f = ForestContainer(num_trees_f, 1, True, False)\n", - "forest_samples_h = ForestContainer(num_trees_h, 1, True, False)\n", - "\n", - "# Instantiate \"active\" forests\n", - "active_forest_f = Forest(num_trees_f, 1, True, False)\n", - "active_forest_h = Forest(num_trees_h, 1, True, False)\n", - "\n", - "# Set algorithm specifications \n", - "# these are set in the earlier script for the outcome model; number of draws needs to be commensurable \n", - "\n", - "# num_warmstart = 40\n", - "# num_mcmc = 5000\n", - "# num_samples = num_warmstart + num_mcmc\n", - "\n", - "# Initialize the Markov chain\n", - "\n", - "# Initialize (R0, R1), the latent binary variables that enforce the monotonicty \n", - "v1 = v[z==1]\n", - "v0 = v[z==0]\n", - "\n", - "R1 = np.empty(n0, dtype=float)\n", - "R0 = np.empty(n0, dtype=float)\n", - "\n", - "R1[v0==1] = 1\n", - "R0[v0==1] = 1\n", - "\n", - "nv0 = np.sum(v0==0)\n", - "R1[v0 == 0] = 0\n", - "R0[v0 == 0] = rng.choice([0,1], size = nv0)\n", - "\n", - "# The first n1 observations of vaug are actually observed\n", - "# The next n0 of them are the latent variable R1\n", - "vaug = np.append(v1, R1)\n", - "\n", - "# Initialize the Albert and Chib latent Gaussian variables\n", - "z_f = (2.0*vaug - 1.0)\n", - "z_h = (2.0*R0 - 1.0)\n", - "z_f = z_f/np.std(z_f)\n", - "z_h = 
z_h/np.std(z_h)\n", - "\n", - "# Pass these variables to the BART models as outcome variables\n", - "outcome_f = Residual(z_f)\n", - "outcome_h = Residual(z_h)\n", - "\n", - "# Initialize active forests to constant (0) predictions\n", - "forest_init_val_f = np.array([0.0])\n", - "forest_sampler_f.prepare_for_sampler(\n", - " forest_dataset_f,\n", - " outcome_f,\n", - " active_forest_f,\n", - " outcome_model_type,\n", - " forest_init_val_f,\n", - ")\n", - "forest_init_val_h = np.array([0.0])\n", - "forest_sampler_h.prepare_for_sampler(\n", - " forest_dataset_h,\n", - " outcome_h,\n", - " active_forest_h,\n", - " outcome_model_type,\n", - " forest_init_val_h,\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Now we run the main sampling loop, which consists of three key steps: sample the BART forests, given the latent probit utilities, sampling the latent binary outcome pairs (this is the step that is necessary for enforcing monotonicity), given the forest predictions and the latent utilities, and finally sample the latent utilities." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# PART IV: run the Markov chain \n", - "\n", - "# Initialize the Markov chain with num_warmstart grow-from-root iterations\n", - "gfr_flag = True\n", - "for i in range(num_samples):\n", - " # Switch over to random walk Metropolis-Hastings tree updates after num_warmstart\n", - " if i >= num_warmstart:\n", - " gfr_flag = False\n", - " \n", - " # Step 1: Sample the BART forests\n", - "\n", - " # Sample forest for the function f based on (y_f, R1)\n", - " forest_sampler_f.sample_one_iteration(\n", - " forest_samples_f, active_forest_f, forest_dataset_f, outcome_f, cpp_rng, \n", - " global_model_config, forest_model_config_f, keep_forest=True, gfr = gfr_flag\n", - " )\n", - "\n", - " # Sample forest for the function h based on outcome R0\n", - " forest_sampler_h.sample_one_iteration(\n", - " forest_samples_h, active_forest_h, forest_dataset_h, outcome_h, cpp_rng, \n", - " global_model_config, forest_model_config_h, keep_forest=True, gfr = gfr_flag\n", - " )\n", - "\n", - " # Get the current means\n", - " eta_f = np.squeeze(forest_samples_f.predict_raw_single_forest(forest_dataset_f, i))\n", - " eta_h = np.squeeze(forest_samples_h.predict_raw_single_forest(forest_dataset_h, i))\n", - "\n", - " # Step 2: sample the latent binary pair (R0, R1) given eta_h, eta_f, and y_g\n", - "\n", - " # Three cases: (0,0), (0,1), (1,0)\n", - " w1 = (1 - norm.cdf(eta_h[v0==0]))*(1 - norm.cdf(eta_f[n1 + np.where(v0==0)]))\n", - " w2 = (1 - norm.cdf(eta_h[v0==0]))*norm.cdf(eta_f[n1 + np.where(v0==0)])\n", - " w3 = norm.cdf(eta_h[v0==0])*(1 - norm.cdf(eta_f[n1 + np.where(v0==0)]))\n", - "\n", - " s = w1 + w2 + w3\n", - " w1 = w1/s\n", - " w2 = w2/s\n", - " w3 = w3/s\n", - "\n", - " u = rng.uniform(low=0,high=1,size=np.sum(v0==0))\n", - " temp = 1*(np.squeeze(u < w1)) + 2*(np.squeeze((u > w1) & (u < (w1 + w2)))) + 3*(np.squeeze(u > (w1 + w2)))\n", - "\n", - " R1[v0==0] = 
1*(temp==2)\n", - " R0[v0==0] = 1*(temp==3)\n", - "\n", - " # Redefine y with the updated R1 component\n", - " vaug = np.append(v1, R1)\n", - "\n", - " # Step 3: sample the latent normals, given (R0, R1) and y_f\n", - "\n", - " # First z0\n", - " mu1 = eta_h[R0==1]\n", - " U1 = rng.uniform(\n", - " low=norm.cdf(0 - mu1), \n", - " high=1,\n", - " size=np.sum(R0).astype(int)\n", - " )\n", - " z_h[R0==1] = mu1 + norm.ppf(U1)\n", - "\n", - " mu0 = eta_h[R0==0]\n", - " U0 = rng.uniform(\n", - " low=0, \n", - " high=norm.cdf(0 - mu0),\n", - " size=(n0 - np.sum(R0)).astype(int)\n", - " )\n", - " z_h[R0==0] = mu0 + norm.ppf(U0)\n", - "\n", - " # Then z1\n", - " mu1 = eta_f[vaug==1]\n", - " U1 = rng.uniform(\n", - " low=norm.cdf(0 - mu1), \n", - " high=1,\n", - " size=np.sum(vaug).astype(int)\n", - " )\n", - " z_f[vaug==1] = mu1 + norm.ppf(U1)\n", - "\n", - " mu0 = eta_f[vaug==0]\n", - " U0 = rng.uniform(\n", - " low=0, \n", - " high=norm.cdf(0 - mu0),\n", - " size=(n - np.sum(vaug)).astype(int)\n", - " )\n", - " z_f[vaug==0] = mu0 + norm.ppf(U0)\n", - "\n", - " # Propagate the updated outcomes through the BART models\n", - " new_outcome_h = np.squeeze(z_h) - eta_h\n", - " outcome_h.update_data(new_outcome_h)\n", - "\n", - " new_outcome_f = np.squeeze(z_f) - eta_f\n", - " outcome_f.update_data(new_outcome_f)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Extracting the estimates and plotting the results.\n", - "\n", - "Now for the most interesting part, which is taking the stochtree BART model fits and producing the causal estimates of interest. \n", - "\n", - "First we set up our grid for plotting the functions in $X$. This is possible in this example because the moderator, age, is one dimensional; in may applied problems this will not be the case and visualization will be substantially trickier. 
" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Extract the credible intervals for the conditional treatment effects as a function of x.\n", - "# We use a grid of values for plotting, with grid points that are typically fewer than the number of observations.\n", - "\n", - "ngrid = 200\n", - "xgrid = np.linspace(start=0.1, stop=2.9, num=ngrid)\n", - "X_11 = np.column_stack((xgrid, np.ones(ngrid), np.ones(ngrid)))\n", - "X_00 = np.column_stack((xgrid, np.zeros(ngrid), np.zeros(ngrid)))\n", - "X_01 = np.column_stack((xgrid, np.zeros(ngrid), np.ones(ngrid)))\n", - "X_10 = np.column_stack((xgrid, np.ones(ngrid), np.zeros(ngrid)))" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Next, we compute the truth function evaluations on this plotting grid, using the functions defined above when we generated our data." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Compute the true conditional outcome probabilities for plotting\n", - "pi_strat = pi_s(xgrid, alpha_a, beta_a, alpha_n, beta_n, alpha_c, beta_c)\n", - "w_a = pi_strat[:,0]\n", - "w_n = pi_strat[:,1]\n", - "w_c = pi_strat[:,2]\n", - "\n", - "w = (w_c/(w_a + w_c))\n", - "p11_true = w*gamfun(xgrid,1,1,\"c\") + (1-w)*gamfun(xgrid,1,1,\"a\")\n", - "\n", - "w = (w_c/(w_n + w_c))\n", - "p00_true = w*gamfun(xgrid,0,0,\"c\") + (1-w)*gamfun(xgrid,0,0,\"n\")\n", - "\n", - "# Compute the true ITT_c for plotting and comparison\n", - "itt_c_true = gamfun(xgrid,1,1,\"c\") - gamfun(xgrid,0,0,\"c\")\n", - "\n", - "# Compute the true LATE for plotting and comparison\n", - "LATE_true0 = gamfun(xgrid,1,0,\"c\") - gamfun(xgrid,0,0,\"c\")\n", - "LATE_true1 = gamfun(xgrid,1,1,\"c\") - gamfun(xgrid,0,1,\"c\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Next we populate the data structures for stochtree to operate on, call the predict functions to 
extract the predictions, convert them to probability scale using the `scipy.stats.norm.cdf` method." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Datasets for counterfactual predictions\n", - "forest_dataset_grid = Dataset()\n", - "forest_dataset_grid.add_covariates(np.expand_dims(xgrid, 1))\n", - "forest_dataset_11 = Dataset()\n", - "forest_dataset_11.add_covariates(X_11)\n", - "forest_dataset_00 = Dataset()\n", - "forest_dataset_00.add_covariates(X_00)\n", - "forest_dataset_10 = Dataset()\n", - "forest_dataset_10.add_covariates(X_10)\n", - "forest_dataset_01 = Dataset()\n", - "forest_dataset_01.add_covariates(X_01)\n", - "\n", - "# Forest predictions\n", - "preds_00 = forest_samples.predict(forest_dataset_00)\n", - "preds_11 = forest_samples.predict(forest_dataset_11)\n", - "preds_01 = forest_samples.predict(forest_dataset_01)\n", - "preds_10 = forest_samples.predict(forest_dataset_10)\n", - "\n", - "# Probability transformations\n", - "phat_00 = norm.cdf(preds_00)\n", - "phat_11 = norm.cdf(preds_11)\n", - "phat_01 = norm.cdf(preds_01)\n", - "phat_10 = norm.cdf(preds_10)\n", - "\n", - "preds_ac = forest_samples_f.predict(forest_dataset_grid)\n", - "phat_ac = norm.cdf(preds_ac)\n", - "\n", - "preds_adj = forest_samples_h.predict(forest_dataset_grid)\n", - "phat_a = norm.cdf(preds_ac) * norm.cdf(preds_adj)\n", - "phat_c = phat_ac - phat_a\n", - "phat_n = 1 - phat_ac" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Now we may plot posterior means of various quantities (as a function of $X$) to visualize how well the models are fitting." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "fig, (ax1, ax2) = plt.subplots(1, 2)\n", - "ax1.scatter(p11_true, np.mean(phat_11, axis=1), color=\"black\")\n", - "ax1.axline((0, 0), slope=1, color=\"red\", linestyle=(0, (3, 3)))\n", - "ax2.scatter(p00_true, np.mean(phat_00, axis=1), color=\"black\")\n", - "ax2.axline((0, 0), slope=1, color=\"red\", linestyle=(0, (3, 3)))\n", - "plt.show()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "fig, (ax1, ax2, ax3) = plt.subplots(1, 3, sharex=\"none\", sharey=\"none\")\n", - "ax1.scatter(np.mean(phat_ac, axis=1), w_c + w_a, color=\"black\")\n", - "ax1.axline((0, 0), slope=1, color=\"red\", linestyle=(0, (3, 3)))\n", - "ax1.set_xlim(0.5,1.1)\n", - "ax1.set_ylim(0.5,1.1)\n", - "ax2.scatter(np.mean(phat_a, axis=1), w_a, color=\"black\")\n", - "ax2.axline((0, 0), slope=1, color=\"red\", linestyle=(0, (3, 3)))\n", - "ax2.set_xlim(0.1,0.4)\n", - "ax2.set_ylim(0.1,0.3)\n", - "ax3.scatter(np.mean(phat_c, axis=1), w_c, color=\"black\")\n", - "ax3.axline((0, 0), slope=1, color=\"red\", linestyle=(0, (3, 3)))\n", - "ax3.set_xlim(0.4,0.9)\n", - "ax3.set_ylim(0.4,0.8)\n", - "plt.show()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "These plots are not as pretty as we might hope, but mostly this is a function of how difficult it is to learn conditional probabilities from binary outcomes. That we capture the trend broadly turns out to be adequate for estimating treatment effects. Fit does improve with simpler DGPs and larger training sets, as can be confirmed by experimentation with this script. \n", - "\n", - "Lastly, we can construct the estimate of the $ITT_c$ and compare it to the true value as well as the $Z=0$ and $Z=1$ complier average treatment effects (also called \"local average treatment effects\" or LATE). 
The key step in this process is to center our posterior on the identified interval (at each iteration of the sampler) at the value implied by a valid exclusion restriction. For some draws this will not be possible, as that value will be outside the identification region." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Generate draws from the posterior of the treatment effect\n", - "# centered at the point-identified value under the exclusion restriction\n", - "itt_c = np.empty((ngrid, phat_c.shape[1]))\n", - "late = np.empty((ngrid, phat_c.shape[1]))\n", - "ss = 6\n", - "for j in range(phat_c.shape[1]):\n", - " # Value of gamma11 implied by an exclusion restriction\n", - " gamest11 = ((phat_a[:,j] + phat_c[:,j])/phat_c[:,j])*phat_11[:,j] - phat_10[:,j]*phat_a[:,j]/phat_c[:,j]\n", - "\n", - " # Identified region for gamma11\n", - " lower11 = np.maximum(0., ((phat_a[:,j] + phat_c[:,j])/phat_c[:,j])*phat_11[:,j] - phat_a[:,j]/phat_c[:,j])\n", - " upper11 = np.minimum(1., ((phat_a[:,j] + phat_c[:,j])/phat_c[:,j])*phat_11[:,j])\n", - "\n", - " # Center a beta distribution at gamma11, but restricted to (lower11, upper11)\n", - " # do this by shifting and scaling the mean, drawing from a beta on (0,1), then shifting and scaling to the \n", - " # correct restricted interval\n", - " m11 = (gamest11 - lower11)/(upper11 - lower11)\n", - "\n", - " # Parameters of the beta\n", - " a1 = ss*m11\n", - " b1 = ss*(1-m11)\n", - "\n", - " # When the corresponding mean is out-of-range, sample from a beta with mass piled near the violated boundary\n", - " a1[m11<0] = 1\n", - " b1[m11<0] = 5\n", - " \n", - " a1[m11>1] = 5\n", - " b1[m11>1] = 1\n", - "\n", - " # Value of gamma00 implied by an exclusion restriction\n", - " gamest00 = ((phat_n[:,j] + phat_c[:,j])/phat_c[:,j])*phat_00[:,j] - phat_01[:,j]*phat_n[:,j]/phat_c[:,j]\n", - "\n", - " # Identified region for gamma00\n", - " lower00 = np.maximum(0., ((phat_n[:,j] + 
phat_c[:,j])/phat_c[:,j])*phat_00[:,j] - phat_n[:,j]/phat_c[:,j])\n", - " upper00 = np.minimum(1., ((phat_n[:,j] + phat_c[:,j])/phat_c[:,j])*phat_00[:,j])\n", - "\n", - " # Center a beta distribution at gamma00, but restricted to (lower00, upper00)\n", - " # do this by shifting and scaling the mean, drawing from a beta on (0,1), then shifting and scaling to the \n", - " # correct restricted interval\n", - " m00 = (gamest00 - lower00)/(upper00 - lower00)\n", - "\n", - " a0 = ss*m00\n", - " b0 = ss*(1-m00)\n", - " a0[m00<0] = 1\n", - " b0[m00<0] = 5 \n", - " a0[m00>1] = 5\n", - " b0[m00>1] = 1\n", - "\n", - " # ITT and LATE \n", - " itt_c[:,j] = lower11 + (upper11 - lower11)*rng.beta(a=a1,b=b1,size=ngrid) - (lower00 + (upper00 - lower00)*rng.beta(a=a0,b=b0,size=ngrid))\n", - " late[:,j] = gamest11 - gamest00\n", - "\n", - "upperq = np.quantile(itt_c, q=0.975, axis=1)\n", - "lowerq = np.quantile(itt_c, q=0.025, axis=1)\n", - "upperq_er = np.quantile(late, q=0.975, axis=1)\n", - "lowerq_er = np.quantile(late, q=0.025, axis=1)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "And now we can plot all of this, shading posterior quantiles with [pyplot's `fill` function](https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.fill.html)." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "plt.plot(xgrid, itt_c_true, color = \"black\")\n", - "plt.ylim(-0.75, 0.05)\n", - "plt.fill(np.append(xgrid, xgrid[::-1]), np.append(lowerq, upperq[::-1]), color = (0.5,0.5,0,0.25))\n", - "plt.fill(np.append(xgrid, xgrid[::-1]), np.append(lowerq_er, upperq_er[::-1]), color = (0,0,0.5,0.25))\n", - "\n", - "itt_c_est = np.mean(itt_c, axis=1)\n", - "late_est = np.mean(late, axis=1)\n", - "\n", - "plt.plot(xgrid, late_est, color = \"darkgrey\")\n", - "plt.plot(xgrid, itt_c_est, color = \"gold\")\n", - "plt.plot(xgrid, LATE_true0, color = \"black\", linestyle = (0, (2, 2)))\n", - "plt.plot(xgrid, LATE_true1, color = \"black\", linestyle = (0, (4, 4)))\n", - "plt.plot(xgrid, itt_c_true, color = \"black\")\n", - "\n", - "plt.show()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "With a valid exclusion restriction the three black curves would all be the same. With no exclusion restriction, as we have here, the direct effect of $Z$ on $Y$ (the vaccine reminder on flu status) makes it so these three treatment effects are different. Specifically, the $ITT_c$ compares getting the vaccine *and* getting the reminder to not getting the vaccine *and* not getting the reminder. When both things have risk reducing impacts, we see a larger risk reduction over all values of $X$. Meanwhile, the two LATE effects compare the isolated impact of the vaccine among people that got the reminder and those that didn't, respectively. Here, not getting the reminder makes the vaccine more effective because the risk reduction is as a fraction of baseline risk, and the reminder reduces baseline risk in our DGP. \n", - "\n", - "We see also that the posterior mean of the $ITT_c$ estimate (gold) is very similar to the posterior mean under the assumption of an exclusion restriction (gray). 
This is by design...they will only deviate due to Monte Carlo variation or due to the rare situations where the exclusion restriction is incompatible with the identification interval. \n", - "\n", - "By changing the sample size and various aspects of the DGP this script allows us to build some intuition for how aspects of the DGP affect posterior inferences, particularly how violates of assumptions affect accuracy and posterior uncertainty." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# References\n", - "\n", - "Albert, James H, and Siddhartha Chib. 1993. β€œBayesian Analysis of Binary and Polychotomous Response Data.” *Journal of the American Statistical Association* 88 (422): 669–79.\n", - "\n", - "Hahn, P Richard, Jared S Murray, and Ioanna Manolopoulou. 2016. β€œA Bayesian Partial Identification Approach to Inferring the Prevalence of Accounting Misconduct.” *Journal of the American Statistical Association* 111 (513): 14–26.\n", - "\n", - "Hirano, Keisuke, Guido W. Imbens, Donald B. Rubin, and Xiao-Hua Zhou. 2000. β€œAssessing the Effect of an Influenza Vaccine in an Encouragement Design.” *Biostatistics* 1 (1): 69–88. https://doi.org/10.1093/biostatistics/1.1.69.\n", - "\n", - "Imbens, Guido W., and Donald B. Rubin. 2015. *Causal Inference for Statistics, Social, and Biomedical Sciences: An Introduction*. Cambridge University Press.\n", - "\n", - "McDonald, Clement J, Siu L Hui, and William M Tierney. 1992. β€œEffects of Computer Reminders for Influenza Vaccination on Morbidity During Influenza Epidemics.” *MD Computing: Computers in Medical Practice* 9 (5): 304–12.\n", - "\n", - "Papakostas, Demetrios, P Richard Hahn, Jared Murray, Frank Zhou, and Joseph Gerakos. 2023. β€œDo Forecasts of Bankruptcy Cause Bankruptcy? A Machine Learning Sensitivity Analysis.” *The Annals of Applied Statistics* 17 (1): 711–39.\n", - "\n", - "Richardson, Thomas S., Robin J. Evans, and James M. Robins. 2011. 
β€œTransparent Parametrizations of Models for Potential Outcomes.” In *Bayesian Statistics 9*. Oxford University Press. https://doi.org/10.1093/acprof:oso/9780199694587.003.0019." - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "venv", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.12.9" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/vignettes/Python/RDD/RDD_DAG.png b/vignettes/Python/RDD/RDD_DAG.png deleted file mode 100644 index a73abc16e..000000000 Binary files a/vignettes/Python/RDD/RDD_DAG.png and /dev/null differ diff --git a/vignettes/Python/RDD/rdd.html b/vignettes/Python/RDD/rdd.html deleted file mode 100644 index 09923ab16..000000000 --- a/vignettes/Python/RDD/rdd.html +++ /dev/null @@ -1,8089 +0,0 @@ - - - - - -rdd - - - - - - - - - - - - -
- - - - - - - - - - - - - - - - - - - - - -
- - diff --git a/vignettes/Python/RDD/rdd.ipynb b/vignettes/Python/RDD/rdd.ipynb deleted file mode 100644 index 31cf4f09d..000000000 --- a/vignettes/Python/RDD/rdd.ipynb +++ /dev/null @@ -1,475 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Regression Discontinuity Design (RDD) with `stochtree`" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Introduction\n", - "\n", - "We study conditional average treatment effect (CATE) estimation for regression discontinuity designs (RDD), in which treatment assignment is based on whether a particular covariate --- referred to as the running variable --- lies above or below a known value, referred to as the cutoff value. Because treatment is deterministically assigned as a known function of the running variable, RDDs are trivially deconfounded: treatment assignment is independent of the outcome variable, given the running variable (because treatment is conditionally constant). However, estimation of treatment effects in RDDs is more complicated than simply controlling for the running variable, because doing so introduces a complete lack of overlap, which is the other key condition needed to justify regression adjustment for causal inference. Nonetheless, the CATE _at the cutoff_, $X=c$, may still be identified provided the conditional expectation $E[Y \\mid X,W]$ is continuous at that point for _all_ $W=w$. We exploit this assumption with the leaf regression BART model implemented in Stochtree, which allows us to define an explicit prior on the CATE. We now describe the RDD setup and our model in more detail, and provide code to implement our approach.\n", - "\n", - "## Regression Discontinuity Design\n", - "\n", - "We conceptualize the treatment effect estimation problem via a quartet of random variables $(Y, X, Z, U)$. 
The variable $Y$ is the outcome variable; the variable $X$ is the running variable; the variable $Z$ is the treatment assignment indicator variable; and the variable $U$ represents additional, possibly unobserved, causal factors. What specifically makes this correspond to an RDD is that we stipulate that $Z = I(X > c)$, for cutoff $c$. We assume $c = 0$ without loss of generality. \n", - "\t \n", - "The following figure depicts a causal diagram representing the assumed causal relationships between these variables. Two key features of this diagram are one, that $X$ blocks the impact of $U$ on $Z$: in other words, $X$ satisfies the back-door criterion for learning causal effects of $Z$ on $Y$. And two, $X$ and $U$ are not descendants of $Z$.\n", - "\n", - "![RDD_DAG](RDD_DAG.png)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Using this causal diagram, we may express $Y$ as some function of its graph parents, the random variables $(X,Z,U)$: $$Y = F(X,Z,U).$$ In principle, we may obtain draws of $Y$ by first drawing $(X,Z,U)$ according to their joint distribution and then applying the function $F$. Similarly, we may relate this formulation to the potential outcomes framework straightforwardly:\n", - "\\begin{equation}\n", - "\\begin{split}\n", - "Y^1 &= F(X,1,U),\\\\\n", - "Y^0 &= F(X,0,U).\n", - "\\end{split}\n", - "\\end{equation}\n", - "Here, draws of $(Y^1, Y^0)$ may be obtained (in principle) by drawing $(X,Z,U)$ from their joint distribution and using only the $(X,U)$ elements as arguments in the above two equations, \"discarding\" the drawn value of $Z$. Note that this construction implies the _consistency_ condition: $Y = Y^1 Z + Y^0 ( 1 - Z)$. Likewise, this construction implies the _no interference_ condition because each $Y_i$ is considered to be produced with arguments ($X_i, Z_i, U_i)$ and not those from other units $j$; in particular, in constructing $Y_i$, $F$ does not take $Z_j$ for $j \\neq i$ as an argument." 
- ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Next, we define the following conditional expectations\n", - "\\begin{equation}\n", - "\\begin{split}\n", - "\\mu_1(x) &= E[ F(x, 1, U) \\mid X = x] ,\\\\\n", - "\\mu_0(x) &= E[ F(x, 0, U) \\mid X = x],\n", - "\\end{split}\n", - "\\end{equation}\n", - "with which we can define the treatment effect function\n", - "$$\\tau(x) = \\mu_1(x) - \\mu_0(x).$$\n", - "Because $X$ satisfies the back-door criterion, $\\mu_1$ and $\\mu_0$ are estimable from the data, meaning that \n", - "\\begin{equation}\n", - "\\begin{split}\n", - "\\mu_1(x) &= E[ F(x, 1, U) \\mid X = x] = E[Y \\mid X=x, Z=1],\\\\\n", - "\\mu_0(x) &= E[ F(x, 0, U) \\mid X = x] = E[Y \\mid X=x, Z=0],\n", - "\\end{split}\n", - "\\end{equation}\t\n", - "the right-hand-sides of which can be estimated from sample data, which we supposed to be independent and identically distributed realizations of $(Y_i, X_i, Z_i)$ for $i = 1, \\dots, n$. However, because $Z = I(X >0)$ we can in fact only learn $\\mu_1(x)$ for $X > 0$ and $\\mu_0(x)$ for $X < 0$. In potential outcomes terminology, conditioning on $X$ satisfies ignorability,\n", - "$$(Y^1, Y^0) \\perp \\!\\!\\! \\perp Z \\mid X,$$\n", - "but not _strong ignorability_, because overlap is violated. Overlap would require that\n", - "$$0 < P(Z = 1 \\mid X=x) < 1 \\;\\;\\;\\; \\forall x,$$\n", - "which is clearly violated by the RDD assumption that $Z = I(X > 0)$. Consequently, the overall ATE, \n", - "$\\bar{\\tau} = E(\\tau(X)),$ is unidentified, and we must content ourselves with estimating $\\tau(0)$, the conditional average effect at the point $x = 0$, which we estimate as the difference between $\\mu_1(0) - \\mu_0(0)$. This is possible for continuous $X$ so long as one is willing to assume that $\\mu_1(x)$ and $\\mu_0(x)$ are both suitably smooth functions of $x$: any inferred discontinuity at $x = 0$ must therefore be attributable to treatment effect." 
- ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Conditional average treatment effects in RDD\n", - "\n", - "We are concerned with learning not only $\\tau(0)$, the \"RDD ATE\" (e.g. the CATE at $x = 0$), but also RDD CATEs, $\\tau(0, \\mathrm{w})$ for some covariate vector $\\mathrm{w}$. Incorporating additional covariates in the above framework turns out to be straightforward, simply by defining $W = \\varphi(U)$ to be an observable function of the (possibly unobservable) causal factors $U$. We may then define our potential outcome means as\n", - "\\begin{equation}\n", - "\\begin{split}\n", - "\\mu_1(x,\\mathrm{w}) &= E[ F(x, 1, U) \\mid X = x, W = \\mathrm{w}] = E[Y \\mid X=x, W=\\mathrm{w}, Z=1],\\\\\n", - "\\mu_0(x,\\mathrm{w}) &= E[ F(x, 0, U) \\mid X = x, W = \\mathrm{w}] = E[Y \\mid X=x, W =\\mathrm{w}, Z=0],\n", - "\\end{split}\n", - "\\end{equation}\n", - "and our treatment effect function as\n", - "\\begin{equation}\n", - "\\tau(x,\\mathrm{w}) = \\mu_1(x,\\mathrm{w}) - \\mu_0(x,\\mathrm{w})\n", - "\\end{equation}\n", - "We consider our data to be independent and identically distributed realizations $(Y_i, X_i, Z_i, W_i)$ for $i = 1, \\dots, n$. Furthermore, we must assume that $\\mu_1(x,\\mathrm{w})$ and $\\mu_0(x,\\mathrm{w})$ are suitably smooth functions of $x$, {\\em for every} $\\mathrm{w}$; in other words, for each value of $\\mathrm{w}$ the usual continuity-based identification assumptions must hold. \n", - "\n", - "With this framework and notation established, CATE estimation in RDDs boils down to estimation of condition expectation functions $E[Y \\mid X=x, W=\\mathrm{w}, Z=z]$, for which we turn to BART models." 
- ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## The BARDDT Model\n", - "\n", - "We propose a BART model where the trees are allowed to split on $(x,\\mathrm{w})$ but where each leaf node parameter is a vector of regression coefficients tailored to the RDD context (rather than a scalar constant as in default BART). In one sense, such a model can be seen as implying distinct RDD ATE regressions for each subgroup determined by a given tree; however, this intuition is only heuristic, as the entire model is fit jointly as an ensemble of such trees. Instead, we motivate this model as a way to estimate the necessary conditional expectations via a parametrization where the conditional treatment effect function can be explicitly regularized, as follows.\n", - "\n", - "Let $\\psi$ denote the following basis vector:\n", - "\\begin{equation}\n", - "\\psi(x,z) = \\begin{bmatrix}\n", - "1 & z x & (1-z) x & z\n", - "\\end{bmatrix}.\n", - "\\end{equation}\n", - "To generalize the original BART model, we define $g_j(x, \\mathrm{w}, z)$ as a piecewise linear function as follows. Let $b_j(x, \\mathrm{w})$ denote the node in the $j$th tree which contains the point $(x, \\mathrm{w})$; then the prediction function for tree $j$ is defined to be:\n", - "\\begin{equation}\n", - "g_j(x, \\mathrm{w}, z) = \\psi(x, z) \\Gamma_{b_j(x, \\mathrm{w})}\n", - "\\end{equation}\t\n", - "for a leaf-specific regression vector $\\Gamma_{b_j} = (\\eta_{b_j}, \\lambda_{b_j}, \\theta_{b_j}, \\Delta_{b_j})^t$. 
Therefore, letting $n_{b_j}$ denote the number of data points allocated to node $b$ in the $j$th tree and $\\Psi_{b_j}$ denote the $n_{b_j} \\times 4$ matrix, with rows equal to $\\psi(x,z)$ for all $(x_i,z_i) \\in b_j$, the model for observations assigned to leaf $b_j$, can be expressed in matrix notation as:\n", - "\\begin{equation}\n", - "\\begin{split}\n", - "\\mathbf{Y}_{b_j} \\mid \\Gamma_{b_j}, \\sigma^2 &\\sim \\mathrm{N}(\\Psi_{b_j} \\Gamma_{b_j},\\sigma^2)\\\\\n", - "\\Gamma_{b_j} &\\sim \\mathrm{N} (0, \\Sigma_0),\n", - "\\end{split}\n", - "\\end{equation}\n", - "where we set $\\Sigma_0 = \\frac{0.033}{J} \\mathbf{I}$ as a default (for $x$ vectors standardized to have unit variance in-sample). " - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "This choice of basis entails that the RDD CATE at $\\mathrm{w}$, $\\tau(0, \\mathrm{w})$, is a sum of the $\\Delta_{b_j(0, \\mathrm{w})}$ elements across all trees $j = 1, \\dots, J$:\n", - "\\begin{equation}\n", - "\\begin{split}\n", - "\\tau(0, \\mathrm{w}) &= E[Y^1 \\mid X=0, W = \\mathrm{w}] - E[Y^0 \\mid X = 0, W = \\mathrm{w}]\\\\\n", - "& = E[Y \\mid X=0, W = \\mathrm{w}, Z = 1] - E[Y \\mid X = 0, W = \\mathrm{w}, Z = 0]\\\\\n", - "&= \\sum_{j = 1}^J g_j(0, \\mathrm{w}, 1) - \\sum_{j = 1}^J g_j(0, \\mathrm{w}, 0)\\\\\n", - "&= \\sum_{j = 1}^J \\psi(0, 1) \\Gamma_{b_j(0, \\mathrm{w})} - \\sum_{j = 1}^J \\psi(0, 0) \\Gamma_{b_j(0, \\mathrm{w})} \\\\\n", - "& = \\sum_{j = 1}^J \\Bigl( \\psi(0, 1) - \\psi(0, 0) \\Bigr) \\Gamma_{b_j(0, \\mathrm{w})} \\\\\n", - "& = \\sum_{j = 1}^J \\Bigl( (1,0,0,1) - (1,0,0,0) \\Bigr) \\Gamma_{b_j(0, \\mathrm{w})} \\\\\n", - "&= \\sum_{j=1}^J \\Delta_{b_j(0, \\mathrm{w})}.\n", - "\\end{split}\n", - "\\end{equation}\n", - "As a result, the priors on the $\\Delta$ coefficients directly regularize the treatment effect. We set the tree and error variance priors as in the original BART model. 
" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "The following figures provide a graphical depiction of how the BARDDT model fits a response surface and thereby estimates CATEs for distinct values of $\\mathrm{w}$. For simplicity only two trees are used in the illustration, while in practice dozens or hundreds of trees may be used (in our simulations and empirical example, we use 150 trees)." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "
\n", - " \n", - "
Two regression trees with splits in x and a single scalar w. Node images depict the g(x,w,z) function (in x) defined by that node's coefficients. The vertical gap between the two line segments in a node that contain x=0 is that node's contribution to the CATE at X = 0. Note that only such nodes contribute for CATE prediction at x=0
\n", - "
" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "
\n", - " \n", - "
The two top figures show the same two regression trees as in the preceding figure, now represented as a partition of the x-w plane. Labels in each partition correspond to the leaf nodes depicted in the previous picture. The bottom figure shows the partition of the x-w plane implied by the sum of the two trees; the red dashed line marks point W=w* and the combination of nodes that include this point
\n", - "
" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "
\n", - " \"trees3\"/\n", - "
Left: The function fit at W = w* for the two trees shown in the previous two figures, shown superimposed. Right: The aggregated fit achieved by summing the contributes of two regression tree fits shown at left. The magnitude of the discontinuity at x = 0 (located at the dashed gray vertical line) represents the treatment effect at that point. Different values of w will produce distinct fits; for the two trees shown, there can be three distinct fits based on the value of w.
\n", - "
" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "vscode": { - "languageId": "plaintext" - } - }, - "source": [ - "An interesting property of BARDDT can be seen in this small illustration --- by letting the regression trees split on the running variable, there is no need to separately define a 'bandwidth' as is used in the polynomial approach to RDD. Instead, the regression trees automatically determine (in the course of posterior sampling) when to 'prune' away regions away from the cutoff value. There are two notable features of this approach. One, different trees in the ensemble are effectively using different local bandwidths and these fits are then blended together. For example, in the bottom panel of the second figure, we obtain one bandwidth for the region $d+i$, and a different one for regions $a+g$ and $d+g$. Two, for cells in the tree partition that do not span the cutoff, the regression within that partition contains no causal contrasts --- all observations either have $Z = 1$ or $Z = 0$. For those cells, the treatment effect coefficient is ill-posed and in those cases the posterior sampling is effectively a draw from the prior; however, such draws correspond to points where the treatment effect is unidentified and none of these draws contribute to the estimation of $\\tau(0, \\mathrm{w})$ --- for example, only nodes $a+g$, $d+g$, and $d+i$ provide any contribution. This implies that draws of $\\Delta$ corresponding to nodes not predicting at $X=0$ will always be draws from the prior, which has some intuitive appeal." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Demo\n", - "\n", - "In this section, we provide code for implementing our model in `stochtree` on a popular RDD dataset.\n", - "First, let us load `stochtree` and all the necessary libraries for our posterior analysis." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import matplotlib.pyplot as plt\n", - "import seaborn as sns\n", - "import numpy as np\n", - "import pandas as pd\n", - "from sklearn.tree import DecisionTreeRegressor, plot_tree\n", - "from stochtree import BARTModel" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Dataset\n", - "\n", - "The data comes from Lindo et al (2010), who analyze data on college students enrolled in a large Canadian university in order to evaluate the effectiveness of an academic probation policy. Students who present a grade point average (GPA) lower than a certain threshold at the end of each term are placed on academic probation and must improve their GPA in the subsequent term or else face suspension. We are interested in how being put on probation or not, $Z$, affects students' GPA, $Y$, at the end of the current term. The running variable, $X$, is the negative distance between a student's previous-term GPA and the probation threshold, so that students placed on probation ($Z = 1$) have a positive score and the cutoff is 0. 
Potential moderators, $W$, are:\n", - "\n", - "* gender (`male`), \n", - "* age upon entering university (`age_at_entry`)\n", - "* a dummy for being born in North America (`bpl_north_america`), \n", - "* the number of credits taken in the first year (`totcredits_year1`)\n", - "* an indicator designating each of three campuses (`loc_campus` 1, 2 and 3), and\n", - "* high school GPA as a quantile w.r.t the university's incoming class (`hsgrade_pct`).\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Load and organize data\n", - "data = pd.read_csv(\"https://raw.githubusercontent.com/rdpackages-replication/CIT_2024_CUP/refs/heads/main/CIT_2024_CUP_discrete.csv\")\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "y = data.loc[:,\"nextGPA\"].to_numpy()\n", - "x = data.loc[:,\"X\"].to_numpy()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "x = x/np.std(x)\n", - "w = data.iloc[:,3:11]" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "ordered_cat = pd.api.types.CategoricalDtype(ordered=True)\n", - "unordered_cat = pd.api.types.CategoricalDtype(ordered=False)\n", - "w.loc[:,\"totcredits_year1\"] = w.loc[:,\"totcredits_year1\"].astype(ordered_cat)\n", - "w.loc[:,\"male\"] = w.loc[:,\"male\"].astype(unordered_cat)\n", - "w.loc[:,\"bpl_north_america\"] = w.loc[:,\"bpl_north_america\"].astype(unordered_cat)\n", - "w.loc[:,\"loc_campus1\"] = w.loc[:,\"loc_campus1\"].astype(unordered_cat)\n", - "w.loc[:,\"loc_campus2\"] = w.loc[:,\"loc_campus2\"].astype(unordered_cat)\n", - "w.loc[:,\"loc_campus3\"] = w.loc[:,\"loc_campus3\"].astype(unordered_cat)\n", - "c = 0\n", - "n = data.shape[0]\n", - "z = np.where(x > c, 1.0, 0.0)\n", - "# Window for prediction sample\n", - "h = 0.1\n", - "test = (x > -h) & (x < h)\n", - 
"ntest = np.sum(np.where(test, 1, 0))" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Target estimand\n", - "\n", - "Generically, our estimand is the CATE function at $x = 0$; i.e. $\\tau(0, \\mathrm{w})$. The key practical question is which values of $\\mathrm{w}$ to consider. Some values of $\\mathrm{w}$ will not be well-represented near $x=0$ and so no estimation technique will be able to estimate those points effectively. As such, to focus on feasible points --- which will lead to interesting comparisons between methods --- we recommend restricting the evaluation points to the observed $\\mathrm{w}_i$ such that $|x_i| \\leq \\delta$, for some $\\delta > 0$. In our example, we use $\\delta = 0.1$ for a standardized $x$ variable. Therefore, our estimand of interest is a vector of treatment effects:\n", - "$$\\tau(0, \\mathrm{w}_i) \\;\\;\\; \\forall i \\;\\text{ such that }\\; |x_i| \\leq \\delta$$" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Implementing BARDDT\n", - "\n", - "In order to implement our model, we write the Psi vector, as defined before: `Psi = np.column_stack([np.ones(n), z * x, (1 - z) * x, z])`. The training matrix for the model is `np.column_stack([x, w])`, which we feed into the `BARTModel` sampler via the `X_train` parameter. The basis vector `Psi` is fed into the function via the `leaf_basis_train` parameter. The parameter list `barddt_mean_params` defines options for the mean forest (a different list can be defined for a variance forest in the case of heteroscedastic BART, which we do not consider here). Importantly, in this list we define parameter `sigma2_leaf_init = np.diag(np.repeat(0.1/150, 4))`, which sets $\\Sigma_0$ as described above. 
Now, we can fit the model, which is saved in object `barddt_model`.\n", - "\n", - "Once the model is fit, we need 3 elements to obtain the CATE predictions: the basis vectors at the cutoff for $z=1$ and $z=0$, the test matrix $[X \\quad W]$ at the cutoff, and the testing sample. We define the prediction basis vectors $\\psi_1 = [1 \\quad 0 \\quad 0 \\quad 1]$ and $\\psi_0 = [1 \\quad 0 \\quad 0 \\quad 0]$, which correspond to $\\psi$ at $(x=0,z=1)$, and $(x=0,z=0)$, respectively. These vectors are written into Python as `Psi1 = np.column_stack([np.ones(n), np.repeat(c, n), np.zeros(n), np.ones(n)])` and `Psi0 = np.column_stack([np.ones(n), np.zeros(n), np.repeat(c, n), np.zeros(n)])`. Then, we write the test matrix at $(x=0,\\mathrm{w})$ as `xmat_test = np.column_stack([np.zeros(n), w])[test,:]`. Finally, we must define the testing window. As discussed previously, our window is set such that $|x| \\leq 0.1$, which can be set in Python as `test = (x > -h) & (x < h)`.\n", - "\n", - "Once all of these elements are set, we can obtain the outcome predictions at the cutoff by running `barddt_model.predict(xmat_test, Psi1)` (resp. `barddt_model.predict(xmat_test, Psi0)`). Each of these calls returns a list, from which we can extract element `y_hat` to obtain the posterior distribution for the outcome. In the code below, the treated and control outcome predictions are saved in the matrix objects `pred1` and `pred0`, respectively. Now, we can obtain draws from the CATE posterior by simply subtracting these matrices. The function below outlines how to perform each of these steps in Python." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "def estimate_barddt(y,x,w,z,test,c,num_gfr=10,num_mcmc=100,seed=None):\n", - " ## Lists of parameters for the Stochtree BART function\n", - " barddt_global_params = {\n", - " \"standardize\": True,\n", - " \"sample_sigma_global\": True,\n", - " \"sigma2_global_init\": 0.1\n", - " }\n", - " if seed is not None:\n", - " barddt_global_params[\"random_seed\"] = seed\n", - " barddt_mean_params = {\n", - " \"num_trees\": 50,\n", - " \"min_samples_leaf\": 20,\n", - " \"alpha\": 0.95,\n", - " \"beta\": 2,\n", - " \"max_depth\": 20,\n", - " \"sample_sigma2_leaf\": False,\n", - " \"sigma2_leaf_init\": np.diag(np.repeat(0.1/150, 4))\n", - " }\n", - " ## Set basis vector for leaf regressions\n", - " n = y.shape[0]\n", - " Psi = np.column_stack([np.ones(n), z * x, (1 - z) * x, z])\n", - " covariates = np.column_stack([x, w])\n", - " ## Model fit\n", - " barddt_model = BARTModel()\n", - " barddt_model.sample(\n", - " X_train=covariates,\n", - " y_train=y,\n", - " leaf_basis_train=Psi,\n", - " num_gfr=num_gfr,\n", - " num_mcmc=num_mcmc,\n", - " general_params=barddt_global_params,\n", - " mean_forest_params=barddt_mean_params\n", - " )\n", - " ## Define basis vectors and test matrix for outcome predictions at X=c\n", - " Psi1 = np.column_stack([np.ones(n), np.repeat(c, n), np.zeros(n), np.ones(n)])\n", - " Psi0 = np.column_stack([np.ones(n), np.zeros(n), np.repeat(c, n), np.zeros(n)])\n", - " Psi1 = Psi1[test,:]\n", - " Psi0 = Psi0[test,:]\n", - " xmat_test = np.column_stack([np.zeros(n), w])[test,:]\n", - " ## Obtain outcome predictions\n", - " pred1 = barddt_model.predict(xmat_test, Psi1)\n", - " pred0 = barddt_model.predict(xmat_test, Psi0)\n", - " ## Obtain CATE posterior\n", - " return pred1 - pred0" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Now, we proceed to fit the BARDDT model." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "num_chains = 4\n", - "num_gfr = 2\n", - "num_mcmc = 100\n", - "cate_result = np.empty((ntest, num_chains*num_mcmc))\n", - "for i in range(num_chains):\n", - " cate_rdd = estimate_barddt(y,x,w,z,test,c,num_gfr=2,num_mcmc=100,seed=i)\n", - " cate_result[:,(i*num_mcmc):((i+1)*num_mcmc)] = cate_rdd" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "We now proceed to analyze the CATE posterior. The figure produced below presents a summary of the CATE posterior produced by BARDDT for this application. This picture is produced fitting a regression tree, using $W$ as the predictors, to the individual posterior mean CATEs:\n", - "\\begin{equation}\n", - "\\bar{\\tau}_i = \\frac{1}{M} \\sum_{h = 1}^M \\tau^{(h)}(0, \\mathrm{w}_i),\n", - "\\end{equation}\n", - "where $h$ indexes each of $M$ total posterior samples. As in our simulation studies, we restrict our posterior analysis to use $\\mathrm{w}_i$ values of observations with $|x_i| \\leq \\delta = 0.1$ (after normalizing $X$ to have standard deviation 1 in-sample). For the Lindo et al (2010) data, this means that BARDDT was trained on $n = 40,582$ observations, of which 1,602 satisfy $x_i \\leq 0.1$, which were used to generate the effect moderation tree." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "## Fit regression tree\n", - "y_surrogate = np.mean(cate_rdd, axis=1)\n", - "X_surrogate = w.iloc[test,:]\n", - "cate_surrogate = DecisionTreeRegressor(min_impurity_decrease=0.0001)\n", - "cate_surrogate.fit(X=X_surrogate, y=y_surrogate)\n", - "plot_tree(cate_surrogate, impurity=False, filled=True, feature_names=w.columns, proportion=False, label='root', node_ids=True)\n", - "plt.show()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "The resulting effect moderation tree indicates that course load (credits attempted) in the academic term leading to their probation is a strong moderator. Contextually, this result is plausible, both because course load could relate to latent character attributes that influence a student's responsiveness to sanctions and also because it could predict course load in the current term, which would in turn have implications for the GPA (i.e. it is harder to get a high GPA while taking more credit hours). The tree also suggests that effects differ by age and gender of the student. These findings are all prima facie plausible as well.\n", - "\n", - "To gauge how strong these findings are statistically, we can zoom in on isolated subgroups and compare the posteriors of their subgroup average treatment effects. This approach is valid because in fitting the effect moderation tree to the posterior mean CATEs we in no way altered the posterior itself; the effect moderation tree is a posterior summary tool and not any additional inferential approach; the posterior is obtained once and can be explored freely using a variety of techniques without vitiating its statistical validity. 
Investigating the most extreme differences is a good place to start: consider the two groups of students at opposite ends of the treatment effect range discovered by the effect moderation tree:\t\n", - "\n", - "* **Group A** a male student that attempted more than 4.8 credits in their first year (rightmost leaf node, colored red, comprising 211 individuals)\n", - "* **Group B** a female student of any gender who entered college younger than 19 (leftmost leaf node, colored deep orange, comprising 369 individuals).\n", - "\n", - "Subgroup CATEs are obtained by aggregating CATEs across the observed $\\mathrm{w}_i$ values for individuals in each group; this can be done for individual posterior samples, yielding a posterior distribution over the subgroup CATE:\n", - "\\begin{equation}\n", - "\\bar{\\tau}_A^{(h)} = \\frac{1}{n_A} \\sum_{i : \\mathrm{w}_i} \\tau^{(h)}(0, \\mathrm{w}_i),\n", - "\\end{equation}\n", - "where $h$ indexes a posterior draw and $n_A$ denotes the number of individuals in the group A.\n", - "\n", - "The code below produces a contour plot for a bivariate kernel density estimate of the joint CATE posterior distribution for subgroups A and B. The contour lines are nearly all above the $45^{\\circ}$ line, indicating that the preponderance of posterior probability falls in the region where the treatment effect for Group A is greater than that of Group B, meaning that the difference in the subgroup treatment effects flagged by the effect moderation tree persist even after accounting for estimation uncertainty in the underlying CATE function." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "predicted_nodes = cate_surrogate.apply(X=X_surrogate)\n", - "posterior_group_a = np.mean(cate_result[predicted_nodes==2,:],axis=0)\n", - "posterior_group_b = np.mean(cate_result[predicted_nodes==6,:],axis=0)\n", - "posterior_df = pd.DataFrame({'group_a': posterior_group_a, 'group_b': posterior_group_b})\n", - "sns.kdeplot(data=posterior_df, x=\"group_b\", y=\"group_a\")\n", - "plt.axline((0, 0), slope=1, color=\"black\", linestyle=(0, (3, 3)))\n", - "plt.show()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "As always, CATEs that vary with observable factors do not necessarily represent a _causal_ moderating relationship. Here, if the treatment effect of academic probation is seen to vary with the number of credits, that does not imply that this association is causal: prescribing students to take a certain number of credits will not necessarily lead to a more effective probation policy, it may simply be that the type of student to naturally enroll for fewer credit hours is more likely to be responsive to academic probation. An entirely distinct set of causal assumptions are required to interpret the CATE variations themselves as causal. All the same, uncovering these patterns of treatment effect variability are crucial to suggesting causal mechanism to be investigated in future studies." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# References\n", - "\n", - "Lindo, Jason M., Nicholas J. Sanders, and Philip Oreopoulos. \"Ability, gender, and performance standards: Evidence from academic probation.\" American economic journal: Applied economics 2, no. 2 (2010): 95-117." 
- ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "venv", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.8.17" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/vignettes/Python/RDD/trees1.png b/vignettes/Python/RDD/trees1.png deleted file mode 100644 index 0a5bc3acf..000000000 Binary files a/vignettes/Python/RDD/trees1.png and /dev/null differ diff --git a/vignettes/Python/RDD/trees2.png b/vignettes/Python/RDD/trees2.png deleted file mode 100644 index 90b7bd5f1..000000000 Binary files a/vignettes/Python/RDD/trees2.png and /dev/null differ diff --git a/vignettes/Python/RDD/trees3.png b/vignettes/Python/RDD/trees3.png deleted file mode 100644 index ededa55f5..000000000 Binary files a/vignettes/Python/RDD/trees3.png and /dev/null differ diff --git a/vignettes/R/IV/iv.Rmd b/vignettes/R/IV/iv.Rmd deleted file mode 100644 index 232c6f3d3..000000000 --- a/vignettes/R/IV/iv.Rmd +++ /dev/null @@ -1,774 +0,0 @@ ---- -title: 'Instrumental Variables (IV) with Stochtree' -author: - - P. Richard Hahn, Arizona State University - - Drew Herren, University of Texas at Austin -date: "`r Sys.Date()`" -output: html_document -bibliography: iv.bib ---- - - - -```{r setup, include=FALSE} -knitr::opts_chunk$set(echo = TRUE) -``` - -# Introduction - -Here we consider a causal inference problem with a binary treatment and a binary outcome where there -is unobserved confounding, but an exogenous instrument is available (also binary). This problem will require a number of extensions to the basic BART model, all of which can be implemented straightforwardly as Gibbs samplers using `stochtree`. We'll go through all of the model fitting steps in quite a lot of detail here. 
- -# Background - -To be concrete, suppose we wish to measure the effect of receiving a flu vaccine on the probability of getting the flu. Individuals who opt to get a flu shot differ in many ways from those that don't, and these lifestyle differences presumably also affect their respective chances of getting the flu. Consequently, comparing the percentage of individuals who get the flu in the vaccinated and unvaccinated groups does not give a clear picture of the vaccine efficacy. - -However, a so-called encouragement design can be implemented, where some individuals are selected at random to be given some extra incentive to get a flu shot (free clinics at the workplace or a personalized reminder, for example). Studying the impact of this randomized encouragement allows us to tease apart the impact of the vaccine from the confounding factors, at least to some extent. This exact problem has been considered several times in the literature, starting with @mcdonald1992effects with follow-on analysis by @hirano2000assessing, @richardson2011transparent, and @imbens2015causal. - -Our analysis here follows the Bayesian nonparametric approach described in the supplement to @hahn2016bayesian. - -## Notation - -Let $V$ denote the treatment variable (as in "vaccine"). Let $Y$ denote the response variable (getting the flu), $Z$ denote the instrument (encouragement or reminder to get a flu shot), and $X$ denote an additional observable covariate (for instance, patient age). - -Further, let $S$ denote the so-called *principal strata*, which is an exhaustive characterization of how individuals' might be affected by the encouragement regarding the flu shot. Some people will get a flu shot no matter what: these are the *always takers* (a). Some people will not get the flu shot no matter what: these are the *never takers* (n). 
For both always-takers and never-takers, the randomization of the encouragement is irrelevant and our data set contains no always takers who skipped the vaccine and no never takers who got the vaccine and so the treatment effect of the vaccine in these groups is fundamentally non-identifiable. - -By contrast, we also have *compliers* (c): folks who would not have gotten the shot but for the fact that they were encouraged to do so. These are the people about whom our randomized encouragement provides some information, because they are precisely the ones that have been randomized to treatment. - -Lastly, we could have *defiers* (d): contrarians who who were planning on getting the shot, but -- upon being reminded -- decided not to! For our analysis we will do the usual thing of assuming that there are no defiers. And because we are going to simulate our data, we can make sure that this assumption is true. - -## The causal diagram - -The causal diagram for this model can be expressed as follows. Here we are considering one confounder and moderator variable ($X$), which is the patient's age. In our data generating process (which we know because this is a simulation demonstration) higher age will make it more likely that a person is an always taker or complier and less likely that they are a never taker, which in turn has an effect on flu risk. We stipulate here that always takers are at lower risk and never takers at higher risk. Simultaneously, age has an increasing and then decreasing direct effect on flu risk; very young and very old are at higher risk, while young and middle age adults are at lower risk. In this DGP the flu efficacy has a multiplicative effect, reducing flu risk as a fixed proportion of baseline risk -- accordingly, the treatment effect (as a difference) is nonlinear in Age (for each principal stratum). 
- -```{r pressure, echo=FALSE, fig.cap="The causal directed acyclic graph (CDAG) for the instrumental variables flu example.", fig.align="center", out.width = '50%'} -knitr::include_graphics("IV_CDAG.png") -``` - -The biggest question about this graph concerns the dashed red arrow from the putative instrument $Z$ to the outcome (flu). We say "putative" because if that dashed red arrow is there, then technically $Z$ is not a valid instrument. The assumption/assertion that there is no dashed red arrow is called the "exclusion restriction". In this vignette, we will explore what sorts of inferences are possible if we remain agnostic about the presence or absence of that dashed red arrow. - -## Potential outcomes - -There are two relevant potential outcomes in an instrumental variables analysis, corresponding to the causal effect of the instrument on the treatment and the causal effect of the treatment on the outcome. In this example, that is the effect of the reminder/encouragement on vaccine status and the effect of the vaccine itself on the flu. The notation is $V(Z)$ and $Y(V(Z),Z)$ respectively, so that we have six distinct random variables: $V(0)$, $V(1)$, $Y(0,0)$, $Y(1,0)$, $Y(0,1)$ and $Y(1,1)$. The problem -- sometimes called the *fundamental problem of causal inference* -- is that some of these random variables can never be seen simultaneously, they are observationally mutually exclusive. For this reason, it may be helpful to think about causal inference as a missing data problem, as depicted in the following table. 
- -```{r missing_data, echo=FALSE} -d <- data.frame(i = c(1:4, "$\\vdots$"), z = c(1,0,0,1, "$\\vdots$"),v0=c("?",1,0,"?", "$\\vdots$"),v1=c(1,"?","?",0, "$\\vdots$"), y00 = c("?","?",1,"?", "$\\vdots$"), y10 = c("?",1,"?","?", "$\\vdots$"), y01 = c("?","?","?",0, "$\\vdots$"), y11 = c(0,"?","?","?", "$\\vdots$")) -library(kableExtra) -colnames(d) <- c("$i$","$Z_i$", "$V_i(0)$","$V_i(1)$","$Y_i(0,0)$","$Y_i(1,0)$","$Y_i(0,1)$","$Y_i(1,1)$") -knitr::kable(d, escape = FALSE, align = 'c') %>% kable_styling("striped", position = "center") -``` - -Likewise, with this notation we can formally define the principal strata: - -```{r principle_strata, echo=FALSE} -d <- data.frame(v0=c(0,1,0,1),v1=c(0,1,1,0), S = c("Never Taker (n)", "Always Taker (a)", "Complier (c)", "Defier (d)")) -colnames(d) <- c("$V_i(0)$","$V_i(1)$","$S_i$") -knitr::kable(d, escape = FALSE, align='c') %>% kable_styling("striped", position = "center") -``` - -## Estimands and Identification - -Let $\pi_s(x)$ denote the conditional (on $x$) probability that an individual belongs to principal stratum $s$: -\begin{equation} -\pi_s(x)=\operatorname{Pr}(S=s \mid X=x), -\end{equation} -and let $\gamma_s^{v z}(x)$ denote the potential outcome probability for given values $v$ and $z$: -\begin{equation} -\gamma_s^{v z}(x)=\operatorname{Pr}(Y(v, z)=1 \mid S=s, X=x). -\end{equation} - -Various estimands of interest may be expressed in terms of the functions $\gamma_c^{vz}(x)$. In particular, the complier conditional average treatment effect $$\gamma_c^{1,z}(x) - \gamma_c^{0,z}(x)$$ is the ultimate goal (for either $z=0$ or $z=1$). Under an exclusion restriction, we would have $\gamma_s^{vz}(x) = \gamma_s^{v}(x)$ and the reminder status $z$ itself would not matter. 
In that case, we can estimate $$\gamma_c^{1,z}(x) - \gamma_c^{0,z}$$ and $$\gamma_c^{1,1}(x) - \gamma_c^{0,0}(x).$$ This latter quantity is called the complier intent-to-treat effect, or $ITT_c$, and it can be partially identify even if the exclusion restriction is violated, as follows. - -The left-hand side of the following system of equations are all estimable quantities that can be learned from observable data, while the right hand side expressions involve the unknown functions of interest, $\gamma_s^{vz}(x)$: - -\begin{equation} -\begin{aligned} -p_{1 \mid 00}(x) = \operatorname{Pr}(Y=1 \mid V=0, Z=0, X=x)=\frac{\pi_c(x)}{\pi_c(x)+\pi_n(x)} \gamma_c^{00}(x)+\frac{\pi_n(x)}{\pi_c(x)+\pi_n(x)} \gamma_n^{00}(x) \\ -p_{1 \mid 11}(x) =\operatorname{Pr}(Y=1 \mid V=1, Z=1, X=x)=\frac{\pi_c(x)}{\pi_c(x)+\pi_a(x)} \gamma_c^{11}(x)+\frac{\pi_a(x)}{\pi_c(x)+\pi_a(x)} \gamma_a^{11}(x) \\ -p_{1 \mid 01}(x) =\operatorname{Pr}(Y=1 \mid V=0, Z=1, X=x)=\frac{\pi_d(x)}{\pi_d(x)+\pi_n(x)} \gamma_d^{01}(x)+\frac{\pi_n(x)}{\pi_d(x)+\pi_n(x)} \gamma_n^{01}(x) \\ -p_{1 \mid 10}(x) =\operatorname{Pr}(Y=1 \mid V=1, Z=0, X=x)=\frac{\pi_d(x)}{\pi_d(x)+\pi_a(x)} \gamma_d^{10}(x)+\frac{\pi_a(x)}{\pi_d(x)+\pi_a(x)} \gamma_a^{10}(x) -\end{aligned} -\end{equation} - -Furthermore, we have -\begin{equation} -\begin{aligned} -\operatorname{Pr}(V=1 \mid Z=0, X=x)&=\pi_a(x)+\pi_d(x)\\ -\operatorname{Pr}(V=1 \mid Z=1, X=x)&=\pi_a(x)+\pi_c(x) -\end{aligned} -\end{equation} - -Under the monotonicy assumption, $\pi_d(x) = 0$ and these expressions simplify somewhat. 
-\begin{equation} -\begin{aligned} -p_{1 \mid 00}(x)&=\frac{\pi_c(x)}{\pi_c(x)+\pi_n(x)} \gamma_c^{00}(x)+\frac{\pi_n(x)}{\pi_c(x)+\pi_n(x)} \gamma_n^{00}(x) \\ -p_{1 \mid 11}(x)&=\frac{\pi_c(x)}{\pi_c(x)+\pi_a(x)} \gamma_c^{11}(x)+\frac{\pi_a(x)}{\pi_c(x)+\pi_a(x)} \gamma_a^{11}(x) \\ -p_{1 \mid 01}(x)&=\gamma_n^{01}(x) \\ -p_{1 \mid 10}(x)&=\gamma_a^{10}(x) -\end{aligned} -\end{equation} -and -\begin{equation} -\begin{aligned} -\operatorname{Pr}(V=1 \mid Z=0, X=x)&=\pi_a(x)\\ -\operatorname{Pr}(V=1 \mid Z=1, X=x)&=\pi_a(x)+\pi_c(x) -\end{aligned} -\end{equation} - -The exclusion restriction would dictate that $\gamma_s^{01}(x) = \gamma_s^{00}(x)$ and $\gamma_s^{11}(x) = \gamma_s^{10}(x)$ for all $s$. This has two implications. One, $\gamma_n^{01}(x) = \gamma_n^{00}(x)$ and $\gamma_a^{10}(x) = \gamma_a^{11}(x)$,and because the left-hand terms are identified, this permits $\gamma_c^{11}(x)$ and $\gamma_c^{00}(x)$ to be solved for by substitution. Two, with these two quantities solved for, we also have the two other quantities (the different settings of $z$), since $\gamma_c^{11}(x) = \gamma_c^{10}(x)$ and $\gamma_c^{00}(x) = \gamma_c^{01}(x)$. Consequently, both of our estimands from above can be estimated: - -$$\gamma_c^{11}(x) - \gamma_c^{01}(x)$$ -and - -$$\gamma_c^{10}(x) - \gamma_c^{00}(x)$$ -because they are both (supposing the exclusion restriction holds) the same as - -$$\gamma_c^{11}(x) - \gamma_c^{00}(x).$$ -If the exclusion restriction does *not* hold, then the three above treatment effects are all (potentially) distinct and not much can be said about the former two. The latter one, the $ITT_c$, however, can be partially identified, by recognizing that the first two equations (in our four equation system) provide non-trivial bounds based on the fact that while $\gamma_c^{11}(x)$ and $\gamma_c^{00}(x)$ are no longer identified, as probabilities both must lie between 0 and 1. 
Thus, - -\begin{equation} -\begin{aligned} - \max\left( - 0, \frac{\pi_c(x)+\pi_n(x)}{\pi_c(x)}p_{1\mid 00}(x) - \frac{\pi_n(x)}{\pi_c(x)} - \right) -&\leq\gamma^{00}_c(x)\leq - \min\left( - 1, \frac{\pi_c(x)+\pi_n(x)}{\pi_c(x)}p_{1\mid 00}(x) - \right)\\\\ -% -\max\left( - 0, \frac{\pi_a(x)+\pi_c(x)}{\pi_c(x)}p_{1\mid 11}(x) - \frac{\pi_a(x)}{\pi_c(x)} -\right) -&\leq\gamma^{11}_c(x)\leq -\min\left( - 1, \frac{\pi_a(x)+\pi_c(x)}{\pi_c(x)}p_{1\mid 11}(x) -\right) -\end{aligned} -\end{equation} - -The point of all this is that the data (plus a no-defiers assumption) lets us estimate all the necessary inputs to these upper and lower bounds on $\gamma^{11}_c(x)$ and $\gamma^{00}_c(x)$ which in turn define our estimand. What remains is to estimate those inputs, as functions of $x$, and to do so while enforcing the monotonicty restriction $$\operatorname{Pr}(V=1 \mid Z=0, X=x)=\pi_a(x) \leq -\operatorname{Pr}(V=1 \mid Z=1, X=x)=\pi_a(x)+\pi_c(x).$$ - -We can do all of this with calls to stochtree from R (or Python). But first, let's generate some test data. - -### Simulate the data - -Start with some initial setup / housekeeping - -```{r preliminaries} -library(stochtree) - -# size of the training sample -n <- 20000 - -# To set the seed for reproducibility/illustration purposes, replace "NULL" with a positive integer -random_seed <- NULL -``` - -First, we generate the instrument exogenously - -```{r instrument} -z <- rbinom(n, 1, 0.5) -``` - -Next, we generate the covariate. (For this example, let's think of it as patient age, although we are generating it from a uniform distribution between 0 and 3, so you have to imagine that it has been pre-standardized to this scale. It keeps the DGPs cleaner for illustration purposes.) - -```{r covariate} -p_X <- 1 -X <- matrix(runif(n*p_X, 0, 3), ncol = p_X) -x <- X[,1] # for ease of reference later -``` - -Next, we generate the principal strata $S$ based on the observed value of $X$. 
We generate it according to a logistic regression with two coefficients per strata, an intercept and a slope. Here, these coefficients are set so that the probability of being a never taker decreases with age. - -```{r principal strata} -alpha_a <- 0 -beta_a <- 1 - -alpha_n <- 1 -beta_n <- -1 - -alpha_c <- 1 -beta_c <- 1 - -# define function (a logistic model) to generate Pr(S = s | X = x) -pi_s <- function(xval){ - - w_a <- exp(alpha_a + beta_a*xval) - w_n <- exp(alpha_n + beta_n*xval) - w_c <- exp(alpha_c + beta_c*xval) - - w <- cbind(w_a, w_n, w_c) - colnames(w) <- c("w_a","w_n","w_c") - w <- w/rowSums(w) - - return(w) - -} -s <- sapply(1:n, function(j) sample(c("a","n","c"), 1, prob = pi_s(X[j,1]))) -``` - -Next, we generate the treatment variable, here denoted $V$ (for "vaccine"), as a *deterministic* function of $S$ and $Z$; this is what gives the principal strata their meaning. - -```{r vaccine} -v <- 1*(s=="a") + 0*(s=="n") + z*(s=="c") + (1-z)*(s == "d") -``` - -Finally, the outcome structural model is specified, based on which the outcome is sampled. By varying this function in particular ways, we can alter the identification conditions. - -```{r ymodel} -gamfun <- function(xval,vval, zval,sval){ - - # if this function depends on zval, then exclusion restriction is violated - # if this function does not depend on sval, then IV analysis wasn't necessary - # if this function does not depend on x, then there are no HTEs - - baseline <- pnorm(2 -1*xval - 2.5*(xval-1.5)^2 - 0.5*zval + 1*(sval=="n") - 1*(sval=="a") ) - prob <- baseline - 0.5*vval*baseline # 0.5*vval*baseline - - return(prob) -} - -# Generate the observed outcome -y <- rbinom(n, 1, gamfun(X[,1],v,z,s)) -``` - -Lastly, we perform some organization for our supervised learning algorithms later on. 
- -```{r organizedata} -# Concatenate X, v and z for our supervised learning algorithms -Xall <- cbind(X,v,z) - -# update the size of "X" to be the size of Xall -p_X <- p_X + 2 - -# For the monotone probit model it is necessary to sort the observations so that the Z=1 cases are all together -# at the start of the outcome vector. -index <- sort(z,decreasing = TRUE, index.return = TRUE) - -X <- matrix(X[index$ix,],ncol= 1) -Xall <- Xall[index$ix,] -z <- z[index$ix] -v <- v[index$ix] -s <- s[index$ix] -y <- y[index$ix] -x <- x[index$ix] -``` - -Now let's see if we can recover these functions from the observed data. - - -### Fit the outcome model - -We have to fit three models here, the treatment models: $\operatorname{Pr}(V = 1 | Z = 1, X=x)$ and $\operatorname{Pr}(V = 1 | Z = 0,X = x)$, subject to the monotonicity constraint $\operatorname{Pr}(V = 1 | Z = 1, X=x) \geq \operatorname{Pr}(V = 1 | Z = 0,X = x)$, and an outcome model $\operatorname{Pr}(Y = 1 | Z = 1, V = 1, X = x)$. All of this will be done with stochtree. - -The outcome model is fit with a single (S-learner) BART model. This part of the model could be fit as a T-Learner or as a BCF model. Here we us an S-Learner for simplicity. Both models are probit models, and use the well-known @albert1993bayesian data augmentation Gibbs sampler. This section covers the more straightforward outcome model. The next section describes how the monotonicity constraint is handled with a data augmentation Gibbs sampler. - -These models could (and probably should) be wrapped as functions. Here they are implemented as scripts, with the full loops shown. The output -- at the end of the loops -- are stochtree forest objects from which we can extract posterior samples and generate predictions. In particular, the $ITT_c$ will be constructed using posterior counterfactual predictions derived from these forest objects. 
- -We begin by setting a bunch of hyperparameters and instantiating the forest objects to be operated upon in the main sampling loop. We also initialize the latent variables. - -```{r outcomefit1} -# Fit the BART model for Pr(Y = 1 | Z = 1, V = 1, X = x) - -# Set number of iterations -num_warmstart <- 10 -num_mcmc <- 1000 -num_samples <- num_warmstart + num_mcmc - -# Set a bunch of hyperparameters. These are ballpark default values. -alpha <- 0.95 -beta <- 2 -min_samples_leaf <- 1 -max_depth <- 20 -num_trees <- 50 -cutpoint_grid_size = 100 -global_variance_init = 1. -tau_init = 0.5 -leaf_prior_scale = matrix(c(tau_init), ncol = 1) -a_leaf <- 2. -b_leaf <- 0.5 -leaf_regression <- F -feature_types <- as.integer(c(rep(0, p_X-2),1,1)) # 0 = numeric -var_weights <- rep(1,p_X)/p_X -outcome_model_type <- 0 - -# C++ dataset -forest_dataset <- createForestDataset(Xall) - -# Random number generator (std::mt19937) -if (is.null(random_seed)) { - rng <- createCppRNG(-1) -} else { - rng <- createCppRNG(random_seed) -} - -# Sampling data structures -forest_model_config <- createForestModelConfig( - feature_types = feature_types, num_trees = num_trees, num_features = p_X, - num_observations = n, variable_weights = var_weights, leaf_dimension = 1, - alpha = alpha, beta = beta, min_samples_leaf = min_samples_leaf, - max_depth = max_depth, leaf_model_type = outcome_model_type, - leaf_model_scale = leaf_prior_scale, cutpoint_grid_size = cutpoint_grid_size -) -global_model_config <- createGlobalModelConfig(global_error_variance = 1) -forest_model <- createForestModel(forest_dataset, forest_model_config, global_model_config) - -# Container of forest samples -forest_samples <- createForestSamples(num_trees, 1, T, F) - -# "Active" forest state -active_forest <- createForest(num_trees, 1, T, F) - -# Initialize the latent outcome zed -n1 <- sum(y) -zed <- 0.25*(2*as.numeric(y) - 1) - -# C++ outcome variable -outcome <- createOutcome(zed) - -# Initialize the active forest and subtract each 
root tree's predictions from outcome -active_forest$prepare_for_sampler(forest_dataset, outcome, forest_model, outcome_model_type, 0.0) -active_forest$adjust_residual(forest_dataset, outcome, forest_model, FALSE, FALSE) -``` - -Now we enter the main loop, which involves only two steps: sample the forest, given the latent utilities, then sample the latent utilities given the estimated conditional means defined by the forest and its parameters. - -```{r outcomefit2} -# Initialize the Markov chain with num_warmstart grow-from-root iterations -gfr_flag <- T -for (i in 1:num_samples) { - - # The first num_warmstart iterations use the grow-from-root algorithm of He and Hahn - if (i > num_warmstart){ - gfr_flag <- F - } - - # Sample forest - forest_model$sample_one_iteration( - forest_dataset, outcome, forest_samples, active_forest, - rng, forest_model_config, global_model_config, - keep_forest = T, gfr = gfr_flag - ) - - # Get the current means - eta <- forest_samples$predict_raw_single_forest(forest_dataset, i-1) - - # Sample latent normals, truncated according to the observed outcome y - U1 <- runif(n1,pnorm(0,eta[y==1],1),1) - zed[y==1] <- qnorm(U1,eta[y==1],1) - U0 <- runif(n - n1,0, pnorm(0,eta[y==0],1)) - zed[y==0] <- qnorm(U0,eta[y==0],1) - - # Propagate the newly sampled latent outcome to the BART model - outcome$update_data(zed) - forest_model$propagate_residual_update(outcome) -} -``` - -### Fit the monotone probit model(s) - -The monotonicty constraint relies on a data augmentation as described in @papakostas2023forecasts. The implementation of this sampler is inherently cumbersome, as one of the "data" vectors is constructed from some observed data and some latent data and there are two forest objects, one of which applies to all of the observations and one of which applies to only those observations with $Z = 0$. We go into more details about this sampler in a dedicated vignette. 
Here we include the code, but without producing the equations derived in @papakostas2023forecasts. What is most important is simply that - -\begin{equation} -\begin{aligned} -\operatorname{Pr}(V=1 \mid Z=0, X=x)&=\pi_a(x) = \Phi_f(x)\Phi_h(x),\\ -\operatorname{Pr}(V=1 \mid Z=1, X=x)&=\pi_a(x)+\pi_c(x) = \Phi_f(x), -\end{aligned} -\end{equation} -where $\Phi_{\mu}(x)$ denotes the normal cumulative distribution function with mean $\mu(x)$ and variance 1. - -We first create a secondary data matrix for the $Z=0$ group only. We also set all of the hyperparameters and initialize the latent variables. - -```{r treatmentfit1} -# Fit the monotone probit model to the treatment such that Pr(V = 1 | Z = 1, X=x) >= Pr(V = 1 | Z = 0,X = x) - -X_h <- as.matrix(X[z==0,]) -n0 <- sum(z==0) -n1 <- sum(z==1) - -num_trees_f <- 50 -num_trees_h <- 20 -feature_types <- as.integer(rep(0, 1)) # 0 = numeric -var_weights <- rep(1,1) -cutpoint_grid_size = 100 -global_variance_init = 1. -tau_init = 1/num_trees_h -leaf_prior_scale = matrix(c(tau_init), ncol = 1) -nu <- 4 -lambda <- 0.5 -a_leaf <- 2. 
-b_leaf <- 0.5 -leaf_regression <- F # fit a constant leaf mean BART model - -# Instantiate the C++ dataset objects -forest_dataset_f <- createForestDataset(X) -forest_dataset_h <- createForestDataset(X_h) - -# Tell it we're fitting a normal BART model -outcome_model_type <- 0 - -# Set up model configuration objects -forest_model_config_f <- createForestModelConfig( - feature_types = feature_types, num_trees = num_trees_f, num_features = ncol(X), - num_observations = nrow(X), variable_weights = var_weights, leaf_dimension = 1, - alpha = alpha, beta = beta, min_samples_leaf = min_samples_leaf, - max_depth = max_depth, leaf_model_type = outcome_model_type, - leaf_model_scale = leaf_prior_scale, cutpoint_grid_size = cutpoint_grid_size -) -forest_model_config_h <- createForestModelConfig( - feature_types = feature_types, num_trees = num_trees_h, num_features = ncol(X_h), - num_observations = nrow(X_h), variable_weights = var_weights, leaf_dimension = 1, - alpha = alpha, beta = beta, min_samples_leaf = min_samples_leaf, - max_depth = max_depth, leaf_model_type = outcome_model_type, - leaf_model_scale = leaf_prior_scale, cutpoint_grid_size = cutpoint_grid_size -) -global_model_config <- createGlobalModelConfig(global_error_variance = 1) - -# Instantiate the sampling data structures -forest_model_f <- createForestModel(forest_dataset_f, forest_model_config_f, global_model_config) -forest_model_h <- createForestModel(forest_dataset_h, forest_model_config_h, global_model_config) - -# Instantiate containers of forest samples -forest_samples_f <- createForestSamples(num_trees_f, 1, T) -forest_samples_h <- createForestSamples(num_trees_h, 1, T) - -# Instantiate "active" forests -active_forest_f <- createForest(num_trees_f, 1, T) -active_forest_h <- createForest(num_trees_h, 1, T) - -# Set algorithm specifications -# these are set in the earlier script for the outcome model; number of draws needs to be commensurable - -#num_warmstart <- 10 -#num_mcmc <- 2000 -#num_samples <- 
num_warmstart + num_mcmc - -# Initialize the Markov chain - -# Initialize (R0, R1), the latent binary variables that enforce the monotonicty - -v1 <- v[z==1] -v0 <- v[z==0] - -R1 = rep(NA,n0) -R0 = rep(NA,n0) - -R1[v0==1] <- 1 -R0[v0==1] <- 1 - -R1[v0 == 0] <- 0 -R0[v0 == 0] <- sample(c(0,1),sum(v0==0),replace=TRUE) - -# The first n1 observations of vaug are actually observed -# The next n0 of them are the latent variable R1 -vaug <- c(v1, R1) - -# Initialize the Albert and Chib latent Gaussian variables -z_f <- (2*as.numeric(vaug) - 1) -z_h <- (2*as.numeric(R0)-1) -z_f <- z_f/sd(z_f) -z_h <- z_h/sd(z_h) - -# Pass these variables to the BART models as outcome variables -outcome_f <- createOutcome(z_f) -outcome_h <- createOutcome(z_h) - -# Initialize active forests to constant (0) predictions -active_forest_f$prepare_for_sampler(forest_dataset_f, outcome_f, forest_model_f, outcome_model_type, 0.0) -active_forest_h$prepare_for_sampler(forest_dataset_h, outcome_h, forest_model_h, outcome_model_type, 0.0) -active_forest_f$adjust_residual(forest_dataset_f, outcome_f, forest_model_f, FALSE, FALSE) -active_forest_h$adjust_residual(forest_dataset_h, outcome_h, forest_model_h, FALSE, FALSE) -``` - -Now we run the main sampling loop, which consists of three key steps: sample the BART forests, given the latent probit utilities, sampling the latent binary outcome pairs (this is the step that is necessary for enforcing monotonicity), given the forest predictions and the latent utilities, and finally sample the latent utilities. 
- -```{r treatmentfit2} -# PART IV: run the Markov chain - -# Initialize the Markov chain with num_warmstart grow-from-root iterations -gfr_flag <- T -for (i in 1:num_samples) { - - # Switch over to random walk Metropolis-Hastings tree updates after num_warmstart - if (i > num_warmstart) { - gfr_flag <- F - } - - # Step 1: Sample the BART forests - - # Sample forest for the function f based on (y_f, R1) - forest_model_f$sample_one_iteration( - forest_dataset_f, outcome_f, forest_samples_f, active_forest_f, - rng, forest_model_config_f, global_model_config, - keep_forest = T, gfr = gfr_flag - ) - - # Sample forest for the function h based on outcome R0 - forest_model_h$sample_one_iteration( - forest_dataset_h, outcome_h, forest_samples_h, active_forest_h, - rng, forest_model_config_h, global_model_config, - keep_forest = T, gfr = gfr_flag - ) - - # Extract the means for use in sampling the latent variables - eta_f <- forest_samples_f$predict_raw_single_forest(forest_dataset_f, i-1) - eta_h <- forest_samples_h$predict_raw_single_forest(forest_dataset_h, i-1) - - - # Step 2: sample the latent binary pair (R0, R1) given eta_h, eta_f, and y_g - - # Three cases: (0,0), (0,1), (1,0) - w1 <- (1 - pnorm(eta_h[v0==0]))*(1-pnorm(eta_f[n1 + which(v0==0)])) - w2 <- (1 - pnorm(eta_h[v0==0]))*pnorm(eta_f[n1 + which(v0==0)]) - w3 <- pnorm(eta_h[v0==0])*(1 - pnorm(eta_f[n1 + which(v0==0)])) - - s <- w1 + w2 + w3 - w1 <- w1/s - w2 <- w2/s - w3 <- w3/s - - u <- runif(sum(v0==0)) - temp <- 1*(u < w1) + 2*(u > w1 & u < w1 + w2) + 3*(u > w1 + w2) - - R1[v0==0] <- 1*(temp==2) - R0[v0==0] <- 1*(temp==3) - - # Redefine y with the updated R1 component - vaug <- c(v1, R1) - - # Step 3: sample the latent normals, given (R0, R1) and y_f - - # First z0 - U1 <- runif(sum(R0),pnorm(0, eta_h[R0==1],1),1) - z_h[R0==1] <- qnorm(U1, eta_h[R0==1],1) - - U0 <- runif(n0 - sum(R0),0, pnorm(0, eta_h[R0==0],1)) - z_h[R0==0] <- qnorm(U0, eta_h[R0==0],1) - - # Then z1 - U1 <- runif(sum(vaug),pnorm(0, 
eta_f[vaug==1],1),1) - z_f[vaug==1] <- qnorm(U1, eta_f[vaug==1],1) - - U0 <- runif(n - sum(vaug),0, pnorm(0, eta_f[vaug==0],1)) - z_f[vaug==0] <- qnorm(U0, eta_f[vaug==0],1) - - # Propagate the updated outcomes through the BART models - outcome_h$update_data(z_h) - forest_model_h$propagate_residual_update(outcome_h) - - outcome_f$update_data(z_f) - forest_model_f$propagate_residual_update(outcome_f) - - # No more steps, just repeat a bunch of times -} -``` - -### Extracting the estimates and plotting the results. - -Now for the most interesting part, which is taking the stochtree BART model fits and producing the causal estimates of interest. - -First we set up our grid for plotting the functions in $X$. This is possible in this example because the moderator, age, is one dimensional; in may applied problems this will not be the case and visualization will be substantially trickier. - -```{r plot1} -# Extract the credible intervals for the conditional treatment effects as a function of x. -# We use a grid of values for plotting, with grid points that are typically fewer than the number of observations. - -ngrid <- 200 -xgrid <- seq(0.1,2.5,length.out = ngrid) -X_11 <- cbind(xgrid,rep(1,ngrid),rep(1,ngrid)) - -X_00 <- cbind(xgrid,rep(0,ngrid),rep(0,ngrid)) -X_01 <- cbind(xgrid,rep(0,ngrid),rep(1,ngrid)) -X_10 <- cbind(xgrid,rep(1,ngrid),rep(0,ngrid)) -``` - -Next, we compute the truth function evaluations on this plotting grid, using the functions defined above when we generated our data. 
- -```{r plot2} -# Compute the true conditional outcome probabilities for plotting -pi_strat <- pi_s(xgrid) -w_a <- pi_strat[,1] -w_n <- pi_strat[,2] -w_c <- pi_strat[,3] - -w <- (w_c/(w_a + w_c)) - -p11_true <- w*gamfun(xgrid,1,1,"c") + (1-w)*gamfun(xgrid,1,1,"a") - -w <- (w_c/(w_n + w_c)) - -p00_true <- w*gamfun(xgrid,0,0,"c") + (1-w)*gamfun(xgrid,0,0,"n") - -# Compute the true ITT_c for plotting and comparison -itt_c_true <- gamfun(xgrid,1,1,"c") - gamfun(xgrid,0,0,"c") - -# Compute the true LATE for plotting and comparison -LATE_true0 <- gamfun(xgrid,1,0,"c") - gamfun(xgrid,0,0,"c") -LATE_true1 <- gamfun(xgrid,1,1,"c") - gamfun(xgrid,0,1,"c") -``` - -Next we populate the data structures for stochtree to operate on, call the predict functions to extract the predictions, convert them to probability scale using the built in `pnorm` function. - -```{r plot3} -# Datasets for counterfactual predictions -forest_dataset_grid <- createForestDataset(cbind(xgrid)) -forest_dataset_11 <- createForestDataset(X_11) -forest_dataset_00 <- createForestDataset(X_00) -forest_dataset_10 <- createForestDataset(X_10) -forest_dataset_01 <- createForestDataset(X_01) - -# Forest predictions -preds_00 <- forest_samples$predict(forest_dataset_00) -preds_11 <- forest_samples$predict(forest_dataset_11) -preds_01 <- forest_samples$predict(forest_dataset_01) -preds_10 <- forest_samples$predict(forest_dataset_10) - -# Probability transformations -phat_00 <- pnorm(preds_00) -phat_11 <- pnorm(preds_11) -phat_01 <- pnorm(preds_01) -phat_10 <- pnorm(preds_10) - -# Cleanup -rm(preds_00) -rm(preds_11) -rm(preds_01) -rm(preds_10) - - -preds_ac <- forest_samples_f$predict(forest_dataset_grid) -phat_ac <- pnorm(preds_ac) - -preds_adj <- forest_samples_h$predict(forest_dataset_grid) -phat_a <- pnorm(preds_ac)*pnorm(preds_adj) -rm(preds_adj) -rm(preds_ac) - -phat_c <- phat_ac - phat_a - -phat_n <- 1 - phat_ac -``` - -Now we may plot posterior means of various quantities (as a function of $X$) to 
visualize how well the models are fitting. - -```{r plot4, fig.align='center'} -# Set up the plotting window -par(mfrow=c(1,2)) - -# Plot the fitted outcome probabilities against the truth -plot(p11_true,rowMeans(phat_11),pch=20,cex=0.5,bty='n') -abline(0,1,col='red') - -plot(p00_true,rowMeans(phat_00),pch=20,cex=0.5,bty='n') -abline(0,1,col='red') - -par(mfrow=c(1,3)) -plot(rowMeans(phat_ac),w_c +w_a,pch=20) -abline(0,1,col='red') - -plot(rowMeans(phat_a),w_a,pch=20) -abline(0,1,col='red') - -plot(rowMeans(phat_c),w_c,pch=20) -abline(0,1,col='red') -``` - -These plots are not as pretty as we might hope, but mostly this is a function of how difficult it is to learn conditional probabilities from binary outcomes. That we capture the trend broadly turns out to be adequate for estimating treatment effects. Fit does improve with simpler DGPs and larger training sets, as can be confirmed by experimentation with this script. - -Lastly, we can construct the estimate of the $ITT_c$ and compare it to the true value as well as the $Z=0$ and $Z=1$ complier average treatment effects (also called "local average treatment effects" or LATE). The key step in this process is to center our posterior on the identified interval (at each iteration of the sampler) at the value implied by a valid exclusion restriction. For some draws this will not be possible, as that value will be outside the identification region. 
- -```{r plot5, fig.height = 3} -# Generate draws from the posterior of the treatment effect -# centered at the point-identified value under the exclusion restriction -itt_c <- late <- matrix(NA,ngrid, ncol(phat_c)) -ss <- 6 -for (j in 1:ncol(phat_c)){ - - # Value of gamma11 implied by an exclusion restriction - gamest11 <- ((phat_a[,j] + phat_c[,j])/phat_c[,j])*phat_11[,j] - phat_10[,j]*phat_a[,j]/phat_c[,j] - - # Identified region for gamma11 - lower11 <- pmax(rep(0,ngrid), ((phat_a[,j] + phat_c[,j])/phat_c[,j])*phat_11[,j] - phat_a[,j]/phat_c[,j]) - upper11 <- pmin(rep(1,ngrid), ((phat_a[,j] + phat_c[,j])/phat_c[,j])*phat_11[,j]) - - # Center a beta distribution at gamma11, but restricted to (lower11, upper11) - # do this by shifting and scaling the mean, drawing from a beta on (0,1), then shifting and scaling to the - # correct restricted interval - m11 <- (gamest11 - lower11)/(upper11-lower11) - - # Parameters to the beta - a1 <- ss*m11 - b1 <- ss*(1-m11) - - # When the corresponding mean is out-of-range, sample from a beta with mass piled near the violated boundary - a1[m11<0] <- 1 - b1[m11<0] <- 5 - - a1[m11>1] <- 5 - b1[m11>1] <- 1 - - # Value of gamma00 implied by an exclusion restriction - gamest00 <- ((phat_n[,j] + phat_c[,j])/phat_c[,j])*phat_00[,j] - phat_01[,j]*phat_n[,j]/phat_c[,j] - - # Identified region for gamma00 - lower00 <- pmax(rep(0,ngrid), ((phat_n[,j] + phat_c[,j])/phat_c[,j])*phat_00[,j] - phat_n[,j]/phat_c[,j]) - upper00 <- pmin(rep(1,ngrid), ((phat_n[,j] + phat_c[,j])/phat_c[,j])*phat_00[,j]) - - # Center a beta distribution at gamma00, but restricted to (lower00, upper00) - # do this by shifting and scaling the mean, drawing from a beta on (0,1), then shifting and scaling to the - # correct restricted interval - m00 <- (gamest00 - lower00)/(upper00-lower00) - - a0 <- ss*m00 - b0 <- ss*(1-m00) - - a0[m00<0] <- 1 - b0[m00<0] <- 5 - - a0[m00>1] <- 5 - b0[m00>1] <- 1 - - # ITT and LATE - itt_c[,j] <- lower11 + (upper11 - 
lower11)*rbeta(ngrid,a1,b1) - (lower00 + (upper00 - lower00)*rbeta(ngrid,a0,b0)) - late[,j] <- gamest11 - gamest00 -} - -upperq <- apply(itt_c,1,quantile,0.975) -lowerq <- apply(itt_c,1,quantile,0.025) -``` - -And now we can plot all of this, using the "polygon" function to shade posterior quantiles. - -```{r plot6, fig.align="center"} -par(mfrow=c(1,1)) -plot(xgrid,itt_c_true,pch=20,cex=0.5,ylim=c(-0.75,0.05),bty='n',type='n',xlab='x',ylab='Treatment effect') - -upperq_er <- apply(late,1,quantile,0.975,na.rm=TRUE) - -lowerq_er <- apply(late,1,quantile,0.025,na.rm=TRUE) - -polygon(c(xgrid,rev(xgrid)),c(lowerq,rev(upperq)),col=rgb(0.5,0.25,0,0.25),pch=20,border=FALSE) -polygon(c(xgrid,rev(xgrid)),c(lowerq_er,rev(upperq_er)),col=rgb(0,0,0.5,0.25),pch=20,border=FALSE) - -itt_c_est <- rowMeans(itt_c) -late_est <- rowMeans(late) -lines(xgrid,late_est,col="slategray",lwd=3) - -lines(xgrid,itt_c_est,col='goldenrod1',lwd=1) - -lines(xgrid,LATE_true0,col="black",lwd=2,lty=3) -lines(xgrid,LATE_true1,col="black",lwd=2,lty=2) - -lines(xgrid,itt_c_true,col="black",lwd=1) -``` - -With a valid exclusion restriction the three black curves would all be the same. With no exclusion restriction, as we have here, the direct effect of $Z$ on $Y$ (the vaccine reminder on flu status) makes it so these three treatment effects are different. Specifically, the $ITT_c$ compares getting the vaccine *and* getting the reminder to not getting the vaccine *and* not getting the reminder. When both things have risk reducing impacts, we see a larger risk reduction over all values of $X$. Meanwhile, the two LATE effects compare the isolated impact of the vaccine among people that got the reminder and those that didn't, respectively. Here, not getting the reminder makes the vaccine more effective because the risk reduction is as a fraction of baseline risk, and the reminder reduces baseline risk in our DGP. 
- -We see also that the posterior mean of the $ITT_c$ estimate (gold) is very similar to the posterior mean under the assumption of an exclusion restriction (gray). This is by design...they will only deviate due to Monte Carlo variation or due to the rare situations where the exclusion restriction is incompatible with the identification interval. - -By changing the sample size and various aspects of the DGP this script allows us to build some intuition for how aspects of the DGP affect posterior inferences, particularly how violates of assumptions affect accuracy and posterior uncertainty. - -# References diff --git a/vignettes/R/IV/iv.bib b/vignettes/R/IV/iv.bib deleted file mode 100644 index b0eba609b..000000000 --- a/vignettes/R/IV/iv.bib +++ /dev/null @@ -1,79 +0,0 @@ -@article{mcdonald1992effects, - title={Effects of computer reminders for influenza vaccination on morbidity during influenza epidemics.}, - author={McDonald, Clement J and Hui, Siu L and Tierney, William M}, - journal={MD computing: computers in medical practice}, - volume={9}, - number={5}, - pages={304--312}, - year={1992} -} - -@article{hirano2000assessing, - author = {Hirano, Keisuke and Imbens, Guido W. and Rubin, Donald B. and Zhou, Xiao-Hua}, - title = {Assessing the effect of an influenza vaccine in an - encouragement design }, - journal = {Biostatistics}, - volume = {1}, - number = {1}, - pages = {69-88}, - year = {2000}, - month = {03}, - issn = {1465-4644}, - doi = {10.1093/biostatistics/1.1.69}, - url = {https://doi.org/10.1093/biostatistics/1.1.69}, - eprint = {https://academic.oup.com/biostatistics/article-pdf/1/1/69/17744019/100069.pdf}, -} - -@incollection{richardson2011transparent, - author = {Richardson, Thomas S. and Evans, Robin J. 
and Robins, James M.}, - isbn = {9780199694587}, - title = {Transparent Parametrizations of Models for Potential Outcomes}, - booktitle = {Bayesian Statistics 9}, - publisher = {Oxford University Press}, - year = {2011}, - month = {10}, - doi = {10.1093/acprof:oso/9780199694587.003.0019}, - url = {https://doi.org/10.1093/acprof:oso/9780199694587.003.0019}, - eprint = {https://academic.oup.com/book/0/chapter/141661815/chapter-ag-pdf/45787772/book\_1879\_section\_141661815.ag.pdf}, -} - -@book{imbens2015causal, - place={Cambridge}, - title={Causal Inference for Statistics, Social, and Biomedical Sciences: An Introduction}, - publisher={Cambridge University Press}, - author={Imbens, Guido W. and Rubin, Donald B.}, - year={2015} -} - -@article{hahn2016bayesian, - title={A Bayesian partial identification approach to inferring the prevalence of accounting misconduct}, - author={Hahn, P Richard and Murray, Jared S and Manolopoulou, Ioanna}, - journal={Journal of the American Statistical Association}, - volume={111}, - number={513}, - pages={14--26}, - year={2016}, - publisher={Taylor \& Francis} -} - -@article{albert1993bayesian, - title={Bayesian analysis of binary and polychotomous response data}, - author={Albert, James H and Chib, Siddhartha}, - journal={Journal of the American statistical Association}, - volume={88}, - number={422}, - pages={669--679}, - year={1993}, - publisher={Taylor \& Francis} -} - -@article{papakostas2023forecasts, - title={Do forecasts of bankruptcy cause bankruptcy? 
A machine learning sensitivity analysis}, - author={Papakostas, Demetrios and Hahn, P Richard and Murray, Jared and Zhou, Frank and Gerakos, Joseph}, - journal={The Annals of Applied Statistics}, - volume={17}, - number={1}, - pages={711--739}, - year={2023}, - publisher={Institute of Mathematical Statistics} -} \ No newline at end of file diff --git a/vignettes/R/IV/iv.html b/vignettes/R/IV/iv.html deleted file mode 100644 index 98de3c632..000000000 --- a/vignettes/R/IV/iv.html +++ /dev/null @@ -1,1795 +0,0 @@ - - - - - - - - - - - - - - - - -Instrumental Variables (IV) with Stochtree - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
- - - - - - - - -
-

Introduction

-

Here we consider a causal inference problem with a binary treatment -and a binary outcome where there is unobserved confounding, but an -exogenous instrument is available (also binary). This problem will -require a number of extensions to the basic BART model, all of which can -be implemented straightforwardly as Gibbs samplers using -stochtree. We’ll go through all of the model fitting steps -in quite a lot of detail here.

-
-
-

Background

-

To be concrete, suppose we wish to measure the effect of receiving a -flu vaccine on the probability of getting the flu. Individuals who opt -to get a flu shot differ in many ways from those that don’t, and these -lifestyle differences presumably also affect their respective chances of -getting the flu. Consequently, comparing the percentage of individuals -who get the flu in the vaccinated and unvaccinated groups does not give -a clear picture of the vaccine efficacy.

-

However, a so-called encouragement design can be implemented, where -some individuals are selected at random to be given some extra incentive -to get a flu shot (free clinics at the workplace or a personalized -reminder, for example). Studying the impact of this randomized -encouragement allows us to tease apart the impact of the vaccine from -the confounding factors, at least to some extent. This exact problem has -been considered several times in the literature, starting with McDonald, Hui, and Tierney (1992) with follow-on -analysis by Hirano et al. (2000), Richardson, Evans, and Robins (2011), and Imbens and Rubin (2015).

-

Our analysis here follows the Bayesian nonparametric approach -described in the supplement to Hahn, Murray, and -Manolopoulou (2016).

-
-

Notation

-

Let \(V\) denote the treatment -variable (as in β€œvaccine”). Let \(Y\) -denote the response variable (getting the flu), \(Z\) denote the instrument (encouragement or -reminder to get a flu shot), and \(X\) -denote an additional observable covariate (for instance, patient -age).

-

Further, let \(S\) denote the -so-called principal strata, which is an exhaustive -characterization of how individuals’ might be affected by the -encouragement regarding the flu shot. Some people will get a flu shot no -matter what: these are the always takers (a). Some people will -not get the flu shot no matter what: these are the never takers -(n). For both always-takers and never-takers, the randomization of the -encouragement is irrelevant and our data set contains no always takers -who skipped the vaccine and no never takers who got the vaccine and so -the treatment effect of the vaccine in these groups is fundamentally -non-identifiable.

-

By contrast, we also have compliers (c): folks who would not -have gotten the shot but for the fact that they were encouraged to do -so. These are the people about whom our randomized encouragement -provides some information, because they are precisely the ones that have -been randomized to treatment.

-

Lastly, we could have defiers (d): contrarians who who were -planning on getting the shot, but – upon being reminded – decided not -to! For our analysis we will do the usual thing of assuming that there -are no defiers. And because we are going to simulate our data, we can -make sure that this assumption is true.

-
-
-

The causal diagram

-

The causal diagram for this model can be expressed as follows. Here -we are considering one confounder and moderator variable (\(X\)), which is the patient’s age. In our -data generating process (which we know because this is a simulation -demonstration) higher age will make it more likely that a person is an -always taker or complier and less likely that they are a never taker, -which in turn has an effect on flu risk. We stipulate here that always -takers are at lower risk and never takers at higher risk. -Simultaneously, age has an increasing and then decreasing direct effect -on flu risk; very young and very old are at higher risk, while young and -middle age adults are at lower risk. In this DGP the flu efficacy has a -multiplicative effect, reducing flu risk as a fixed proportion of -baseline risk – accordingly, the treatment effect (as a difference) is -nonlinear in Age (for each principal stratum).

-
-The causal directed acyclic graph (CDAG) for the instrumental variables flu example. -

-The causal directed acyclic graph (CDAG) for the instrumental variables -flu example. -

-
-

The biggest question about this graph concerns the dashed red arrow -from the putative instrument \(Z\) to -the outcome (flu). We say β€œputative” because if that dashed red arrow is -there, then technically \(Z\) is not a -valid instrument. The assumption/assertion that there is no dashed red -arrow is called the β€œexclusion restriction”. In this vignette, we will -explore what sorts of inferences are possible if we remain agnostic -about the presence or absence of that dashed red arrow.

-
-
-

Potential outcomes

-

There are two relevant potential outcomes in an instrumental -variables analysis, corresponding to the causal effect of the instrument -on the treatment and the causal effect of the treatment on the outcome. -In this example, that is the effect of the reminder/encouragement on -vaccine status and the effect of the vaccine itself on the flu. The -notation is \(V(Z)\) and \(Y(V(Z),Z)\) respectively, so that we have -six distinct random variables: \(V(0)\), \(V(1)\), \(Y(0,0)\), \(Y(1,0)\), \(Y(0,1)\) and \(Y(1,1)\). The problem – sometimes called -the fundamental problem of causal inference – is that some of -these random variables can never be seen simultaneously, they are -observationally mutually exclusive. For this reason, it may be helpful -to think about causal inference as a missing data problem, as depicted -in the following table.

- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
-\(i\) - -\(Z_i\) - -\(V_i(0)\) - -\(V_i(1)\) - -\(Y_i(0,0)\) - -\(Y_i(1,0)\) - -\(Y_i(0,1)\) - -\(Y_i(1,1)\) -
-1 - -1 - -? - -1 - -? - -? - -? - -0 -
-2 - -0 - -1 - -? - -? - -1 - -? - -? -
-3 - -0 - -0 - -? - -1 - -? - -? - -? -
-4 - -1 - -? - -0 - -? - -? - -0 - -? -
-\(\vdots\) - -\(\vdots\) - -\(\vdots\) - -\(\vdots\) - -\(\vdots\) - -\(\vdots\) - -\(\vdots\) - -\(\vdots\) -
-

Likewise, with this notation we can formally define the principal -strata:

- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
-\(V_i(0)\) - -\(V_i(1)\) - -\(S_i\) -
-0 - -0 - -Never Taker (n) -
-1 - -1 - -Always Taker (a) -
-0 - -1 - -Complier (c) -
-1 - -0 - -Defier (d) -
-
-
-

Estimands and Identification

-

Let \(\pi_s(x)\) denote the -conditional (on \(x\)) probability that -an individual belongs to principal stratum \(s\): \[\begin{equation} -\pi_s(x)=\operatorname{Pr}(S=s \mid X=x), -\end{equation}\] and let \(\gamma_s^{v -z}(x)\) denote the potential outcome probability for given values -\(v\) and \(z\): \[\begin{equation} -\gamma_s^{v z}(x)=\operatorname{Pr}(Y(v, z)=1 \mid S=s, X=x). -\end{equation}\]

-

Various estimands of interest may be expressed in terms of the -functions \(\gamma_c^{vz}(x)\). In -particular, the complier conditional average treatment effect \[\gamma_c^{1,z}(x) - \gamma_c^{0,z}(x)\] is -the ultimate goal (for either \(z=0\) -or \(z=1\)). Under an exclusion -restriction, we would have \(\gamma_s^{vz}(x) -= \gamma_s^{v}(x)\) and the reminder status \(z\) itself would not matter. In that case, -we can estimate \[\gamma_c^{1,z}(x) - -\gamma_c^{0,z}\] and \[\gamma_c^{1,1}(x) - \gamma_c^{0,0}(x).\] -This latter quantity is called the complier intent-to-treat effect, or -\(ITT_c\), and it can be partially -identify even if the exclusion restriction is violated, as follows.

-

The left-hand side of the following system of equations are all -estimable quantities that can be learned from observable data, while the -right hand side expressions involve the unknown functions of interest, -\(\gamma_s^{vz}(x)\):

-

\[\begin{equation} -\begin{aligned} -p_{1 \mid 00}(x) = \operatorname{Pr}(Y=1 \mid V=0, Z=0, -X=x)=\frac{\pi_c(x)}{\pi_c(x)+\pi_n(x)} -\gamma_c^{00}(x)+\frac{\pi_n(x)}{\pi_c(x)+\pi_n(x)} \gamma_n^{00}(x) \\ -p_{1 \mid 11}(x) =\operatorname{Pr}(Y=1 \mid V=1, Z=1, -X=x)=\frac{\pi_c(x)}{\pi_c(x)+\pi_a(x)} -\gamma_c^{11}(x)+\frac{\pi_a(x)}{\pi_c(x)+\pi_a(x)} \gamma_a^{11}(x) \\ -p_{1 \mid 01}(x) =\operatorname{Pr}(Y=1 \mid V=0, Z=1, -X=x)=\frac{\pi_d(x)}{\pi_d(x)+\pi_n(x)} -\gamma_d^{01}(x)+\frac{\pi_n(x)}{\pi_d(x)+\pi_n(x)} \gamma_n^{01}(x) \\ -p_{1 \mid 10}(x) =\operatorname{Pr}(Y=1 \mid V=1, Z=0, -X=x)=\frac{\pi_d(x)}{\pi_d(x)+\pi_a(x)} -\gamma_d^{10}(x)+\frac{\pi_a(x)}{\pi_d(x)+\pi_a(x)} \gamma_a^{10}(x) -\end{aligned} -\end{equation}\]

-

Furthermore, we have \[\begin{equation} -\begin{aligned} -\operatorname{Pr}(V=1 \mid Z=0, X=x)&=\pi_a(x)+\pi_d(x)\\ -\operatorname{Pr}(V=1 \mid Z=1, X=x)&=\pi_a(x)+\pi_c(x) -\end{aligned} -\end{equation}\]

-

Under the monotonicy assumption, \(\pi_d(x) -= 0\) and these expressions simplify somewhat. \[\begin{equation} -\begin{aligned} -p_{1 \mid 00}(x)&=\frac{\pi_c(x)}{\pi_c(x)+\pi_n(x)} -\gamma_c^{00}(x)+\frac{\pi_n(x)}{\pi_c(x)+\pi_n(x)} \gamma_n^{00}(x) \\ -p_{1 \mid 11}(x)&=\frac{\pi_c(x)}{\pi_c(x)+\pi_a(x)} -\gamma_c^{11}(x)+\frac{\pi_a(x)}{\pi_c(x)+\pi_a(x)} \gamma_a^{11}(x) \\ -p_{1 \mid 01}(x)&=\gamma_n^{01}(x) \\ -p_{1 \mid 10}(x)&=\gamma_a^{10}(x) -\end{aligned} -\end{equation}\] and \[\begin{equation} -\begin{aligned} -\operatorname{Pr}(V=1 \mid Z=0, X=x)&=\pi_a(x)\\ -\operatorname{Pr}(V=1 \mid Z=1, X=x)&=\pi_a(x)+\pi_c(x) -\end{aligned} -\end{equation}\]

-

The exclusion restriction would dictate that \(\gamma_s^{01}(x) = \gamma_s^{00}(x)\) and -\(\gamma_s^{11}(x) = \gamma_s^{10}(x)\) -for all \(s\). This has two -implications. One, \(\gamma_n^{01}(x) = -\gamma_n^{00}(x)\) and \(\gamma_a^{10}(x) = \gamma_a^{11}(x)\),and -because the left-hand terms are identified, this permits \(\gamma_c^{11}(x)\) and \(\gamma_c^{00}(x)\) to be solved for by -substitution. Two, with these two quantities solved for, we also have -the two other quantities (the different settings of \(z\)), since \(\gamma_c^{11}(x) = \gamma_c^{10}(x)\) and -\(\gamma_c^{00}(x) = -\gamma_c^{01}(x)\). Consequently, both of our estimands from -above can be estimated:

-

\[\gamma_c^{11}(x) - -\gamma_c^{01}(x)\] and

-

\[\gamma_c^{10}(x) - -\gamma_c^{00}(x)\] because they are both (supposing the exclusion -restriction holds) the same as

-

\[\gamma_c^{11}(x) - -\gamma_c^{00}(x).\] If the exclusion restriction does -not hold, then the three above treatment effects are all -(potentially) distinct and not much can be said about the former two. -The latter one, the \(ITT_c\), however, -can be partially identified, by recognizing that the first two equations -(in our four equation system) provide non-trivial bounds based on the -fact that while \(\gamma_c^{11}(x)\) -and \(\gamma_c^{00}(x)\) are no longer -identified, as probabilities both must lie between 0 and 1. Thus,

-

\[\begin{equation} -\begin{aligned} - \max\left( - 0, \frac{\pi_c(x)+\pi_n(x)}{\pi_c(x)}p_{1\mid 00}(x) - -\frac{\pi_n(x)}{\pi_c(x)} - \right) -&\leq\gamma^{00}_c(x)\leq - \min\left( - 1, \frac{\pi_c(x)+\pi_n(x)}{\pi_c(x)}p_{1\mid 00}(x) - \right)\\\\ -% -\max\left( - 0, \frac{\pi_a(x)+\pi_c(x)}{\pi_c(x)}p_{1\mid 11}(x) - -\frac{\pi_a(x)}{\pi_c(x)} -\right) -&\leq\gamma^{11}_c(x)\leq -\min\left( - 1, \frac{\pi_a(x)+\pi_c(x)}{\pi_c(x)}p_{1\mid 11}(x) -\right) -\end{aligned} -\end{equation}\]

-

The point of all this is that the data (plus a no-defiers assumption) -lets us estimate all the necessary inputs to these upper and lower -bounds on \(\gamma^{11}_c(x)\) and -\(\gamma^{00}_c(x)\) which in turn -define our estimand. What remains is to estimate those inputs, as -functions of \(x\), and to do so while -enforcing the monotonicty restriction \[\operatorname{Pr}(V=1 \mid Z=0, X=x)=\pi_a(x) -\leq -\operatorname{Pr}(V=1 \mid Z=1, X=x)=\pi_a(x)+\pi_c(x).\]

-

We can do all of this with calls to stochtree from R (or Python). But -first, let’s generate some test data.

-
-

Simulate the data

-

Start with some initial setup / housekeeping

-
library(stochtree)
-
-# size of the training sample
-n <- 20000
-
-# To set the seed for reproducibility/illustration purposes, replace "NULL" with a positive integer
-random_seed <- NULL
-

First, we generate the instrument exogenously

-
z <- rbinom(n, 1, 0.5)
-

Next, we generate the covariate. (For this example, let’s think of it -as patient age, although we are generating it from a uniform -distribution between 0 and 3, so you have to imagine that it has been -pre-standardized to this scale. It keeps the DGPs cleaner for -illustration purposes.)

-
p_X <- 1
-X <- matrix(runif(n*p_X, 0, 3), ncol = p_X)
-x <- X[,1] # for ease of reference later
-

Next, we generate the principal strata \(S\) based on the observed value of \(X\). We generate it according to a logistic -regression with two coefficients per strata, an intercept and a slope. -Here, these coefficients are set so that the probability of being a -never taker decreases with age.

-
alpha_a <- 0
-beta_a <- 1
-
-alpha_n <- 1
-beta_n <- -1
-
-alpha_c <- 1
-beta_c <- 1
-
-# define function (a logistic model) to generate Pr(S = s | X = x)
-pi_s <- function(xval){
-  
-  w_a <- exp(alpha_a + beta_a*xval)
-  w_n <- exp(alpha_n + beta_n*xval)
-  w_c <- exp(alpha_c + beta_c*xval)
-   
-  w <- cbind(w_a, w_n, w_c)
-  colnames(w) <- c("w_a","w_n","w_c")
-  w <- w/rowSums(w)
-  
-  return(w)
-  
-}
-s <- sapply(1:n, function(j) sample(c("a","n","c"), 1, prob = pi_s(X[j,1])))
-

Next, we generate the treatment variable, here denoted \(V\) (for β€œvaccine”), as a -deterministic function of \(S\) and \(Z\); this is what gives the principal -strata their meaning.

-
v <- 1*(s=="a") + 0*(s=="n") + z*(s=="c") + (1-z)*(s == "d")
-

Finally, the outcome structural model is specified, based on which -the outcome is sampled. By varying this function in particular ways, we -can alter the identification conditions.

-
gamfun <- function(xval,vval, zval,sval){
-  
-  # if this function depends on zval, then exclusion restriction is violated
-  # if this function does not depend on sval, then IV analysis wasn't necessary
-  # if this function does not depend on x, then there are no HTEs
-  
-  baseline <- pnorm(2 -1*xval - 2.5*(xval-1.5)^2 - 0.5*zval + 1*(sval=="n") - 1*(sval=="a") )
-  prob <- baseline - 0.5*vval*baseline # 0.5*vval*baseline
-  
-  return(prob)
-}
-
-# Generate the observed outcome
-y <- rbinom(n, 1, gamfun(X[,1],v,z,s))
-

Lastly, we perform some organization for our supervised learning -algorithms later on.

-
# Concatenate X, v and z for our supervised learning algorithms
-Xall <- cbind(X,v,z)
-
-# update the size of "X" to be the size of Xall
-p_X <- p_X + 2
-
-# For the monotone probit model it is necessary to sort the observations so that the Z=1 cases are all together
-# at the start of the outcome vector.  
-index <- sort(z,decreasing = TRUE, index.return = TRUE)
-
-X <- matrix(X[index$ix,],ncol= 1)
-Xall <- Xall[index$ix,]
-z <- z[index$ix]
-v <- v[index$ix]
-s <- s[index$ix]
-y <- y[index$ix]
-x <- x[index$ix]
-

Now let’s see if we can recover these functions from the observed -data.

-
-
-

Fit the outcome model

-

We have to fit three models here, the treatment models: \(\operatorname{Pr}(V = 1 | Z = 1, X=x)\) and -\(\operatorname{Pr}(V = 1 | Z = 0,X = -x)\), subject to the monotonicity constraint \(\operatorname{Pr}(V = 1 | Z = 1, X=x) \geq -\operatorname{Pr}(V = 1 | Z = 0,X = x)\), and an outcome model -\(\operatorname{Pr}(Y = 1 | Z = 1, V = 1, X = -x)\). All of this will be done with stochtree.

-

The outcome model is fit with a single (S-learner) BART model. This -part of the model could be fit as a T-Learner or as a BCF model. Here we -us an S-Learner for simplicity. Both models are probit models, and use -the well-known Albert and Chib (1993) data -augmentation Gibbs sampler. This section covers the more straightforward -outcome model. The next section describes how the monotonicity -constraint is handled with a data augmentation Gibbs sampler.

-

These models could (and probably should) be wrapped as functions. -Here they are implemented as scripts, with the full loops shown. The -output – at the end of the loops – are stochtree forest objects from -which we can extract posterior samples and generate predictions. In -particular, the \(ITT_c\) will be -constructed using posterior counterfactual predictions derived from -these forest objects.

-

We begin by setting a bunch of hyperparameters and instantiating the -forest objects to be operated upon in the main sampling loop. We also -initialize the latent variables.

-
# Fit the BART model for Pr(Y = 1 | Z = 1, V = 1, X = x)
-
-# Set number of iterations
-num_warmstart <- 10
-num_mcmc <- 1000
-num_samples <- num_warmstart + num_mcmc
-
-# Set a bunch of hyperparameters. These are ballpark default values.
-alpha <- 0.95
-beta <- 2
-min_samples_leaf <- 1
-max_depth <- 20
-num_trees <- 50
-cutpoint_grid_size = 100
-global_variance_init = 1.
-tau_init = 0.5
-leaf_prior_scale = matrix(c(tau_init), ncol = 1)
-a_leaf <- 2.
-b_leaf <- 0.5
-leaf_regression <- F
-feature_types <- as.integer(c(rep(0, p_X-2),1,1)) # 0 = numeric
-var_weights <- rep(1,p_X)/p_X
-outcome_model_type <- 0
-
-# C++ dataset
-forest_dataset <- createForestDataset(Xall)
-
-# Random number generator (std::mt19937)
-if (is.null(random_seed)) {
-    rng <- createCppRNG(-1)
-} else {
-    rng <- createCppRNG(random_seed)
-}
-
-# Sampling data structures
-forest_model_config <- createForestModelConfig(
-  feature_types = feature_types, num_trees = num_trees, num_features = p_X, 
-  num_observations = n, variable_weights = var_weights, leaf_dimension = 1, 
-  alpha = alpha, beta = beta, min_samples_leaf = min_samples_leaf, 
-  max_depth = max_depth, leaf_model_type = outcome_model_type, 
-  leaf_model_scale = leaf_prior_scale, cutpoint_grid_size = cutpoint_grid_size
-)
-global_model_config <- createGlobalModelConfig(global_error_variance = 1)
-forest_model <- createForestModel(forest_dataset, forest_model_config, global_model_config)
-
-# Container of forest samples
-forest_samples <- createForestSamples(num_trees, 1, T, F)
-
-# "Active" forest state
-active_forest <- createForest(num_trees, 1, T, F)
-
-# Initialize the latent outcome zed
-n1 <- sum(y)
-zed <- 0.25*(2*as.numeric(y) - 1)
-
-# C++ outcome variable
-outcome <- createOutcome(zed)
-
-# Initialize the active forest and subtract each root tree's predictions from outcome
-active_forest$prepare_for_sampler(forest_dataset, outcome, forest_model, outcome_model_type, 0.0)
-active_forest$adjust_residual(forest_dataset, outcome, forest_model, FALSE, FALSE)
-

Now we enter the main loop, which involves only two steps: sample the -forest, given the latent utilities, then sample the latent utilities -given the estimated conditional means defined by the forest and its -parameters.

-
# Initialize the Markov chain with num_warmstart grow-from-root iterations
-gfr_flag <- T
-for (i in 1:num_samples) {
-  
-  # The first num_warmstart iterations use the grow-from-root algorithm of He and Hahn
-  if (i > num_warmstart){
-    gfr_flag <- F
-  } 
-  
-  # Sample forest
-  forest_model$sample_one_iteration(
-    forest_dataset, outcome, forest_samples, active_forest, 
-    rng, forest_model_config, global_model_config, 
-    keep_forest = T, gfr = gfr_flag
-  )
-  
-  # Get the current means
-  eta <- forest_samples$predict_raw_single_forest(forest_dataset, i-1)
-  
-  # Sample latent normals, truncated according to the observed outcome y
-  U1 <- runif(n1,pnorm(0,eta[y==1],1),1)
-  zed[y==1] <- qnorm(U1,eta[y==1],1)
-  U0 <- runif(n - n1,0, pnorm(0,eta[y==0],1))
-  zed[y==0] <- qnorm(U0,eta[y==0],1)
-  
-  # Propagate the newly sampled latent outcome to the BART model
-  outcome$update_data(zed)
-  forest_model$propagate_residual_update(outcome)
-}
-
-
-

Fit the monotone probit model(s)

-

The monotonicty constraint relies on a data augmentation as described -in Papakostas et al. (2023). The -implementation of this sampler is inherently cumbersome, as one of the -β€œdata” vectors is constructed from some observed data and some latent -data and there are two forest objects, one of which applies to all of -the observations and one of which applies to only those observations -with \(Z = 0\). We go into more details -about this sampler in a dedicated vignette. Here we include the code, -but without producing the equations derived in Papakostas et al. (2023). What is most important -is simply that

-

\[\begin{equation} -\begin{aligned} -\operatorname{Pr}(V=1 \mid Z=0, X=x)&=\pi_a(x) = -\Phi_f(x)\Phi_h(x),\\ -\operatorname{Pr}(V=1 \mid Z=1, X=x)&=\pi_a(x)+\pi_c(x) = \Phi_f(x), -\end{aligned} -\end{equation}\] where \(\Phi_{\mu}(x)\) denotes the normal -cumulative distribution function with mean \(\mu(x)\) and variance 1.

-

We first create a secondary data matrix for the \(Z=0\) group only. We also set all of the -hyperparameters and initialize the latent variables.

-
# Fit the monotone probit model to the treatment such that Pr(V = 1 | Z = 1, X=x) >= Pr(V = 1 | Z = 0,X = x) 
-
-X_h <- as.matrix(X[z==0,])
-n0 <- sum(z==0)
-n1 <- sum(z==1)
-
-num_trees_f <- 50
-num_trees_h <- 20
-feature_types <- as.integer(rep(0, 1)) # 0 = numeric
-var_weights <- rep(1,1)
-cutpoint_grid_size = 100
-global_variance_init = 1.
-tau_init = 1/num_trees_h
-leaf_prior_scale = matrix(c(tau_init), ncol = 1)
-nu <- 4
-lambda <- 0.5
-a_leaf <- 2.
-b_leaf <- 0.5
-leaf_regression <- F # fit a constant leaf mean BART model
-
-# Instantiate the C++ dataset objects
-forest_dataset_f <- createForestDataset(X)
-forest_dataset_h <- createForestDataset(X_h)
-
-# Tell it we're fitting a normal BART model
-outcome_model_type <- 0
-
-# Set up model configuration objects
-forest_model_config_f <- createForestModelConfig(
-  feature_types = feature_types, num_trees = num_trees_f, num_features = ncol(X), 
-  num_observations = nrow(X), variable_weights = var_weights, leaf_dimension = 1, 
-  alpha = alpha, beta = beta, min_samples_leaf = min_samples_leaf, 
-  max_depth = max_depth, leaf_model_type = outcome_model_type, 
-  leaf_model_scale = leaf_prior_scale, cutpoint_grid_size = cutpoint_grid_size
-)
-forest_model_config_h <- createForestModelConfig(
-  feature_types = feature_types, num_trees = num_trees_h, num_features = ncol(X_h), 
-  num_observations = nrow(X_h), variable_weights = var_weights, leaf_dimension = 1, 
-  alpha = alpha, beta = beta, min_samples_leaf = min_samples_leaf, 
-  max_depth = max_depth, leaf_model_type = outcome_model_type, 
-  leaf_model_scale = leaf_prior_scale, cutpoint_grid_size = cutpoint_grid_size
-)
-global_model_config <- createGlobalModelConfig(global_error_variance = 1)
-
-# Instantiate the sampling data structures
-forest_model_f <- createForestModel(forest_dataset_f, forest_model_config_f, global_model_config)
-forest_model_h <- createForestModel(forest_dataset_h, forest_model_config_h, global_model_config)
-
-# Instantiate containers of forest samples
-forest_samples_f <- createForestSamples(num_trees_f, 1, T)
-forest_samples_h <- createForestSamples(num_trees_h, 1, T)
-
-# Instantiate "active" forests
-active_forest_f <- createForest(num_trees_f, 1, T)
-active_forest_h <- createForest(num_trees_h, 1, T)
-
-# Set algorithm specifications 
-# these are set in the earlier script for the outcome model; number of draws needs to be commensurable 
-
-#num_warmstart <- 10
-#num_mcmc <- 2000
-#num_samples <- num_warmstart + num_mcmc
-
-# Initialize the Markov chain
-
-# Initialize (R0, R1), the latent binary variables that enforce the monotonicty 
-
-v1 <- v[z==1]
-v0 <- v[z==0]
-
-R1 = rep(NA,n0)
-R0 = rep(NA,n0)
-
-R1[v0==1] <- 1
-R0[v0==1] <- 1
-
-R1[v0 == 0] <- 0
-R0[v0 == 0] <- sample(c(0,1),sum(v0==0),replace=TRUE)
-
-# The first n1 observations of vaug are actually observed
-# The next n0 of them are the latent variable R1
-vaug <- c(v1, R1)
-
-# Initialize the Albert and Chib latent Gaussian variables
-z_f <- (2*as.numeric(vaug) - 1)
-z_h <- (2*as.numeric(R0)-1)
-z_f <- z_f/sd(z_f)
-z_h <- z_h/sd(z_h)
-
-# Pass these variables to the BART models as outcome variables
-outcome_f <- createOutcome(z_f)
-outcome_h <- createOutcome(z_h)
-
-# Initialize active forests to constant (0) predictions
-active_forest_f$prepare_for_sampler(forest_dataset_f, outcome_f, forest_model_f, outcome_model_type, 0.0)
-active_forest_h$prepare_for_sampler(forest_dataset_h, outcome_h, forest_model_h, outcome_model_type, 0.0)
-active_forest_f$adjust_residual(forest_dataset_f, outcome_f, forest_model_f, FALSE, FALSE)
-active_forest_h$adjust_residual(forest_dataset_h, outcome_h, forest_model_h, FALSE, FALSE)
-

Now we run the main sampling loop, which consists of three key steps: -sample the BART forests, given the latent probit utilities, sampling the -latent binary outcome pairs (this is the step that is necessary for -enforcing monotonicity), given the forest predictions and the latent -utilities, and finally sample the latent utilities.

-
# PART IV: run the Markov chain 
-
-# Initialize the Markov chain with num_warmstart grow-from-root iterations
-gfr_flag <- T
-for (i in 1:num_samples) {
-  
-  # Switch over to random walk Metropolis-Hastings tree updates after num_warmstart
-  if (i > num_warmstart) {
-    gfr_flag <- F
-  }
-  
-  # Step 1: Sample the BART forests
-  
-  # Sample forest for the function f based on (y_f, R1)
-  forest_model_f$sample_one_iteration(
-    forest_dataset_f, outcome_f, forest_samples_f, active_forest_f, 
-    rng, forest_model_config_f, global_model_config, 
-    keep_forest = T, gfr = gfr_flag
-  )
-  
-  # Sample forest for the function h based on outcome R0
-  forest_model_h$sample_one_iteration(
-    forest_dataset_h, outcome_h, forest_samples_h, active_forest_h,
-    rng, forest_model_config_h, global_model_config, 
-    keep_forest = T, gfr = gfr_flag
-  )
-  
-  # Extract the means for use in sampling the latent variables
-  eta_f <- forest_samples_f$predict_raw_single_forest(forest_dataset_f, i-1)
-  eta_h <- forest_samples_h$predict_raw_single_forest(forest_dataset_h, i-1)
-  
-  
-  # Step 2: sample the latent binary pair (R0, R1) given eta_h, eta_f, and y_g
-  
-  # Three cases: (0,0), (0,1), (1,0)
-  w1 <- (1 - pnorm(eta_h[v0==0]))*(1-pnorm(eta_f[n1 + which(v0==0)]))
-  w2 <-   (1 - pnorm(eta_h[v0==0]))*pnorm(eta_f[n1 + which(v0==0)])
-  w3 <- pnorm(eta_h[v0==0])*(1 - pnorm(eta_f[n1 + which(v0==0)]))
-  
-  s <- w1 + w2 + w3
-  w1 <- w1/s
-  w2 <- w2/s
-  w3 <- w3/s
-  
-  u <- runif(sum(v0==0))
-  temp <- 1*(u < w1) + 2*(u > w1 & u < w1 + w2) + 3*(u > w1 + w2)
-  
-  R1[v0==0] <- 1*(temp==2)
-  R0[v0==0] <- 1*(temp==3)
-  
-  # Redefine y with the updated R1 component 
-  vaug <- c(v1, R1)
-  
-  # Step 3: sample the latent normals, given (R0, R1) and y_f
-  
-  # First z0
-  U1 <- runif(sum(R0),pnorm(0, eta_h[R0==1],1),1)
-  z_h[R0==1] <- qnorm(U1, eta_h[R0==1],1)
-  
-  U0 <- runif(n0 - sum(R0),0, pnorm(0, eta_h[R0==0],1))
-  z_h[R0==0] <- qnorm(U0, eta_h[R0==0],1)
-  
-  # Then z1
-  U1 <- runif(sum(vaug),pnorm(0, eta_f[vaug==1],1),1)
-  z_f[vaug==1] <- qnorm(U1, eta_f[vaug==1],1)
-  
-  U0 <- runif(n - sum(vaug),0, pnorm(0, eta_f[vaug==0],1))
-  z_f[vaug==0] <- qnorm(U0, eta_f[vaug==0],1)
-  
-  # Propagate the updated outcomes through the BART models
-  outcome_h$update_data(z_h)
-  forest_model_h$propagate_residual_update(outcome_h)
-  
-  outcome_f$update_data(z_f)
-  forest_model_f$propagate_residual_update(outcome_f)
-  
-  # No more steps, just repeat a bunch of times
-}
-
-
-

Extracting the estimates and plotting the results.

-

Now for the most interesting part, which is taking the stochtree BART -model fits and producing the causal estimates of interest.

-

First we set up our grid for plotting the functions in \(X\). This is possible in this example -because the moderator, age, is one dimensional; in may applied problems -this will not be the case and visualization will be substantially -trickier.

-
# Extract the credible intervals for the conditional treatment effects as a function of x.
-# We use a grid of values for plotting, with grid points that are typically fewer than the number of observations.
-
-ngrid <- 200
-xgrid <- seq(0.1,2.5,length.out = ngrid)
-X_11 <- cbind(xgrid,rep(1,ngrid),rep(1,ngrid))
-
-X_00 <- cbind(xgrid,rep(0,ngrid),rep(0,ngrid))
-X_01 <- cbind(xgrid,rep(0,ngrid),rep(1,ngrid))
-X_10 <- cbind(xgrid,rep(1,ngrid),rep(0,ngrid))
-

Next, we compute the truth function evaluations on this plotting -grid, using the functions defined above when we generated our data.

-
# Compute the true conditional outcome probabilities for plotting
-pi_strat <- pi_s(xgrid)
-w_a <- pi_strat[,1]
-w_n <- pi_strat[,2]
-w_c <- pi_strat[,3]
-
-w <- (w_c/(w_a + w_c))
-
-p11_true <- w*gamfun(xgrid,1,1,"c") + (1-w)*gamfun(xgrid,1,1,"a")
-
-w <- (w_c/(w_n + w_c))
-
-p00_true <- w*gamfun(xgrid,0,0,"c") + (1-w)*gamfun(xgrid,0,0,"n")
-
-# Compute the true ITT_c for plotting and comparison
-itt_c_true <- gamfun(xgrid,1,1,"c") - gamfun(xgrid,0,0,"c")
-
-# Compute the true LATE for plotting and comparison
-LATE_true0 <- gamfun(xgrid,1,0,"c") - gamfun(xgrid,0,0,"c")
-LATE_true1 <- gamfun(xgrid,1,1,"c") - gamfun(xgrid,0,1,"c")
-

Next we populate the data structures for stochtree to operate on, -call the predict functions to extract the predictions, convert them to -probability scale using the built in pnorm function.

-
# Datasets for counterfactual predictions
-forest_dataset_grid <- createForestDataset(cbind(xgrid))
-forest_dataset_11 <- createForestDataset(X_11)
-forest_dataset_00 <- createForestDataset(X_00)
-forest_dataset_10 <- createForestDataset(X_10)
-forest_dataset_01 <- createForestDataset(X_01)
-
-# Forest predictions
-preds_00 <- forest_samples$predict(forest_dataset_00)
-preds_11 <- forest_samples$predict(forest_dataset_11)
-preds_01 <- forest_samples$predict(forest_dataset_01)
-preds_10 <- forest_samples$predict(forest_dataset_10)
-
-# Probability transformations
-phat_00 <- pnorm(preds_00)
-phat_11 <- pnorm(preds_11)
-phat_01 <- pnorm(preds_01)
-phat_10 <- pnorm(preds_10)
-
-# Cleanup
-rm(preds_00)
-rm(preds_11)
-rm(preds_01)
-rm(preds_10)
-
-
-preds_ac <- forest_samples_f$predict(forest_dataset_grid)
-phat_ac <- pnorm(preds_ac)
-
-preds_adj <- forest_samples_h$predict(forest_dataset_grid)
-phat_a <- pnorm(preds_ac)*pnorm(preds_adj)
-rm(preds_adj)
-rm(preds_ac)
-
-phat_c <- phat_ac - phat_a
-
-phat_n <- 1 - phat_ac
-

Now we may plot posterior means of various quantities (as a function -of \(X\)) to visualize how well the -models are fitting.

-
# Set up the plotting window
-par(mfrow=c(1,2))
-
-# Plot the fitted outcome probabilities against the truth
-plot(p11_true,rowMeans(phat_11),pch=20,cex=0.5,bty='n')
-abline(0,1,col='red')
-
-plot(p00_true,rowMeans(phat_00),pch=20,cex=0.5,bty='n')
-abline(0,1,col='red')
-

-
par(mfrow=c(1,3))
-plot(rowMeans(phat_ac),w_c +w_a,pch=20)
-abline(0,1,col='red')
-
-plot(rowMeans(phat_a),w_a,pch=20)
-abline(0,1,col='red')
-
-plot(rowMeans(phat_c),w_c,pch=20)
-abline(0,1,col='red')
-

-

These plots are not as pretty as we might hope, but mostly this is a -function of how difficult it is to learn conditional probabilities from -binary outcomes. That we capture the trend broadly turns out to be -adequate for estimating treatment effects. Fit does improve with simpler -DGPs and larger training sets, as can be confirmed by experimentation -with this script.

-

Lastly, we can construct the estimate of the \(ITT_c\) and compare it to the true value as -well as the \(Z=0\) and \(Z=1\) complier average treatment effects -(also called β€œlocal average treatment effects” or LATE). The key step in -this process is to center our posterior on the identified interval (at -each iteration of the sampler) at the value implied by a valid exclusion -restriction. For some draws this will not be possible, as that value -will be outside the identification region.

-
# Generate draws from the posterior of the treatment effect
-# centered at the point-identified value under the exclusion restriction
-itt_c <- late <- matrix(NA,ngrid, ncol(phat_c))
-ss <- 6
-for (j in 1:ncol(phat_c)){
-  
-  # Value of gamma11 implied by an exclusion restriction
-  gamest11 <- ((phat_a[,j] + phat_c[,j])/phat_c[,j])*phat_11[,j] - phat_10[,j]*phat_a[,j]/phat_c[,j]
-  
-  # Identified region for gamma11
-  lower11 <- pmax(rep(0,ngrid), ((phat_a[,j] + phat_c[,j])/phat_c[,j])*phat_11[,j] - phat_a[,j]/phat_c[,j])
-  upper11 <- pmin(rep(1,ngrid), ((phat_a[,j] + phat_c[,j])/phat_c[,j])*phat_11[,j])
-  
-  # Center a beta distribution at gamma11, but restricted to (lower11, upper11)
-  # do this by shifting and scaling the mean, drawing from a beta on (0,1), then shifting and scaling to the 
-  # correct restricted interval
-  m11 <- (gamest11 - lower11)/(upper11-lower11)
-
-  # Parameters to the beta
-  a1 <- ss*m11
-  b1 <- ss*(1-m11)
-  
-  # When the corresponding mean is out-of-range, sample from a beta with mass piled near the violated boundary
-  a1[m11<0] <- 1
-  b1[m11<0] <- 5
-  
-  a1[m11>1] <- 5
-  b1[m11>1] <- 1
-  
-  # Value of gamma00 implied by an exclusion restriction
-  gamest00 <- ((phat_n[,j] + phat_c[,j])/phat_c[,j])*phat_00[,j] - phat_01[,j]*phat_n[,j]/phat_c[,j]
-  
-  # Identified region for gamma00
-  lower00 <- pmax(rep(0,ngrid), ((phat_n[,j] + phat_c[,j])/phat_c[,j])*phat_00[,j] - phat_n[,j]/phat_c[,j])
-  upper00 <- pmin(rep(1,ngrid), ((phat_n[,j] + phat_c[,j])/phat_c[,j])*phat_00[,j])
-  
-  # Center a beta distribution at gamma00, but restricted to (lower00, upper00)
-  # do this by shifting and scaling the mean, drawing from a beta on (0,1), then shifting and scaling to the 
-  # correct restricted interval
-  m00 <- (gamest00 - lower00)/(upper00-lower00)
-  
-  a0 <- ss*m00
-  b0 <- ss*(1-m00)
-  
-  a0[m00<0] <- 1
-  b0[m00<0] <- 5
-  
-  a0[m00>1] <- 5
-  b0[m00>1] <- 1
- 
-  # ITT and LATE    
-  itt_c[,j] <- lower11 + (upper11 - lower11)*rbeta(ngrid,a1,b1) - (lower00 + (upper00 - lower00)*rbeta(ngrid,a0,b0))
-  late[,j] <- gamest11 - gamest00
-}
-
-upperq <- apply(itt_c,1,quantile,0.975)
-lowerq <- apply(itt_c,1,quantile,0.025)
-

And now we can plot all of this, using the β€œpolygon” function to -shade posterior quantiles.

-
par(mfrow=c(1,1))
-plot(xgrid,itt_c_true,pch=20,cex=0.5,ylim=c(-0.75,0.05),bty='n',type='n',xlab='x',ylab='Treatment effect')
-
-upperq_er <- apply(late,1,quantile,0.975,na.rm=TRUE)
-
-lowerq_er <- apply(late,1,quantile,0.025,na.rm=TRUE)
-
-polygon(c(xgrid,rev(xgrid)),c(lowerq,rev(upperq)),col=rgb(0.5,0.25,0,0.25),pch=20,border=FALSE)
-polygon(c(xgrid,rev(xgrid)),c(lowerq_er,rev(upperq_er)),col=rgb(0,0,0.5,0.25),pch=20,border=FALSE)
-
-itt_c_est <- rowMeans(itt_c)
-late_est <- rowMeans(late)
-lines(xgrid,late_est,col="slategray",lwd=3)
-
-lines(xgrid,itt_c_est,col='goldenrod1',lwd=1)
-
-lines(xgrid,LATE_true0,col="black",lwd=2,lty=3)
-lines(xgrid,LATE_true1,col="black",lwd=2,lty=2)
-
-lines(xgrid,itt_c_true,col="black",lwd=1)
-

-

With a valid exclusion restriction the three black curves would all -be the same. With no exclusion restriction, as we have here, the direct -effect of \(Z\) on \(Y\) (the vaccine reminder on flu status) -makes it so these three treatment effects are different. Specifically, -the \(ITT_c\) compares getting the -vaccine and getting the reminder to not getting the vaccine -and not getting the reminder. When both things have risk -reducing impacts, we see a larger risk reduction over all values of -\(X\). Meanwhile, the two LATE effects -compare the isolated impact of the vaccine among people that got the -reminder and those that didn’t, respectively. Here, not getting the -reminder makes the vaccine more effective because the risk reduction is -as a fraction of baseline risk, and the reminder reduces baseline risk -in our DGP.

-

We see also that the posterior mean of the \(ITT_c\) estimate (gold) is very similar to -the posterior mean under the assumption of an exclusion restriction -(gray). This is by design…they will only deviate due to Monte Carlo -variation or due to the rare situations where the exclusion restriction -is incompatible with the identification interval.

-

By changing the sample size and various aspects of the DGP this -script allows us to build some intuition for how aspects of the DGP -affect posterior inferences, particularly how violates of assumptions -affect accuracy and posterior uncertainty.

-
-
-
-
-

References

-
-
-Albert, James H, and Siddhartha Chib. 1993. β€œBayesian Analysis of -Binary and Polychotomous Response Data.” Journal of the -American Statistical Association 88 (422): 669–79. -
-
-Hahn, P Richard, Jared S Murray, and Ioanna Manolopoulou. 2016. β€œA -Bayesian Partial Identification Approach to Inferring the Prevalence of -Accounting Misconduct.” Journal of the American Statistical -Association 111 (513): 14–26. -
-
-Hirano, Keisuke, Guido W. Imbens, Donald B. Rubin, and Xiao-Hua Zhou. -2000. β€œAssessing the Effect of an Influenza Vaccine in an -Encouragement Design.” Biostatistics 1 (1): 69–88. https://doi.org/10.1093/biostatistics/1.1.69. -
-
-Imbens, Guido W., and Donald B. Rubin. 2015. Causal Inference for -Statistics, Social, and Biomedical Sciences: An Introduction. -Cambridge University Press. -
-
-McDonald, Clement J, Siu L Hui, and William M Tierney. 1992. -β€œEffects of Computer Reminders for Influenza Vaccination on -Morbidity During Influenza Epidemics.” MD Computing: -Computers in Medical Practice 9 (5): 304–12. -
-
-Papakostas, Demetrios, P Richard Hahn, Jared Murray, Frank Zhou, and -Joseph Gerakos. 2023. β€œDo Forecasts of Bankruptcy Cause -Bankruptcy? A Machine Learning Sensitivity Analysis.” The -Annals of Applied Statistics 17 (1): 711–39. -
-
-Richardson, Thomas S., Robin J. Evans, and James M. Robins. 2011. -β€œTransparent Parametrizations of Models for Potential -Outcomes.” In Bayesian Statistics 9. Oxford University -Press. https://doi.org/10.1093/acprof:oso/9780199694587.003.0019. -
-
-
- - - - -
- - - - - - - - - - - - - - - diff --git a/vignettes/R/RDD/rdd.bib b/vignettes/R/RDD/rdd.bib deleted file mode 100644 index ec0c287a7..000000000 --- a/vignettes/R/RDD/rdd.bib +++ /dev/null @@ -1,10 +0,0 @@ -@article{lindo2010ability, - title={Ability, gender, and performance standards: Evidence from academic probation}, - author={Lindo, Jason M and Sanders, Nicholas J and Oreopoulos, Philip}, - journal={American economic journal: Applied economics}, - volume={2}, - number={2}, - pages={95--117}, - year={2010}, - publisher={American Economic Association} -} \ No newline at end of file diff --git a/vignettes/R/RDD/rdd.html b/vignettes/R/RDD/rdd.html deleted file mode 100644 index ba24f0f2e..000000000 --- a/vignettes/R/RDD/rdd.html +++ /dev/null @@ -1,1032 +0,0 @@ - - - - - - - - - - - - - - - - - -Regression Discontinuity Design (RDD) with stochtree - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
- - - - - - - -
-

Introduction

-

We study conditional average treatment effect (CATE) estimation for -regression discontinuity designs (RDD), in which treatment assignment is -based on whether a particular covariate β€” referred to as the running -variable β€” lies above or below a known value, referred to as the cutoff -value. Because treatment is deterministically assigned as a known -function of the running variable, RDDs are trivially deconfounded: -treatment assignment is independent of the outcome variable, given the -running variable (because treatment is conditionally constant). However, -estimation of treatment effects in RDDs is more complicated than simply -controlling for the running variable, because doing so introduces a -complete lack of overlap, which is the other key condition needed to -justify regression adjustment for causal inference. Nonetheless, the -CATE at the cutoff, \(X=c\), -may still be identified provided the conditional expectation \(E[Y \mid X,W]\) is continuous at that point -for all \(W=w\). We exploit -this assumption with the leaf regression BART model implemented in -Stochtree, which allows us to define an explicit prior on the CATE. We -now describe the RDD setup and our model in more detail, and provide -code to implement our approach.

-
-
-

Regression Discontinuity Design

-

We conceptualize the treatment effect estimation problem via a -quartet of random variables \((Y, X, Z, -U)\). The variable \(Y\) is the -outcome variable; the variable \(X\) is -the running variable; the variable \(Z\) is the treatment assignment indicator -variable; and the variable \(U\) -represents additional, possibly unobserved, causal factors. What -specifically makes this correspond to an RDD is that we stipulate that -\(Z = I(X > c)\), for cutoff \(c\). We assume \(c = 0\) without loss of generality.

-

The following figure depicts a causal diagram representing the -assumed causal relationships between these variables. Two key features -of this diagram are one, that \(X\) -blocks the impact of \(U\) on \(Z\): in other words, \(X\) satisfies the back-door criterion for -learning causal effects of \(Z\) on -\(Y\). And two, \(X\) and \(U\) are not descendants of \(Z\).

-
-A causal directed acyclic graph representing the general structure of a regression discontinuity design problem -

-A causal directed acyclic graph representing the general structure of a -regression discontinuity design problem -

-
-

Using this causal diagram, we may express \(Y\) as some function of its graph parents, -the random variables \((X,Z,U)\): \[Y = F(X,Z,U).\] In principle, we may -obtain draws of \(Y\) by first drawing -\((X,Z,U)\) according to their joint -distribution and then applying the function \(F\). Similarly, we may relate this -formulation to the potential outcomes framework straightforwardly: \[\begin{equation} -\begin{split} -Y^1 &= F(X,1,U),\\ -Y^0 &= F(X,0,U). -\end{split} -\end{equation}\] Here, draws of \((Y^1, -Y^0)\) may be obtained (in principle) by drawing \((X,Z,U)\) from their joint distribution and -using only the \((X,U)\) elements as -arguments in the above two equations, ``discarding’’ the drawn value of -\(Z\). Note that this construction -implies the consistency condition: \(Y = Y^1 Z + Y^0 ( 1 - Z)\). Likewise, this -construction implies the no interference condition because each -\(Y_i\) is considered to be produced -with arguments (\(X_i, Z_i, U_i)\) and -not those from other units \(j\); in -particular, in constructing \(Y_i\), -\(F\) does not take \(Z_j\) for \(j -\neq i\) as an argument.

-

Next, we define the following conditional expectations \[\begin{equation} -\begin{split} -\mu_1(x) &= E[ F(x, 1, U) \mid X = x] ,\\ -\mu_0(x) &= E[ F(x, 0, U) \mid X = x], -\end{split} -\end{equation}\] with which we can define the treatment effect -function \[\tau(x) = \mu_1(x) - -\mu_0(x).\] Because \(X\) -satisfies the back-door criterion, \(\mu_1\) and \(\mu_0\) are estimable from the data, -meaning that \[\begin{equation} -\begin{split} -\mu_1(x) &= E[ F(x, 1, U) \mid X = x] = E[Y \mid X=x, Z=1],\\ -\mu_0(x) &= E[ F(x, 0, U) \mid X = x] = E[Y \mid X=x, Z=0], -\end{split} -\end{equation}\]
-the right-hand-sides of which can be estimated from sample data, which -we supposed to be independent and identically distributed realizations -of \((Y_i, X_i, Z_i)\) for \(i = 1, \dots, n\). However, because \(Z = I(X >0)\) we can in fact only learn -\(\mu_1(x)\) for \(X > 0\) and \(\mu_0(x)\) for \(X < 0\). In potential outcomes -terminology, conditioning on \(X\) -satisfies ignorability, \[(Y^1, Y^0) \perp -\!\!\! \perp Z \mid X,\] but not strong ignorability, -because overlap is violated. Overlap would require that \[0 < P(Z = 1 \mid X=x) < 1 \;\;\;\; \forall -x,\] which is clearly violated by the RDD assumption that \(Z = I(X > 0)\). Consequently, the -overall ATE, \(\bar{\tau} = -E(\tau(X)),\) is unidentified, and we must content ourselves with -estimating \(\tau(0)\), the conditional -average effect at the point \(x = 0\), -which we estimate as the difference between \(\mu_1(0) - \mu_0(0)\). This is possible for -continuous \(X\) so long as one is -willing to assume that \(\mu_1(x)\) and -\(\mu_0(x)\) are both suitably smooth -functions of \(x\): any inferred -discontinuity at \(x = 0\) must -therefore be attributable to treatment effect.

-
-

Conditional average treatment effects in RDD

-

We are concerned with learning not only \(\tau(0)\), the β€œRDD ATE” (e.g.Β the CATE at -\(x = 0\)), but also RDD CATEs, \(\tau(0, \mathrm{w})\) for some covariate -vector \(\mathrm{w}\). Incorporating -additional covariates in the above framework turns out to be -straightforward, simply by defining \(W = -\varphi(U)\) to be an observable function of the (possibly -unobservable) causal factors \(U\). We -may then define our potential outcome means as \[\begin{equation} -\begin{split} -\mu_1(x,\mathrm{w}) &= E[ F(x, 1, U) \mid X = x, W = \mathrm{w}] = -E[Y \mid X=x, W=\mathrm{w}, Z=1],\\ -\mu_0(x,\mathrm{w}) &= E[ F(x, 0, U) \mid X = x, W = \mathrm{w}] = -E[Y \mid X=x, W = \mathrm{w}, Z=0], -\end{split} -\end{equation}\] and our treatment effect function as \[\tau(x,\mathrm{w}) = \mu_1(x,\mathrm{w}) - -\mu_0(x,\mathrm{w}).\] We consider our data to be independent and -identically distributed realizations \((Y_i, -X_i, Z_i, W_i)\) for \(i = 1, \dots, -n\). Furthermore, we must assume that \(\mu_1(x,\mathrm{w})\) and \(\mu_0(x,\mathrm{w})\) are suitably smooth -functions of \(x\), {} \(\mathrm{w}\); in other words, for each -value of \(\mathrm{w}\) the usual -continuity-based identification assumptions must hold.

-

With this framework and notation established, CATE estimation in RDDs -boils down to estimation of conditional expectation functions \(E[Y \mid X=x, W=\mathrm{w}, Z=z]\), for -which we turn to BART models.

-
-
-
-

The BARDDT Model

-

We propose a BART model where the trees are allowed to split on \((x,\mathrm{w})\) but where each leaf node -parameter is a vector of regression coefficients tailored to the RDD -context (rather than a scalar constant as in default BART). In one -sense, such a model can be seen as implying distinct RDD ATE regressions -for each subgroup determined by a given tree; however, this intuition is -only heuristic, as the entire model is fit jointly as an ensemble of -such trees. Instead, we motivate this model as a way to estimate the -necessary conditional expectations via a parametrization where the -conditional treatment effect function can be explicitly regularized, as -follows.

-

Let \(\psi\) denote the following -basis vector: \[\begin{equation} -\psi(x,z) = \begin{bmatrix} -1 & z x & (1-z) x & z -\end{bmatrix}. -\end{equation}\] To generalize the original BART model, we define -\(g_j(x, \mathrm{w}, z)\) as a -piecewise linear function as follows. Let \(b_j(x, \mathrm{w})\) denote the node in the -\(j\)th tree which contains the point -\((x, \mathrm{w})\); then the -prediction function for tree \(j\) is -defined to be: \[\begin{equation} -g_j(x, \mathrm{w}, z) = \psi(x, z) \Gamma_{b_j(x, \mathrm{w})} -\end{equation}\]
-for a leaf-specific regression vector \(\Gamma_{b_j} = (\eta_{b_j}, \lambda_{b_j}, -\theta_{b_j}, \Delta_{b_j})^t\). Therefore, letting \(n_{b_j}\) denote the number of data points -allocated to node \(b\) in the \(j\)th tree and \(\Psi_{b_j}\) denote the \(n_{b_j} \times 4\) matrix, with rows equal -to \(\psi(x,z)\) for all \((x_i,z_i) \in b_j\), the model for -observations assigned to leaf \(b_j\), -can be expressed in matrix notation as: \[\begin{equation} -\begin{split} -\mathbf{Y}_{b_j} \mid \Gamma_{b_j}, \sigma^2 &\sim -\mathrm{N}(\Psi_{b_j} \Gamma_{b_j},\sigma^2)\\ -\Gamma_{b_j} &\sim \mathrm{N}(0, \Sigma_0), -\end{split} \label{eq:leaf.regression} -\end{equation}\] where we set \(\Sigma_0 = \frac{0.033}{J} \mbox{I}\) as a -default (for \(x\) vectors standardized -to have unit variance in-sample).

-

This choice of basis entails that the RDD CATE at \(\mathrm{w}\), \(\tau(0, \mathrm{w})\), is a sum of the -\(\Delta_{b_j(0, \mathrm{w})}\) -elements across all trees \(j = 1, \dots, -J\): \[\begin{equation} -\begin{split} -\tau(0, \mathrm{w}) &= E[Y^1 \mid X=0, W = \mathrm{w}] - E[Y^0 \mid -X = 0, W = \mathrm{w}]\\ -& = E[Y \mid X=0, W = \mathrm{w}, Z = 1] - E[Y \mid X = 0, W = -\mathrm{w}, Z = 0]\\ -&= \sum_{j = 1}^J g_j(0, \mathrm{w}, 1) - \sum_{j = 1}^J g_j(0, -\mathrm{w}, 0)\\ -&= \sum_{j = 1}^J \psi(0, 1) \Gamma_{b_j(0, \mathrm{w})} - \sum_{j -= 1}^J \psi(0, 0) \Gamma_{b_j(0, \mathrm{w})} \\ -& = \sum_{j = 1}^J \Bigl( \psi(0, 1) - \psi(0, 0) -\Bigr) \Gamma_{b_j(0, \mathrm{w})} \\ -& = \sum_{j = 1}^J \Bigl( (1,0,0,1) - -(1,0,0,0) \Bigr) \Gamma_{b_j(0, \mathrm{w})} \\ -&= \sum_{j=1}^J \Delta_{b_j(0, \mathrm{w})}. -\end{split} -\end{equation}\] As a result, the priors on the \(\Delta\) coefficients directly regularize -the treatment effect. We set the tree and error variance priors as in -the original BART model.

-

The following figures provide a graphical depiction of how the BARDDT -model fits a response surface and thereby estimates CATEs for distinct -values of \(\mathrm{w}\). For -simplicity only two trees are used in the illustration, while in -practice dozens or hundreds of trees may be used (in our simulations and -empirical example, we use 150 trees).

-
-Two regression trees with splits in x and a single scalar w. Node images depict the g(x,w,z) function (in x) defined by that node's coefficients. The vertical gap between the two line segments in a node that contain x=0 is that node's contribution to the CATE at X = 0. Note that only such nodes contribute for CATE prediction at x=0 -

-Two regression trees with splits in x and a single scalar w. Node images -depict the g(x,w,z) function (in x) defined by that node’s coefficients. -The vertical gap between the two line segments in a node that contain -x=0 is that node’s contribution to the CATE at X = 0. Note that only -such nodes contribute for CATE prediction at x=0 -

-
-
-The two top figures show the same two regression trees as in the preceding figure, now represented as a partition of the x-w plane. Labels in each partition correspond to the leaf nodes depicted in the previous picture. The bottom figure shows the partition of the x-w plane implied by the sum of the two trees; the red dashed line marks point W=w* and the combination of nodes that include this point -

-The two top figures show the same two regression trees as in the -preceding figure, now represented as a partition of the x-w plane. -Labels in each partition correspond to the leaf nodes depicted in the -previous picture. The bottom figure shows the partition of the x-w plane -implied by the sum of the two trees; the red dashed line marks point -W=w* and the combination of nodes that include this point -

-
-
-Left: The function fit at W = w* for the two trees shown in the previous two figures, shown superimposed. Right: The aggregated fit achieved by summing the contributions of two regression tree fits shown at left. The magnitude of the discontinuity at x = 0 (located at the dashed gray vertical line) represents the treatment effect at that point. Different values of w will produce distinct fits; for the two trees shown, there can be three distinct fits based on the value of w. -

-Left: The function fit at W = w* for the two trees shown in the previous -two figures, shown superimposed. Right: The aggregated fit achieved by -summing the contributions of two regression tree fits shown at left. The -magnitude of the discontinuity at x = 0 (located at the dashed gray -vertical line) represents the treatment effect at that point. Different -values of w will produce distinct fits; for the two trees shown, there -can be three distinct fits based on the value of w. -

-
-

An interesting property of BARDDT can be seen in this small -illustration β€” by letting the regression trees split on the running -variable, there is no need to separately define a β€˜bandwidth’ as is used -in the polynomial approach to RDD. Instead, the regression trees -automatically determine (in the course of posterior sampling) when to -β€˜prune’ away regions away from the cutoff value. There are two notable -features of this approach. One, different trees in the ensemble are -effectively using different local bandwidths and these fits are then -blended together. For example, in the bottom panel of the second figure, -we obtain one bandwidth for the region \(d+i\), and a different one for regions -\(a+g\) and \(d+g\). Two, for cells in the tree partition -that do not span the cutoff, the regression within that partition -contains no causal contrasts β€” all observations either have \(Z = 1\) or \(Z = -0\). For those cells, the treatment effect coefficient is -ill-posed and in those cases the posterior sampling is effectively a -draw from the prior; however, such draws correspond to points where the -treatment effect is unidentified and none of these draws contribute to -the estimation of \(\tau(0, -\mathrm{w})\) β€” for example, only nodes \(a+g\), \(d+g\), and \(d+i\) provide any contribution. This -implies that draws of \(\Delta\) -corresponding to nodes not predicting at \(X=0\) will always be draws from the prior, -which has some intuitive appeal.

-
-
-

Demo

-

In this section, we provide code for implementing our model in -stochtree on a popular RDD dataset. First, let us load -stochtree and all the necessary libraries for our posterior -analysis.

-
## Load libraries
-library(stochtree)
-library(rpart)
-library(rpart.plot)
-library(xtable)
-library(foreach)
-library(doParallel)
-
## Loading required package: iterators
-
## Loading required package: parallel
-
-

Dataset

-

The data comes from Lindo, Sanders, and -Oreopoulos (2010), who analyze data on college students enrolled -in a large Canadian university in order to evaluate the effectiveness of -an academic probation policy. Students who present a grade point average -(GPA) lower than a certain threshold at the end of each term are placed -on academic probation and must improve their GPA in the subsequent term -or else face suspension. We are interested in how being put on probation -or not, \(Z\), affects students’ GPA, -\(Y\), at the end of the current term. -The running variable, \(X\), is the -negative distance between a student’s previous-term GPA and the -probation threshold, so that students placed on probation (\(Z = 1\)) have a positive score and the -cutoff is 0. Potential moderators, \(W\), are:

-
    -
  • gender (male),
  • -
  • age upon entering university (age_at_entry)
  • -
  • a dummy for being born in North America -(bpl_north_america),
  • -
  • the number of credits taken in the first year -(totcredits_year1)
  • -
  • an indicator designating each of three campuses -(loc_campus 1, 2 and 3), and
  • -
  • high school GPA as a quantile w.r.t the university’s incoming class -(hsgrade_pct).
  • -
-
## Load and organize data
-data <- read.csv("https://raw.githubusercontent.com/rdpackages-replication/CIT_2024_CUP/refs/heads/main/CIT_2024_CUP_discrete.csv")
-y <- data$nextGPA
-x <- data$X
-x <- x/sd(x) ## we always standardize X
-w <- data[,4:11]
-### Must define categorical features as ordered/unordered factors
-w$totcredits_year1 <- factor(w$totcredits_year1,ordered=TRUE)
-w$male <- factor(w$male,ordered=FALSE)
-w$bpl_north_america <- factor(w$bpl_north_america,ordered=FALSE)
-w$loc_campus1 <- factor(w$loc_campus1,ordered=FALSE)
-w$loc_campus2 <- factor(w$loc_campus2,ordered=FALSE)
-w$loc_campus3 <- factor(w$loc_campus3,ordered=FALSE)
-c <- 0
-n <- nrow(data)
-z <- as.numeric(x>c)
-h <- 0.1 ## window for prediction sample
-test <- -h < x & x < h
-ntest <- sum(test)
-
-
-

Target estimand

-

Generically, our estimand is the CATE function at \(x = 0\); i.e.Β \(\tau(0, \mathrm{w})\). The key practical -question is which values of \(\mathrm{w}\) to consider. Some values of -\(\mathrm{w}\) will not be -well-represented near \(x=0\) and so no -estimation technique will be able to estimate those points effectively. -As such, to focus on feasible points β€” which will lead to interesting -comparisons between methods β€” we recommend restricting the evaluation -points to the observed \(\mathrm{w}_i\) -such that \(|x_i| \leq \delta\), for -some \(\delta > 0\). In our example, -we use \(\delta = 0.1\) for a -standardized \(x\) variable. Therefore, -our estimand of interest is a vector of treatment effects: \[\begin{equation} -\tau(0, \mathrm{w}_i) \;\;\; \forall i \;\mbox{ such that }\; |x_i| \leq -\delta. -\end{equation}\]

-
-
-

Implementing BARDDT

-

In order to implement our model, we write the Psi vector, as defined -before: Psi <- cbind(z*x,(1-z)*x, z,rep(1,n)). The -training matrix for the model is as.matrix(cbind(x,w)), -which we feed into the stochtree::bart function via the -X_train parameter. The basis vector Psi is fed -into the function via the leaf_basis_train parameter. The -list object barddt.mean.parmlist defines options for the -mean forest (a different list can be defined for a variance forest in -the case of heteroscedastic BART, which we do not consider here). -Importantly, in this list we define parameter -sigma2_leaf_init = diag(rep(0.1/150,4)), which sets \(\Sigma_0\) as described above. Now, we can -fit the model, which is saved in object barddt.fit.

-

Once the model is fit, we need 3 elements to obtain the CATE -predictions: the basis vectors at the cutoff for \(z=1\) and \(z=0\), the test matrix \([X \quad W]\) at the cutoff, and the -testing sample. We define the prediction basis vectors \(\psi_1 = [1 \quad 0 \quad 0 \quad 1]\) and -\(\psi_0 = [1 \quad 0 \quad 0 \quad -0]\), which correspond to \(\psi\) at \((x=0,z=1)\), and \((x=0,z=0)\), respectively. These vectors -are written into R as -Psi1 <- cbind(rep(1,n), rep(c,n), rep(0,n), rep(1,n)) -and -Psi0 <- cbind(rep(1,n), rep(0,n), rep(c,n), rep(0,n)). -Then, we write the test matrix at \((x=0,\mathrm{w})\) as -xmat_test <- as.matrix(cbind(rep(0,n),w)). Finally, we -must define the testing window. As discussed previously, our window is -set such that \(|x| \leq 0.1\), which -can be set in R as -test <- -0.1 < x & x < 0.1.

-

Once all of these elements are set, we can obtain the outcome -predictions at the cutoff by running -predict(barddt.fit, xmat_test, Psi1) (resp. -predict(barddt.fit, xmat_test, Psi0)). Each of these calls -returns a list, from which we can extract element y_hat to -obtain the posterior distribution for the outcome. In the code below, -the treated and control outcome predictions are saved in the matrix -objects pred1 and pred0, respectively. Now, we -can obtain draws from the CATE posterior by simply subtracting these -matrices. The function below outlines how to perform each of these steps -in R.

-
fit.barddt <- function(y,x,w,z,test,c)
-{
-  ## Lists of parameters for the Stochtree BART function
-  barddt.global.parmlist <- list(standardize=T,sample_sigma_global=TRUE,sigma2_global_init=0.1)
-  barddt.mean.parmlist <- list(num_trees=50, min_samples_leaf=20, alpha=0.95, beta=2,
-                               max_depth=20, sample_sigma2_leaf=FALSE, sigma2_leaf_init = diag(rep(0.1/150,4)))
-  ## Set basis vector for leaf regressions
-  Psi <- cbind(rep(1,n),z*x,(1-z)*x,z)
-  ## Model fit
-  barddt.fit = stochtree::bart(X_train= as.matrix(cbind(x,w)), y_train=y,
-                               leaf_basis_train = Psi, mean_forest_params=barddt.mean.parmlist,
-                               general_params=barddt.global.parmlist,
-                               num_mcmc=1000,num_gfr=30)
-  ## Define basis vectors and test matrix for outcome predictions at X=c
-  Psi1 <- cbind(rep(1,n), rep(c,n), rep(0,n), rep(1,n))
-  Psi0 <- cbind(rep(1,n), rep(0,n), rep(c,n), rep(0,n))
-  Psi1 <- Psi1[test,]
-  Psi0 <- Psi0[test,]
-  xmat_test <- as.matrix(cbind(rep(0,n),w)[test,])
-  ## Obtain outcome predictions
-  pred1 <- predict(barddt.fit,xmat_test,Psi1)$y_hat
-  pred0 <- predict(barddt.fit,xmat_test,Psi0)$y_hat
-  ## Obtain CATE posterior
-  out <- pred1-pred0
-  return(out)
-}
-

Now, we proceed to fit the BARDDT model. The procedure is exactly the -same as described in the simulation section.

-
## We will sample multiple chains sequentially
-num_chains <- 20
-num_gfr <- 2
-num_burnin <- 0
-num_mcmc <- 500
-bart_models <- list()
-## Define basis functions for training and testing
-B <- cbind(z*x,(1-z)*x, z,rep(1,n))
-B1 <- cbind(rep(c,n), rep(0,n), rep(1,n), rep(1,n))
-B0 <- cbind(rep(0,n), rep(c,n), rep(0,n), rep(1,n))
-B1 <- B1[test,]
-B0 <- B0[test,]
-B_test <- rbind(B1,B0)
-xmat_test <- cbind(x=rep(0,n),w)[test,]
-xmat_test <- rbind(xmat_test,xmat_test)
-### We combine the basis for Z=1 and Z=0 to feed it to the BART call and get the Y(z) predictions instantaneously
-### Then we separate the posterior matrix between each Z and calculate the CATE prediction
-## Sampling trees in parallel
-ncores <- 5
-cl <- makeCluster(ncores)
-registerDoParallel(cl)
-
-start_time <- Sys.time()
-bart_model_outputs <- foreach (i = 1:num_chains) %dopar% {
-  random_seed <- i
-  ## Lists to define BARDDT parameters
-  barddt.global.parmlist <- list(standardize=T,sample_sigma_global=TRUE,sigma2_global_init=0.1)
-  barddt.mean.parmlist <- list(num_trees=50, min_samples_leaf=20, alpha=0.95, beta=2,
-                               max_depth=20, sample_sigma2_leaf=FALSE, sigma2_leaf_init = diag(rep(0.1/50,4)))
-  bart_model <- stochtree::bart(
-    X_train = cbind(x,w), leaf_basis_train = B, y_train = y, 
-    X_test = xmat_test, leaf_basis_test = B_test,
-    num_gfr = num_gfr, num_burnin = num_burnin, num_mcmc = num_mcmc, 
-    general_params = barddt.global.parmlist, mean_forest_params = barddt.mean.parmlist
-  )
-  bart_model <- bart_model$y_hat_test[1:ntest,]-bart_model$y_hat_test[(ntest+1):(2*ntest),]
-}
-stopCluster(cl)
-## Combine CATE predictions
-pred <- do.call("cbind",bart_model_outputs)
-
-end_time <- Sys.time()
-
-print(end_time - start_time)
-
## Time difference of 9.554316 mins
-
## Save the results
-saveRDS(pred, "bart_rdd_posterior.rds")
-

We now proceed to analyze the CATE posterior. The figure produced -below presents a summary of the CATE posterior produced by BARDDT for -this application. This picture is produced by fitting a regression tree, -using \(W\) as the predictors, to the -individual posterior mean CATEs: \[\begin{equation} -\bar{\tau}_i = \frac{1}{M} \sum_{h = 1}^M \tau^{(h)}(0, \mathrm{w}_i), -\end{equation}\] where \(h\) -indexes each of \(M\) total posterior -samples. As in our simulation studies, we restrict our posterior -analysis to use \(\mathrm{w}_i\) values -of observations with \(|x_i| \leq \delta = -0.1\) (after normalizing \(X\) -to have standard deviation 1 in-sample). For the Lindo, Sanders, and Oreopoulos (2010) data, this -means that BARDDT was trained on \(n = -40,582\) observations, of which 1,602 satisfy \(|x_i| \leq 0.1\), which were used to generate -the effect moderation tree.

-
## Fit regression tree
-cate <- rpart(y~.,data.frame(y=rowMeans(pred),w[test,]),control = rpart.control(cp=0.015))
-## Define separate colors for left and rightmost nodes
-plot.cart <- function(rpart.obj)
-{
-  rpart.frame <- rpart.obj$frame
-  left <- which.min(rpart.frame$yval)
-  right <- which.max(rpart.frame$yval)
-  nodes <- rep(NA,nrow(rpart.frame))
-  for (i in 1:length(nodes))
-  {
-    if (rpart.frame$yval[i]==rpart.frame$yval[right]) nodes[i] <- "gold2"
-    else if (rpart.frame$yval[i]==rpart.frame$yval[left]) nodes[i] <- "tomato3"
-    else nodes[i] <- "lightblue3"
-  }
-  return(nodes)
-}
-## Plot regression tree
-rpart.plot(cate,main="",box.col=plot.cart(cate))
-
-Regression tree fit to posterior point estimates of individual treatment effects: top number in each box is the average subgroup treatment effect, lower number shows the percentage of the total sample in that subgroup; the tree flags credits in first year, gender, and age at entry as important moderators. -

-Regression tree fit to posterior point estimates of individual treatment -effects: top number in each box is the average subgroup treatment -effect, lower number shows the percentage of the total sample in that -subgroup; the tree flags credits in first year, gender, and age at entry -as important moderators. -

-
-

The resulting effect moderation tree indicates that course load -(credits attempted) in the academic term leading to their probation is a -strong moderator. Contextually, this result is plausible, both because -course load could relate to latent character attributes that influence a -student’s responsiveness to sanctions and also because it could predict -course load in the current term, which would in turn have implications -for the GPA (i.e.Β it is harder to get a high GPA while taking more -credit hours). The tree also suggests that effects differ by campus, and -age and gender of the student. These findings are all prima facie -plausible as well.

-

To gauge how strong these findings are statistically, we can zoom in -on isolated subgroups and compare the posteriors of their subgroup -average treatment effects. This approach is valid because in fitting the -effect moderation tree to the posterior mean CATEs we in no way altered -the posterior itself; the effect moderation tree is a posterior summary -tool and not any additional inferential approach; the posterior is -obtained once and can be explored freely using a variety of techniques -without vitiating its statistical validity. Investigating the most -extreme differences is a good place to start: consider the two groups of -students at opposite ends of the treatment effect range discovered by -the effect moderation tree:

-
    -
  • Group A a male student that entered college older -than 19 and attempted more than 4.8 credits in the first year (leftmost -leaf node, colored red, comprising 128 individuals)
  • -
  • Group B a student of any gender who entered college -younger than 19 and attempted between 4.3 and 4.8 credits in the first -year (rightmost leaf node, colored gold, comprising 108 -individuals).
  • -
-

Subgroup CATEs are obtained by aggregating CATEs across the observed -\(\mathrm{w}_i\) values for individuals -in each group; this can be done for individual posterior samples, -yielding a posterior distribution over the subgroup CATE: \[\begin{equation} -\bar{\tau}_A^{(h)} = \frac{1}{n_A} \sum_{i \in A} \tau^{(h)}(0, -\mathrm{w}_i), -\end{equation}\] where \(h\) -indexes a posterior draw and \(n_A\) -denotes the number of individuals in group A.

-

The code below produces a contour plot for a bivariate kernel density -estimate of the joint CATE posterior distribution for subgroups A and B. -The contour lines are nearly all above the \(45^{\circ}\) line, indicating that the -preponderance of posterior probability falls in the region where the -treatment effect for Group B is greater than that of Group A, meaning -that the difference in the subgroup treatment effects flagged by the -effect moderation tree persist even after accounting for estimation -uncertainty in the underlying CATE function.

-
## Define function to produce KD estimates of the joint distribution of two subgroups
-cate.kde <- function(rpart.obj,pred)
-{
-  rpart.frame <- rpart.obj$frame
-  left <- rpart.obj$where==which.min(rpart.frame$yval)
-  right <- rpart.obj$where==which.max(rpart.frame$yval)
-  ## Calculate CATE posterior for groups A and B
-  cate.a <- do.call("cbind",by(pred,left, colMeans))
-  cate.b <- do.call("cbind",by(pred,right, colMeans))
-  cate.a <- cate.a[,2]
-  cate.b <- cate.b[,2]
-  ## Estimate kernel density
-  denshat <- MASS::kde2d(cate.a, cate.b, n=200)
-  return(denshat)
-}
-contour(cate.kde(cate,pred),bty='n',xlab="Group A",ylab="Group B")
-abline(a=0,b=1)
-
-Kernel density estimates for the joint CATE posterior between male students who entered college older than 19 and attempted more than 4.8 credits in the first year (leftmost leaf node, red) and students who entered college younger than 19 and attempted between 4.3 and 4.8 credits in the first year (rightmost leaf node, gold) -

-Kernel density estimates for the joint CATE posterior between male -students who entered college older than 19 and attempted more than 4.8 -credits in the first year (leftmost leaf node, red) and students who -entered college younger than 19 and attempted between 4.3 and 4.8 -credits in the first year (rightmost leaf node, gold) -

-
-

As always, CATEs that vary with observable factors do not necessarily -represent a causal moderating relationship. Here, if the -treatment effect of academic probation is seen to vary with the number -of credits, that does not imply that this association is causal: -prescribing students to take a certain number of credits will not -necessarily lead to a more effective probation policy, it may simply be -that the type of student to naturally enroll for fewer credit hours is -more likely to be responsive to academic probation. An entirely distinct -set of causal assumptions are required to interpret the CATE variations -themselves as causal. All the same, uncovering these patterns of -treatment effect variability are crucial to suggesting causal mechanism -to be investigated in future studies.

-
-
-
-

References

-
-
-Lindo, Jason M, Nicholas J Sanders, and Philip Oreopoulos. 2010. -β€œAbility, Gender, and Performance Standards: Evidence from -Academic Probation.” American Economic Journal: Applied -Economics 2 (2): 95–117. -
-
-
- - - - -
- - - - - - - - - - - - - - - diff --git a/vignettes/R/RDD/rdd_vignette.Rmd b/vignettes/R/RDD/rdd_vignette.Rmd deleted file mode 100644 index 12ae6ca6c..000000000 --- a/vignettes/R/RDD/rdd_vignette.Rmd +++ /dev/null @@ -1,354 +0,0 @@ ---- -title: 'Regression Discontinuity Design (RDD) with stochtree' -author: - - Rafael Alcantara, University of Texas at Austin - - P. Richard Hahn, Arizona State University - - Drew Herren, University of Texas at Austin -date: "`r Sys.Date()`" -output: html_document -bibliography: rdd.bib ---- - -```{r setup, include=FALSE} -knitr::opts_chunk$set(echo = TRUE) -``` - -\usepackage{amsmath,asfonts,amssymb,amsthm} -\newcommand{\ind}{\perp \!\!\! \perp} -\newcommand{\B}{\mathcal{B}} -\newcommand{\res}{\mathbf{r}} -\newcommand{\m}{\mathbf{m}} -\newcommand{\x}{\mathbf{x}} -\newcommand{\C}{\mathbb{C}} -\newcommand{\N}{\mathrm{N}} -\newcommand{\w}{\mathrm{w}} -\newcommand{\iidsim}[0]{\stackrel{\mathrm{iid}}{\sim}} -\newcommand{\V}{ \mathbb{V}} -\newcommand{\f}{\mathrm{f}} -\newcommand{\F}{\mathbf{F}} -\newcommand{\Y}{\mathbf{Y}} - -## Introduction - -We study conditional average treatment effect (CATE) estimation for regression discontinuity designs (RDD), in which treatment assignment is based on whether a particular covariate --- referred to as the running variable --- lies above or below a known value, referred to as the cutoff value. Because treatment is deterministically assigned as a known function of the running variable, RDDs are trivially deconfounded: treatment assignment is independent of the outcome variable, given the running variable (because treatment is conditionally constant). However, estimation of treatment effects in RDDs is more complicated than simply controlling for the running variable, because doing so introduces a complete lack of overlap, which is the other key condition needed to justify regression adjustment for causal inference. 
Nonetheless, the CATE _at the cutoff_, $X=c$, may still be identified provided the conditional expectation $E[Y \mid X,W]$ is continuous at that point for _all_ $W=w$. We exploit this assumption with the leaf regression BART model implemented in Stochtree, which allows us to define an explicit prior on the CATE. We now describe the RDD setup and our model in more detail, and provide code to implement our approach. - -## Regression Discontinuity Design - -We conceptualize the treatment effect estimation problem via a quartet of random variables $(Y, X, Z, U)$. The variable $Y$ is the outcome variable; the variable $X$ is the running variable; the variable $Z$ is the treatment assignment indicator variable; and the variable $U$ represents additional, possibly unobserved, causal factors. What specifically makes this correspond to an RDD is that we stipulate that $Z = I(X > c)$, for cutoff $c$. We assume $c = 0$ without loss of generality. - -The following figure depicts a causal diagram representing the assumed causal relationships between these variables. Two key features of this diagram are one, that $X$ blocks the impact of $U$ on $Z$: in other words, $X$ satisfies the back-door criterion for learning causal effects of $Z$ on $Y$. And two, $X$ and $U$ are not descendants of $Z$. - -```{r cdag, echo=FALSE, fig.cap="A causal directed acyclic graph representing the general structure of a regression discontinuity design problem", fig.align="center", out.width = '40%'} -knitr::include_graphics("RDD_DAG.png") -``` - -Using this causal diagram, we may express $Y$ as some function of its graph parents, the random variables $(X,Z,U)$: $$Y = F(X,Z,U).$$ In principle, we may obtain draws of $Y$ by first drawing $(X,Z,U)$ according to their joint distribution and then applying the function $F$. Similarly, we may relate this formulation to the potential outcomes framework straightforwardly: -\begin{equation} -\begin{split} -Y^1 &= F(X,1,U),\\ -Y^0 &= F(X,0,U). 
-\end{split} -\end{equation} -Here, draws of $(Y^1, Y^0)$ may be obtained (in principle) by drawing $(X,Z,U)$ from their joint distribution and using only the $(X,U)$ elements as arguments in the above two equations, "discarding" the drawn value of $Z$. Note that this construction implies the _consistency_ condition: $Y = Y^1 Z + Y^0 ( 1 - Z)$. Likewise, this construction implies the _no interference_ condition because each $Y_i$ is considered to be produced with arguments ($X_i, Z_i, U_i)$ and not those from other units $j$; in particular, in constructing $Y_i$, $F$ does not take $Z_j$ for $j \neq i$ as an argument. - -Next, we define the following conditional expectations -\begin{equation} -\begin{split} -\mu_1(x) &= E[ F(x, 1, U) \mid X = x] ,\\ -\mu_0(x) &= E[ F(x, 0, U) \mid X = x], -\end{split} -\end{equation} -with which we can define the treatment effect function -$$\tau(x) = \mu_1(x) - \mu_0(x).$$ -Because $X$ satisfies the back-door criterion, $\mu_1$ and $\mu_0$ are estimable from the data, meaning that -\begin{equation} -\begin{split} -\mu_1(x) &= E[ F(x, 1, U) \mid X = x] = E[Y \mid X=x, Z=1],\\ -\mu_0(x) &= E[ F(x, 0, U) \mid X = x] = E[Y \mid X=x, Z=0], -\end{split} -\end{equation} -the right-hand-sides of which can be estimated from sample data, which we supposed to be independent and identically distributed realizations of $(Y_i, X_i, Z_i)$ for $i = 1, \dots, n$. However, because $Z = I(X >0)$ we can in fact only learn $\mu_1(x)$ for $X > 0$ and $\mu_0(x)$ for $X < 0$. In potential outcomes terminology, conditioning on $X$ satisfies ignorability, -$$(Y^1, Y^0) \ind Z \mid X,$$ -but not _strong ignorability_, because overlap is violated. Overlap would require that -$$0 < P(Z = 1 \mid X=x) < 1 \;\;\;\; \forall x,$$ -which is clearly violated by the RDD assumption that $Z = I(X > 0)$. 
Consequently, the overall ATE, -$\bar{\tau} = E(\tau(X)),$ is unidentified, and we must content ourselves with estimating $\tau(0)$, the conditional average effect at the point $x = 0$, which we estimate as the difference between $\mu_1(0) - \mu_0(0)$. This is possible for continuous $X$ so long as one is willing to assume that $\mu_1(x)$ and $\mu_0(x)$ are both suitably smooth functions of $x$: any inferred discontinuity at $x = 0$ must therefore be attributable to treatment effect. - -### Conditional average treatment effects in RDD - -We are concerned with learning not only $\tau(0)$, the "RDD ATE" (e.g. the CATE at $x = 0$), but also RDD CATEs, $\tau(0, \w)$ for some covariate vector $\w$. Incorporating additional covariates in the above framework turns out to be straightforward, simply by defining $W = \varphi(U)$ to be an observable function of the (possibly unobservable) causal factors $U$. We may then define our potential outcome means as -\begin{equation} -\begin{split} -\mu_1(x,\w) &= E[ F(x, 1, U) \mid X = x, W = \w] = E[Y \mid X=x, W=\w, Z=1],\\ -\mu_0(x,\w) &= E[ F(x, 0, U) \mid X = x, W = \w] = E[Y \mid X=x, W = \w, Z=0], -\end{split} -\end{equation} -and our treatment effect function as -$$\tau(x,\w) = \mu_1(x,\w) - \mu_0(x,\w).$$ We consider our data to be independent and identically distributed realizations $(Y_i, X_i, Z_i, W_i)$ for $i = 1, \dots, n$. Furthermore, we must assume that $\mu_1(x,\w)$ and $\mu_0(x,\w)$ are suitably smooth functions of $x$, {\em for every} $\w$; in other words, for each value of $\w$ the usual continuity-based identification assumptions must hold. - -With this framework and notation established, CATE estimation in RDDs boils down to estimation of condition expectation functions $E[Y \mid X=x, W=\w, Z=z]$, for which we turn to BART models. 
- -## The BARDDT Model - -We propose a BART model where the trees are allowed to split on $(x,\w)$ but where each leaf node parameter is a vector of regression coefficients tailored to the RDD context (rather than a scalar constant as in default BART). In one sense, such a model can be seen as implying distinct RDD ATE regressions for each subgroup determined by a given tree; however, this intuition is only heuristic, as the entire model is fit jointly as an ensemble of such trees. Instead, we motivate this model as a way to estimate the necessary conditional expectations via a parametrization where the conditional treatment effect function can be explicitly regularized, as follows. - -Let $\psi$ denote the following basis vector: -\begin{equation} -\psi(x,z) = \begin{bmatrix} -1 & z x & (1-z) x & z -\end{bmatrix}. -\end{equation} -To generalize the original BART model, we define $g_j(x, \w, z)$ as a piecewise linear function as follows. Let $b_j(x, \w)$ denote the node in the $j$th tree which contains the point $(x, \w)$; then the prediction function for tree $j$ is defined to be: -\begin{equation} -g_j(x, \w, z) = \psi(x, z) \Gamma_{b_j(x, \w)} -\end{equation} -for a leaf-specific regression vector $\Gamma_{b_j} = (\eta_{b_j}, \lambda_{b_j}, \theta_{b_j}, \Delta_{b_j})^t$. Therefore, letting $n_{b_j}$ denote the number of data points allocated to node $b$ in the $j$th tree and $\Psi_{b_j}$ denote the $n_{b_j} \times 4$ matrix, with rows equal to $\psi(x,z)$ for all $(x_i,z_i) \in b_j$, the model for observations assigned to leaf $b_j$, can be expressed in matrix notation as: -\begin{equation} -\begin{split} -\Y_{b_j} \mid \Gamma_{b_j}, \sigma^2 &\sim \N(\Psi_{b_j} \Gamma_{b_j},\sigma^2)\\ -\Gamma_{b_j} &\sim \N (0, \Sigma_0), -\end{split} \label{eq:leaf.regression} -\end{equation} -where we set $\Sigma_0 = \frac{0.033}{J} \mbox{I}$ as a default (for $x$ vectors standardized to have unit variance in-sample). 
- -This choice of basis entails that the RDD CATE at $\w$, $\tau(0, \w)$, is a sum of the $\Delta_{b_j(0, \w)}$ elements across all trees $j = 1, \dots, J$: -\begin{equation} -\begin{split} -\tau(0, \w) &= E[Y^1 \mid X=0, W = \w] - E[Y^0 \mid X = 0, W = \w]\\ -& = E[Y \mid X=0, W = \w, Z = 1] - E[Y \mid X = 0, W = \w, Z = 0]\\ -&= \sum_{j = 1}^J g_j(0, \w, 1) - \sum_{j = 1}^J g_j(0, \w, 0)\\ -&= \sum_{j = 1}^J \psi(0, 1) \Gamma_{b_j(0, \w)} - \sum_{j = 1}^J \psi(0, 0) \Gamma_{b_j(0, \w)} \\ -& = \sum_{j = 1}^J \Bigl( \psi(0, 1) - \psi(0, 0) \Bigr) \Gamma_{b_j(0, \w)} \\ -& = \sum_{j = 1}^J \Bigl( (1,0,0,1) - (1,0,0,0) \Bigr) \Gamma_{b_j(0, \w)} \\ -&= \sum_{j=1}^J \Delta_{b_j(0, \w)}. -\end{split} -\end{equation} -As a result, the priors on the $\Delta$ coefficients directly regularize the treatment effect. We set the tree and error variance priors as in the original BART model. - -The following figures provide a graphical depiction of how the BARDDT model fits a response surface and thereby estimates CATEs for distinct values of $\w$. For simplicity only two trees are used in the illustration, while in practice dozens or hundreds of trees may be used (in our simulations and empirical example, we use 150 trees). - -```{r trees1, echo=FALSE, fig.cap="Two regression trees with splits in x and a single scalar w. Node images depict the g(x,w,z) function (in x) defined by that node's coefficients. The vertical gap between the two line segments in a node that contain x=0 is that node's contribution to the CATE at X = 0. Note that only such nodes contribute for CATE prediction at x=0", fig.align="center", out.width = '70%'} -knitr::include_graphics("trees1.png") -``` - -```{r trees2, echo=FALSE, fig.cap="The two top figures show the same two regression trees as in the preceding figure, now represented as a partition of the x-w plane. Labels in each partition correspond to the leaf nodes depicted in the previous picture. 
The bottom figure shows the partition of the x-w plane implied by the sum of the two trees; the red dashed line marks point W=w* and the combination of nodes that include this point", fig.align="center", out.width = '70%'} -knitr::include_graphics("trees2.png") -``` - -```{r trees3, echo=FALSE, fig.cap="Left: The function fit at W = w* for the two trees shown in the previous two figures, shown superimposed. Right: The aggregated fit achieved by summing the contributions of the two regression tree fits shown at left. The magnitude of the discontinuity at x = 0 (located at the dashed gray vertical line) represents the treatment effect at that point. Different values of w will produce distinct fits; for the two trees shown, there can be three distinct fits based on the value of w.", fig.align="center", out.width = '70%'} -knitr::include_graphics("trees3.png") -``` - -An interesting property of BARDDT can be seen in this small illustration --- by letting the regression trees split on the running variable, there is no need to separately define a 'bandwidth' as is used in the polynomial approach to RDD. Instead, the regression trees automatically determine (in the course of posterior sampling) when to 'prune' away regions far from the cutoff value. There are two notable features of this approach. One, different trees in the ensemble are effectively using different local bandwidths and these fits are then blended together. For example, in the bottom panel of the second figure, we obtain one bandwidth for the region $d+i$, and a different one for regions $a+g$ and $d+g$. Two, for cells in the tree partition that do not span the cutoff, the regression within that partition contains no causal contrasts --- all observations either have $Z = 1$ or $Z = 0$.
For those cells, the treatment effect coefficient is ill-posed and in those cases the posterior sampling is effectively a draw from the prior; however, such draws correspond to points where the treatment effect is unidentified and none of these draws contribute to the estimation of $\tau(0, \w)$ --- for example, only nodes $a+g$, $d+g$, and $d+i$ provide any contribution. This implies that draws of $\Delta$ corresponding to nodes not predicting at $X=0$ will always be draws from the prior, which has some intuitive appeal. - -## Demo - -In this section, we provide code for implementing our model in `stochtree` on a popular RDD dataset. -First, let us load `stochtree` and all the necessary libraries for our posterior analysis. - -```{r} -## Load libraries -library(stochtree) -library(rpart) -library(rpart.plot) -library(xtable) -library(foreach) -library(doParallel) -``` - -### Dataset - -The data comes from @lindo2010ability, who analyze data on college students enrolled in a large Canadian university in order to evaluate the effectiveness of an academic probation policy. Students who present a grade point average (GPA) lower than a certain threshold at the end of each term are placed on academic probation and must improve their GPA in the subsequent term or else face suspension. We are interested in how being put on probation or not, $Z$, affects students' GPA, $Y$, at the end of the current term. The running variable, $X$, is the negative distance between a student's previous-term GPA and the probation threshold, so that students placed on probation ($Z = 1$) have a positive score and the cutoff is 0. 
Potential moderators, $W$, are: - -* gender (`male`), -* age upon entering university (`age_at_entry`) -* a dummy for being born in North America (`bpl_north_america`), -* the number of credits taken in the first year (`totcredits_year1`) -* an indicator designating each of three campuses (`loc_campus` 1, 2 and 3), and -* high school GPA as a quantile w.r.t the university's incoming class (`hsgrade_pct`). - -```{r} -## Load and organize data -data <- read.csv("https://raw.githubusercontent.com/rdpackages-replication/CIT_2024_CUP/refs/heads/main/CIT_2024_CUP_discrete.csv") -y <- data$nextGPA -x <- data$X -x <- x/sd(x) ## we always standardize X -w <- data[,4:11] -### Must define categorical features as ordered/unordered factors -w$totcredits_year1 <- factor(w$totcredits_year1,ordered=TRUE) -w$male <- factor(w$male,ordered=FALSE) -w$bpl_north_america <- factor(w$bpl_north_america,ordered=FALSE) -w$loc_campus1 <- factor(w$loc_campus1,ordered=FALSE) -w$loc_campus2 <- factor(w$loc_campus2,ordered=FALSE) -w$loc_campus3 <- factor(w$loc_campus3,ordered=FALSE) -c <- 0 -n <- nrow(data) -z <- as.numeric(x>c) -h <- 0.1 ## window for prediction sample -test <- -h < x & x < h -ntest <- sum(test) -``` - -### Target estimand - -Generically, our estimand is the CATE function at $x = 0$; i.e. $\tau(0, \w)$. The key practical question is which values of $\w$ to consider. Some values of $\w$ will not be well-represented near $x=0$ and so no estimation technique will be able to estimate those points effectively. As such, to focus on feasible points --- which will lead to interesting comparisons between methods --- we recommend restricting the evaluation points to the observed $\w_i$ such that $|x_i| \leq \delta$, for some $\delta > 0$. In our example, we use $\delta = 0.1$ for a standardized $x$ variable. Therefore, our estimand of interest is a vector of treatment effects: -\begin{equation} -\tau(0, \w_i) \;\;\; \forall i \;\mbox{ such that }\; |x_i| \leq \delta. 
-\end{equation} - -### Implementing BARDDT - -In order to implement our model, we write the Psi vector, as defined before: `Psi <- cbind(z*x,(1-z)*x, z,rep(1,n))`. The training matrix for the model is `as.matrix(cbind(x,w))`, which we feed into the `stochtree::bart` function via the `X_train` parameter. The basis vector `Psi` is fed into the function via the `leaf_basis_train` parameter. The list object `barddt.mean.parmlist` defines options for the mean forest (a different list can be defined for a variance forest in the case of heteroscedastic BART, which we do not consider here). Importantly, in this list we define parameter `sigma2_leaf_init = diag(rep(0.1/150,4))`, which sets $\Sigma_0$ as described above. Now, we can fit the model, which is saved in object `barddt.fit`. - -Once the model is fit, we need 3 elements to obtain the CATE predictions: the basis vectors at the cutoff for $z=1$ and $z=0$, the test matrix $[X \quad W]$ at the cutoff, and the testing sample. We define the prediction basis vectors $\psi_1 = [1 \quad 0 \quad 0 \quad 1]$ and $\psi_0 = [1 \quad 0 \quad 0 \quad 0]$, which correspond to $\psi$ at $(x=0,z=1)$, and $(x=0,z=0)$, respectively. These vectors are written into R as `Psi1 <- cbind(rep(1,n), rep(c,n), rep(0,n), rep(1,n))` and `Psi0 <- cbind(rep(1,n), rep(0,n), rep(c,n), rep(0,n))`. Then, we write the test matrix at $(x=0,\w)$ as `xmat_test <- as.matrix(cbind(rep(0,n),w)`. Finally, we must define the testing window. As discussed previously, our window is set such that $|x| \leq 0.1$, which can be set in R as `test <- -0.1 < x & x <0.1`. - -Once all of these elements are set, we can obtain the outcome predictions at the cutoff by running `predict(barddt.fit, xmat_test, Psi1)` (resp. `predict(barddt.fit, xmat_test, Psi0)`). Each of these calls returns a list, from which we can extract element `y_hat` to obtain the posterior distribution for the outcome. 
In the code below, the treated and control outcome predictions are saved in the matrix objects `pred1` and `pred0`, respectively. Now, we can obtain draws from the CATE posterior by simply subtracting these matrices. The function below outlines how to perform each of these steps in R. - -```{r} -fit.barddt <- function(y,x,w,z,test,c) -{ - ## Lists of parameters for the Stochtree BART function - barddt.global.parmlist <- list(standardize=T,sample_sigma_global=TRUE,sigma2_global_init=0.1) - barddt.mean.parmlist <- list(num_trees=50, min_samples_leaf=20, alpha=0.95, beta=2, - max_depth=20, sample_sigma2_leaf=FALSE, sigma2_leaf_init = diag(rep(0.1/150,4))) - ## Set basis vector for leaf regressions - Psi <- cbind(rep(1,n),z*x,(1-z)*x,z) - ## Model fit - barddt.fit = stochtree::bart(X_train= as.matrix(cbind(x,w)), y_train=y, - leaf_basis_train = Psi, mean_forest_params=barddt.mean.parmlist, - general_params=barddt.global.parmlist, - num_mcmc=1000,num_gfr=30) - ## Define basis vectors and test matrix for outcome predictions at X=c - Psi1 <- cbind(rep(1,n), rep(c,n), rep(0,n), rep(1,n)) - Psi0 <- cbind(rep(1,n), rep(0,n), rep(c,n), rep(0,n)) - Psi1 <- Psi1[test,] - Psi0 <- Psi0[test,] - xmat_test <- as.matrix(cbind(rep(0,n),w)[test,]) - ## Obtain outcome predictions - pred1 <- predict(barddt.fit,xmat_test,Psi1)$y_hat - pred0 <- predict(barddt.fit,xmat_test,Psi0)$y_hat - ## Obtain CATE posterior - out <- pred1-pred0 - return(out) -} -``` - -Now, we proceed to fit the BARDDT model. The procedure is exactly the same as described in the simulation section. 
- -```{r empiricalPosterior, cache=TRUE,cache.lazy=FALSE} -## We will sample multiple chains sequentially -num_chains <- 20 -num_gfr <- 2 -num_burnin <- 0 -num_mcmc <- 500 -bart_models <- list() -## Define basis functions for training and testing -B <- cbind(z*x,(1-z)*x, z,rep(1,n)) -B1 <- cbind(rep(c,n), rep(0,n), rep(1,n), rep(1,n)) -B0 <- cbind(rep(0,n), rep(c,n), rep(0,n), rep(1,n)) -B1 <- B1[test,] -B0 <- B0[test,] -B_test <- rbind(B1,B0) -xmat_test <- cbind(x=rep(0,n),w)[test,] -xmat_test <- rbind(xmat_test,xmat_test) -### We combine the basis for Z=1 and Z=0 to feed it to the BART call and get the Y(z) predictions instantaneously -### Then we separate the posterior matrix between each Z and calculate the CATE prediction -## Sampling trees in parallel -ncores <- 5 -cl <- makeCluster(ncores) -registerDoParallel(cl) - -start_time <- Sys.time() -bart_model_outputs <- foreach (i = 1:num_chains) %dopar% { - random_seed <- i - ## Lists to define BARDDT parameters - barddt.global.parmlist <- list(standardize=T,sample_sigma_global=TRUE,sigma2_global_init=0.1) - barddt.mean.parmlist <- list(num_trees=50, min_samples_leaf=20, alpha=0.95, beta=2, - max_depth=20, sample_sigma2_leaf=FALSE, sigma2_leaf_init = diag(rep(0.1/50,4))) - bart_model <- stochtree::bart( - X_train = cbind(x,w), leaf_basis_train = B, y_train = y, - X_test = xmat_test, leaf_basis_test = B_test, - num_gfr = num_gfr, num_burnin = num_burnin, num_mcmc = num_mcmc, - general_params = barddt.global.parmlist, mean_forest_params = barddt.mean.parmlist - ) - bart_model <- bart_model$y_hat_test[1:ntest,]-bart_model$y_hat_test[(ntest+1):(2*ntest),] -} -stopCluster(cl) -## Combine CATE predictions -pred <- do.call("cbind",bart_model_outputs) - -end_time <- Sys.time() - -print(end_time - start_time) -## Save the results -saveRDS(pred, "bart_rdd_posterior.rds") -``` - -We now proceed to analyze the CATE posterior. 
The figure produced below presents a summary of the CATE posterior produced by BARDDT for this application. This picture is produced by fitting a regression tree, using $W$ as the predictors, to the individual posterior mean CATEs: -\begin{equation} -\bar{\tau}_i = \frac{1}{M} \sum_{h = 1}^M \tau^{(h)}(0, \w_i), -\end{equation} -where $h$ indexes each of $M$ total posterior samples. As in our simulation studies, we restrict our posterior analysis to use $\w_i$ values of observations with $|x_i| \leq \delta = 0.1$ (after normalizing $X$ to have standard deviation 1 in-sample). For the @lindo2010ability data, this means that BARDDT was trained on $n = 40,582$ observations, of which 1,602 satisfy $|x_i| \leq 0.1$, which were used to generate the effect moderation tree. - -```{r cart_summary, fig.cap="Regression tree fit to posterior point estimates of individual treatment effects: top number in each box is the average subgroup treatment effect, lower number shows the percentage of the total sample in that subgroup; the tree flags credits in first year, gender, and age at entry as important moderators.", fig.align="center"} -## Fit regression tree -cate <- rpart(y~.,data.frame(y=rowMeans(pred),w[test,]),control = rpart.control(cp=0.015)) -## Define separate colors for left and rightmost nodes -plot.cart <- function(rpart.obj) -{ - rpart.frame <- rpart.obj$frame - left <- which.min(rpart.frame$yval) - right <- which.max(rpart.frame$yval) - nodes <- rep(NA,nrow(rpart.frame)) - for (i in 1:length(nodes)) - { - if (rpart.frame$yval[i]==rpart.frame$yval[right]) nodes[i] <- "gold2" - else if (rpart.frame$yval[i]==rpart.frame$yval[left]) nodes[i] <- "tomato3" - else nodes[i] <- "lightblue3" - } - return(nodes) -} -## Plot regression tree -rpart.plot(cate,main="",box.col=plot.cart(cate)) -``` - -The resulting effect moderation tree indicates that course load (credits attempted) in the academic term leading to their probation is a strong moderator.
Contextually, this result is plausible, both because course load could relate to latent character attributes that influence a student's responsiveness to sanctions and also because it could predict course load in the current term, which would in turn have implications for the GPA (i.e. it is harder to get a high GPA while taking more credit hours). The tree also suggests that effects differ by campus, and age and gender of the student. These findings are all prima facie plausible as well. - -To gauge how strong these findings are statistically, we can zoom in on isolated subgroups and compare the posteriors of their subgroup average treatment effects. This approach is valid because in fitting the effect moderation tree to the posterior mean CATEs we in no way altered the posterior itself; the effect moderation tree is a posterior summary tool and not any additional inferential approach; the posterior is obtained once and can be explored freely using a variety of techniques without vitiating its statistical validity. Investigating the most extreme differences is a good place to start: consider the two groups of students at opposite ends of the treatment effect range discovered by the effect moderation tree: - -* **Group A** a male student that entered college older than 19 and attempted more than 4.8 credits in the first year (leftmost leaf node, colored red, comprising 128 individuals) -* **Group B** a student of any gender who entered college younger than 19 and attempted between 4.3 and 4.8 credits in the first year (rightmost leaf node, colored gold, comprising 108 individuals). 
- -Subgroup CATEs are obtained by aggregating CATEs across the observed $\w_i$ values for individuals in each group; this can be done for individual posterior samples, yielding a posterior distribution over the subgroup CATE: -\begin{equation} -\bar{\tau}_A^{(h)} = \frac{1}{n_A} \sum_{i : \w_i \in A} \tau^{(h)}(0, \w_i), -\end{equation} -where $h$ indexes a posterior draw and $n_A$ denotes the number of individuals in group A. - -The code below produces a contour plot for a bivariate kernel density estimate of the joint CATE posterior distribution for subgroups A and B. The contour lines are nearly all above the $45^{\circ}$ line, indicating that the preponderance of posterior probability falls in the region where the treatment effect for Group B is greater than that of Group A, meaning that the difference in the subgroup treatment effects flagged by the effect moderation tree persists even after accounting for estimation uncertainty in the underlying CATE function. - -```{r kde, fig.cap="Kernel density estimates for the joint CATE posterior between male students who entered college older than 19 and attempted more than 4.8 credits in the first year (leftmost leaf node, red) and students who entered college younger than 19 and attempted between 4.3 and 4.8 credits in the first year (rightmost leaf node, gold)", fig.align="center"} -## Define function to produce KD estimates of the joint distribution of two subgroups -cate.kde <- function(rpart.obj,pred) -{ - rpart.frame <- rpart.obj$frame - left <- rpart.obj$where==which.min(rpart.frame$yval) - right <- rpart.obj$where==which.max(rpart.frame$yval) - ## Calculate CATE posterior for groups A and B - cate.a <- do.call("cbind",by(pred,left, colMeans)) - cate.b <- do.call("cbind",by(pred,right, colMeans)) - cate.a <- cate.a[,2] - cate.b <- cate.b[,2] - ## Estimate kernel density - denshat <- MASS::kde2d(cate.a, cate.b, n=200) - return(denshat) -} -contour(cate.kde(cate,pred),bty='n',xlab="Group A",ylab="Group B")
-abline(a=0,b=1) -``` - -As always, CATEs that vary with observable factors do not necessarily represent a _causal_ moderating relationship. Here, if the treatment effect of academic probation is seen to vary with the number of credits, that does not imply that this association is causal: prescribing students to take a certain number of credits will not necessarily lead to a more effective probation policy, it may simply be that the type of student to naturally enroll for fewer credit hours is more likely to be responsive to academic probation. An entirely distinct set of causal assumptions are required to interpret the CATE variations themselves as causal. All the same, uncovering these patterns of treatment effect variability are crucial to suggesting causal mechanism to be investigated in future studies. - -# References - - diff --git a/vignettes/_quarto.yml b/vignettes/_quarto.yml new file mode 100644 index 000000000..e0fd2dc90 --- /dev/null +++ b/vignettes/_quarto.yml @@ -0,0 +1,49 @@ +project: + type: website + output-dir: _site + +website: + title: "StochTree Vignettes" + navbar: + left: + - href: index.qmd + text: Home + sidebar: + - title: "Vignettes" + contents: + - index.qmd + - section: "Core Models" + contents: + - bart.qmd + - bcf.qmd + - heteroskedastic.qmd + - ordinal-outcome.qmd + - multivariate-bcf.qmd + - section: "Practical Topics" + contents: + - serialization.qmd + - multi-chain.qmd + - tree-inspection.qmd + - summary-plotting.qmd + - prior-calibration.qmd + - sklearn.qmd + - section: "Low-Level Interface" + contents: + - custom-sampling.qmd + - ensemble-kernel.qmd + - section: "Advanced Methods" + contents: + - rdd.qmd + - iv.qmd + +format: + html: + theme: cosmo + toc: true + toc-depth: 3 + grid: + body-width: 960px + margin-width: 200px + +execute: + freeze: auto diff --git a/vignettes/bart.qmd b/vignettes/bart.qmd new file mode 100644 index 000000000..2e902d366 --- /dev/null +++ b/vignettes/bart.qmd @@ -0,0 +1,322 @@ +--- +title: "Bayesian 
Additive Regression Trees for Supervised Learning" +bibliography: vignettes.bib +execute: + freeze: auto # re-render only when source changes +--- + +```{r} +#| include: false +reticulate::use_python( + Sys.getenv( + "RETICULATE_PYTHON", + unset = file.path(rprojroot::find_root(rprojroot::has_file(".here")), ".venv", "bin", "python") + ), + required = TRUE +) +``` + +This vignette demonstrates how to sample variants of the BART model (@chipman2010bart), using the `bart()` function in `stochtree`. The original BART model is + +$$ +\begin{aligned} +y_i \mid X_i = x_i &\sim \mathcal{N}(f(x_i), \sigma^2)\\ +\sigma^2 &\sim \text{IG}\left(\frac{\nu}{2}, \frac{\nu \lambda}{2}\right) +\end{aligned} +$$ + +where + +$$ +f(X) = \sum_{s=1}^m g_s(X) +$$ + +and each $g_s$ refers to a decision tree function which partitions $X$ into $k_s$ mutually exclusive regions ($\mathcal{A}_s = \mathcal{A}_{s,1} \cup \dots \cup \mathcal{A}_{s,k_s}$) and assigns a scalar parameter $\mu_{s,j}$ to each region $\mathcal{A}_{s,j}$ + +$$ +g_s(x) = \sum_{j = 1}^{k_s} \mu_{s,j} \mathbb{I}\left(x \in \mathcal{A}_{s,j}\right). +$$ + +The partitions $\mathcal{A}_s$ are defined by a series of logical split rules $X_i \leq c$ where $i$ is a variable index and $c$ is a numeric cutpoint and these partitions are guided by a uniform prior on variables and cutpoints. The prior on partitions is further specified by a probability of splitting a node + +$$ +P(\text{split node } \eta) = \alpha (1 + \text{depth}_{\eta})^{-\beta} +$$ + +The prior for each leaf node parameter is + +$$ +\mu_{s,j} \sim \mathcal{N}\left(0, \sigma^2_{\mu}\right) +$$ + +Together, we refer to this conditional mean model as + +$$ +f(X) \sim \text{BART}(\alpha, \beta, m) +$$ + +This is the "core" of stochtree's supervised learning interface, though we support many expanded models including + +* linear leaf regression (i.e. 
each leaf node evaluates a linear regression on basis $W$ rather than return a constant), +* additive random effects, +* forest-based heteroskedasticity, +* binary / ordinal outcome modeling using the probit and complementary log-log (cloglog) links, + +and we offer the ability to sample any of the above models using the MCMC or the Grow-From-Root sampler (@he2023stochastic). + +# Setup + +To begin, we load the `stochtree` and other necessary packages. + +::::{.panel-tabset group="language"} + +## R + +```{r} +library(stochtree) +``` + +## Python + +```{python} +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +from sklearn.model_selection import train_test_split +from stochtree import BARTModel +``` + +:::: + +We set a seed for reproducibility + +::::{.panel-tabset group="language"} + +## R + +```{r} +random_seed <- 1234 +set.seed(random_seed) +``` + +## Python + +```{python} +random_seed = 1234 +rng = np.random.default_rng(random_seed) +``` + +:::: + +# Demo 1: Step Function + +## Data Generation + +We generate data from a simple step function + +::::{.panel-tabset group="language"} + +## R + +```{r} +# Generate the data +n <- 500 +p_x <- 10 +snr <- 3 +X <- matrix(runif(n * p_x), ncol = p_x) +f_XW <- (((0 <= X[, 1]) & (0.25 > X[, 1])) * + (-7.5) + + ((0.25 <= X[, 1]) & (0.5 > X[, 1])) * (-2.5) + + ((0.5 <= X[, 1]) & (0.75 > X[, 1])) * (2.5) + + ((0.75 <= X[, 1]) & (1 > X[, 1])) * (7.5)) +noise_sd <- sd(f_XW) / snr +y <- f_XW + rnorm(n, 0, 1) * noise_sd +``` + +## Python + +```{python} +# Generate the data +n = 500 +p_x = 10 +snr = 3 +X = rng.uniform(0, 1, (n, p_x)) +f_XW = ( + ((X[:, 0] >= 0.0) & (X[:, 0] < 0.25)) * (-7.5) + + ((X[:, 0] >= 0.25) & (X[:, 0] < 0.5)) * (-2.5) + + ((X[:, 0] >= 0.5) & (X[:, 0] < 0.75)) * (2.5) + + ((X[:, 0] >= 0.75) & (X[:, 0] < 1.0)) * (7.5) +) +noise_sd = np.std(f_XW) / snr +y = f_XW + rng.normal(0, noise_sd, n) +``` + +:::: + +Split the data into train and test sets + +::::{.panel-tabset group="language"} + 
+## R + +```{r} +test_set_pct <- 0.2 +n_test <- round(test_set_pct * n) +n_train <- n - n_test +test_inds <- sort(sample(1:n, n_test, replace = FALSE)) +train_inds <- (1:n)[!((1:n) %in% test_inds)] +X_test <- as.data.frame(X[test_inds, ]) +X_train <- as.data.frame(X[train_inds, ]) +y_test <- y[test_inds] +y_train <- y[train_inds] +``` + +## Python + +```{python} +test_set_pct = 0.2 +X_train, X_test, y_train, y_test = train_test_split( + X, y, test_size=test_set_pct, random_state=random_seed +) +``` + +:::: + +## Sampling and Analysis + +We sample from a BART model of $y \mid X$ with 10 grow-from-root GFR samples (@he2023stochastic) followed by 100 MCMC samples (this is the default in `stochtree`), run for 4 chains initialized by different GFR iterations. + +We also specify $m = 100$ trees and we let both $\sigma^2$ and $\sigma^2_{\mu}$ be updated by Gibbs samplers. + +::::{.panel-tabset group="language"} + +## R + +```{r} +num_gfr <- 10 +num_burnin <- 0 +num_mcmc <- 100 +general_params <- list( + sample_sigma2_global = T, + num_threads = 1, + num_chains = 4, + random_seed = random_seed +) +mean_forest_params <- list(sample_sigma2_leaf = T, num_trees = 100) +bart_model <- stochtree::bart( + X_train = X_train, + y_train = y_train, + X_test = X_test, + num_gfr = num_gfr, + num_burnin = num_burnin, + num_mcmc = num_mcmc, + general_params = general_params, + mean_forest_params = mean_forest_params +) +``` + +## Python + +```{python} +num_gfr = 10 +num_burnin = 0 +num_mcmc = 100 +general_params = { + "sample_sigma2_global": True, + "num_threads": 1, + "num_chains": 4, + "random_seed": random_seed, +} +mean_forest_params = {"sample_sigma2_leaf": True, "num_trees": 100} +bart_model = BARTModel() +bart_model.sample( + X_train=X_train, + y_train=y_train, + X_test=X_test, + num_gfr=num_gfr, + num_burnin=num_burnin, + num_mcmc=num_mcmc, + general_params=general_params, + mean_forest_params=mean_forest_params, +) +``` + +:::: + +Plot the mean outcome predictions versus the true 
outcomes + +::::{.panel-tabset group="language"} + +## R + +```{r} +y_hat_test <- predict( + bart_model, + X = X_test, + terms = "y_hat", + type = "mean" +) +plot( + y_hat_test, + y_test, + xlab = "predicted", + ylab = "actual", + main = "Outcome" +) +abline(0, 1, col = "red", lty = 3, lwd = 3) +``` + +## Python + +```{python} +y_hat_test = bart_model.predict(X=X_test, terms="y_hat", type="mean") +lo, hi = min(y_hat_test.min(), y_test.min()), max(y_hat_test.max(), y_test.max()) +plt.scatter(y_hat_test, y_test, alpha=0.5) +plt.plot([lo, hi], [lo, hi], color="red", linestyle="dashed", linewidth=2) +plt.xlabel("Predicted") +plt.ylabel("Actual") +plt.title("Outcome") +plt.show() +``` + +:::: + +Plot the $\sigma^2$ traceplot + +::::{.panel-tabset group="language"} + +## R + +```{r} +sigma_observed <- var(y - f_XW) +sigma2_global_samples <- extractParameter(bart_model, "sigma2_global") +plot_bounds <- c( + min(c(sigma2_global_samples, sigma_observed)), + max(c(sigma2_global_samples, sigma_observed)) +) +plot( + sigma2_global_samples, + ylim = plot_bounds, + ylab = "sigma^2", + xlab = "Sample", + main = "Global variance parameter" +) +abline(h = sigma_observed, lty = 3, lwd = 3, col = "blue") +``` + +## Python + +```{python} +sigma_observed = np.var(y - f_XW) +global_var_samples = bart_model.extract_parameter("sigma2_global") +plt.plot(global_var_samples) +plt.axhline(sigma_observed, color="blue", linestyle="dashed", linewidth=2) +plt.xlabel("Sample") +plt.ylabel(r"$\sigma^2$") +plt.title("Global variance parameter") +plt.show() +``` + +:::: + +# References diff --git a/vignettes/bcf.qmd b/vignettes/bcf.qmd new file mode 100644 index 000000000..2920bc192 --- /dev/null +++ b/vignettes/bcf.qmd @@ -0,0 +1,766 @@ +--- +title: "Bayesian Causal Forests for Treatment Effect Estimation" +bibliography: vignettes.bib +execute: + freeze: auto # re-render only when source changes +--- + +```{r} +#| include: false +reticulate::use_python( + Sys.getenv( + "RETICULATE_PYTHON", + unset = 
file.path(rprojroot::find_root(rprojroot::has_file(".here")), ".venv", "bin", "python") + ), + required = TRUE +) +``` + +This vignette demonstrates how to use the `bcf()` function for causal inference +(@hahn2020bayesian). BCF models the conditional average treatment effect (CATE) +by fitting two separate tree ensembles + +$$ +Y_i = \mu(X_i) + \tau(X_i) Z_i + \epsilon_i, \quad \epsilon_i \sim \mathcal{N}(0, \sigma^2) +$$ + +where $\mu(\cdot)$ is a prognostic forest and $\tau(\cdot)$ is a treatment effect +forest. The estimated propensity score $\hat{\pi}(X_i)$ is included as a covariate +in $\mu(\cdot)$ to reduce confounding bias. + +# Setup + +Load necessary packages + +::::{.panel-tabset group="language"} + +## R + +```{r} +library(stochtree) +``` + +## Python + +```{python} +import numpy as np +import pandas as pd +import matplotlib.pyplot as plt +from scipy.stats import norm +from stochtree import BCFModel +``` + +:::: + +Set a seed for reproducibility + +::::{.panel-tabset group="language"} + +## R + +```{r} +random_seed <- 1234 +set.seed(random_seed) +``` + +## Python + +```{python} +random_seed = 1234 +rng = np.random.default_rng(random_seed) +``` + +:::: + +We also define several simple functions that configure the data generating processes +used in this vignette + +::::{.panel-tabset group="language"} + +## R + +```{r} +g <- function(x) { + ifelse(x[, 5] == 1, 2, ifelse(x[, 5] == 2, -1, -4)) +} +mu1 <- function(x) { + 1 + g(x) + x[, 1] * x[, 3] +} +mu2 <- function(x) { + 1 + g(x) + 6 * abs(x[, 3] - 1) +} +tau1 <- function(x) { + rep(3, nrow(x)) +} +tau2 <- function(x) { + 1 + 2 * x[, 2] * x[, 4] +} +``` + +## Python + +```{python} +def g(x): + return np.where(x[:, 4] == 1, 2, np.where(x[:, 4] == 2, -1, -4)) + + +def mu1(x): + return 1 + g(x) + x[:, 0] * x[:, 2] + + +def mu2(x): + return 1 + g(x) + 6 * np.abs(x[:, 2] - 1) + + +def tau1(x): + return np.full(x.shape[0], 3.0) + + +def tau2(x): + return 1 + 2 * x[:, 1] * x[:, 3] +``` + +:::: + +# Binary 
Treatment + +## Demo 1: Linear Outcome Model, Heterogeneous Treatment Effect + +We consider the following data generating process from @hahn2020bayesian: + +\begin{equation*} +\begin{aligned} +y &= \mu(X) + \tau(X) Z + \epsilon\\ +\epsilon &\sim N\left(0,\sigma^2\right)\\ +\mu(X) &= 1 + g(X) + X_1 X_3\\ +\tau(X) &= 1 + 2 X_2 X_4\\ +g(X) &= \mathbb{I}(X_5=1) \times 2 - \mathbb{I}(X_5=2) \times 1 - \mathbb{I}(X_5=3) \times 4\\ +s_{\mu} &= \sqrt{\mathbb{V}(\mu(X))}\\ +\pi(X) &= 0.8 \Phi\left(\frac{3\mu(X)}{s_{\mu}} - \frac{X_1}{2}\right) + \frac{2U+1}{20}\\ +X_1,X_2,X_3 &\sim N\left(0,1\right)\\ +X_4 &\sim \text{Bernoulli}(1/2)\\ +X_5 &\sim \text{Categorical}(1/3,1/3,1/3)\\ +U &\sim \text{Uniform}\left(0,1\right)\\ +Z &\sim \text{Bernoulli}\left(\pi(X)\right) +\end{aligned} +\end{equation*} + +### Simulation + +We generate data from the DGP defined above + +::::{.panel-tabset group="language"} + +## R + +```{r} +n <- 1000 +snr <- 3 +x1 <- rnorm(n) +x2 <- rnorm(n) +x3 <- rnorm(n) +x4 <- as.numeric(rbinom(n, 1, 0.5)) +x5 <- as.numeric(sample(1:3, n, replace = TRUE)) +X <- cbind(x1, x2, x3, x4, x5) +p <- ncol(X) +mu_x <- mu1(X) +tau_x <- tau2(X) +pi_x <- 0.8 * pnorm((3 * mu_x / sd(mu_x)) - 0.5 * X[, 1]) + 0.05 + runif(n) / 10 +Z <- rbinom(n, 1, pi_x) +E_XZ <- mu_x + Z * tau_x +y <- E_XZ + rnorm(n, 0, 1) * (sd(E_XZ) / snr) +X <- as.data.frame(X) +X$x4 <- factor(X$x4, ordered = TRUE) +X$x5 <- factor(X$x5, ordered = TRUE) +``` + +## Python + +```{python} +n = 1000 +snr = 3 +x1 = rng.normal(size=n) +x2 = rng.normal(size=n) +x3 = rng.normal(size=n) +x4 = rng.binomial(1, 0.5, n).astype(float) +x5 = rng.choice([1, 2, 3], size=n).astype(float) +X = np.column_stack([x1, x2, x3, x4, x5]) +mu_x = mu1(X) +tau_x = tau2(X) +pi_x = ( + 0.8 * norm.cdf((3 * mu_x / np.std(mu_x)) - 0.5 * X[:, 0]) + + 0.05 + + rng.uniform(size=n) / 10 +) +Z = rng.binomial(1, pi_x, n).astype(float) +E_XZ = mu_x + Z * tau_x +y = E_XZ + rng.normal(size=n) * (np.std(E_XZ) / snr) +X_df = pd.DataFrame({"x1":
x1, "x2": x2, "x3": x3, "x4": x4, "x5": x5}) +X_df["x4"] = pd.Categorical(X_df["x4"].astype(int), categories=[0, 1], ordered=True) +X_df["x5"] = pd.Categorical(X_df["x5"].astype(int), categories=[1, 2, 3], ordered=True) +``` + +:::: + +Split data into test and train sets + +::::{.panel-tabset group="language"} + +## R + +```{r} +test_set_pct <- 0.2 +n_test <- round(test_set_pct * n) +n_train <- n - n_test +test_inds <- sort(sample(1:n, n_test, replace = FALSE)) +train_inds <- (1:n)[!((1:n) %in% test_inds)] +X_test <- X[test_inds, ] +X_train <- X[train_inds, ] +pi_test <- pi_x[test_inds] +pi_train <- pi_x[train_inds] +Z_test <- Z[test_inds] +Z_train <- Z[train_inds] +y_test <- y[test_inds] +y_train <- y[train_inds] +mu_test <- mu_x[test_inds] +mu_train <- mu_x[train_inds] +tau_test <- tau_x[test_inds] +tau_train <- tau_x[train_inds] +``` + +## Python + +```{python} +test_set_pct = 0.2 +n_test = round(test_set_pct * n) +n_train = n - n_test +test_inds = rng.choice(n, n_test, replace=False) +train_inds = np.setdiff1d(np.arange(n), test_inds) +X_test = X_df.iloc[test_inds] +X_train = X_df.iloc[train_inds] +pi_test, pi_train = pi_x[test_inds], pi_x[train_inds] +Z_test, Z_train = Z[test_inds], Z[train_inds] +y_test, y_train = y[test_inds], y[train_inds] +mu_test, mu_train = mu_x[test_inds], mu_x[train_inds] +tau_test, tau_train = tau_x[test_inds], tau_x[train_inds] +``` + +:::: + + +### Sampling and Analysis + +We simulate from a BCF model initialized by "warm-start" samples fit with the grow-from-root algorithm (@he2023stochastic, @krantsevich2023stochastic). This is the default in `stochtree`. 
+ +::::{.panel-tabset group="language"} + +## R + +```{r} +general_params <- list( + num_threads=1, + num_chains=4, + random_seed=random_seed +) +bcf_model <- bcf( + X_train = X_train, + Z_train = Z_train, + y_train = y_train, + propensity_train = pi_train, + X_test = X_test, + Z_test = Z_test, + num_gfr = 10, + num_burnin = 1000, + num_mcmc = 100, + propensity_test = pi_test, + general_params = general_params +) +``` + +## Python + +```{python} +general_params = {"num_threads": 1, "num_chains": 4, "random_seed": random_seed} +bcf_model = BCFModel() +bcf_model.sample( + X_train=X_train, + Z_train=Z_train, + y_train=y_train, + propensity_train=pi_train, + X_test=X_test, + Z_test=Z_test, + num_gfr=10, + num_burnin=1000, + num_mcmc=100, + propensity_test=pi_test, + general_params=general_params, +) +``` + +:::: + +Plot the true versus estimated prognostic function + +::::{.panel-tabset group="language"} + +## R + +```{r} +mu_hat_test <- predict( + bcf_model, + X = X_test, + Z = Z_test, + propensity = pi_test, + terms = "prognostic_function" +) +plot( + rowMeans(mu_hat_test), + mu_test, + xlab = "predicted", + ylab = "actual", + main = "Prognostic function" +) +abline(0, 1, col = "red", lty = 3, lwd = 3) +``` + +## Python + +```{python} +mu_hat_test = bcf_model.predict( + X=X_test, Z=Z_test, propensity=pi_test, terms="prognostic_function" +) +mu_pred = mu_hat_test.mean(axis=1) +lo, hi = min(mu_pred.min(), mu_test.min()), max(mu_pred.max(), mu_test.max()) +plt.scatter(mu_pred, mu_test, alpha=0.5) +plt.plot([lo, hi], [lo, hi], color="red", linestyle="dashed", linewidth=2) +plt.xlabel("Predicted") +plt.ylabel("Actual") +plt.title("Prognostic function") +plt.show() +``` + +:::: + +Plot the true versus estimated CATE function + +::::{.panel-tabset group="language"} + +## R + +```{r} +tau_hat_test <- predict( + bcf_model, + X = X_test, + Z = Z_test, + propensity = pi_test, + terms = "cate" +) +plot( + rowMeans(tau_hat_test), + tau_test, + xlab = "predicted", + ylab = 
"actual", + main = "Treatment effect" +) +abline(0, 1, col = "red", lty = 3, lwd = 3) +``` + +## Python + +```{python} +tau_hat_test = bcf_model.predict(X=X_test, Z=Z_test, propensity=pi_test, terms="cate") +tau_pred = tau_hat_test.mean(axis=1) +lo, hi = min(tau_pred.min(), tau_test.min()), max(tau_pred.max(), tau_test.max()) +plt.scatter(tau_pred, tau_test, alpha=0.5) +plt.plot([lo, hi], [lo, hi], color="red", linestyle="dashed", linewidth=2) +plt.xlabel("Predicted") +plt.ylabel("Actual") +plt.title("Treatment effect") +plt.show() +``` + +:::: + +Plot the $\sigma^2$ traceplot + +::::{.panel-tabset group="language"} + +## R + +```{r} +sigma_observed <- var(y - E_XZ) +sigma2_global_samples <- extractParameter(bcf_model, "sigma2_global") +plot_bounds <- c( + min(c(sigma2_global_samples, sigma_observed)), + max(c(sigma2_global_samples, sigma_observed)) +) +plot( + sigma2_global_samples, + ylim = plot_bounds, + ylab = "sigma^2", + xlab = "Sample", + main = "Global variance parameter" +) +abline(h = sigma_observed, lty = 3, lwd = 3, col = "blue") +``` + +## Python + +```{python} +sigma_observed = np.var(y - E_XZ) +global_var_samples = bcf_model.extract_parameter("sigma2_global") +plt.plot(global_var_samples) +plt.axhline(sigma_observed, color="blue", linestyle="dashed", linewidth=2) +plt.xlabel("Sample") +plt.ylabel(r"$\sigma^2$") +plt.title("Global variance parameter") +plt.show() +``` + +:::: + +Examine test set interval coverage of $\tau(X)$. 
+ +::::{.panel-tabset group="language"} + +## R + +```{r} +test_lb <- apply(tau_hat_test, 1, quantile, 0.025) +test_ub <- apply(tau_hat_test, 1, quantile, 0.975) +cover <- ((test_lb <= tau_x[test_inds]) & + (test_ub >= tau_x[test_inds])) +cat("CATE function interval coverage: ", mean(cover) * 100, "%\n") +``` + +## Python + +```{python} +test_lb = np.quantile(tau_hat_test, 0.025, axis=1) +test_ub = np.quantile(tau_hat_test, 0.975, axis=1) +cover = (test_lb <= tau_test) & (test_ub >= tau_test) +print(f"CATE function interval coverage: {cover.mean() * 100:.2f}%") +``` + +:::: + +## Demo 2: Nonlinear Outcome Model, Heterogeneous Treatment Effect + +We consider the following data generating process from @hahn2020bayesian: + +\begin{equation*} +\begin{aligned} +y &= \mu(X) + \tau(X) Z + \epsilon\\ +\epsilon &\sim N\left(0,\sigma^2\right)\\ +\mu(X) &= 1 + g(X) + 6 \lvert X_3 - 1 \rvert\\ +\tau(X) &= 1 + 2 X_2 X_4\\ +g(X) &= \mathbb{I}(X_5=1) \times 2 - \mathbb{I}(X_5=2) \times 1 - \mathbb{I}(X_5=3) \times 4\\ +s_{\mu} &= \sqrt{\mathbb{V}(\mu(X))}\\ +\pi(X) &= 0.8 \Phi\left(\frac{3\mu(X)}{s_{\mu}} - \frac{X_1}{2}\right) + \frac{2U+1}{20}\\ +X_1,X_2,X_3 &\sim N\left(0,1\right)\\ +X_4 &\sim \text{Bernoulli}(1/2)\\ +X_5 &\sim \text{Categorical}(1/3,1/3,1/3)\\ +U &\sim \text{Uniform}\left(0,1\right)\\ +Z &\sim \text{Bernoulli}\left(\pi(X)\right) +\end{aligned} +\end{equation*} + +### Simulation + +Generate data from the DGP above + +::::{.panel-tabset group="language"} + +## R + +```{r} +n <- 1000 +snr <- 3 +x1 <- rnorm(n) +x2 <- rnorm(n) +x3 <- rnorm(n) +x4 <- as.numeric(rbinom(n, 1, 0.5)) +x5 <- as.numeric(sample(1:3, n, replace = TRUE)) +X <- cbind(x1, x2, x3, x4, x5) +p <- ncol(X) +mu_x <- mu2(X) +tau_x <- tau2(X) +pi_x <- 0.8 * pnorm((3 * mu_x / sd(mu_x)) - 0.5 * X[, 1]) + 0.05 + runif(n) / 10 +Z <- rbinom(n, 1, pi_x) +E_XZ <- mu_x + Z * tau_x +y <- E_XZ + rnorm(n, 0, 1) * (sd(E_XZ) / snr) +X <- as.data.frame(X) +X$x4 <- factor(X$x4, ordered = TRUE) +X$x5 <- factor(X$x5,
ordered = TRUE) +``` + +## Python + +```{python} +n = 1000 +snr = 3 +x1 = rng.normal(size=n) +x2 = rng.normal(size=n) +x3 = rng.normal(size=n) +x4 = rng.binomial(1, 0.5, n).astype(float) +x5 = rng.choice([1, 2, 3], size=n).astype(float) +X = np.column_stack([x1, x2, x3, x4, x5]) +mu_x = mu2(X) # mu2 for Demo 2 +tau_x = tau2(X) +pi_x = ( + 0.8 * norm.cdf((3 * mu_x / np.std(mu_x)) - 0.5 * X[:, 0]) + + 0.05 + + rng.uniform(size=n) / 10 +) +Z = rng.binomial(1, pi_x, n).astype(float) +E_XZ = mu_x + Z * tau_x +y = E_XZ + rng.normal(size=n) * (np.std(E_XZ) / snr) +X_df = pd.DataFrame({"x1": x1, "x2": x2, "x3": x3, "x4": x4, "x5": x5}) +X_df["x4"] = pd.Categorical(X_df["x4"].astype(int), categories=[0, 1], ordered=True) +X_df["x5"] = pd.Categorical(X_df["x5"].astype(int), categories=[1, 2, 3], ordered=True) +``` + +:::: + +Split into train and test sets + +::::{.panel-tabset group="language"} + +## R + +```{r} +test_set_pct <- 0.2 +n_test <- round(test_set_pct * n) +n_train <- n - n_test +test_inds <- sort(sample(1:n, n_test, replace = FALSE)) +train_inds <- (1:n)[!((1:n) %in% test_inds)] +X_test <- X[test_inds, ] +X_train <- X[train_inds, ] +pi_test <- pi_x[test_inds] +pi_train <- pi_x[train_inds] +Z_test <- Z[test_inds] +Z_train <- Z[train_inds] +y_test <- y[test_inds] +y_train <- y[train_inds] +mu_test <- mu_x[test_inds] +mu_train <- mu_x[train_inds] +tau_test <- tau_x[test_inds] +tau_train <- tau_x[train_inds] +``` + +## Python + +```{python} +test_set_pct = 0.2 +n_test = round(test_set_pct * n) +n_train = n - n_test +test_inds = rng.choice(n, n_test, replace=False) +train_inds = np.setdiff1d(np.arange(n), test_inds) +X_test = X_df.iloc[test_inds] +X_train = X_df.iloc[train_inds] +pi_test, pi_train = pi_x[test_inds], pi_x[train_inds] +Z_test, Z_train = Z[test_inds], Z[train_inds] +y_test, y_train = y[test_inds], y[train_inds] +mu_test, mu_train = mu_x[test_inds], mu_x[train_inds] +tau_test, tau_train = tau_x[test_inds], tau_x[train_inds] +``` + +:::: + +### Sampling 
and Analysis + +We simulate from a BCF model using default settings. + +::::{.panel-tabset group="language"} + +## R + +```{r} +general_params <- list( + num_threads = 1, + num_chains = 4, + random_seed = random_seed +) +bcf_model <- bcf( + X_train = X_train, + Z_train = Z_train, + y_train = y_train, + propensity_train = pi_train, + X_test = X_test, + Z_test = Z_test, + propensity_test = pi_test, + num_gfr = 10, + num_burnin = 1000, + num_mcmc = 100, + general_params = general_params +) +``` + +## Python + +```{python} +general_params = {"num_threads": 1, "num_chains": 4, "random_seed": random_seed} +bcf_model = BCFModel() +bcf_model.sample( + X_train=X_train, + Z_train=Z_train, + y_train=y_train, + propensity_train=pi_train, + X_test=X_test, + Z_test=Z_test, + propensity_test=pi_test, + num_gfr=10, + num_burnin=1000, + num_mcmc=100, + general_params=general_params, +) +``` + +:::: + +Plot the true versus estimated prognostic function + +::::{.panel-tabset group="language"} + +## R + +```{r} +mu_hat_test <- predict( + bcf_model, + X = X_test, + Z = Z_test, + propensity = pi_test, + terms = "prognostic_function" +) +plot( + rowMeans(mu_hat_test), + mu_test, + xlab = "predicted", + ylab = "actual", + main = "Prognostic function" +) +abline(0, 1, col = "red", lty = 3, lwd = 3) +``` + +## Python + +```{python} +mu_hat_test = bcf_model.predict( + X=X_test, Z=Z_test, propensity=pi_test, terms="prognostic_function" +) +mu_pred = mu_hat_test.mean(axis=1) +lo, hi = min(mu_pred.min(), mu_test.min()), max(mu_pred.max(), mu_test.max()) +plt.scatter(mu_pred, mu_test, alpha=0.5) +plt.plot([lo, hi], [lo, hi], color="red", linestyle="dashed", linewidth=2) +plt.xlabel("Predicted") +plt.ylabel("Actual") +plt.title("Prognostic function") +plt.show() +``` + +:::: + +Plot the true versus estimated CATE function + +::::{.panel-tabset group="language"} + +## R + +```{r} +tau_hat_test <- predict( + bcf_model, + X = X_test, + Z = Z_test, + propensity = pi_test, + terms = "cate" +) +plot( + 
rowMeans(tau_hat_test), + tau_test, + xlab = "predicted", + ylab = "actual", + main = "Treatment effect" +) +abline(0, 1, col = "red", lty = 3, lwd = 3) +``` + +## Python + +```{python} +tau_hat_test = bcf_model.predict(X=X_test, Z=Z_test, propensity=pi_test, terms="cate") +tau_pred = tau_hat_test.mean(axis=1) +lo, hi = min(tau_pred.min(), tau_test.min()), max(tau_pred.max(), tau_test.max()) +plt.scatter(tau_pred, tau_test, alpha=0.5) +plt.plot([lo, hi], [lo, hi], color="red", linestyle="dashed", linewidth=2) +plt.xlabel("Predicted") +plt.ylabel("Actual") +plt.title("Treatment effect") +plt.show() +``` + +:::: + +Plot the $\sigma^2$ traceplot + +::::{.panel-tabset group="language"} + +## R + +```{r} +sigma_observed <- var(y - E_XZ) +sigma2_global_samples <- extractParameter(bcf_model, "sigma2_global") +plot_bounds <- c( + min(c(sigma2_global_samples, sigma_observed)), + max(c(sigma2_global_samples, sigma_observed)) +) +plot( + sigma2_global_samples, + ylim = plot_bounds, + ylab = "sigma^2", + xlab = "Sample", + main = "Global variance parameter" +) +abline(h = sigma_observed, lty = 3, lwd = 3, col = "blue") +``` + +## Python + +```{python} +sigma_observed = np.var(y - E_XZ) +global_var_samples = bcf_model.extract_parameter("sigma2_global") +plt.plot(global_var_samples) +plt.axhline(sigma_observed, color="blue", linestyle="dashed", linewidth=2) +plt.xlabel("Sample") +plt.ylabel(r"$\sigma^2$") +plt.title("Global variance parameter") +plt.show() +``` + +:::: + +Examine test set interval coverage of $\tau(X)$. 
+ +::::{.panel-tabset group="language"} + +## R + +```{r} +test_lb <- apply(tau_hat_test, 1, quantile, 0.025) +test_ub <- apply(tau_hat_test, 1, quantile, 0.975) +cover <- ((test_lb <= tau_x[test_inds]) & + (test_ub >= tau_x[test_inds])) +cat("CATE function interval coverage: ", mean(cover) * 100, "%\n") +``` + +## Python + +```{python} +test_lb = np.quantile(tau_hat_test, 0.025, axis=1) +test_ub = np.quantile(tau_hat_test, 0.975, axis=1) +cover = (test_lb <= tau_test) & (test_ub >= tau_test) +print(f"CATE function interval coverage: {cover.mean() * 100:.2f}%") +``` + +:::: + +# References diff --git a/vignettes/custom-sampling.qmd b/vignettes/custom-sampling.qmd new file mode 100644 index 000000000..37aacba48 --- /dev/null +++ b/vignettes/custom-sampling.qmd @@ -0,0 +1,983 @@ +--- +title: "Building a Custom Gibbs Sampler with Stochtree Primitives" +bibliography: vignettes.bib +execute: + freeze: auto # re-render only when source changes +--- + +```{r} +#| include: false +reticulate::use_python( + Sys.getenv( + "RETICULATE_PYTHON", + unset = file.path(rprojroot::find_root(rprojroot::has_file(".here")), ".venv", "bin", "python") + ), + required = TRUE +) +``` + +While the functions `bart()` and `bcf()` provide simple and performant interfaces for +supervised learning / causal inference, `stochtree` also offers access to many of the +"low-level" data structures that are typically implemented in C++. This low-level +interface is not designed for performance or even simplicity β€” rather the intent is +to provide a "prototype" interface to the C++ code that doesn't require modifying any +C++. 
+ +# Motivation + +To illustrate when such a prototype interface might be useful, consider the classic +BART algorithm + + + +::: {.algorithm style="border-top: 2px solid; border-bottom: 2px solid; padding: 0.6em 1em; margin: 1.5em 0;"} + +Input: $y$, $X$, $\tau$, $\nu$, $\lambda$, $\alpha$, $\beta$ + +Output: $mc$ samples of a decision forest with $m$ trees and global variance parameter $\sigma^2$ + +Initialize $\sigma^2$ via a default or a data-dependent calibration exercise + +Initialize a forest with $m$ trees, each with a single root node, referring to tree $j$'s prediction vector as $f_{j}$ + +Compute residual as $r = y - \sum_{j=1}^m f_{j}$ + +For $i$ in $\left\{1,\dots,mc\right\}$: + +::::: {style="margin-left: 2em"} +For $j$ in $\left\{1,\dots,m\right\}$: + +::::: {style="margin-left: 2em"} +Add predictions for tree $j$ to residual: $r = r + f_{j}$ + +Sample tree $j$ of forest $i$ from $p\left(\mathcal{T}_{i,j} \mid r, \sigma^2\right)$ + +Sample tree $j$'s leaf parameters from $p\left(\theta_{i,j} \mid \mathcal{T}_{i,j}, r, \sigma^2\right)$ and update $f_j$ accordingly + +Update residual by removing tree $j$'s predictions: $r = r - f_{j}$ + +::::: + +Sample $\sigma^2$ from $p\left(\sigma^2 \mid r\right)$ + +::::: + +Return each of the forests and $\sigma^2$ draws + +::: + +This algorithm is implemented in `stochtree` via the `bart()` R function or the `BARTModel` python class, but the low-level interface allows you to customize this loop. + +In this vignette, we will demonstrate how to use this interface to fit a modified BART model in which the error term is modeled as $t$-distributed rather than Gaussian.
+ +# Setup + +::::{.panel-tabset group="language"} + +## R + +```{r} +library(stochtree) +``` + +## Python + +```{python} +import numpy as np +import matplotlib.pyplot as plt +from stochtree import ( + RNG, Dataset, Forest, ForestContainer, ForestSampler, + GlobalVarianceModel, LeafVarianceModel, Residual, + ForestModelConfig, GlobalModelConfig, +) +``` + +:::: + +Set seed for reproducibility + +::::{.panel-tabset group="language"} + +## R + +```{r} +random_seed <- 1234 +set.seed(random_seed) +``` + +## Python + +```{python} +random_seed = 1234 +rng = np.random.default_rng(random_seed) +``` + +:::: + +# Data Generation and Preparation + +Consider a modified version of the "Friedman dataset" (@friedman1991multivariate) with heavy-tailed errors + +$$ +\begin{aligned} +Y_i \mid X_i = x_i &\overset{\text{iid}}{\sim} t_{\nu}\left(f(x_i), \sigma^2\right),\\ +f(x) &= 10 \sin \left(\pi x_1 x_2\right) + 20 (x_3 - 1/2)^2 + 10 x_4 + 5 x_5,\\ +X_1, \dots, X_p &\overset{\text{iid}}{\sim} \text{U}\left(0,1\right), +\end{aligned} +$$ + +where $t_{\nu}(\mu,\sigma^2)$ represents a generalized $t$ distribution with location $\mu$, scale $\sigma^2$ and $\nu$ degrees of freedom.
+ +We simulate from this dataset below + +::::{.panel-tabset group="language"} + +## R + +```{r} +n <- 1000 +p <- 20 +X <- matrix(runif(n * p), ncol = p) +m_x <- (10 * + sin(pi * X[, 1] * X[, 2]) + + 20 * (X[, 3] - 0.5)^2 + + 10 * X[, 4] + + 5 * X[, 5]) +sigma2 <- 9 +nu <- 2 +eps <- rt(n, df = nu) * sqrt(sigma2) +y <- m_x + eps +sigma2_true <- var(eps) +``` + +## Python + +```{python} +n = 1000 +p = 20 +X = rng.uniform(low=0.0, high=1.0, size=(n, p)) +m_x = ( + 10 * np.sin(np.pi * X[:, 0] * X[:, 1]) + + 20 * np.power(X[:, 2] - 0.5, 2.0) + + 10 * X[:, 3] + + 5 * X[:, 4] +) +sigma2 = 9 +nu = 2 +eps = rng.standard_t(df=nu, size=n) * np.sqrt(sigma2) +y = m_x + eps +sigma2_true = np.var(eps) +``` + +:::: + +And we pre-standardize the outcome + +::::{.panel-tabset group="language"} + +## R + +```{r} +y_bar <- mean(y) +y_std <- sd(y) +y_standardized <- (y - y_bar) / y_std +``` + +## Python + +```{python} +y_bar = np.mean(y) +y_std = np.std(y) +y_standardized = (y - y_bar) / y_std +``` + +:::: + +# Sampling + +We can obtain $t$-distributed errors by augmenting the basic BART model with a further prior on the individual variances: +$$ +\begin{aligned} +Y_i \mid (X_i = x_i) &\overset{\text{iid}}{\sim} \mathrm{N}(f(x_i), \phi_i),\\ +\phi_i &\overset{\text{iid}}{\sim} \text{IG}\left(\frac{\nu}{2}, \frac{\nu\sigma^2}{2}\right),\\ +f &\sim \mathrm{BART}(\alpha,\beta,m). +\end{aligned} +$$ +Any Gamma prior on $\sigma^2$ ensures conditional conjugacy, though for simplicity's sake we use a log-uniform prior $\sigma^2\propto 1 / \sigma^2$. In the implementation below, we sample from a "parameter-expanded" variant of this model discussed in Section 12.1 of @gelman2013bayesian, which possesses favorable convergence properties. 
+$$ +\begin{aligned} +Y_i \mid (X_i = x_i) &\overset{\text{iid}}{\sim} \mathrm{N}(f(x_i), a^2\phi_i),\\ +\phi_i &\overset{\text{iid}}{\sim} \text{IG}\left(\frac{\nu}{2}, \frac{\nu\tau^2}{2}\right),\\ +a^2 &\propto 1/a^2,\\ +\tau^2 &\propto 1/\tau^2,\\ +f &\sim \mathrm{BART}(\alpha,\beta,m). +\end{aligned} +$$ + +## Helper functions + +We define several helper functions for Gibbs draws of each of the above parameters. + +::::{.panel-tabset group="language"} + +## R + +```{r} +# Sample observation-specific variance parameters phi_i +sample_phi_i <- function(y, dataset, forest, a2, tau2, nu) { + n <- length(y) + yhat_forest <- forest$predict(dataset) + res <- y - yhat_forest + posterior_shape <- (nu + 1) / 2 + posterior_scale <- (nu * tau2 + (res * res / a2)) / 2 + return(1 / rgamma(n, posterior_shape, rate = posterior_scale)) +} + +# Sample variance parameter a^2 +sample_a2 <- function(y, dataset, forest, phi_i) { + n <- length(y) + yhat_forest <- forest$predict(dataset) + res <- y - yhat_forest + posterior_shape <- n / 2 + posterior_scale <- (1 / 2) * sum(res * res / phi_i) + return(1 / rgamma(1, posterior_shape, rate = posterior_scale)) +} + +# Sample variance parameter tau^2 +sample_tau2 <- function(phi_i, nu) { + n <- length(phi_i) + posterior_shape <- nu * n / 2 + posterior_rate <- (nu / 2) * sum(1 / phi_i) + # The full conditional of tau^2 under the prior p(tau^2) ~ 1/tau^2 is + # Gamma(n*nu/2, rate = (nu/2) * sum(1/phi_i)) -- not inverse-gamma -- + # (Gelman et al. 2013, Section 12.1), so return the gamma draw directly + return(rgamma(1, posterior_shape, rate = posterior_rate)) +} +``` + +## Python + +```{python} +def sample_phi_i( + y: np.array, + dataset: Dataset, + forest: Forest, + a2: float, + tau2: float, + nu: float, + rng: np.random.Generator, +) -> np.array: + """ + Sample observation-specific variance parameters phi_i + """ + n = len(y) + yhat_forest = forest.predict(dataset) + res = y - yhat_forest + posterior_shape = (nu + 1) / 2 + posterior_scale = (nu * tau2 + (res * res / a2)) / 2 + return 1 / rng.gamma(shape=posterior_shape, scale=1 / posterior_scale, size=n) + + +def sample_a2( + y: np.array, + dataset: Dataset, + forest: Forest, + phi_i: np.array, + rng: np.random.Generator, +) -> float: + """ + Sample variance parameter a^2 + """ + n = len(y) + yhat_forest = forest.predict(dataset) + res = y - yhat_forest + posterior_shape = n / 2 + posterior_scale = (1 / 2) * np.sum(res * res / phi_i) + return 1 / rng.gamma(shape=posterior_shape, scale=1 / posterior_scale, size=1)[0] + + +def sample_tau2(phi_i: np.array, nu: float, rng: np.random.Generator) -> float: + """ + Sample variance parameter tau^2 + """ + n = len(phi_i) + posterior_shape = nu * n / 2 + posterior_rate = (nu / 2) * np.sum(1 / phi_i) + # The full conditional of tau^2 under the prior p(tau^2) ~ 1/tau^2 is + # Gamma(n*nu/2, rate = (nu/2) * sum(1/phi_i)) -- not inverse-gamma -- + # (Gelman et al. 2013, Section 12.1), so return the gamma draw directly + return rng.gamma(shape=posterior_shape, scale=1 / posterior_rate, size=1)[0] +``` + +:::: + +## Sampling data structures + +The underlying C++ codebase centers around a handful of objects and their interactions. We provide R and Python wrappers for these objects to enable greater customization of stochastic tree samplers than can be furnished by the high-level BART and BCF interfaces. + +A "Forest Dataset" class manages covariates, bases, and variance weights used in a forest model, and contains methods for updating the underlying data as well as querying numeric attributes of the data (i.e. `num_observations`, `num_covariates`, `has_basis`, etc...). An Outcome / Residual class wraps the model outcome, which is updated in-place during sampling to reflect the full, or partial, residual net of mean forest or random effects predictions. A "Forest Samples" class is a container of sampled tree ensembles, essentially a very thin wrapper around a C++ `std::vector` of `std::unique_ptr` to `Ensemble` objects. A Forest class is a thin wrapper around `Ensemble` C++ objects, which is used as the "active forest" or "state" of the forest model during sampling. A "Forest Model" class maintains all of the "temporary" data structures used to sample a forest, and its `sample_one_iteration()` method performs one iteration of the requested forest sampling algorithm (i.e. Metropolis-Hastings or Grow-From-Root).
Two different configuration objects (global and forest-specific) manage the parameters needed to run the samplers. + +Writing a custom Gibbs sampler with one or more stochastic forest terms requires initializing each of these objects and then deploying them in a sampling loop. + +First, we initialize the data objects with covariates and standardized outcomes + +::::{.panel-tabset group="language"} + +## R + +```{r} +# Initial values of robust model parameters +tau2_init <- 1. +a2_init <- 1. +sigma2_init <- 1. +phi_i_init <- rep(1., n) + +# Initialize data objects +forest_dataset <- createForestDataset(X, variance_weights = 1 / phi_i_init) +outcome <- createOutcome(y_standardized) +``` + +## Python + +```{python} +# Initial values of robust model parameters +tau2_init = 1.0 +a2_init = 1.0 +sigma2_init = tau2_init * a2_init +phi_i_init = np.repeat(1.0, n) + +# Initialize data objects +forest_dataset = Dataset() +forest_dataset.add_covariates(X) +forest_dataset.add_variance_weights(1.0 / phi_i_init) +residual = Residual(y_standardized) +``` + +:::: + +Next, we initialize random number generator objects, which are essentially wrappers around `std::mt19937`, which can optionally be seeded for reproducibility. + +::::{.panel-tabset group="language"} + +## R + +```{r} +rng <- createCppRNG(random_seed) +``` + +## Python + +```{python} +cpp_rng = RNG(random_seed) +``` + +:::: + +Next, we initialize the configuration objects. Note that each config has default values so these parameters do not all need to be explicitly set. 
+ +::::{.panel-tabset group="language"} + +## R + +```{r} +# Set parameters +outcome_model_type <- 0 +leaf_dimension <- 1 +num_trees <- 200 +feature_types <- as.integer(rep(0, p)) # 0 = numeric +variable_weights <- rep(1 / p, p) + +# Initialize config objects +forest_model_config <- createForestModelConfig( + feature_types = feature_types, + num_trees = num_trees, + min_samples_leaf = 5, + num_features = p, + num_observations = n, + variable_weights = variable_weights, + leaf_dimension = leaf_dimension, + leaf_model_type = outcome_model_type +) +global_model_config <- createGlobalModelConfig( + global_error_variance = sigma2_init +) +``` + +## Python + +```{python} +# Set parameters +outcome_model_type = 0 +leaf_dimension = 1 +num_trees = 200 +feature_types = np.repeat(0, p).astype(int) # 0 = numeric +var_weights = np.repeat(1 / p, p) + +# Initialize config objects +forest_model_config = ForestModelConfig( + feature_types=feature_types, + num_trees=num_trees, + num_features=p, + num_observations=n, + variable_weights=var_weights, + leaf_dimension=leaf_dimension, + leaf_model_type=outcome_model_type, +) +global_model_config = GlobalModelConfig(global_error_variance=sigma2_init) +``` + +:::: + + +Next, we initialize forest model / sampler objects which dispatch the sampling algorithms + +::::{.panel-tabset group="language"} + +## R + +```{r} +forest_model <- createForestModel(forest_dataset, forest_model_config, global_model_config) +``` + +## Python + +```{python} +forest_sampler = ForestSampler( + forest_dataset, + global_model_config, + forest_model_config +) +``` + +:::: + +Initialize both the (empty) container of retained forest samples and the "active forest." + +We set the leaf node values for every (single-node) tree in the active forest so that they sum to the mean of the scaled outcome (which is 0 since it was centered). 
+ +::::{.panel-tabset group="language"} + +## R + +```{r} +# Create forest container and active forest +forest_samples <- createForestSamples(num_trees, 1, T) +active_forest <- createForest(num_trees, 1, T) + +# Initialize the leaves of each tree in the active forest +leaf_init <- mean(y_standardized) +active_forest$prepare_for_sampler( + forest_dataset, + outcome, + forest_model, + outcome_model_type, + leaf_init +) +``` + +## Python + +```{python} +# Create forest container and active forest +forest_container = ForestContainer(num_trees, leaf_dimension, True, False) +active_forest = Forest(num_trees, leaf_dimension, True, False) + +# Initialize the leaves of each tree in the active forest +leaf_init = np.mean(y_standardized, keepdims=True) +forest_sampler.prepare_for_sampler( + forest_dataset, + residual, + active_forest, + outcome_model_type, + leaf_init, +) +``` + +:::: + +We prepare to run the sampler by initializing empty containers for all of the parametric components of the model (and other intermediate values we track such as RMSE and predicted values).
+ +::::{.panel-tabset group="language"} + +## R + +```{r} +num_burnin <- 3000 +num_mcmc <- 1000 +sigma2_samples <- rep(NA, num_mcmc) +a2_samples <- rep(NA, num_mcmc) +tau2_samples <- rep(NA, num_mcmc) +phi_i_samples <- matrix(NA, n, num_mcmc) +rmse_samples <- rep(0, num_mcmc) +fhat_samples <- matrix(0, n, num_mcmc) +current_sigma2 <- sigma2_init +current_a2 <- a2_init +current_tau2 <- tau2_init +current_phi_i <- phi_i_init +``` + +## Python + +```{python} +num_burnin = 3000 +num_mcmc = 1000 +sigma2_samples = np.empty(num_mcmc) +a2_samples = np.empty(num_mcmc) +tau2_samples = np.empty(num_mcmc) +phi_i_samples = np.empty((n, num_mcmc)) +rmse_samples = np.empty(num_mcmc) +fhat_samples = np.empty((n, num_mcmc)) +current_sigma2 = sigma2_init +current_a2 = a2_init +current_tau2 = tau2_init +current_phi_i = phi_i_init +``` + +:::: + +Run an MCMC sampler + +::::{.panel-tabset group="language"} + +## R + +```{r} +for (i in 1:(num_burnin + num_mcmc)) { + keep_sample <- i > num_burnin + + # Sample forest + forest_model$sample_one_iteration( + forest_dataset, + outcome, + forest_samples, + active_forest, + rng, + forest_model_config, + global_model_config, + keep_forest = keep_sample, + gfr = F, + num_threads = 1 + ) + + # Sample local variance parameters + current_phi_i <- sample_phi_i( + y_standardized, + forest_dataset, + active_forest, + current_a2, + current_tau2, + nu + ) + + # Sample a2 + current_a2 <- sample_a2( + y_standardized, + forest_dataset, + active_forest, + current_phi_i + ) + if (keep_sample) { + a2_samples[i - num_burnin] <- current_a2 * y_std^2 + } + + # Sample tau2 + current_tau2 <- sample_tau2(current_phi_i, nu) + if (keep_sample) { + tau2_samples[i - num_burnin] <- current_tau2 * y_std^2 + sigma2_samples[i - num_burnin] <- current_tau2 * current_a2 * y_std^2 + } + + # Update observation-specific variance weights + forest_dataset$update_variance_weights(current_phi_i * current_a2) + + # Compute in-sample RMSE and cache mean function samples + if 
(keep_sample) { + yhat_forest <- active_forest$predict(forest_dataset) * y_std + y_bar + error <- (m_x - yhat_forest) + rmse_samples[i - num_burnin] <- sqrt(mean(error * error)) + fhat_samples[, i - num_burnin] <- yhat_forest + } +} +``` + +## Python + +```{python} +keep_sample = False +for i in range(num_burnin + num_mcmc): + if i >= num_burnin: + keep_sample = True + + # Sample from the forest + forest_sampler.sample_one_iteration( + forest_container=forest_container, + forest=active_forest, + dataset=forest_dataset, + residual=residual, + rng=cpp_rng, + global_config=global_model_config, + forest_config=forest_model_config, + keep_forest=keep_sample, + gfr=False, + num_threads=1 + ) + + # Sample local variance parameters + current_phi_i = sample_phi_i( + y_standardized, + forest_dataset, + active_forest, + current_a2, + current_tau2, + nu, + rng, + ) + + # Sample a2 + current_a2 = sample_a2( + y_standardized, + forest_dataset, + active_forest, + current_phi_i, + rng, + ) + + # Sample tau2 + current_tau2 = sample_tau2(current_phi_i, nu, rng) + if keep_sample: + tau2_samples[i - num_burnin] = current_tau2 * y_std * y_std + sigma2_samples[i - num_burnin] = current_tau2 * current_a2 * y_std * y_std + + # Update observation-specific variance weights + forest_dataset.update_variance_weights(current_phi_i * current_a2) + + # Compute in-sample RMSE and cache mean function samples + if keep_sample: + yhat_forest = active_forest.predict(forest_dataset) * y_std + y_bar + error = m_x - yhat_forest + rmse_samples[i - num_burnin] = np.sqrt(np.mean(error * error)) + fhat_samples[:, i - num_burnin] = yhat_forest +``` + +:::: + +Compute posterior mean of the conditional expectations for the non-robust model + +::::{.panel-tabset group="language"} + +## R + +```{r} +m_x_hat_posterior_mean <- rowMeans(fhat_samples) +``` + +## Python + +```{python} +m_x_hat_posterior_mean = np.mean(fhat_samples, axis=1) +``` + +:::: + +For comparison, we run the same model without robust errors + 
+::::{.panel-tabset group="language"} + +## R + +```{r} +# Initial value of global error variance parameter +sigma2_init <- 1.0 + +# Initialize data objects +forest_dataset <- createForestDataset(X) +outcome <- createOutcome(y_standardized) + +# Random number generator (std::mt19937) +rng <- createCppRNG(random_seed) + +# Model configuration +outcome_model_type <- 0 +leaf_dimension <- 1 +num_trees <- 200 +feature_types <- as.integer(rep(0, p)) # 0 = numeric +variable_weights <- rep(1 / p, p) +forest_model_config <- createForestModelConfig( + feature_types = feature_types, + num_trees = num_trees, + num_features = p, + min_samples_leaf = 5, + num_observations = n, + variable_weights = variable_weights, + leaf_dimension = leaf_dimension, + leaf_model_type = outcome_model_type +) +global_model_config <- createGlobalModelConfig( + global_error_variance = sigma2_init +) + +# Forest model object +forest_model <- createForestModel( + forest_dataset, + forest_model_config, + global_model_config +) + +# "Active forest" (which gets updated by the sample) and +# container of forest samples (which is written to when +# a sample is not discarded due to burn-in / thinning) +forest_samples <- createForestSamples(num_trees, 1, T) +active_forest <- createForest(num_trees, 1, T) + +# Initialize the leaves of each tree in the forest +leaf_init <- mean(y_standardized) +active_forest$prepare_for_sampler( + forest_dataset, + outcome, + forest_model, + outcome_model_type, + leaf_init +) +active_forest$adjust_residual(forest_dataset, outcome, forest_model, F, F) + +# Prepare to run the sampler +global_var_samples <- rep(NA, num_mcmc) +rmse_samples_non_robust <- rep(0, num_mcmc) +fhat_samples_non_robust <- matrix(0, n, num_mcmc) +current_sigma2 <- sigma2_init + +# Run the MCMC sampler +for (i in 1:(num_burnin + num_mcmc)) { + keep_sample <- i > num_burnin + + # Sample forest + forest_model$sample_one_iteration( + forest_dataset, + outcome, + forest_samples, + active_forest, + rng, + 
forest_model_config, + global_model_config, + keep_forest = keep_sample, + gfr = F, + num_threads = 1 + ) + + # Sample global error variance parameter + current_sigma2 <- sampleGlobalErrorVarianceOneIteration( + outcome, + forest_dataset, + rng, + 1, + 1 + ) + global_model_config$update_global_error_variance(current_sigma2) + if (keep_sample) { + global_var_samples[i - num_burnin] <- current_sigma2 * y_std^2 + } + + # Compute in-sample RMSE + if (keep_sample) { + yhat_forest <- active_forest$predict(forest_dataset) * y_std + y_bar + error <- (m_x - yhat_forest) + rmse_samples_non_robust[i - num_burnin] <- sqrt(mean(error * error)) + fhat_samples_non_robust[, i - num_burnin] <- yhat_forest + } +} +``` + +## Python + +```{python} +# Initial value of global error variance parameter +sigma2_init = 1.0 + +# Initialize data objects +forest_dataset = Dataset() +forest_dataset.add_covariates(X) +residual = Residual(y_standardized) + +# Random number generator (std::mt19937) +cpp_rng = RNG(random_seed) + +# Model configuration +outcome_model_type = 0 +leaf_dimension = 1 +num_trees = 200 +feature_types = np.repeat(0, p).astype(int) # 0 = numeric +var_weights = np.repeat(1 / p, p) +global_model_config = GlobalModelConfig(global_error_variance=sigma2_init) +forest_model_config = ForestModelConfig( + feature_types=feature_types, + num_trees=num_trees, + num_features=p, + num_observations=n, + variable_weights=var_weights, + leaf_dimension=leaf_dimension, + leaf_model_type=outcome_model_type, +) + +# Forest model object +forest_sampler = ForestSampler(forest_dataset, global_model_config, forest_model_config) + +# "Active forest" (which gets updated by the sample) and +# container of forest samples (which is written to when +# a sample is not discarded due to burn-in / thinning) +active_forest = Forest(num_trees, leaf_dimension, True, False) +forest_container = ForestContainer(num_trees, leaf_dimension, True, False) + +# Initialize the leaves of each tree in the mean forest 
+leaf_init = np.mean(y_standardized, keepdims=True) +forest_sampler.prepare_for_sampler( + forest_dataset, + residual, + active_forest, + outcome_model_type, + leaf_init, +) + +# Global error variance model +global_var_model = GlobalVarianceModel() + +# Prepare to run the sampler +num_burnin = 3000 +num_mcmc = 1000 +sigma2_samples_non_robust = np.empty(num_mcmc) +rmse_samples_non_robust = np.empty(num_mcmc) +fhat_samples_non_robust = np.empty((n, num_mcmc)) +current_sigma2 = sigma2_init + +# Run the MCMC sampler +keep_sample = False +for i in range(num_burnin + num_mcmc): + if i >= num_burnin: + keep_sample = True + + # Sample from the forest + forest_sampler.sample_one_iteration( + forest_container=forest_container, + forest=active_forest, + dataset=forest_dataset, + residual=residual, + rng=cpp_rng, + global_config=global_model_config, + forest_config=forest_model_config, + keep_forest=keep_sample, + gfr=False, + num_threads=1 + ) + + # Sample global variance parameter + current_sigma2 = global_var_model.sample_one_iteration(residual, cpp_rng, 1.0, 1.0) + global_model_config.update_global_error_variance(current_sigma2) + if keep_sample: + sigma2_samples_non_robust[i - num_burnin] = current_sigma2 * y_std * y_std + + # Compute in-sample RMSE and cache mean function samples + if keep_sample: + yhat_forest = active_forest.predict(forest_dataset) * y_std + y_bar + error = m_x - yhat_forest + rmse_samples_non_robust[i - num_burnin] = np.sqrt(np.mean(error * error)) + fhat_samples_non_robust[:, i - num_burnin] = yhat_forest +``` + +:::: + +## Results + +Plot RMSE samples side-by-side + +::::{.panel-tabset group="language"} + +## R + +```{r} +par(mar = c(4, 4, 0.5, 0.5)) +y_bounds <- range(c(rmse_samples, rmse_samples_non_robust)) +y_bounds[2] <- y_bounds[2] * 1.25 +plot( + rmse_samples, + type = "l", + col = "blue", + ylim = y_bounds, + ylab = "In-Sample RMSE", + xlab = "Iteration" +) +lines(rmse_samples_non_robust, col = "red") +legend( + "topleft", + legend = 
c("Gaussian Errors", "t Errors"), + col = c("red", "blue"), + lty = 1 +) +``` + +## Python + +```{python} +y_bounds = ( + np.min([rmse_samples, rmse_samples_non_robust]) * 0.8, + np.max([rmse_samples, rmse_samples_non_robust]) * 1.25, +) +plt.ylim(y_bounds) +plt.plot(rmse_samples, label="t Errors", color="blue") +plt.plot( + rmse_samples_non_robust, + label="Gaussian Errors", + color="red", +) +plt.ylabel("In-Sample RMSE") +plt.xlabel("Iteration") +plt.legend(loc="upper left") +plt.tight_layout() +plt.show() +``` + +:::: + +Compute the posterior mean of conditional expectations for the non-robust model and compare to the robust model + +::::{.panel-tabset group="language"} + +## R + +```{r} +m_x_hat_posterior_mean_non_robust <- rowMeans(fhat_samples_non_robust) +par(mar = c(4, 4, 0.5, 0.5)) +y_bounds <- range(m_x) +y_bounds[2] <- y_bounds[2] * 1.1 +plot( + m_x_hat_posterior_mean_non_robust, + m_x, + pch = 20, + col = 'lightgray', + xlab = 'Predicted f(x)', + ylab = 'True f(x)', + ylim = y_bounds +) +abline(0, 1) +points(m_x_hat_posterior_mean, m_x, pch = 20, cex = 0.5) +legend( + "topleft", + legend = c('Gaussian errors', 't errors'), + pch = c(20, 20), + col = c('lightgray', 'black') +) +``` + +## Python + +```{python} +m_x_hat_posterior_mean_non_robust = np.mean(fhat_samples_non_robust, axis=1) +margin = 0.05 * (np.max(m_x) - np.min(m_x)) +y_bounds = (np.min(m_x) - margin, np.max(m_x) + margin) +plt.ylim(y_bounds) +plt.scatter( + m_x_hat_posterior_mean_non_robust, m_x, label="Gaussian Errors", color="lightgray" +) +plt.scatter(m_x_hat_posterior_mean, m_x, label="t Errors", color="black") +plt.axline((np.mean(m_x), np.mean(m_x)), slope=1, color="black", linestyle=(0, (3, 3))) +plt.ylabel("True f(x)") +plt.xlabel("Predicted f(x)") +plt.legend(loc="upper left") +plt.tight_layout() +``` + +:::: + +# References diff --git a/vignettes/ensemble-kernel.qmd b/vignettes/ensemble-kernel.qmd new file mode 100644 index 000000000..ab399f8b9 --- /dev/null +++ 
b/vignettes/ensemble-kernel.qmd @@ -0,0 +1,534 @@ +--- +title: "Using Shared Leaf Membership as a Kernel" +bibliography: vignettes.bib +execute: + freeze: auto # re-render only when source changes +--- + +```{r} +#| include: false +reticulate::use_python( + Sys.getenv( + "RETICULATE_PYTHON", + unset = file.path(rprojroot::find_root(rprojroot::has_file(".here")), ".venv", "bin", "python") + ), + required = TRUE +) +``` + +A trained tree ensemble with strong out-of-sample performance admits a natural +motivation for the "distance" between two samples: shared leaf membership. +This vignette demonstrates how to extract a kernel matrix from a fitted `stochtree` +ensemble and use it for Gaussian process inference. + +# Motivation + +We number the leaves in an ensemble from 1 to $s$ (that is, if tree 1 has 3 leaves, +it reserves the numbers 1 - 3, and in turn if tree 2 has 5 leaves, it reserves the +numbers 4 - 8 to label its leaves, and so on). For a dataset with $n$ observations, +we construct the matrix $W$ as follows: + + + +::: {.algorithm style="border-top: 2px solid; border-bottom: 2px solid; padding: 0.6em 1em; margin: 1.5em 0;"} + +Initialize $W$ as a matrix of all zeroes with $n$ rows and as many columns as leaves in the ensemble + +Let `s` = 0 + +For $j$ in $\left\{1,\dots,m\right\}$: + +:::: {style="margin-left: 2em"} +Let `num_leaves` be the number of leaves in tree $j$ + +For $i$ in $\left\{1,\dots,n\right\}$: + +::::: {style="margin-left: 2em"} +Let `k` be the leaf to which tree $j$ maps observation $i$ + +Set element $W_{i,k+s} = 1$ +::::: + +Let `s` = `s + num_leaves` +:::: + +::: + +This sparse matrix $W$ is a matrix representation of the basis predictions of an +ensemble (i.e. integrating out the leaf parameters and just analyzing the leaf +indices). For an ensemble with $m$ trees, we can determine the proportion of trees +that map each observation to the same leaf by computing $W W^T / m$. 
This can form +the basis for a kernel function used in a Gaussian process regression, as we +demonstrate below. + +# Setup + +Load necessary packages + +::::{.panel-tabset group="language"} + +## R + +```{r} +library(stochtree) +library(tgp) +library(MASS) +library(Matrix) +library(mvtnorm) +``` + +## Python + +```{python} +import numpy as np +import matplotlib.pyplot as plt +from scipy.sparse import csr_matrix +from sklearn.gaussian_process import GaussianProcessRegressor +from sklearn.gaussian_process.kernels import RBF, WhiteKernel +from sklearn.datasets import make_friedman1 + +from stochtree import BARTModel, compute_forest_leaf_indices +``` + +:::: + +Set a seed for reproducibility + +::::{.panel-tabset group="language"} + +## R + +```{r} +random_seed <- 101 +set.seed(random_seed) +``` + +## Python + +```{python} +random_seed = 101 +rng = np.random.default_rng(random_seed) +``` + +:::: + + +# Demo 1: Univariate Supervised Learning + +We begin with a non-stationary simulated DGP with a single numeric covariate, +originally described in @gramacy2010categorical. We define a training set and test +set and evaluate various approaches to modeling the out-of-sample outcome. 
+ +## Traditional Gaussian Process + +::::{.panel-tabset group="language"} + +## R + +```{r} +#| results: hide +# Generate the data +X_train <- seq(0,20,length=100) +X_test <- seq(0,20,length=99) +y_train <- (sin(pi*X_train/5) + 0.2*cos(4*pi*X_train/5)) * (X_train <= 9.6) +lin_train <- X_train>9.6; +y_train[lin_train] <- -1 + X_train[lin_train]/10 +y_train <- y_train + rnorm(length(y_train), sd=0.1) +y_test <- (sin(pi*X_test/5) + 0.2*cos(4*pi*X_test/5)) * (X_test <= 9.6) +lin_test <- X_test>9.6; +y_test[lin_test] <- -1 + X_test[lin_test]/10 + +# Fit the GP +model_gp <- bgp(X=X_train, Z=y_train, XX=X_test) +plot(model_gp$ZZ.mean, y_test, xlab = "predicted", ylab = "actual", main = "Gaussian process") +abline(0,1,lwd=2.5,lty=3,col="red") +``` + +## Python + +```{python} +# Generate the data +X_train_1d = np.linspace(0, 20, 100) +X_test_1d = np.linspace(0, 20, 99) + +y_train_1 = (np.sin(np.pi * X_train_1d / 5) + 0.2 * np.cos(4 * np.pi * X_train_1d / 5)) * (X_train_1d <= 9.6) +y_train_1[X_train_1d > 9.6] = -1 + X_train_1d[X_train_1d > 9.6] / 10 +y_train_1 = y_train_1 + rng.normal(0, 0.1, len(X_train_1d)) + +y_test_1 = (np.sin(np.pi * X_test_1d / 5) + 0.2 * np.cos(4 * np.pi * X_test_1d / 5)) * (X_test_1d <= 9.6) +y_test_1[X_test_1d > 9.6] = -1 + X_test_1d[X_test_1d > 9.6] / 10 + +# sklearn's GaussianProcessRegressor is used here in place of R's tgp::bgp +X_train_2d = X_train_1d.reshape(-1, 1) +X_test_2d = X_test_1d.reshape(-1, 1) +gp_kernel = RBF(length_scale=1.0) + WhiteKernel(noise_level=0.01) +model_gp_1 = GaussianProcessRegressor(kernel=gp_kernel, n_restarts_optimizer=5, + random_state=random_seed) +model_gp_1.fit(X_train_2d, y_train_1) +gp_pred_1 = model_gp_1.predict(X_test_2d) + +plt.scatter(gp_pred_1, y_test_1, alpha=0.5) +lo, hi = min(gp_pred_1.min(), y_test_1.min()), max(gp_pred_1.max(), y_test_1.max()) +plt.plot([lo, hi], [lo, hi], color="red", linestyle="dashed", linewidth=2) +plt.xlabel("Predicted") +plt.ylabel("Actual") +plt.title("Gaussian process") 
+plt.show() +``` + +:::: + +Assess the RMSE + +::::{.panel-tabset group="language"} + +## R + +```{r} +sqrt(mean((model_gp$ZZ.mean - y_test)^2)) +``` + +## Python + +```{python} +print(f"RMSE: {np.sqrt(np.mean((gp_pred_1 - y_test_1)**2)):.4f}") +``` + +:::: + +## BART-based Gaussian Process + +::::{.panel-tabset group="language"} + +## R + +```{r} +# Run BART on the data +num_trees <- 200 +sigma_leaf <- 1/num_trees +X_train <- as.data.frame(X_train) +X_test <- as.data.frame(X_test) +colnames(X_train) <- colnames(X_test) <- "x1" +general_params <- list(num_threads=1) +mean_forest_params <- list(num_trees=num_trees, sigma2_leaf_init=sigma_leaf) +bart_model <- bart(X_train=X_train, y_train=y_train, X_test=X_test, + general_params = general_params, mean_forest_params = mean_forest_params) + +# Extract kernels needed for kriging +leaf_mat_train <- computeForestLeafIndices(bart_model, X_train, forest_type = "mean", + forest_inds = bart_model$model_params$num_samples - 1) +leaf_mat_test <- computeForestLeafIndices(bart_model, X_test, forest_type = "mean", + forest_inds = bart_model$model_params$num_samples - 1) +W_train <- sparseMatrix(i=rep(1:length(y_train),num_trees), j=leaf_mat_train + 1, x=1) +W_test <- sparseMatrix(i=rep(1:length(y_test),num_trees), j=leaf_mat_test + 1, x=1) +Sigma_11 <- tcrossprod(W_test) / num_trees +Sigma_12 <- tcrossprod(W_test, W_train) / num_trees +Sigma_22 <- tcrossprod(W_train) / num_trees +Sigma_22_inv <- ginv(as.matrix(Sigma_22)) +Sigma_21 <- t(Sigma_12) + +# Compute mean and covariance for the test set posterior +mu_tilde <- Sigma_12 %*% Sigma_22_inv %*% y_train +Sigma_tilde <- as.matrix((sigma_leaf)*(Sigma_11 - Sigma_12 %*% Sigma_22_inv %*% Sigma_21)) + +# Sample from f(X_test) | X_test, X_train, f(X_train) +gp_samples <- mvtnorm::rmvnorm(1000, mean = mu_tilde, sigma = Sigma_tilde) + +# Compute posterior mean predictions for f(X_test) +yhat_mean_test <- colMeans(gp_samples) +plot(yhat_mean_test, y_test, xlab = "predicted", ylab = 
"actual", main = "BART Gaussian process") +abline(0,1,lwd=2.5,lty=3,col="red") +``` + +## Python + +```{python} +# Run BART on the data +num_trees = 200 +sigma_leaf = 1 / num_trees +general_params = {"num_threads": 1} +mean_forest_params = {"num_trees": num_trees, "sigma2_leaf_init": sigma_leaf} +bart_model_1 = BARTModel() +bart_model_1.sample(X_train=X_train_2d, y_train=y_train_1, X_test=X_test_2d, + general_params=general_params, mean_forest_params=mean_forest_params) + +# Extract leaf indices for the last retained forest sample +last_sample = bart_model_1.num_samples - 1 +leaf_mat_train_1 = compute_forest_leaf_indices(bart_model_1, X_train_2d, + forest_type="mean", forest_inds=last_sample) +leaf_mat_test_1 = compute_forest_leaf_indices(bart_model_1, X_test_2d, + forest_type="mean", forest_inds=last_sample) + +# Build sparse W matrices (rows = observations, cols = global leaf indices) +n_train_1, n_test_1 = len(y_train_1), len(y_test_1) +col_inds_train = leaf_mat_train_1.flatten() +col_inds_test = leaf_mat_test_1.flatten() +max_col = max(col_inds_train.max(), col_inds_test.max()) + 1 +W_train_1 = csr_matrix( + (np.ones(len(col_inds_train)), (np.tile(np.arange(n_train_1), num_trees), col_inds_train)), + shape=(n_train_1, max_col), +) +W_test_1 = csr_matrix( + (np.ones(len(col_inds_test)), (np.tile(np.arange(n_test_1), num_trees), col_inds_test)), + shape=(n_test_1, max_col), +) + +# Compute kernel matrices +W_tr = W_train_1.toarray() +W_te = W_test_1.toarray() +Sigma_22 = (W_tr @ W_tr.T) / num_trees +Sigma_11 = (W_te @ W_te.T) / num_trees +Sigma_12 = (W_te @ W_tr.T) / num_trees +Sigma_21 = Sigma_12.T +Sigma_22_inv = np.linalg.pinv(Sigma_22) + +# Compute GP posterior mean and covariance +mu_tilde = Sigma_12 @ Sigma_22_inv @ y_train_1 +Sigma_tilde = sigma_leaf * (Sigma_11 - Sigma_12 @ Sigma_22_inv @ Sigma_21) +Sigma_tilde += 1e-8 * np.eye(n_test_1) # small jitter for numerical stability + +# Sample from f(X_test) | X_test, X_train, f(X_train) +gp_samples_1 = 
rng.multivariate_normal(mu_tilde, Sigma_tilde, size=1000, method="eigh") + +# Posterior mean predictions +yhat_mean_test_1 = gp_samples_1.mean(axis=0) +lo = min(yhat_mean_test_1.min(), y_test_1.min()) +hi = max(yhat_mean_test_1.max(), y_test_1.max()) +plt.scatter(yhat_mean_test_1, y_test_1, alpha=0.5) +plt.plot([lo, hi], [lo, hi], color="red", linestyle="dashed", linewidth=2) +plt.xlabel("Predicted") +plt.ylabel("Actual") +plt.title("BART Gaussian process") +plt.show() +``` + +:::: + +Assess the RMSE + +::::{.panel-tabset group="language"} + +## R + +```{r} +sqrt(mean((yhat_mean_test - y_test)^2)) +``` + +## Python + +```{python} +print(f"RMSE: {np.sqrt(np.mean((yhat_mean_test_1 - y_test_1)**2)):.4f}") +``` + +:::: + +# Demo 2: Multivariate Supervised Learning + +We proceed to the simulated "Friedman" dataset (@friedman1991multivariate). In R, this +is accessed via `tgp::friedman.1.data`; in Python we use +`sklearn.datasets.make_friedman1`, which implements the same DGP. + +## Traditional Gaussian Process + +::::{.panel-tabset group="language"} + +## R + +```{r} +#| results: hide +# Generate the data, add many "noise variables" +n <- 100 +friedman.df <- friedman.1.data(n=n) +train_inds <- sort(sample(1:n, floor(0.8*n), replace = FALSE)) +test_inds <- (1:n)[!((1:n) %in% train_inds)] +X <- as.matrix(friedman.df)[,1:10] +X <- cbind(X, matrix(runif(n*10), ncol = 10)) +y <- as.matrix(friedman.df)[,12] + rnorm(n,0,1)*(sd(as.matrix(friedman.df)[,11])/2) +X_train <- X[train_inds,] +X_test <- X[test_inds,] +y_train <- y[train_inds] +y_test <- y[test_inds] + +# Fit the GP +model_gp <- bgp(X=X_train, Z=y_train, XX=X_test) +plot(model_gp$ZZ.mean, y_test, xlab = "predicted", ylab = "actual", main = "Gaussian process") +abline(0,1,lwd=2.5,lty=3,col="red") +``` + +## Python + +```{python} +# Generate the data: 10 Friedman features + 10 noise features +n = 100 +X_raw, y_friedman = make_friedman1(n_samples=n, n_features=10, noise=1.0, + random_state=random_seed) +X_2 = 
np.hstack([X_raw, rng.uniform(size=(n, 10))]) # 20 features total +y_2 = y_friedman + +train_inds_2 = rng.choice(n, int(0.8 * n), replace=False) +test_inds_2 = np.setdiff1d(np.arange(n), train_inds_2) +X_train_2, X_test_2 = X_2[train_inds_2], X_2[test_inds_2] +y_train_2, y_test_2 = y_2[train_inds_2], y_2[test_inds_2] + +# sklearn's GaussianProcessRegressor is used here in place of R's tgp::bgp +gp_kernel_2 = RBF(length_scale=1.0) + WhiteKernel(noise_level=1.0) +model_gp_2 = GaussianProcessRegressor(kernel=gp_kernel_2, n_restarts_optimizer=1, + random_state=random_seed) +model_gp_2.fit(X_train_2, y_train_2) +gp_pred_2 = model_gp_2.predict(X_test_2) + +lo = min(gp_pred_2.min(), y_test_2.min()) +hi = max(gp_pred_2.max(), y_test_2.max()) +plt.scatter(gp_pred_2, y_test_2, alpha=0.6) +plt.plot([lo, hi], [lo, hi], color="red", linestyle="dashed", linewidth=2) +plt.xlabel("Predicted") +plt.ylabel("Actual") +plt.title("Gaussian process") +plt.show() +``` + +:::: + +Assess the RMSE + +::::{.panel-tabset group="language"} + +## R + +```{r} +sqrt(mean((model_gp$ZZ.mean - y_test)^2)) +``` + +## Python + +```{python} +print(f"RMSE: {np.sqrt(np.mean((gp_pred_2 - y_test_2)**2)):.4f}") +``` + +:::: + +## BART-based Gaussian Process + +::::{.panel-tabset group="language"} + +## R + +```{r} +# Run BART on the data +num_trees <- 200 +sigma_leaf <- 1/num_trees +X_train <- as.data.frame(X_train) +X_test <- as.data.frame(X_test) +general_params <- list(num_threads=1) +mean_forest_params <- list(num_trees=num_trees, sigma2_leaf_init=sigma_leaf) +bart_model <- bart(X_train=X_train, y_train=y_train, X_test=X_test, + general_params = general_params, mean_forest_params = mean_forest_params) + +# Extract kernels needed for kriging +leaf_mat_train <- computeForestLeafIndices(bart_model, X_train, forest_type = "mean", + forest_inds = bart_model$model_params$num_samples - 1) +leaf_mat_test <- computeForestLeafIndices(bart_model, X_test, forest_type = "mean", + forest_inds = 
bart_model$model_params$num_samples - 1) +W_train <- sparseMatrix(i=rep(1:length(y_train),num_trees), j=leaf_mat_train + 1, x=1) +W_test <- sparseMatrix(i=rep(1:length(y_test),num_trees), j=leaf_mat_test + 1, x=1) +Sigma_11 <- tcrossprod(W_test) / num_trees +Sigma_12 <- tcrossprod(W_test, W_train) / num_trees +Sigma_22 <- tcrossprod(W_train) / num_trees +Sigma_22_inv <- ginv(as.matrix(Sigma_22)) +Sigma_21 <- t(Sigma_12) + +# Compute mean and covariance for the test set posterior +mu_tilde <- Sigma_12 %*% Sigma_22_inv %*% y_train +Sigma_tilde <- as.matrix((sigma_leaf)*(Sigma_11 - Sigma_12 %*% Sigma_22_inv %*% Sigma_21)) + +# Sample from f(X_test) | X_test, X_train, f(X_train) +gp_samples <- mvtnorm::rmvnorm(1000, mean = mu_tilde, sigma = Sigma_tilde) + +# Compute posterior mean predictions for f(X_test) +yhat_mean_test <- colMeans(gp_samples) +plot(yhat_mean_test, y_test, xlab = "predicted", ylab = "actual", main = "BART Gaussian process") +abline(0,1,lwd=2.5,lty=3,col="red") +``` + +## Python + +```{python} +num_trees = 200 +sigma_leaf = 1 / num_trees +general_params = {"num_threads": 1} +mean_forest_params = {"num_trees": num_trees, "sigma2_leaf_init": sigma_leaf} +bart_model_2 = BARTModel() +bart_model_2.sample(X_train=X_train_2, y_train=y_train_2, X_test=X_test_2, + general_params=general_params, mean_forest_params=mean_forest_params) + +last_sample_2 = bart_model_2.num_samples - 1 +leaf_mat_train_2 = compute_forest_leaf_indices(bart_model_2, X_train_2, + forest_type="mean", forest_inds=last_sample_2) +leaf_mat_test_2 = compute_forest_leaf_indices(bart_model_2, X_test_2, + forest_type="mean", forest_inds=last_sample_2) + +n_train_2, n_test_2 = len(y_train_2), len(y_test_2) +col_inds_train_2 = leaf_mat_train_2.flatten() +col_inds_test_2 = leaf_mat_test_2.flatten() +max_col_2 = max(col_inds_train_2.max(), col_inds_test_2.max()) + 1 +W_train_2 = csr_matrix( + (np.ones(len(col_inds_train_2)), + (np.tile(np.arange(n_train_2), num_trees), col_inds_train_2)), + 
shape=(n_train_2, max_col_2), +) +W_test_2 = csr_matrix( + (np.ones(len(col_inds_test_2)), + (np.tile(np.arange(n_test_2), num_trees), col_inds_test_2)), + shape=(n_test_2, max_col_2), +) + +W_tr2 = W_train_2.toarray() +W_te2 = W_test_2.toarray() +Sigma_22_2 = (W_tr2 @ W_tr2.T) / num_trees +Sigma_11_2 = (W_te2 @ W_te2.T) / num_trees +Sigma_12_2 = (W_te2 @ W_tr2.T) / num_trees +Sigma_21_2 = Sigma_12_2.T +Sigma_22_inv_2 = np.linalg.pinv(Sigma_22_2) + +mu_tilde_2 = Sigma_12_2 @ Sigma_22_inv_2 @ y_train_2 +Sigma_tilde_2 = sigma_leaf * (Sigma_11_2 - Sigma_12_2 @ Sigma_22_inv_2 @ Sigma_21_2) +Sigma_tilde_2 += 1e-8 * np.eye(n_test_2) + +gp_samples_2 = rng.multivariate_normal(mu_tilde_2, Sigma_tilde_2, size=1000, method="eigh") +yhat_mean_test_2 = gp_samples_2.mean(axis=0) + +lo = min(yhat_mean_test_2.min(), y_test_2.min()) +hi = max(yhat_mean_test_2.max(), y_test_2.max()) +plt.scatter(yhat_mean_test_2, y_test_2, alpha=0.6) +plt.plot([lo, hi], [lo, hi], color="red", linestyle="dashed", linewidth=2) +plt.xlabel("Predicted") +plt.ylabel("Actual") +plt.title("BART Gaussian process") +plt.show() +``` + +:::: + +Assess the RMSE + +::::{.panel-tabset group="language"} + +## R + +```{r} +sqrt(mean((yhat_mean_test - y_test)^2)) +``` + +## Python + +```{python} +print(f"RMSE: {np.sqrt(np.mean((yhat_mean_test_2 - y_test_2)**2)):.4f}") +``` + +:::: + +While the use case of a BART kernel for classical kriging is perhaps unclear without +more empirical investigation, the kernel approach can be very beneficial for causal +inference applications. 
+ +# References diff --git a/vignettes/heteroskedastic.qmd b/vignettes/heteroskedastic.qmd new file mode 100644 index 000000000..12ca5a9ba --- /dev/null +++ b/vignettes/heteroskedastic.qmd @@ -0,0 +1,735 @@ +--- +title: "BART with a Forest-based Variance Model" +bibliography: vignettes.bib +execute: + freeze: auto # re-render only when source changes +--- + +```{r} +#| include: false +reticulate::use_python( + Sys.getenv( + "RETICULATE_PYTHON", + unset = file.path(rprojroot::find_root(rprojroot::has_file(".here")), ".venv", "bin", "python") + ), + required = TRUE +) +``` + +This vignette demonstrates how to configure a "variance forest" in stochtree for modeling conditional variance (see @murray2021log). + +# Setup + +Load necessary packages + +::::{.panel-tabset group="language"} + +## R + +```{r} +library(stochtree) +``` + +## Python + +```{python} +import numpy as np +import matplotlib.pyplot as plt +from stochtree import BARTModel +``` + +:::: + +Set a random seed + +::::{.panel-tabset group="language"} + +## R + +```{r} +random_seed = 1234 +set.seed(random_seed) +``` + +## Python + +```{python} +random_seed = 1234 +rng = np.random.default_rng(random_seed) +``` + +:::: + +# Demo 1: Variance-Only Simulation (Simple DGP) + +Here, we generate data with a constant (zero) mean and a relatively simple +covariate-modified variance function. 
+ +\begin{equation*} +\begin{aligned} +y &= 0 + \sigma(X) \epsilon\\ +\sigma^2(X) &= \begin{cases} +0.25 & X_1 \geq 0 \text{ and } X_1 < 0.25\\ +1 & X_1 \geq 0.25 \text{ and } X_1 < 0.5\\ +4 & X_1 \geq 0.5 \text{ and } X_1 < 0.75\\ +9 & X_1 \geq 0.75 \text{ and } X_1 < 1\\ +\end{cases}\\ +X_1,\dots,X_p &\sim \text{U}\left(0,1\right)\\ +\epsilon &\sim \mathcal{N}\left(0,1\right) +\end{aligned} +\end{equation*} + +## Simulation + +Generate data from the DGP above + +::::{.panel-tabset group="language"} + +## R + +```{r} +n <- 1000 +p_x <- 10 +X <- matrix(runif(n * p_x), ncol = p_x) +f_XW <- 0 +s_XW <- (((0 <= X[, 1]) & (0.25 > X[, 1])) * + (0.5) + + ((0.25 <= X[, 1]) & (0.5 > X[, 1])) * (1) + + ((0.5 <= X[, 1]) & (0.75 > X[, 1])) * (2) + + ((0.75 <= X[, 1]) & (1 > X[, 1])) * (3)) +y <- f_XW + rnorm(n, 0, 1) * s_XW +``` + +## Python + +```{python} +n, p_x = 1000, 10 +X = rng.uniform(size=(n, p_x)) +s_XW = ( + ((X[:, 0] >= 0) & (X[:, 0] < 0.25)) * 0.5 + + ((X[:, 0] >= 0.25) & (X[:, 0] < 0.5)) * 1.0 + + ((X[:, 0] >= 0.5) & (X[:, 0] < 0.75)) * 2.0 + + ((X[:, 0] >= 0.75) & (X[:, 0] < 1.0)) * 3.0 +) +y = rng.normal(size=n) * s_XW +``` + +:::: + +Split into train and test sets + +::::{.panel-tabset group="language"} + +## R + +```{r} +test_set_pct <- 0.2 +n_test <- round(test_set_pct * n) +n_train <- n - n_test +test_inds <- sort(sample(1:n, n_test, replace = FALSE)) +train_inds <- (1:n)[!((1:n) %in% test_inds)] +X_test <- as.data.frame(X[test_inds, ]) +X_train <- as.data.frame(X[train_inds, ]) +y_test <- y[test_inds] +y_train <- y[train_inds] +s_x_test <- s_XW[test_inds] +s_x_train <- s_XW[train_inds] +``` + +## Python + +```{python} +test_set_pct = 0.2 +n_test = round(test_set_pct * n) +test_inds = rng.choice(n, n_test, replace=False) +train_inds = np.setdiff1d(np.arange(n), test_inds) +X_test, X_train = X[test_inds], X[train_inds] +y_test, y_train = y[test_inds], y[train_inds] +s_x_test, s_x_train = s_XW[test_inds], s_XW[train_inds] +``` + +:::: + +## Sampling and 
Analysis + +We sample four chains of the $\sigma^2(X)$ forest using "warm-start" initialization (@he2023stochastic). + +We use fewer trees for the variance forest than typically used for mean forests, and we disable sampling a global error scale and omit the mean forest by setting `num_trees = 0` in its parameter list. + +::::{.panel-tabset group="language"} + +## R + +```{r} +num_gfr <- 10 +num_burnin <- 0 +num_mcmc <- 100 +num_trees <- 20 +num_samples <- num_gfr + num_burnin + num_mcmc +general_params <- list( + sample_sigma2_global = F, + num_chains = 4, + num_threads = 1, + random_seed = random_seed +) +mean_forest_params <- list(sample_sigma2_leaf = F, num_trees = 0) +variance_forest_params <- list(num_trees = num_trees) +bart_model <- stochtree::bart( + X_train = X_train, + y_train = y_train, + X_test = X_test, + num_gfr = num_gfr, + num_burnin = num_burnin, + num_mcmc = num_mcmc, + general_params = general_params, + mean_forest_params = mean_forest_params, + variance_forest_params = variance_forest_params +) +``` + +## Python + +```{python} +num_gfr = 10 +num_burnin = 0 +num_mcmc = 100 +num_trees = 20 +bart_model = BARTModel() +bart_model.sample( + X_train=X_train, + y_train=y_train, + X_test=X_test, + num_gfr=num_gfr, + num_burnin=num_burnin, + num_mcmc=num_mcmc, + general_params={ + "sample_sigma2_global": False, + "num_threads": 1, + "num_chains": 4, + "random_seed": random_seed, + }, + mean_forest_params={"sample_sigma2_leaf": False, "num_trees": 0}, + variance_forest_params={"num_trees": num_trees}, +) +``` + +:::: + +We inspect the model by plotting the true variance function against its forest-based predictions + +::::{.panel-tabset group="language"} + +## R + +```{r} +sigma2_x_hat_test <- predict( + bart_model, + X = X_test, + terms = "variance_forest", + type = "mean" +) +plot( + sigma2_x_hat_test, + s_x_test^2, + pch = 16, + cex = 0.75, + xlab = "Predicted", + ylab = "Actual", + main = "Variance function" +) +abline(0, 1, col = "red", lty = 2, lwd 
= 2.5) +``` + +## Python + +```{python} +sigma2_x_hat_test = bart_model.predict(X=X_test, terms="variance_forest", type="mean") +lo, hi = ( + min(sigma2_x_hat_test.min(), (s_x_test**2).min()), + max(sigma2_x_hat_test.max(), (s_x_test**2).max()), +) +plt.scatter(sigma2_x_hat_test, s_x_test**2, s=10, alpha=0.6) +plt.plot([lo, hi], [lo, hi], color="red", linestyle="dashed", linewidth=2) +plt.xlabel("Predicted") +plt.ylabel("Actual") +plt.title("Variance function") +plt.show() +``` + +:::: + +# Demo 2: Variance-Only Simulation (Complex DGP) + +Here, we generate data with a constant (zero) mean and a more complex +covariate-modified variance function. + +\begin{equation*} +\begin{aligned} +y &= 0 + \sigma(X) \epsilon\\ +\sigma^2(X) &= \begin{cases} +0.25X_3^2 & X_1 \geq 0 \text{ and } X_1 < 0.25\\ +1X_3^2 & X_1 \geq 0.25 \text{ and } X_1 < 0.5\\ +4X_3^2 & X_1 \geq 0.5 \text{ and } X_1 < 0.75\\ +9X_3^2 & X_1 \geq 0.75 \text{ and } X_1 < 1\\ +\end{cases}\\ +X_1,\dots,X_p &\sim \text{U}\left(0,1\right)\\ +\epsilon &\sim \mathcal{N}\left(0,1\right) +\end{aligned} +\end{equation*} + +## Simulation + +We generate data from the DGP above + +::::{.panel-tabset group="language"} + +## R + +```{r} +n <- 1000 +p_x <- 10 +X <- matrix(runif(n*p_x), ncol = p_x) +f_XW <- 0 +s_XW <- ( + ((0 <= X[,1]) & (0.25 > X[,1])) * (0.5*X[,3]) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (1*X[,3]) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2*X[,3]) + + ((0.75 <= X[,1]) & (1 > X[,1])) * (3*X[,3]) +) +y <- f_XW + rnorm(n, 0, 1)*s_XW +``` + +## Python + +```{python} +n, p_x = 1000, 10 +X = rng.uniform(size=(n, p_x)) +# R's X[,3] = Python's X[:,2] +s_XW = ( + ((X[:, 0] >= 0) & (X[:, 0] < 0.25)) * (0.5 * X[:, 2]) + + ((X[:, 0] >= 0.25) & (X[:, 0] < 0.5)) * (1.0 * X[:, 2]) + + ((X[:, 0] >= 0.5) & (X[:, 0] < 0.75)) * (2.0 * X[:, 2]) + + ((X[:, 0] >= 0.75) & (X[:, 0] < 1.0)) * (3.0 * X[:, 2]) +) +y = rng.normal(size=n) * s_XW +``` + +:::: + +And split the data into train and test sets + +::::{.panel-tabset 
group="language"} + +## R + +```{r} +test_set_pct <- 0.2 +n_test <- round(test_set_pct*n) +n_train <- n - n_test +test_inds <- sort(sample(1:n, n_test, replace = FALSE)) +train_inds <- (1:n)[!((1:n) %in% test_inds)] +X_test <- as.data.frame(X[test_inds,]) +X_train <- as.data.frame(X[train_inds,]) +y_test <- y[test_inds] +y_train <- y[train_inds] +s_x_test <- s_XW[test_inds] +s_x_train <- s_XW[train_inds] +``` + +## Python + +```{python} +test_set_pct = 0.2 +n_test = round(test_set_pct * n) +test_inds = rng.choice(n, n_test, replace=False) +train_inds = np.setdiff1d(np.arange(n), test_inds) +X_test, X_train = X[test_inds], X[train_inds] +y_test, y_train = y[test_inds], y[train_inds] +s_x_test, s_x_train = s_XW[test_inds], s_XW[train_inds] +``` + +:::: + +## Sampling and Analysis + +We sample four chains of the $\sigma^2(X)$ forest using "warm-start" initialization (@he2023stochastic). + +We use fewer trees for the variance forest than typically used for mean forests, and we disable sampling a global error scale and omit the mean forest by setting `num_trees = 0` in its parameter list. 
+ +::::{.panel-tabset group="language"} + +## R + +```{r} +num_gfr <- 10 +num_burnin <- 0 +num_mcmc <- 100 +num_trees <- 20 +num_samples <- num_gfr + num_burnin + num_mcmc +general_params <- list( + sample_sigma2_global = F, + num_chains = 4, + num_threads = 1, + random_seed = random_seed +) +mean_forest_params <- list(sample_sigma2_leaf = F, num_trees = 0) +variance_forest_params <- list(num_trees = num_trees) +bart_model <- stochtree::bart( + X_train = X_train, + y_train = y_train, + X_test = X_test, + num_gfr = num_gfr, + num_burnin = num_burnin, + num_mcmc = num_mcmc, + general_params = general_params, + mean_forest_params = mean_forest_params, + variance_forest_params = variance_forest_params +) +``` + +## Python + +```{python} +num_gfr = 10 +num_burnin = 0 +num_mcmc = 100 +num_trees = 20 +bart_model = BARTModel() +bart_model.sample( + X_train=X_train, + y_train=y_train, + X_test=X_test, + num_gfr=num_gfr, + num_burnin=num_burnin, + num_mcmc=num_mcmc, + general_params={ + "sample_sigma2_global": False, + "num_threads": 1, + "num_chains": 4, + "random_seed": random_seed, + }, + mean_forest_params={"sample_sigma2_leaf": False, "num_trees": 0}, + variance_forest_params={"num_trees": num_trees}, +) +``` + +:::: + +We inspect the model by plotting the true variance function against its forest-based predictions + +::::{.panel-tabset group="language"} + +## R + +```{r} +sigma2_x_hat_test <- predict( + bart_model, + X = X_test, + terms = "variance_forest", + type = "mean" +) +plot( + sigma2_x_hat_test, + s_x_test^2, + pch = 16, + cex = 0.75, + xlab = "Predicted", + ylab = "Actual", + main = "Variance function" +) +abline(0, 1, col = "red", lty = 2, lwd = 2.5) +``` + +## Python + +```{python} +sigma2_x_hat_test = bart_model.predict(X=X_test, terms="variance_forest", type="mean") +lo, hi = ( + min(sigma2_x_hat_test.min(), (s_x_test**2).min()), + max(sigma2_x_hat_test.max(), (s_x_test**2).max()), +) +plt.scatter(sigma2_x_hat_test, s_x_test**2, s=10, alpha=0.6) 
+plt.plot([lo, hi], [lo, hi], color="red", linestyle="dashed", linewidth=2) +plt.xlabel("Predicted") +plt.ylabel("Actual") +plt.title("Variance function") +plt.show() +``` + +:::: + +# Demo 3: Mean and Variance Function Simulation + +Here, we generate data with (relatively simple) covariate-modified mean and variance +functions. + +\begin{equation*} +\begin{aligned} +y &= f(X) + \sigma(X) \epsilon\\ +f(X) &= \begin{cases} +-6 & X_2 \geq 0 \text{ and } X_2 < 0.25\\ +-2 & X_2 \geq 0.25 \text{ and } X_2 < 0.5\\ +2 & X_2 \geq 0.5 \text{ and } X_2 < 0.75\\ +6 & X_2 \geq 0.75 \text{ and } X_2 < 1\\ +\end{cases}\\ +\sigma^2(X) &= \begin{cases} +0.25 & X_1 \geq 0 \text{ and } X_1 < 0.25\\ +1 & X_1 \geq 0.25 \text{ and } X_1 < 0.5\\ +4 & X_1 \geq 0.5 \text{ and } X_1 < 0.75\\ +9 & X_1 \geq 0.75 \text{ and } X_1 < 1\\ +\end{cases}\\ +X_1,\dots,X_p &\sim \text{U}\left(0,1\right)\\ +\epsilon &\sim \mathcal{N}\left(0,1\right) +\end{aligned} +\end{equation*} + +## Simulation + +Generate data from the DGP above + +::::{.panel-tabset group="language"} + +## R + +```{r} +n <- 1000 +p_x <- 10 +X <- matrix(runif(n*p_x), ncol = p_x) +f_XW <- ( + ((0 <= X[,2]) & (0.25 > X[,2])) * (-6) + + ((0.25 <= X[,2]) & (0.5 > X[,2])) * (-2) + + ((0.5 <= X[,2]) & (0.75 > X[,2])) * (2) + + ((0.75 <= X[,2]) & (1 > X[,2])) * (6) +) +s_XW <- ( + ((0 <= X[,1]) & (0.25 > X[,1])) * (0.5) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (1) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2) + + ((0.75 <= X[,1]) & (1 > X[,1])) * (3) +) +y <- f_XW + rnorm(n, 0, 1)*s_XW +``` + +## Python + +```{python} +n, p_x = 1000, 10 +X = rng.uniform(size=(n, p_x)) +f_XW = ( + ((X[:, 1] >= 0) & (X[:, 1] < 0.25)) * (-6) + + ((X[:, 1] >= 0.25) & (X[:, 1] < 0.5)) * (-2) + + ((X[:, 1] >= 0.5) & (X[:, 1] < 0.75)) * (2) + + ((X[:, 1] >= 0.75) & (X[:, 1] < 1.0)) * (6) +) +s_XW = ( + ((X[:, 0] >= 0) & (X[:, 0] < 0.25)) * 0.5 + + ((X[:, 0] >= 0.25) & (X[:, 0] < 0.5)) * 1.0 + + ((X[:, 0] >= 0.5) & (X[:, 0] < 0.75)) * 2.0 + + ((X[:, 0] >= 0.75) & 
(X[:, 0] < 1.0)) * 3.0
+)
+y = f_XW + rng.normal(size=n) * s_XW
+```
+
+::::
+
+Split the data into train and test sets
+
+::::{.panel-tabset group="language"}
+
+## R
+
+```{r}
+test_set_pct <- 0.2
+n_test <- round(test_set_pct*n)
+n_train <- n - n_test
+test_inds <- sort(sample(1:n, n_test, replace = FALSE))
+train_inds <- (1:n)[!((1:n) %in% test_inds)]
+X_test <- as.data.frame(X[test_inds,])
+X_train <- as.data.frame(X[train_inds,])
+y_test <- y[test_inds]
+y_train <- y[train_inds]
+f_x_test <- f_XW[test_inds]
+s_x_test <- s_XW[test_inds]
+```
+
+## Python
+
+```{python}
+test_set_pct = 0.2
+n_test = round(test_set_pct * n)
+test_inds = rng.choice(n, n_test, replace=False)
+train_inds = np.setdiff1d(np.arange(n), test_inds)
+X_test, X_train = X[test_inds], X[train_inds]
+y_test, y_train = y[test_inds], y[train_inds]
+f_x_test = f_XW[test_inds]
+s_x_test = s_XW[test_inds]
+```
+
+::::
+
+## Sampling and Analysis
+
+As above, we sample four chains of the $\sigma^2(X)$ forest using "warm-start" initialization (@he2023stochastic), except that here we also fit a mean forest (i.e., we no longer set `num_trees = 0` in the mean forest parameter list).
+ +::::{.panel-tabset group="language"} + +## R + +```{r} +num_gfr <- 10 +num_burnin <- 0 +num_mcmc <- 100 +general_params <- list( + sample_sigma2_global = F, + num_threads = 1, + num_chains = 4, + random_seed = random_seed +) +mean_forest_params <- list( + sample_sigma2_leaf = F, + num_trees = 50, + alpha = 0.95, + beta = 2, + min_samples_leaf = 5 +) +variance_forest_params <- list( + num_trees = 50, + alpha = 0.95, + beta = 1.25, + min_samples_leaf = 5 +) +bart_model <- stochtree::bart( + X_train = X_train, + y_train = y_train, + X_test = X_test, + num_gfr = num_gfr, + num_burnin = num_burnin, + num_mcmc = num_mcmc, + general_params = general_params, + mean_forest_params = mean_forest_params, + variance_forest_params = variance_forest_params +) +``` + +## Python + +```{python} +bart_model = BARTModel() +bart_model.sample( + X_train=X_train, + y_train=y_train, + X_test=X_test, + num_gfr=10, + num_burnin=0, + num_mcmc=100, + general_params={ + "sample_sigma2_global": False, + "num_threads": 1, + "num_chains": 4, + "random_seed": random_seed, + }, + mean_forest_params={ + "sample_sigma2_leaf": False, + "num_trees": 50, + "alpha": 0.95, + "beta": 2, + "min_samples_leaf": 5, + }, + variance_forest_params={ + "num_trees": 50, + "alpha": 0.95, + "beta": 1.25, + "min_samples_leaf": 5, + }, +) +``` + +:::: + +We inspect the model by plotting the true variance function against the variance forest predictions + +::::{.panel-tabset group="language"} + +## R + +```{r} +sigma2_x_hat_test <- predict( + bart_model, + X = X_test, + terms = "variance_forest", + type = "mean" +) +plot( + sigma2_x_hat_test, + s_x_test^2, + pch = 16, + cex = 0.75, + xlab = "Predicted", + ylab = "Actual", + main = "Variance function" +) +abline(0, 1, col = "red", lty = 2, lwd = 2.5) +``` + +## Python + +```{python} +sigma2_x_hat_test = bart_model.predict(X=X_test, terms="variance_forest", type="mean") +lo, hi = ( + min(sigma2_x_hat_test.min(), (s_x_test**2).min()), + max(sigma2_x_hat_test.max(), 
(s_x_test**2).max()), +) +plt.scatter(sigma2_x_hat_test, s_x_test**2, s=10, alpha=0.6) +plt.plot([lo, hi], [lo, hi], color="red", linestyle="dashed", linewidth=2) +plt.xlabel("Predicted") +plt.ylabel("Actual") +plt.title("Variance function") +plt.show() +``` + +:::: + +We also plot the true outcome against mean forest predictions + +::::{.panel-tabset group="language"} + +## R + +```{r} +y_hat_test <- predict( + bart_model, + X = X_test, + terms = "y_hat", + type = "mean" +) +plot( + y_hat_test, + y_test, + pch = 16, + cex = 0.75, + xlab = "Predicted", + ylab = "Actual", + main = "Outcome" +) +abline(0, 1, col = "red", lty = 2, lwd = 2.5) +``` + +## Python + +```{python} +y_hat_test = bart_model.predict(X=X_test, terms="y_hat", type="mean") +lo, hi = ( + min(y_hat_test.min(), y_test.min()), + max(y_hat_test.max(), y_test.max()), +) +plt.scatter(y_hat_test, y_test, s=10, alpha=0.6) +plt.plot([lo, hi], [lo, hi], color="red", linestyle="dashed", linewidth=2) +plt.xlabel("Predicted") +plt.ylabel("Actual") +plt.title("Outcome") +plt.show() +``` + +:::: + +# References diff --git a/vignettes/index.qmd b/vignettes/index.qmd new file mode 100644 index 000000000..e8b63888e --- /dev/null +++ b/vignettes/index.qmd @@ -0,0 +1,42 @@ +--- +title: "StochTree Vignettes" +--- + +Extended worked examples for the `stochtree` package, covering core models, +practical topics, and advanced causal inference methods. Each vignette presents +R and Python implementations side-by-side. 
+ +## Core Models + +| Vignette | Description | +|---|---| +| [BART](bart.qmd) | Bayesian Additive Regression Trees for Supervised Learning | +| [BCF](bcf.qmd) | Bayesian Causal Forests for Treatment Effect Estimation | +| [Heteroskedastic BART](heteroskedastic.qmd) | BART with a Forest-based Variance Model | +| [Ordinal Outcome Modeling](ordinal-outcome.qmd) | BART with the Complementary Log-Log Link for Ordinal Outcomes | +| [Multivariate Treatment BCF](multivariate-bcf.qmd) | BCF with Vector-valued Treatments | + +## Practical Topics + +| Vignette | Description | +|---|---| +| [Model Serialization](serialization.qmd) | Saving and Loading Fitted Models | +| [Multi-Chain Inference](multi-chain.qmd) | Running and Combining Multiple MCMC Chains | +| [Tree Inspection](tree-inspection.qmd) | Examining Individual Trees in a Fitted Ensemble | +| [Summary and Plotting](summary-plotting.qmd) | Posterior Summary and Visualization Utilities | +| [Prior Calibration](prior-calibration.qmd) | Calibrating Leaf Node Scale Parameter Priors | +| [Scikit-Learn Interface](sklearn.qmd) | Using Stochtree via Sklearn-Compatible Estimators in Python | + +## Low-Level Interface + +| Vignette | Description | +|---|---| +| [Custom Sampling Routine](custom-sampling.qmd) | Building a Custom Gibbs Sampler with Stochtree Primitives | +| [Ensemble Kernel](ensemble-kernel.qmd) | Using Shared Leaf Membership as a Kernel | + +## Advanced Methods + +| Vignette | Description | +|---|---| +| [Regression Discontinuity Design](rdd.qmd) | BARDDT: Leaf-Regression BART for RDD | +| [Instrumental Variables](iv.qmd) | IV Analysis via a Custom Monotone Probit Gibbs Sampler | \ No newline at end of file diff --git a/vignettes/iv.qmd b/vignettes/iv.qmd new file mode 100644 index 000000000..13b722159 --- /dev/null +++ b/vignettes/iv.qmd @@ -0,0 +1,936 @@ +--- +title: "Instrumental Variables (IV) with StochTree" +bibliography: vignettes.bib +execute: + freeze: auto # re-render only when source changes +--- + 
+```{r} +#| include: false +reticulate::use_python( + Sys.getenv( + "RETICULATE_PYTHON", + unset = file.path(rprojroot::find_root(rprojroot::has_file(".here")), ".venv", "bin", "python") + ), + required = TRUE +) +``` + +# Introduction + +Here we consider a causal inference problem with a binary treatment and a binary outcome +where there is unobserved confounding, but an exogenous instrument is available (also +binary). This problem requires several extensions to the basic BART model, all of which +can be implemented as Gibbs samplers using `stochtree`. Our analysis follows the +Bayesian nonparametric approach described in the supplement to @hahn2016bayesian. + +# Background + +To be concrete, suppose we wish to measure the effect of receiving a flu vaccine on the +probability of getting the flu. Individuals who opt to get a flu shot differ in many +ways from those that don't, and these lifestyle differences presumably also affect their +respective chances of getting the flu. However, a randomized encouragement design β€” +where some individuals are selected at random to receive extra incentive to get a flu +shot β€” allows us to tease apart the impact of the vaccine from the confounding factors. +This exact problem has been studied in @mcdonald1992effects, with follow-on analyses by +@hirano2000assessing, @richardson2011transparent, and @imbens2015causal. + +## Notation + +Let $V$ denote the treatment variable (vaccine). Let $Y$ denote the response +(getting the flu), $Z$ the instrument (encouragement), and $X$ an additional observable +covariate (patient age). + +Let $S$ denote the *principal strata*, an exhaustive characterization of how individuals +are affected by the encouragement. +Some people will get a flu shot no matter what: *always takers* ($a$). +Some will not get the shot no matter what: *never takers* ($n$). +*Compliers* ($c$) would not have gotten the shot but for the encouragement. +We assume no *defiers* ($d$). 
+ +## The Causal Diagram + +![The causal directed acyclic graph (CDAG) for the IV flu example. The dashed red arrow represents a potential direct effect of $Z$ on $Y$, whose absence is the exclusion restriction.](R/IV/IV_CDAG.png){width=50% fig-align="center"} + +The biggest question about this graph concerns the dashed red arrow from the putative +instrument $Z$ to the outcome. If that arrow is present, $Z$ is not a valid instrument. +The assumption that there is no such arrow is the *exclusion restriction*. We will +explore what inferences are possible when we remain agnostic about its presence. + +## Potential Outcomes + +There are six distinct random variables: $V(0)$, $V(1)$, $Y(0,0)$, $Y(1,0)$, $Y(0,1)$, +and $Y(1,1)$. The fundamental problem of causal inference is that some of these are +never simultaneously observed: + +| $i$ | $Z_i$ | $V_i(0)$ | $V_i(1)$ | $Y_i(0,0)$ | $Y_i(1,0)$ | $Y_i(0,1)$ | $Y_i(1,1)$ | +|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:| +| 1 | 1 | ? | 1 | ? | ? | ? | 0 | +| 2 | 0 | 1 | ? | ? | 1 | ? | ? | +| 3 | 0 | 0 | ? | 1 | ? | ? | ? | +| 4 | 1 | ? | 0 | ? | ? | 0 | ? | +| $\vdots$ | $\vdots$ | $\vdots$ | $\vdots$ | $\vdots$ | $\vdots$ | $\vdots$ | $\vdots$ | + +The principal strata are defined by which potential treatment $V(z)$ is observed: + +| $V_i(0)$ | $V_i(1)$ | $S_i$ | +|:---:|:---:|:---:| +| 0 | 0 | Never Taker ($n$) | +| 1 | 1 | Always Taker ($a$) | +| 0 | 1 | Complier ($c$) | +| 1 | 0 | Defier ($d$) | + +## Estimands and Identification + +Let $\pi_s(x) = \Pr(S=s \mid X=x)$ and +$\gamma_s^{vz}(x) = \Pr(Y(v,z)=1 \mid S=s, X=x)$. +The complier conditional average treatment effect +$\gamma_c^{1,z}(x) - \gamma_c^{0,z}(x)$ is our ultimate goal. 
+ +Under the monotonicity assumption ($\pi_d(x) = 0$), the observed data imply: + +$$ +\begin{aligned} +p_{1 \mid 00}(x) &= \frac{\pi_c(x)}{\pi_c(x)+\pi_n(x)} \gamma_c^{00}(x) + + \frac{\pi_n(x)}{\pi_c(x)+\pi_n(x)} \gamma_n^{00}(x) \\ +p_{1 \mid 11}(x) &= \frac{\pi_c(x)}{\pi_c(x)+\pi_a(x)} \gamma_c^{11}(x) + + \frac{\pi_a(x)}{\pi_c(x)+\pi_a(x)} \gamma_a^{11}(x) \\ +p_{1 \mid 01}(x) &= \gamma_n^{01}(x) \\ +p_{1 \mid 10}(x) &= \gamma_a^{10}(x) +\end{aligned} +$$ + +and the strata probabilities satisfy: + +$$ +\Pr(V=1 \mid Z=0, X=x) = \pi_a(x), \qquad +\Pr(V=1 \mid Z=1, X=x) = \pi_a(x) + \pi_c(x). +$$ + +Under the exclusion restriction, $\gamma_c^{11}(x)$ and $\gamma_c^{00}(x)$ are +point-identified. Without it, they are partially identified: + +$$ +\max\!\left(0,\, \frac{\pi_c+\pi_n}{\pi_c} p_{1\mid 00} - \frac{\pi_n}{\pi_c}\right) +\leq \gamma_c^{00}(x) \leq +\min\!\left(1,\, \frac{\pi_c+\pi_n}{\pi_c} p_{1\mid 00}\right), +$$ + +and analogously for $\gamma_c^{11}(x)$. + +# Setup + +We load all necessary libraries + +:::{.panel-tabset group="language"} + +## R + +```{r} +#| message: false +library(stochtree) +``` + +## Python + +```{python} +import numpy as np +import matplotlib.pyplot as plt +from scipy.stats import norm + +from stochtree import ( + RNG, Dataset, Forest, ForestContainer, + ForestSampler, Residual, ForestModelConfig, GlobalModelConfig, +) +``` + +::: + +And set a seed for reproducibility + +:::{.panel-tabset group="language"} + +## R + +```{r} +random_seed <- 1234 +set.seed(random_seed) +``` + +## Python + +```{python} +random_seed = 1234 +rng = np.random.default_rng(random_seed) +``` + +::: + +## Data Generation + +Data size + +:::{.panel-tabset group="language"} + +## R + +```{r} +n <- 20000 +``` + +## Python + +```{python} +n = 20000 +``` + +::: + +Generate the Instrument + +:::{.panel-tabset group="language"} + +## R + +```{r} +z <- rbinom(n, 1, 0.5) +``` + +## Python + +```{python} +z = rng.binomial(n=1, p=0.5, size=n) +``` + +::: + +We 
conceptualize a covariate $X$ as patient age, drawn from a uniform distribution on $[0, 3]$ +(pre-standardized for illustration purposes) and generate the covariate + +:::{.panel-tabset group="language"} + +## R + +```{r} +p_X <- 1 +X <- matrix(runif(n * p_X, 0, 3), ncol = p_X) +x <- X[, 1] +``` + +## Python + +```{python} +p_X = 1 +X = rng.uniform(low=0., high=3., size=(n, p_X)) +x = X[:, 0] +``` + +::: + +We generate principal strata $S$ from a logistic model in $X$, parameterized so that the probability +of being a never taker decreases with age + +:::{.panel-tabset group="language"} + +## R + +```{r} +alpha_a <- 0; beta_a <- 1 +alpha_n <- 1; beta_n <- -1 +alpha_c <- 1; beta_c <- 1 + +pi_s <- function(xval) { + w_a <- exp(alpha_a + beta_a * xval) + w_n <- exp(alpha_n + beta_n * xval) + w_c <- exp(alpha_c + beta_c * xval) + w <- cbind(w_a, w_n, w_c) + w / rowSums(w) +} + +s <- sapply(seq_len(n), function(j) + sample(c("a", "n", "c"), 1, prob = pi_s(X[j, 1]))) +``` + +## Python + +```{python} +alpha_a = 0; beta_a = 1 +alpha_n = 1; beta_n = -1 +alpha_c = 1; beta_c = 1 + +def pi_s(xval, alpha_a, beta_a, alpha_n, beta_n, alpha_c, beta_c): + w = np.column_stack([ + np.exp(alpha_a + beta_a * xval), + np.exp(alpha_n + beta_n * xval), + np.exp(alpha_c + beta_c * xval), + ]) + return w / w.sum(axis=1, keepdims=True) + +strata_probs = pi_s(X[:, 0], alpha_a, beta_a, alpha_n, beta_n, alpha_c, beta_c) +s = np.empty(n, dtype=str) +for i in range(n): + s[i] = rng.choice(['a', 'n', 'c'], p=strata_probs[i, :]) +``` + +::: + +The treatment $V$ is generated as a deterministic function of $S$ and $Z$ β€” this is what gives the +principal strata their meaning + +:::{.panel-tabset group="language"} + +## R + +```{r} +v <- 1*(s == "a") + 0*(s == "n") + z*(s == "c") + (1-z)*(s == "d") +``` + +## Python + +```{python} +v = 1*(s == 'a') + 0*(s == 'n') + z*(s == "c") + (1-z)*(s == "d") +``` + +::: + +The outcome is generated according to the structural model below. 
+By varying this function we can alter the identification conditions. +Setting it to depend on `zval` violates the exclusion restriction, +and we do so here to illustrate partial identification. + +:::{.panel-tabset group="language"} + +## R + +```{r} +gamfun <- function(xval, vval, zval, sval) { + baseline <- pnorm(2 - xval - 2.5*(xval - 1.5)^2 - 0.5*zval + + 1*(sval == "n") - 1*(sval == "a")) + baseline - 0.5 * vval * baseline +} +y <- rbinom(n, 1, gamfun(X[, 1], v, z, s)) +``` + +## Python + +```{python} +def gamfun(xval, vval, zval, sval): + baseline = norm.cdf(2 - xval - 2.5*(xval - 1.5)**2 - 0.5*zval + + 1*(sval == "n") - 1*(sval == "a")) + return baseline - 0.5 * vval * baseline + +y = rng.binomial(n=1, p=gamfun(X[:, 0], v, z, s), size=n) +``` + +::: + +## Model Fitting + +In order to fit a monotone probit model, the observations must be sorted so that $Z=1$ cases come first. + +:::{.panel-tabset group="language"} + +## R + +```{r} +Xall <- cbind(X, v, z) +p_X <- p_X + 2 +index <- sort(z, decreasing = TRUE, index.return = TRUE) +X <- matrix(X[index$ix, ], ncol = 1) +Xall <- Xall[index$ix, ] +z <- z[index$ix] +v <- v[index$ix] +s <- s[index$ix] +y <- y[index$ix] +x <- x[index$ix] +``` + +## Python + +```{python} +Xall = np.concatenate((X, np.column_stack((v, z))), axis=1) +p_X = p_X + 2 +sort_index = np.argsort(z)[::-1] +X = X[sort_index, :] +Xall = Xall[sort_index, :] +z = z[sort_index] +v = v[sort_index] +s = s[sort_index] +y = y[sort_index] +x = x[sort_index] +``` + +::: + +We fit a probit BART model for $\Pr(Y=1 \mid V=1, Z=1, X=x)$ using the +Albert–Chib [@albert1993bayesian] data augmentation Gibbs sampler. We initialize the +forest, enter the main loop (alternating: sample forest | sample latent utilities), +and retain all post-warmstart draws. 
+ +:::{.panel-tabset group="language"} + +## R + +```{r} +num_warmstart <- 10 +num_mcmc <- 1000 +num_samples <- num_warmstart + num_mcmc + +alpha <- 0.95; beta <- 2; min_samples_leaf <- 1; max_depth <- 20 +num_trees <- 50; cutpoint_grid_size <- 100 +tau_init <- 0.5 +leaf_prior_scale <- matrix(tau_init, ncol = 1) +feature_types <- as.integer(c(rep(0, p_X - 2), 1, 1)) +var_weights <- rep(1, p_X) / p_X +outcome_model_type <- 0 + +if (is.null(random_seed)) { + rng_r <- createCppRNG(-1) +} else { + rng_r <- createCppRNG(random_seed) +} + +forest_dataset <- createForestDataset(Xall) +forest_model_config <- createForestModelConfig( + feature_types = feature_types, num_trees = num_trees, + num_features = p_X, num_observations = n, + variable_weights = var_weights, leaf_dimension = 1, + alpha = alpha, beta = beta, min_samples_leaf = min_samples_leaf, + max_depth = max_depth, leaf_model_type = outcome_model_type, + leaf_model_scale = leaf_prior_scale, + cutpoint_grid_size = cutpoint_grid_size +) +global_model_config <- createGlobalModelConfig(global_error_variance = 1) +forest_model <- createForestModel(forest_dataset, forest_model_config, + global_model_config) +forest_samples <- createForestSamples(num_trees, 1, TRUE, FALSE) +active_forest <- createForest(num_trees, 1, TRUE, FALSE) + +n1 <- sum(y) +zed <- 0.25 * (2 * as.numeric(y) - 1) +outcome <- createOutcome(zed) +active_forest$prepare_for_sampler(forest_dataset, outcome, forest_model, + outcome_model_type, 0.0) +active_forest$adjust_residual(forest_dataset, outcome, forest_model, + FALSE, FALSE) + +gfr_flag <- TRUE +for (i in seq_len(num_samples)) { + if (i > num_warmstart) gfr_flag <- FALSE + forest_model$sample_one_iteration( + forest_dataset, outcome, forest_samples, active_forest, + rng_r, forest_model_config, global_model_config, + keep_forest = TRUE, gfr = gfr_flag, num_threads = 1 + ) + eta <- forest_samples$predict_raw_single_forest(forest_dataset, i - 1) + U1 <- runif(n1, pnorm(0, eta[y == 1], 1), 1) + zed[y 
== 1] <- qnorm(U1, eta[y == 1], 1) + U0 <- runif(n - n1, 0, pnorm(0, eta[y == 0], 1)) + zed[y == 0] <- qnorm(U0, eta[y == 0], 1) + outcome$update_data(zed) + forest_model$propagate_residual_update(outcome) +} +``` + +## Python + +```{python} +num_warmstart = 10 +num_mcmc = 1000 +num_samples = num_warmstart + num_mcmc + +alpha = 0.95; beta = 2; min_samples_leaf = 1; max_depth = 20 +num_trees = 50; cutpoint_grid_size = 100 +tau_init = 0.5 +leaf_prior_scale = np.array([[tau_init]]) +feature_types = np.append(np.repeat(0, p_X - 2), [1, 1]).astype(int) +var_weights = np.repeat(1.0 / p_X, p_X) +outcome_model_type = 0 + +cpp_rng = RNG(random_seed) if random_seed is not None else RNG() + +forest_dataset = Dataset() +forest_dataset.add_covariates(Xall) + +forest_model_config = ForestModelConfig( + feature_types=feature_types, num_trees=num_trees, + num_features=p_X, num_observations=n, + variable_weights=var_weights, leaf_dimension=1, + alpha=alpha, beta=beta, min_samples_leaf=min_samples_leaf, + max_depth=max_depth, leaf_model_type=outcome_model_type, + leaf_model_scale=leaf_prior_scale, + cutpoint_grid_size=cutpoint_grid_size, +) +global_model_config = GlobalModelConfig(global_error_variance=1.0) +forest_sampler = ForestSampler(forest_dataset, global_model_config, + forest_model_config) +forest_samples = ForestContainer(num_trees, 1, True, False) +active_forest = Forest(num_trees, 1, True, False) + +n1 = int(np.sum(y)) +zed = 0.25 * (2.0 * y - 1.0) +outcome = Residual(zed) +forest_sampler.prepare_for_sampler(forest_dataset, outcome, active_forest, + outcome_model_type, np.array([0.0])) + +gfr_flag = True +for i in range(num_samples): + if i >= num_warmstart: + gfr_flag = False + forest_sampler.sample_one_iteration( + forest_samples, active_forest, forest_dataset, outcome, cpp_rng, + global_model_config, forest_model_config, + keep_forest=True, gfr=gfr_flag, num_threads=1, + ) + eta = np.squeeze(forest_samples.predict_raw_single_forest(forest_dataset, i)) + mu0 = eta[y == 
0]; mu1 = eta[y == 1] + u0 = rng.uniform(0, norm.cdf(-mu0), size=n - n1) + u1 = rng.uniform(norm.cdf(-mu1), 1, size=n1) + zed[y == 0] = mu0 + norm.ppf(u0) + zed[y == 1] = mu1 + norm.ppf(u1) + outcome.update_data(np.squeeze(zed) - eta) +``` + +::: + +The monotonicity constraint $\Pr(V=1 \mid Z=0, X=x) \leq \Pr(V=1 \mid Z=1, X=x)$ is enforced via the data augmentation of @papakostas2023forecasts. We parameterize + +$$ +\Pr(V=1 \mid Z=0, X=x) = \Phi_f(x)\,\Phi_h(x), \qquad +\Pr(V=1 \mid Z=1, X=x) = \Phi_f(x), +$$ + +where $\Phi_\mu(x)$ is the normal CDF with mean $\mu(x)$ and variance 1. + +:::{.panel-tabset group="language"} + +## R + +```{r} +X_h <- as.matrix(X[z == 0, ]) +n0 <- sum(z == 0); n1 <- sum(z == 1) +num_trees_f <- 50; num_trees_h <- 20 +feature_types_mono <- as.integer(rep(0, 1)) +var_weights_mono <- rep(1, 1) +tau_h <- 1 / num_trees_h +leaf_scale_h <- matrix(tau_h, ncol = 1) +leaf_scale_f <- matrix(1 / num_trees_f, ncol = 1) + +forest_dataset_f <- createForestDataset(X) +forest_dataset_h <- createForestDataset(X_h) + +fmc_f <- createForestModelConfig( + feature_types = feature_types_mono, num_trees = num_trees_f, + num_features = ncol(X), num_observations = nrow(X), + variable_weights = var_weights_mono, leaf_dimension = 1, + alpha = alpha, beta = beta, min_samples_leaf = min_samples_leaf, + max_depth = max_depth, leaf_model_type = 0, + leaf_model_scale = leaf_scale_f, cutpoint_grid_size = cutpoint_grid_size +) +fmc_h <- createForestModelConfig( + feature_types = feature_types_mono, num_trees = num_trees_h, + num_features = ncol(X_h), num_observations = nrow(X_h), + variable_weights = var_weights_mono, leaf_dimension = 1, + alpha = alpha, beta = beta, min_samples_leaf = min_samples_leaf, + max_depth = max_depth, leaf_model_type = 0, + leaf_model_scale = leaf_scale_h, cutpoint_grid_size = cutpoint_grid_size +) +gmc_mono <- createGlobalModelConfig(global_error_variance = 1) +fm_f <- createForestModel(forest_dataset_f, fmc_f, gmc_mono) +fm_h <- 
createForestModel(forest_dataset_h, fmc_h, gmc_mono) + +fs_f <- createForestSamples(num_trees_f, 1, TRUE) +fs_h <- createForestSamples(num_trees_h, 1, TRUE) +af_f <- createForest(num_trees_f, 1, TRUE) +af_h <- createForest(num_trees_h, 1, TRUE) + +v1 <- v[z == 1]; v0 <- v[z == 0] +R1 <- rep(NA, n0); R0 <- rep(NA, n0) +R1[v0 == 1] <- 1; R0[v0 == 1] <- 1 +R1[v0 == 0] <- 0; R0[v0 == 0] <- sample(c(0, 1), sum(v0 == 0), replace = TRUE) +vaug <- c(v1, R1) + +z_f <- (2 * as.numeric(vaug) - 1); z_f <- z_f / sd(z_f) +z_h <- (2 * as.numeric(R0) - 1); z_h <- z_h / sd(z_h) +out_f <- createOutcome(z_f); out_h <- createOutcome(z_h) +af_f$prepare_for_sampler(forest_dataset_f, out_f, fm_f, 0, 0.0) +af_h$prepare_for_sampler(forest_dataset_h, out_h, fm_h, 0, 0.0) +af_f$adjust_residual(forest_dataset_f, out_f, fm_f, FALSE, FALSE) +af_h$adjust_residual(forest_dataset_h, out_h, fm_h, FALSE, FALSE) + +gfr_flag <- TRUE +for (i in seq_len(num_samples)) { + if (i > num_warmstart) gfr_flag <- FALSE + fm_f$sample_one_iteration(forest_dataset_f, out_f, fs_f, af_f, + rng_r, fmc_f, gmc_mono, keep_forest = TRUE, gfr = gfr_flag, num_threads = 1) + fm_h$sample_one_iteration(forest_dataset_h, out_h, fs_h, af_h, + rng_r, fmc_h, gmc_mono, keep_forest = TRUE, gfr = gfr_flag, num_threads = 1) + + eta_f <- fs_f$predict_raw_single_forest(forest_dataset_f, i - 1) + eta_h <- fs_h$predict_raw_single_forest(forest_dataset_h, i - 1) + + idx0 <- which(v0 == 0) + w1 <- (1 - pnorm(eta_h[idx0])) * (1 - pnorm(eta_f[n1 + idx0])) + w2 <- (1 - pnorm(eta_h[idx0])) * pnorm(eta_f[n1 + idx0]) + w3 <- pnorm(eta_h[idx0]) * (1 - pnorm(eta_f[n1 + idx0])) + s_w <- w1 + w2 + w3 + u <- runif(length(idx0)) + temp <- 1*(u < w1/s_w) + 2*(u > w1/s_w & u < (w1+w2)/s_w) + 3*(u > (w1+w2)/s_w) + R1[v0 == 0] <- 1*(temp == 2); R0[v0 == 0] <- 1*(temp == 3) + vaug <- c(v1, R1) + + U1 <- runif(sum(R0), pnorm(0, eta_h[R0 == 1], 1), 1) + z_h[R0 == 1] <- qnorm(U1, eta_h[R0 == 1], 1) + U0 <- runif(n0 - sum(R0), 0, pnorm(0, eta_h[R0 == 0], 1)) + 
z_h[R0 == 0] <- qnorm(U0, eta_h[R0 == 0], 1) + + U1 <- runif(sum(vaug), pnorm(0, eta_f[vaug == 1], 1), 1) + z_f[vaug == 1] <- qnorm(U1, eta_f[vaug == 1], 1) + U0 <- runif(n - sum(vaug), 0, pnorm(0, eta_f[vaug == 0], 1)) + z_f[vaug == 0] <- qnorm(U0, eta_f[vaug == 0], 1) + + out_h$update_data(z_h); fm_h$propagate_residual_update(out_h) + out_f$update_data(z_f); fm_f$propagate_residual_update(out_f) +} +``` + +## Python + +```{python} +X_h = X[z == 0, :] +n0 = int(np.sum(z == 0)); n1 = int(np.sum(z == 1)) +num_trees_f = 50; num_trees_h = 20 +feature_types_mono = np.repeat(0, p_X - 2).astype(int) +var_weights_mono = np.repeat(1.0 / (p_X - 2.0), p_X - 2) +leaf_scale_f = np.array([[1.0 / num_trees_f]]) +leaf_scale_h = np.array([[1.0 / num_trees_h]]) + +forest_dataset_f = Dataset(); forest_dataset_f.add_covariates(X) +forest_dataset_h = Dataset(); forest_dataset_h.add_covariates(X_h) + +fmc_f = ForestModelConfig( + feature_types=feature_types_mono, num_trees=num_trees_f, + num_features=X.shape[1], num_observations=n, + variable_weights=var_weights_mono, leaf_dimension=1, + alpha=alpha, beta=beta, min_samples_leaf=min_samples_leaf, + max_depth=max_depth, leaf_model_type=0, + leaf_model_scale=leaf_scale_f, cutpoint_grid_size=cutpoint_grid_size, +) +fmc_h = ForestModelConfig( + feature_types=feature_types_mono, num_trees=num_trees_h, + num_features=X_h.shape[1], num_observations=n0, + variable_weights=var_weights_mono, leaf_dimension=1, + alpha=alpha, beta=beta, min_samples_leaf=min_samples_leaf, + max_depth=max_depth, leaf_model_type=0, + leaf_model_scale=leaf_scale_h, cutpoint_grid_size=cutpoint_grid_size, +) +gmc_mono = GlobalModelConfig(global_error_variance=1.0) +fs_f = ForestSampler(forest_dataset_f, gmc_mono, fmc_f) +fs_h = ForestSampler(forest_dataset_h, gmc_mono, fmc_h) +forest_samples_f = ForestContainer(num_trees_f, 1, True, False) +forest_samples_h = ForestContainer(num_trees_h, 1, True, False) +af_f = Forest(num_trees_f, 1, True, False) +af_h = 
Forest(num_trees_h, 1, True, False) + +v1 = v[z == 1]; v0 = v[z == 0] +R1 = np.empty(n0); R0 = np.empty(n0) +R1[v0 == 1] = 1; R0[v0 == 1] = 1 +nv0 = int(np.sum(v0 == 0)) +R1[v0 == 0] = 0; R0[v0 == 0] = rng.choice([0, 1], size=nv0) +vaug = np.append(v1, R1) +z_f = (2.0 * vaug - 1.0); z_f = z_f / np.std(z_f) +z_h = (2.0 * R0 - 1.0); z_h = z_h / np.std(z_h) +out_f = Residual(z_f); out_h = Residual(z_h) +fs_f.prepare_for_sampler(forest_dataset_f, out_f, af_f, 0, np.array([0.0])) +fs_h.prepare_for_sampler(forest_dataset_h, out_h, af_h, 0, np.array([0.0])) + +gfr_flag = True +for i in range(num_samples): + if i >= num_warmstart: + gfr_flag = False + fs_f.sample_one_iteration(forest_samples_f, af_f, forest_dataset_f, out_f, + cpp_rng, gmc_mono, fmc_f, keep_forest=True, gfr=gfr_flag, num_threads=1) + fs_h.sample_one_iteration(forest_samples_h, af_h, forest_dataset_h, out_h, + cpp_rng, gmc_mono, fmc_h, keep_forest=True, gfr=gfr_flag, num_threads=1) + + eta_f = np.squeeze(forest_samples_f.predict_raw_single_forest(forest_dataset_f, i)) + eta_h = np.squeeze(forest_samples_h.predict_raw_single_forest(forest_dataset_h, i)) + + idx0 = np.where(v0 == 0)[0] + w1 = (1 - norm.cdf(eta_h[idx0])) * (1 - norm.cdf(eta_f[n1 + idx0])) + w2 = (1 - norm.cdf(eta_h[idx0])) * norm.cdf(eta_f[n1 + idx0]) + w3 = norm.cdf(eta_h[idx0]) * (1 - norm.cdf(eta_f[n1 + idx0])) + s_w = w1 + w2 + w3 + u = rng.uniform(size=len(idx0)) + temp = 1*(u < w1/s_w) + 2*((u > w1/s_w) & (u < (w1+w2)/s_w)) + 3*(u > (w1+w2)/s_w) + R1[v0 == 0] = (temp == 2).astype(float) + R0[v0 == 0] = (temp == 3).astype(float) + vaug = np.append(v1, R1) + + mu1 = eta_h[R0 == 1] + z_h[R0 == 1] = mu1 + norm.ppf(rng.uniform(norm.cdf(-mu1), 1, size=int(np.sum(R0)))) + mu0 = eta_h[R0 == 0] + z_h[R0 == 0] = mu0 + norm.ppf(rng.uniform(0, norm.cdf(-mu0), size=n0 - int(np.sum(R0)))) + + mu1 = eta_f[vaug == 1] + z_f[vaug == 1] = mu1 + norm.ppf(rng.uniform(norm.cdf(-mu1), 1, size=int(np.sum(vaug)))) + mu0 = eta_f[vaug == 0] + z_f[vaug == 0] = mu0 
+ norm.ppf(rng.uniform(0, norm.cdf(-mu0), size=n - int(np.sum(vaug)))) + + out_h.update_data(np.squeeze(z_h) - eta_h) + out_f.update_data(np.squeeze(z_f) - eta_f) +``` + +::: + +## Extracting Estimates and Plotting + +We compute the true $ITT_c$ and LATE functions on a prediction grid, then extract +posterior predictions and plot credible bands. + +### Prediction Grid and Truth + +:::{.panel-tabset group="language"} + +## R + +```{r} +ngrid <- 200 +xgrid <- seq(0.1, 2.5, length.out = ngrid) +X_11 <- cbind(xgrid, rep(1, ngrid), rep(1, ngrid)) +X_00 <- cbind(xgrid, rep(0, ngrid), rep(0, ngrid)) +X_01 <- cbind(xgrid, rep(0, ngrid), rep(1, ngrid)) +X_10 <- cbind(xgrid, rep(1, ngrid), rep(0, ngrid)) + +pi_strat <- pi_s(xgrid) +w_a <- pi_strat[, 1]; w_n <- pi_strat[, 2]; w_c <- pi_strat[, 3] + +p11_true <- (w_c/(w_a+w_c))*gamfun(xgrid,1,1,"c") + (w_a/(w_a+w_c))*gamfun(xgrid,1,1,"a") +p00_true <- (w_c/(w_n+w_c))*gamfun(xgrid,0,0,"c") + (w_n/(w_n+w_c))*gamfun(xgrid,0,0,"n") +itt_c_true <- gamfun(xgrid, 1, 1, "c") - gamfun(xgrid, 0, 0, "c") +LATE_true0 <- gamfun(xgrid, 1, 0, "c") - gamfun(xgrid, 0, 0, "c") +LATE_true1 <- gamfun(xgrid, 1, 1, "c") - gamfun(xgrid, 0, 1, "c") +``` + +## Python + +```{python} +ngrid = 200 +xgrid = np.linspace(0.1, 2.5, ngrid) +X_11 = np.column_stack((xgrid, np.ones(ngrid), np.ones(ngrid))) +X_00 = np.column_stack((xgrid, np.zeros(ngrid), np.zeros(ngrid))) +X_01 = np.column_stack((xgrid, np.zeros(ngrid), np.ones(ngrid))) +X_10 = np.column_stack((xgrid, np.ones(ngrid), np.zeros(ngrid))) + +pi_strat = pi_s(xgrid, alpha_a, beta_a, alpha_n, beta_n, alpha_c, beta_c) +w_a = pi_strat[:, 0]; w_n = pi_strat[:, 1]; w_c = pi_strat[:, 2] + +p11_true = (w_c/(w_a+w_c))*gamfun(xgrid,1,1,"c") + (w_a/(w_a+w_c))*gamfun(xgrid,1,1,"a") +p00_true = (w_c/(w_n+w_c))*gamfun(xgrid,0,0,"c") + (w_n/(w_n+w_c))*gamfun(xgrid,0,0,"n") +itt_c_true = gamfun(xgrid, 1, 1, "c") - gamfun(xgrid, 0, 0, "c") +LATE_true0 = gamfun(xgrid, 1, 0, "c") - gamfun(xgrid, 0, 0, "c") +LATE_true1 
= gamfun(xgrid, 1, 1, "c") - gamfun(xgrid, 0, 1, "c") +``` + +::: + +### Extract Posterior Predictions + +:::{.panel-tabset group="language"} + +## R + +```{r} +fd_grid <- createForestDataset(as.matrix(xgrid)) +fd_11 <- createForestDataset(X_11) +fd_00 <- createForestDataset(X_00) +fd_01 <- createForestDataset(X_01) +fd_10 <- createForestDataset(X_10) + +phat_11 <- pnorm(forest_samples$predict(fd_11)) +phat_00 <- pnorm(forest_samples$predict(fd_00)) +phat_01 <- pnorm(forest_samples$predict(fd_01)) +phat_10 <- pnorm(forest_samples$predict(fd_10)) +phat_ac <- pnorm(fs_f$predict(fd_grid)) +phat_a <- phat_ac * pnorm(fs_h$predict(fd_grid)) +phat_c <- phat_ac - phat_a +phat_n <- 1 - phat_ac +``` + +## Python + +```{python} +def make_dataset(mat): + ds = Dataset() + ds.add_covariates(mat) + return ds + +fd_grid = make_dataset(np.expand_dims(xgrid, 1)) +fd_11 = make_dataset(X_11); fd_00 = make_dataset(X_00) +fd_01 = make_dataset(X_01); fd_10 = make_dataset(X_10) + +phat_11 = norm.cdf(forest_samples.predict(fd_11)) +phat_00 = norm.cdf(forest_samples.predict(fd_00)) +phat_01 = norm.cdf(forest_samples.predict(fd_01)) +phat_10 = norm.cdf(forest_samples.predict(fd_10)) +phat_ac = norm.cdf(forest_samples_f.predict(fd_grid)) +phat_a = phat_ac * norm.cdf(forest_samples_h.predict(fd_grid)) +phat_c = phat_ac - phat_a +phat_n = 1 - phat_ac +``` + +::: + +### Model Fit Diagnostics + +:::{.panel-tabset group="language"} + +## R + +```{r} +#| fig-cap: "Fitted vs. true conditional outcome probabilities." +par(mfrow = c(1, 2)) +plot(p11_true, rowMeans(phat_11), pch = 20, cex = 0.5, bty = "n", + xlab = "True p11", ylab = "Fitted p11") +abline(0, 1, col = "red") +plot(p00_true, rowMeans(phat_00), pch = 20, cex = 0.5, bty = "n", + xlab = "True p00", ylab = "Fitted p00") +abline(0, 1, col = "red") +``` + +## Python + +```{python} +#| fig-cap: "Fitted vs. true conditional outcome probabilities." 
+fig, (ax1, ax2) = plt.subplots(1, 2) +ax1.scatter(p11_true, np.mean(phat_11, axis=1), color="black", s=5) +ax1.axline((0, 0), slope=1, color="red", linestyle=(0, (3, 3))) +ax2.scatter(p00_true, np.mean(phat_00, axis=1), color="black", s=5) +ax2.axline((0, 0), slope=1, color="red", linestyle=(0, (3, 3))) +plt.show() +``` + +::: + +### Construct and Plot the $ITT_c$ + +We center the posterior on the identified interval at the value implied by a valid +exclusion restriction, then construct credible bands for the $ITT_c$ and compare +to the LATE. + +:::{.panel-tabset group="language"} + +## R + +```{r} +#| cache: true +#| fig-cap: "Posterior credible bands for the ITT_c (gold/brown) and LATE (gray/blue) compared to the true ITT_c (solid black), LATE_z0 (dotted), and LATE_z1 (dashed)." +ss <- 6 +itt_c <- late <- matrix(NA, ngrid, ncol(phat_c)) + +for (j in seq_len(ncol(phat_c))) { + gamest11 <- ((phat_a[,j]+phat_c[,j])/phat_c[,j])*phat_11[,j] - + phat_10[,j]*phat_a[,j]/phat_c[,j] + lower11 <- pmax(0, ((phat_a[,j]+phat_c[,j])/phat_c[,j])*phat_11[,j] - + phat_a[,j]/phat_c[,j]) + upper11 <- pmin(1, ((phat_a[,j]+phat_c[,j])/phat_c[,j])*phat_11[,j]) + m11 <- (gamest11 - lower11)/(upper11 - lower11) + a1 <- ss*m11; b1 <- ss*(1 - m11) + a1[m11 < 0] <- 1; b1[m11 < 0] <- 5 + a1[m11 > 1] <- 5; b1[m11 > 1] <- 1 + + gamest00 <- ((phat_n[,j]+phat_c[,j])/phat_c[,j])*phat_00[,j] - + phat_01[,j]*phat_n[,j]/phat_c[,j] + lower00 <- pmax(0, ((phat_n[,j]+phat_c[,j])/phat_c[,j])*phat_00[,j] - + phat_n[,j]/phat_c[,j]) + upper00 <- pmin(1, ((phat_n[,j]+phat_c[,j])/phat_c[,j])*phat_00[,j]) + m00 <- (gamest00 - lower00)/(upper00 - lower00) + a0 <- ss*m00; b0 <- ss*(1 - m00) + a0[m00 < 0] <- 1; b0[m00 < 0] <- 5 + a0[m00 > 1] <- 5; b0[m00 > 1] <- 1 + + itt_c[,j] <- lower11 + (upper11-lower11)*rbeta(ngrid, a1, b1) - + (lower00 + (upper00-lower00)*rbeta(ngrid, a0, b0)) + late[,j] <- gamest11 - gamest00 +} + +upperq <- apply(itt_c, 1, quantile, 0.975) +lowerq <- apply(itt_c, 1, quantile, 0.025) 
+upperq_er <- apply(late, 1, quantile, 0.975, na.rm = TRUE) +lowerq_er <- apply(late, 1, quantile, 0.025, na.rm = TRUE) + +plot(xgrid, itt_c_true, type = "n", ylim = c(-0.75, 0.05), bty = "n", + xlab = "x", ylab = "Treatment effect") +polygon(c(xgrid, rev(xgrid)), c(lowerq, rev(upperq)), + col = rgb(0.5, 0.25, 0, 0.25), border = FALSE) +polygon(c(xgrid, rev(xgrid)), c(lowerq_er, rev(upperq_er)), + col = rgb(0, 0, 0.5, 0.25), border = FALSE) +lines(xgrid, rowMeans(late), col = "slategray", lwd = 3) +lines(xgrid, rowMeans(itt_c), col = "goldenrod1", lwd = 1) +lines(xgrid, LATE_true0, col = "black", lwd = 2, lty = 3) +lines(xgrid, LATE_true1, col = "black", lwd = 2, lty = 2) +lines(xgrid, itt_c_true, col = "black", lwd = 1) +``` + +## Python + +```{python} +#| cache: true +#| fig-cap: "Posterior credible bands for ITT_c and LATE compared to the true functions." +ss = 6 +itt_c = np.empty((ngrid, phat_c.shape[1])) +late = np.empty((ngrid, phat_c.shape[1])) + +for j in range(phat_c.shape[1]): + gamest11 = ((phat_a[:,j]+phat_c[:,j])/phat_c[:,j])*phat_11[:,j] - \ + phat_10[:,j]*phat_a[:,j]/phat_c[:,j] + lower11 = np.maximum(0., ((phat_a[:,j]+phat_c[:,j])/phat_c[:,j])*phat_11[:,j] - + phat_a[:,j]/phat_c[:,j]) + upper11 = np.minimum(1., ((phat_a[:,j]+phat_c[:,j])/phat_c[:,j])*phat_11[:,j]) + m11 = (gamest11 - lower11) / (upper11 - lower11) + a1 = ss * m11; b1 = ss * (1 - m11) + a1[m11 < 0] = 1; b1[m11 < 0] = 5 + a1[m11 > 1] = 5; b1[m11 > 1] = 1 + + gamest00 = ((phat_n[:,j]+phat_c[:,j])/phat_c[:,j])*phat_00[:,j] - \ + phat_01[:,j]*phat_n[:,j]/phat_c[:,j] + lower00 = np.maximum(0., ((phat_n[:,j]+phat_c[:,j])/phat_c[:,j])*phat_00[:,j] - + phat_n[:,j]/phat_c[:,j]) + upper00 = np.minimum(1., ((phat_n[:,j]+phat_c[:,j])/phat_c[:,j])*phat_00[:,j]) + m00 = (gamest00 - lower00) / (upper00 - lower00) + a0 = ss * m00; b0 = ss * (1 - m00) + a0[m00 < 0] = 1; b0[m00 < 0] = 5 + a0[m00 > 1] = 5; b0[m00 > 1] = 1 + + itt_c[:, j] = lower11 + (upper11-lower11)*rng.beta(a1, b1, ngrid) - \ + 
(lower00 + (upper00-lower00)*rng.beta(a0, b0, ngrid))
+    late[:, j] = gamest11 - gamest00
+
+upperq = np.quantile(itt_c, 0.975, axis=1)
+lowerq = np.quantile(itt_c, 0.025, axis=1)
+upperq_er = np.quantile(late, 0.975, axis=1)
+lowerq_er = np.quantile(late, 0.025, axis=1)
+
+plt.plot(xgrid, itt_c_true, color="black")
+plt.ylim(-0.75, 0.05)
+plt.fill(np.append(xgrid, xgrid[::-1]), np.append(lowerq, upperq[::-1]),
+         color=(0.5, 0.5, 0, 0.25))
+plt.fill(np.append(xgrid, xgrid[::-1]), np.append(lowerq_er, upperq_er[::-1]),
+         color=(0, 0, 0.5, 0.25))
+plt.plot(xgrid, np.mean(late, axis=1), color="darkgrey")
+plt.plot(xgrid, np.mean(itt_c, axis=1), color="gold")
+plt.plot(xgrid, LATE_true0, color="black", linestyle=(0, (2, 2)))
+plt.plot(xgrid, LATE_true1, color="black", linestyle=(0, (4, 4)))
+plt.show()
+```
+
+:::
+
+With a valid exclusion restriction the three black curves would all be identical.
+Without it, the direct effect of $Z$ on $Y$ causes them to diverge. Specifically, the
+$ITT_c$ (gold) compares getting the vaccine *and* the reminder to not getting either —
+when both reduce risk, we see a larger overall reduction. The two LATE effects compare
+the isolated impact of the vaccine among those who did and did not receive the reminder,
+respectively.
+
+## References
diff --git a/vignettes/multi-chain.qmd b/vignettes/multi-chain.qmd
new file mode 100644
index 000000000..40204324a
--- /dev/null
+++ b/vignettes/multi-chain.qmd
@@ -0,0 +1,858 @@
+---
+title: "Running and Combining Multiple MCMC Chains"
+bibliography: vignettes.bib
+execute:
+  freeze: auto # re-render only when source changes
+---
+
+```{r}
+#| include: false
+reticulate::use_python(
+  Sys.getenv(
+    "RETICULATE_PYTHON",
+    unset = file.path(rprojroot::find_root(rprojroot::has_file(".here")), ".venv", "bin", "python")
+  ),
+  required = TRUE
+)
+```
+
+# Motivation
+
+Mixing of an MCMC sampler is a perennial concern for complex Bayesian models. BART
+and BCF are no exception.
One common way to address such concerns is to run multiple +independent "chains" of an MCMC sampler, so that if each chain gets stuck in a +different region of the posterior, their combined samples attain better coverage of +the full posterior. + +This idea works with the classic "root-initialized" MCMC sampler of @chipman2010bart, +but a key insight of @he2023stochastic and @krantsevich2023stochastic is that the GFR +algorithm may be used to warm-start initialize multiple chains of the BART / BCF MCMC +sampler. + +Operationally, the above two approaches have the same implementation (setting +`num_gfr > 0` if warm-start initialization is desired), so this vignette will +demonstrate how to run a multi-chain sampler sequentially or in parallel. + +# Setup + +::::{.panel-tabset group="language"} + +## R + +```{r} +#| warning: false +#| message: false +library(stochtree) +library(ggplot2) +library(coda) +library(bayesplot) +library(foreach) +library(doParallel) +``` + +## Python + +```{python} +import numpy as np +import matplotlib.pyplot as plt +import arviz as az +from stochtree import BARTModel + +rng = np.random.default_rng(1111) +``` + +:::: + +# Demo 1: Supervised Learning + +## Data Simulation + +Simulate a simple partitioned linear model. 
+ +::::{.panel-tabset group="language"} + +## R + +```{r} +# Generate the data +set.seed(1111) +n <- 500 +p_x <- 10 +p_w <- 1 +snr <- 3 +X <- matrix(runif(n * p_x), ncol = p_x) +leaf_basis <- matrix(runif(n * p_w), ncol = p_w) +f_XW <- (((0 <= X[, 1]) & (0.25 > X[, 1])) * + (-7.5 * leaf_basis[, 1]) + + ((0.25 <= X[, 1]) & (0.5 > X[, 1])) * (-2.5 * leaf_basis[, 1]) + + ((0.5 <= X[, 1]) & (0.75 > X[, 1])) * (2.5 * leaf_basis[, 1]) + + ((0.75 <= X[, 1]) & (1 > X[, 1])) * (7.5 * leaf_basis[, 1])) +noise_sd <- sd(f_XW) / snr +y <- f_XW + rnorm(n, 0, 1) * noise_sd + +# Split data into test and train sets +test_set_pct <- 0.2 +n_test <- round(test_set_pct * n) +n_train <- n - n_test +test_inds <- sort(sample(1:n, n_test, replace = FALSE)) +train_inds <- (1:n)[!((1:n) %in% test_inds)] +X_test <- X[test_inds, ] +X_train <- X[train_inds, ] +leaf_basis_test <- leaf_basis[test_inds, ] +leaf_basis_train <- leaf_basis[train_inds, ] +y_test <- y[test_inds] +y_train <- y[train_inds] +``` + +## Python + +```{python} +n, p_x, p_w, snr = 500, 10, 1, 3 +X = rng.uniform(size=(n, p_x)) +leaf_basis = rng.uniform(size=(n, p_w)) +f_XW = (((0 <= X[:, 0]) & (0.25 > X[:, 0])) * (-7.5 * leaf_basis[:, 0]) + + ((0.25 <= X[:, 0]) & (0.5 > X[:, 0])) * (-2.5 * leaf_basis[:, 0]) + + ((0.5 <= X[:, 0]) & (0.75 > X[:, 0])) * (2.5 * leaf_basis[:, 0]) + + ((0.75 <= X[:, 0]) & (1 > X[:, 0])) * (7.5 * leaf_basis[:, 0])) +noise_sd = np.std(f_XW) / snr +y = f_XW + rng.normal(0, noise_sd, size=n) + +test_set_pct = 0.2 +n_test = round(test_set_pct * n) +test_inds = rng.choice(n, n_test, replace=False) +train_inds = np.setdiff1d(np.arange(n), test_inds) +X_test, X_train = X[test_inds], X[train_inds] +leaf_basis_test, leaf_basis_train = leaf_basis[test_inds], leaf_basis[train_inds] +y_test, y_train = y[test_inds], y[train_inds] +``` + +:::: + +## Sampling Multiple Chains Sequentially from Scratch + +The simplest way to sample multiple chains of a stochtree model is to do so +"sequentially," that is, after chain 
1 is sampled, chain 2 is sampled from a +different starting state, and similarly for each of the requested chains. This is +supported internally in both the `bart()` and `bcf()` functions, with the +`num_chains` parameter in the `general_params` list. + +Define some high-level parameters, including number of chains to run and number of +samples per chain. Here we run 4 independent chains with 2000 MCMC iterations, each +of which is burned in for 1000 iterations. + +::::{.panel-tabset group="language"} + +## R + +```{r} +num_chains <- 4 +num_gfr <- 0 +num_burnin <- 1000 +num_mcmc <- 2000 +``` + +## Python + +```{python} +num_chains = 4 +num_gfr = 0 +num_burnin = 1000 +num_mcmc = 2000 +``` + +:::: + +Run the sampler. + +::::{.panel-tabset group="language"} + +## R + +```{r} +bart_model <- stochtree::bart( + X_train = X_train, + leaf_basis_train = leaf_basis_train, + y_train = y_train, + num_gfr = num_gfr, + num_burnin = num_burnin, + num_mcmc = num_mcmc, + general_params = list(num_chains = num_chains) +) +``` + +## Python + +```{python} +bart_model = BARTModel() +bart_model.sample( + X_train=X_train, leaf_basis_train=leaf_basis_train, y_train=y_train, + num_gfr=num_gfr, num_burnin=num_burnin, num_mcmc=num_mcmc, + general_params={"num_threads": 1, "num_chains": num_chains}, +) +``` + +:::: + +Now we have a model with `num_chains * num_mcmc` samples stored internally. These +samples are arranged sequentially, with the first `num_mcmc` samples corresponding +to chain 1, the next `num_mcmc` samples to chain 2, etc. + +Since each chain is a set of samples of the same model, we can analyze the samples +collectively, for example, by looking at out-of-sample predictions. 
+ +::::{.panel-tabset group="language"} + +## R + +```{r} +y_hat_test <- predict( + bart_model, + X = X_test, + leaf_basis = leaf_basis_test, + type = "mean", + terms = "y_hat" +) +plot(y_hat_test, y_test, xlab = "Predicted", ylab = "Actual") +abline(0, 1, col = "red", lty = 3, lwd = 3) +``` + +## Python + +```{python} +y_hat_test = bart_model.predict( + X=X_test, leaf_basis=leaf_basis_test, type="mean", terms="y_hat" +) +lo, hi = min(y_hat_test.min(), y_test.min()), max(y_hat_test.max(), y_test.max()) +plt.scatter(y_hat_test, y_test, alpha=0.5) +plt.plot([lo, hi], [lo, hi], color="red", linestyle="dashed", linewidth=2) +plt.xlabel("Predicted"); plt.ylabel("Actual") +plt.show() +``` + +:::: + +Now, suppose we want to analyze each of the chains separately to assess mixing / +convergence. We can construct an `mcmc.list` in the `coda` package to perform various +diagnostics. + +::::{.panel-tabset group="language"} + +## R + +```{r} +sigma2_coda_list <- coda::as.mcmc.list(lapply( + 1:num_chains, + function(chain_idx) { + offset <- (chain_idx - 1) * num_mcmc + inds_start <- offset + 1 + inds_end <- offset + num_mcmc + coda::mcmc(bart_model$sigma2_global_samples[inds_start:inds_end]) + } +)) +traceplot(sigma2_coda_list, ylab = expression(sigma^2)) +abline(h = noise_sd^2, col = "black", lty = 3, lwd = 3) +acf <- autocorr.diag(sigma2_coda_list) +ess <- effectiveSize(sigma2_coda_list) +rhat <- gelman.diag(sigma2_coda_list, autoburnin = F) +cat(paste0( + "Average autocorrelation across chains:\n", + paste0(paste0(rownames(acf), ": ", round(acf, 3)), collapse = ", "), + "\nTotal effective sample size across chains: ", + paste0(round(ess, 1), collapse = ", "), + "\n'R-hat' potential scale reduction factor of Gelman and Rubin (1992)): ", + paste0(round(rhat$psrf[, 1], 3), collapse = ", ") +)) +``` + +## Python + +```{python} +# Reshape flat sigma2 samples into (num_chains, num_mcmc) for per-chain diagnostics +# az.from_dict requires nested dict: {"posterior": {"var": 
array(chains, draws)}}
+idata = az.from_dict({"posterior": {"sigma2": bart_model.global_var_samples.reshape(num_chains, num_mcmc)}})
+
+az.plot_trace(idata)
+plt.axhline(noise_sd**2, color="black", linestyle="dashed", linewidth=1.5)
+plt.show()
+
+print("ESS: ", az.ess(idata))
+print("R-hat:", az.rhat(idata))
+az.plot_autocorr(idata)
+plt.show()
+```
+
+:::
+
+We can convert this to an array to be consumed by the `bayesplot` package.
+
+::::{.panel-tabset group="language"}
+
+## R
+
+```{r}
+coda_array <- as.array(sigma2_coda_list)
+dim(coda_array) <- c(nrow(coda_array), ncol(coda_array), 1)
+dimnames(coda_array) <- list(
+  Iteration = paste0("iter", 1:num_mcmc),
+  Chain = paste0("chain", 1:num_chains),
+  Parameter = "sigma2_global"
+)
+```
+
+## Python
+
+```{python}
+# Reshape flat sigma2 samples into (num_chains, num_mcmc) — ready for per-chain plots
+sigma2_chains = bart_model.global_var_samples.reshape(num_chains, num_mcmc)
+```
+
+::::
+
+From here, we can visualize the posterior of $\sigma^2$ for each chain, comparing
+to the true simulated value.
+ +::::{.panel-tabset group="language"} + +## R + +```{r} +#| warning: false +#| message: false +bayesplot::mcmc_hist_by_chain( + coda_array, + pars = "sigma2_global" +) + + ggplot2::labs( + title = "Global error scale posterior by chain", + x = expression(sigma^2) + ) + + ggplot2::theme( + plot.title = ggplot2::element_text(hjust = 0.5) + ) + + ggplot2::geom_vline( + xintercept = noise_sd^2, + color = "black", + linetype = "dashed", + size = 1 + ) +``` + +## Python + +```{python} +fig, axes = plt.subplots(1, num_chains, figsize=(12, 3), sharey=True) +for i, ax in enumerate(axes): + ax.hist(sigma2_chains[i], bins=30) + ax.axvline(noise_sd**2, color="black", linestyle="dashed", linewidth=1.5) + ax.set_title(f"Chain {i+1}") + ax.set_xlabel(r"$\sigma^2$") +fig.suptitle("Global error scale posterior by chain") +plt.tight_layout() +plt.show() +``` + +:::: + +## Sampling Multiple Chains Sequentially from XBART Forests + +In the example above, each chain was initialized from "root". If we sample a model +using a small number of 'grow-from-root' iterations, we can use these forests to +initialize MCMC chains. + +::::{.panel-tabset group="language"} + +## R + +```{r} +num_chains <- 4 +num_gfr <- 5 +num_burnin <- 1000 +num_mcmc <- 2000 +``` + +## Python + +```{python} +num_chains = 4 +num_gfr = 5 +num_burnin = 1000 +num_mcmc = 2000 +``` + +:::: + +Run the initial GFR sampler. 
+ +::::{.panel-tabset group="language"} + +## R + +```{r} +xbart_model <- stochtree::bart( + X_train = X_train, + leaf_basis_train = leaf_basis_train, + y_train = y_train, + num_gfr = num_gfr, + num_burnin = 0, + num_mcmc = 0 +) +xbart_model_string <- stochtree::saveBARTModelToJsonString(xbart_model) +``` + +## Python + +```{python} +xbart_model = BARTModel() +xbart_model.sample( + X_train=X_train, leaf_basis_train=leaf_basis_train, y_train=y_train, + num_gfr=num_gfr, num_burnin=0, num_mcmc=0, + general_params={"num_threads": 1}, +) +xbart_model_json = xbart_model.to_json() +``` + +:::: + +Run the multi-chain BART sampler, with each chain initialized from a different GFR +forest. + +::::{.panel-tabset group="language"} + +## R + +```{r} +bart_model <- stochtree::bart( + X_train = X_train, + leaf_basis_train = leaf_basis_train, + y_train = y_train, + num_gfr = num_gfr, + num_burnin = num_burnin, + num_mcmc = num_mcmc, + general_params = list(num_chains = num_chains), + previous_model_json = xbart_model_string, + previous_model_warmstart_sample_num = num_gfr +) +``` + +## Python + +```{python} +bart_model = BARTModel() +bart_model.sample( + X_train=X_train, leaf_basis_train=leaf_basis_train, y_train=y_train, + num_gfr=0, num_burnin=num_burnin, num_mcmc=num_mcmc, + general_params={"num_threads": 1, "num_chains": num_chains}, + previous_model_json=xbart_model_json, + previous_model_warmstart_sample_num=num_gfr - 1, # 0-indexed +) +``` + +:::: + +::::{.panel-tabset group="language"} + +## R + +```{r} +y_hat_test <- predict( + bart_model, + X = X_test, + leaf_basis = leaf_basis_test, + type = "mean", + terms = "y_hat" +) +plot(y_hat_test, y_test, xlab = "Predicted", ylab = "Actual") +abline(0, 1, col = "red", lty = 3, lwd = 3) +``` + +## Python + +```{python} +y_hat_test = bart_model.predict( + X=X_test, leaf_basis=leaf_basis_test, type="mean", terms="y_hat" +) +lo, hi = min(y_hat_test.min(), y_test.min()), max(y_hat_test.max(), y_test.max()) +plt.scatter(y_hat_test, 
y_test, alpha=0.5) +plt.plot([lo, hi], [lo, hi], color="red", linestyle="dashed", linewidth=2) +plt.xlabel("Predicted"); plt.ylabel("Actual") +plt.show() +``` + +:::: + +::::{.panel-tabset group="language"} + +## R + +```{r} +sigma2_coda_list <- coda::as.mcmc.list(lapply( + 1:num_chains, + function(chain_idx) { + offset <- (chain_idx - 1) * num_mcmc + inds_start <- offset + 1 + inds_end <- offset + num_mcmc + coda::mcmc(bart_model$sigma2_global_samples[inds_start:inds_end]) + } +)) +traceplot(sigma2_coda_list, ylab = expression(sigma^2)) +abline(h = noise_sd^2, col = "black", lty = 3, lwd = 3) +acf <- autocorr.diag(sigma2_coda_list) +ess <- effectiveSize(sigma2_coda_list) +rhat <- gelman.diag(sigma2_coda_list, autoburnin = F) +cat(paste0( + "Average autocorrelation across chains:\n", + paste0(paste0(rownames(acf), ": ", round(acf, 3)), collapse = ", "), + "\nTotal effective sample size across chains: ", + paste0(round(ess, 1), collapse = ", "), + "\n'R-hat' potential scale reduction factor of Gelman and Rubin (1992)): ", + paste0(round(rhat$psrf[, 1], 3), collapse = ", ") +)) +``` + +## Python + +```{python} +idata = az.from_dict({"posterior": {"sigma2": bart_model.global_var_samples.reshape(num_chains, num_mcmc)}}) + +az.plot_trace(idata) +plt.axhline(noise_sd**2, color="black", linestyle="dashed", linewidth=1.5) +plt.show() + +print("ESS: ", az.ess(idata)) +print("R-hat:", az.rhat(idata)) +az.plot_autocorr(idata) +plt.show() +``` + +:::: + +::::{.panel-tabset group="language"} + +## R + +```{r} +coda_array <- as.array(sigma2_coda_list) +dim(coda_array) <- c(nrow(coda_array), ncol(coda_array), 1) +dimnames(coda_array) <- list( + Iteration = paste0("iter", 1:num_mcmc), + Chain = paste0("chain", 1:num_chains), + Parameter = "sigma2_global" +) +``` + +## Python + +```{python} +sigma2_chains = bart_model.global_var_samples.reshape(num_chains, num_mcmc) +``` + +:::: + +::::{.panel-tabset group="language"} + +## R + +```{r} +#| warning: false +#| message: false 
+bayesplot::mcmc_hist_by_chain( + coda_array, + pars = "sigma2_global" +) + + ggplot2::labs( + title = "Global error scale posterior by chain", + x = expression(sigma^2) + ) + + ggplot2::theme( + plot.title = ggplot2::element_text(hjust = 0.5) + ) + + ggplot2::geom_vline( + xintercept = noise_sd^2, + color = "black", + linetype = "dashed", + size = 1 + ) +``` + +## Python + +```{python} +fig, axes = plt.subplots(1, num_chains, figsize=(12, 3), sharey=True) +for i, ax in enumerate(axes): + ax.hist(sigma2_chains[i], bins=30) + ax.axvline(noise_sd**2, color="black", linestyle="dashed", linewidth=1.5) + ax.set_title(f"Chain {i+1}") + ax.set_xlabel(r"$\sigma^2$") +fig.suptitle("Global error scale posterior by chain") +plt.tight_layout() +plt.show() +``` + +:::: + +## Sampling Multiple Chains in Parallel + +While the above examples used sequential multi-chain sampling internally, it is also +possible to run chains in parallel. In R, this is done via `doParallel` / `foreach`; +in Python, via `concurrent.futures.ProcessPoolExecutor`. In both cases, each chain +is serialized to JSON for cross-process communication, then combined into a single +model via `createBARTModelFromCombinedJsonString()` (R) or +`BARTModel.from_json_string_list()` (Python). + +In order to run multiple parallel stochtree chains in R, a parallel backend must be +registered. Note that we do not evaluate the cluster setup code below in order to +interact nicely with GitHub Actions. 
+ +::::{.panel-tabset group="language"} + +## R + +```{r} +#| eval: false +ncores <- parallel::detectCores() +cl <- makeCluster(ncores) +registerDoParallel(cl) +``` + +## Python + +```{python} +#| eval: false +# Worker function must be defined at module level for pickling +from concurrent.futures import ProcessPoolExecutor + +def _run_bart_chain(args): + X_tr, lb_tr, y_tr, X_te, lb_te, num_burnin, num_mcmc, seed = args + from stochtree import BARTModel + m = BARTModel() + m.sample( + X_train=X_tr, leaf_basis_train=lb_tr, y_train=y_tr, + X_test=X_te, leaf_basis_test=lb_te, + num_gfr=0, num_burnin=num_burnin, num_mcmc=num_mcmc, + general_params={"num_threads": 1, "random_seed": seed}, + mean_forest_params={"sample_sigma2_leaf": False}, + ) + return m.to_json(), m.y_hat_test +``` + +:::: + +::::{.panel-tabset group="language"} + +## R + +```{r} +num_chains <- 4 +num_gfr <- 0 +num_burnin <- 100 +num_mcmc <- 100 +``` + +## Python + +```{python} +num_chains = 4 +num_gfr = 0 +num_burnin = 100 +num_mcmc = 100 +``` + +:::: + +::::{.panel-tabset group="language"} + +## R + +```{r} +bart_model_outputs <- foreach(i = 1:num_chains) %dopar% + { + random_seed <- i + general_params <- list(sample_sigma2_global = T, random_seed = random_seed) + mean_forest_params <- list(sample_sigma2_leaf = F) + bart_model <- stochtree::bart( + X_train = X_train, + leaf_basis_train = leaf_basis_train, + y_train = y_train, + X_test = X_test, + leaf_basis_test = leaf_basis_test, + num_gfr = num_gfr, + num_burnin = num_burnin, + num_mcmc = num_mcmc, + general_params = general_params, + mean_forest_params = mean_forest_params + ) + bart_model_string <- stochtree::saveBARTModelToJsonString(bart_model) + y_hat_test <- bart_model$y_hat_test + list(model = bart_model_string, yhat = y_hat_test) + } +``` + +## Python + +```{python} +# Sequential loop β€” replace the loop body with ProcessPoolExecutor for true parallelism +bart_model_outputs = [] +for i in range(num_chains): + m = BARTModel() + m.sample( + 
X_train=X_train, leaf_basis_train=leaf_basis_train, y_train=y_train, + X_test=X_test, leaf_basis_test=leaf_basis_test, + num_gfr=0, num_burnin=num_burnin, num_mcmc=num_mcmc, + general_params={"num_threads": 1, "sample_sigma2_global": True, "random_seed": i + 1}, + mean_forest_params={"sample_sigma2_leaf": False}, + ) + bart_model_outputs.append({"model": m.to_json(), "yhat": m.y_hat_test}) +``` + +:::: + +Close the parallel cluster (not evaluated here). + +::::{.panel-tabset group="language"} + +## R + +```{r} +#| eval: false +stopCluster(cl) +``` + +## Python + +```{python} +# No explicit teardown required when using concurrent.futures context manager +``` + +:::: + +Combine the forests from each BART model into a single forest. + +::::{.panel-tabset group="language"} + +## R + +```{r} +bart_model_strings <- list() +bart_model_yhats <- matrix(NA, nrow = length(y_test), ncol = num_chains) +for (i in 1:length(bart_model_outputs)) { + bart_model_strings[[i]] <- bart_model_outputs[[i]]$model + bart_model_yhats[, i] <- rowMeans(bart_model_outputs[[i]]$yhat) +} +combined_bart <- createBARTModelFromCombinedJsonString(bart_model_strings) +``` + +## Python + +```{python} +bart_model_strings = [out["model"] for out in bart_model_outputs] +bart_model_yhats = np.column_stack([ + out["yhat"].mean(axis=1) for out in bart_model_outputs +]) # shape: (n_test, num_chains) +combined_bart = BARTModel() +combined_bart.from_json_string_list(bart_model_strings) +``` + +:::: + +::::{.panel-tabset group="language"} + +## R + +```{r} +yhat_combined <- predict(combined_bart, X_test, leaf_basis_test)$y_hat +``` + +## Python + +```{python} +# type="posterior" (default) returns the full n_test Γ— (num_chains * num_mcmc) matrix +yhat_combined = combined_bart.predict(X=X_test, leaf_basis=leaf_basis_test, terms="y_hat") +``` + +:::: + +Compare average predictions from each chain to the original predictions and to +the true $y$ values. 
+ +::::{.panel-tabset group="language"} + +## R + +```{r} +par(mfrow = c(1, 2)) +for (i in 1:num_chains) { + offset <- (i - 1) * num_mcmc + inds_start <- offset + 1 + inds_end <- offset + num_mcmc + plot( + rowMeans(yhat_combined[, inds_start:inds_end]), + bart_model_yhats[, i], + xlab = "deserialized", + ylab = "original", + main = paste0("Chain ", i, "\nPredictions") + ) + abline(0, 1, col = "red", lty = 3, lwd = 3) +} +par(mfrow = c(1, 1)) +``` + +## Python + +```{python} +fig, axes = plt.subplots(2, 2, figsize=(8, 8)) +for i, ax in enumerate(axes.flat): + chain_combined = yhat_combined[:, i * num_mcmc:(i + 1) * num_mcmc].mean(axis=1) + chain_orig = bart_model_yhats[:, i] + lo = min(chain_combined.min(), chain_orig.min()) + hi = max(chain_combined.max(), chain_orig.max()) + ax.scatter(chain_combined, chain_orig, alpha=0.4, s=10) + ax.plot([lo, hi], [lo, hi], color="red", linestyle="dashed", linewidth=1.5) + ax.set_xlabel("Deserialized"); ax.set_ylabel("Original") + ax.set_title(f"Chain {i+1} Predictions") +plt.tight_layout() +plt.show() +``` + +:::: + +::::{.panel-tabset group="language"} + +## R + +```{r} +par(mfrow = c(1, 2)) +for (i in 1:num_chains) { + offset <- (i - 1) * num_mcmc + inds_start <- offset + 1 + inds_end <- offset + num_mcmc + plot( + rowMeans(yhat_combined[, inds_start:inds_end]), + y_test, + xlab = "predicted", + ylab = "actual", + main = paste0("Chain ", i, "\nPredictions") + ) + abline(0, 1, col = "red", lty = 3, lwd = 3) +} +par(mfrow = c(1, 1)) +``` + +## Python + +```{python} +fig, axes = plt.subplots(2, 2, figsize=(8, 8)) +for i, ax in enumerate(axes.flat): + chain_pred = yhat_combined[:, i * num_mcmc:(i + 1) * num_mcmc].mean(axis=1) + lo = min(chain_pred.min(), y_test.min()) + hi = max(chain_pred.max(), y_test.max()) + ax.scatter(chain_pred, y_test, alpha=0.4, s=10) + ax.plot([lo, hi], [lo, hi], color="red", linestyle="dashed", linewidth=1.5) + ax.set_xlabel("Predicted"); ax.set_ylabel("Actual") + ax.set_title(f"Chain {i+1} 
Predictions")
+plt.tight_layout()
+plt.show()
+```
+
+::::
+
+# References
diff --git a/vignettes/multivariate-bcf.qmd b/vignettes/multivariate-bcf.qmd
new file mode 100644
index 000000000..1e239d1f5
--- /dev/null
+++ b/vignettes/multivariate-bcf.qmd
@@ -0,0 +1,463 @@
+---
+title: "BCF with Vector-valued Treatments"
+execute:
+  freeze: auto # re-render only when source changes
+---
+
+```{r}
+#| include: false
+reticulate::use_python(
+  Sys.getenv(
+    "RETICULATE_PYTHON",
+    unset = file.path(rprojroot::find_root(rprojroot::has_file(".here")), ".venv", "bin", "python")
+  ),
+  required = TRUE
+)
+```
+
+BCF extended to vector-valued (multivariate) treatments, estimating heterogeneous effects for multiple treatment arms simultaneously.
+
+## Background
+
+When treatments are multivariate — such as continuous dose vectors or multiple binary arms — the standard BCF model extends to
+
+$$
+Y_i = \mu(X_i) + \tau(X_i)^\top Z_i + \epsilon_i
+$$
+
+where $Z_i \in \mathbb{R}^p$ and $\tau(X_i) \in \mathbb{R}^p$ is a vector of covariate-dependent treatment effects.
+ +## Setup + +Load necessary packages + +::::{.panel-tabset group="language"} + +## R + +```{r} +library(stochtree) +``` + +## Python + +```{python} +import matplotlib.pyplot as plt +import numpy as np +from sklearn.model_selection import train_test_split +from stochtree import BCFModel +``` + +:::: + +Set a seed for reproducibility + +::::{.panel-tabset group="language"} + +## R + +```{r} +random_seed <- 4321 +set.seed(random_seed) +``` + +## Python + +```{python} +random_seed = 4321 +rng = np.random.default_rng(random_seed) +``` + +:::: + +## Data Simulation + +::::{.panel-tabset group="language"} + +## R + +```{r} +# Generate covariates, propensities, and treatments +n <- 1000 +p_X <- 5 +X <- matrix(runif(n * p_X), nrow = n, ncol = p_X) +pi_X <- cbind(0.25 + 0.5 * X[, 1], 0.75 - 0.5 * X[, 2]) +Z <- cbind( + as.numeric(rbinom(n, 1, pi_X[, 1])), + as.numeric(rbinom(n, 1, pi_X[, 2])) +) + +# Define outcome mean functions (prognostic and treatment effects) +mu_X <- pi_X[, 1] * 5 + pi_X[, 2] * 2 + 2 * X[, 3] +tau_X <- cbind(X[, 2], X[, 3]) + +# Generate outcome +treatment_term <- rowSums(tau_X * Z) +y <- mu_X + treatment_term + rnorm(n) +``` + +## Python + +```{python} +# Generate covariates, propensities, and treatments +n = 1000 +p_X = 5 +X = rng.uniform(0, 1, (n, p_X)) +pi_X = np.c_[0.25 + 0.5 * X[:, 0], 0.75 - 0.5 * X[:, 1]] +Z = rng.binomial(1, pi_X, (n, 2)).astype(float) + +# Define the outcome mean functions (prognostic and treatment effects) +mu_X = pi_X[:, 0] * 5 + pi_X[:, 1] * 2 + 2 * X[:, 2] +tau_X = np.stack((X[:, 1], X[:, 2]), axis=-1) + +# Generate outcome +epsilon = rng.normal(0, 1, n) +treatment_term = np.multiply(tau_X, Z).sum(axis=1) +y = mu_X + treatment_term + epsilon +``` + +:::: + +Split the data into train and test sets + +::::{.panel-tabset group="language"} + +## R + +```{r} +n_test <- round(n * 0.2) +test_inds <- sort(sample(seq_len(n), n_test, replace = FALSE)) +train_inds <- setdiff(seq_len(n), test_inds) +X_train <- X[train_inds, ] 
+X_test <- X[test_inds, ] +Z_train <- Z[train_inds, ] +Z_test <- Z[test_inds, ] +y_train <- y[train_inds] +y_test <- y[test_inds] +pi_train <- pi_X[train_inds, ] +pi_test <- pi_X[test_inds, ] +mu_train <- mu_X[train_inds] +mu_test <- mu_X[test_inds] +tau_train <- tau_X[train_inds, ] +tau_test <- tau_X[test_inds, ] +``` + +## Python + +```{python} +sample_inds = np.arange(n) +train_inds, test_inds = train_test_split(sample_inds, test_size=0.2) +X_train = X[train_inds, :] +X_test = X[test_inds, :] +Z_train = Z[train_inds, :] +Z_test = Z[test_inds, :] +y_train = y[train_inds] +y_test = y[test_inds] +pi_train = pi_X[train_inds] +pi_test = pi_X[test_inds] +mu_train = mu_X[train_inds] +mu_test = mu_X[test_inds] +tau_train = tau_X[train_inds, :] +tau_test = tau_X[test_inds, :] +``` + +:::: + +## Model Fitting + +Fit a multivariate BCF model + +::::{.panel-tabset group="language"} + +## R + +```{r} +general_params <- list( + num_threads = 1, + num_chains = 4, + random_seed = random_seed, + adaptive_coding = FALSE +) +bcf_model <- bcf( + X_train = X_train, + Z_train = Z_train, + y_train = y_train, + propensity_train = pi_train, + num_gfr = 10, + num_burnin = 500, + num_mcmc = 100, + general_params = general_params +) +``` + +## Python + +```{python} +general_params = { + "num_threads": 1, + "num_chains": 4, + "random_seed": random_seed, + "adaptive_coding": False +} +bcf_model = BCFModel() +bcf_model.sample( + X_train=X_train, + Z_train=Z_train, + y_train=y_train, + propensity_train=pi_train, + num_gfr=10, + num_burnin=500, + num_mcmc=100, + general_params=general_params, +) +``` + +:::: + +## Posterior Summaries + +Compare true outcomes to predicted conditional means + +::::{.panel-tabset group="language"} + +## R + +```{r} +y_hat_test <- predict( + bcf_model, + X = X_test, + Z = Z_test, + propensity = pi_test, + terms = "y_hat", + type = "mean" +) +plot( + y_hat_test, + y_test, + xlab = "Average estimated outcome", + ylab = "True outcome" +) +abline(0, 1, col = "black", 
lty = 3) +rmse <- sqrt(mean((y_hat_test - y_test)^2)) +cat("Test-set RMSE: ", rmse, "\n") +``` + +## Python + +```{python} +y_hat_test = bcf_model.predict( + X=X_test, Z=Z_test, propensity=pi_test, terms="y_hat", type="mean" +) +lo, hi = min(y_hat_test.min(), y_test.min()), max(y_hat_test.max(), y_test.max()) +plt.scatter(y_hat_test, y_test, alpha=0.5) +plt.plot([lo, hi], [lo, hi], color="red", linestyle="dashed", linewidth=2) +plt.xlabel("Predicted") +plt.ylabel("Actual") +plt.title("Outcome") +plt.show() +rmse = np.sqrt(np.mean(np.power(y_hat_test - y_test, 2))) +print(f"Test-set RMSE: {rmse:.2f}") +``` + +:::: + +Compare true versus estimated treatment effects for each treatment entry + +::::{.panel-tabset group="language"} + +## R + +```{r} +tau_hat_test <- predict( + bcf_model, + X = X_test, + Z = Z_test, + propensity = pi_test, + terms = "cate", + type = "mean" +) +plot( + tau_test[, 1], + tau_hat_test[, 1], + xlab = "True tau", + ylab = "Average estimated tau", + main = "Treatment 1" +) +abline(0, 1, col = "black", lty = 3) +``` + +```{r} +plot( + tau_test[, 2], + tau_hat_test[, 2], + xlab = "True tau", + ylab = "Average estimated tau", + main = "Treatment 2" +) +abline(0, 1, col = "black", lty = 3) +``` + +## Python + +```{python} +tau_hat_test = bcf_model.predict( + X=X_test, Z=Z_test, propensity=pi_test, terms="cate", type="mean" +) +treatment_idx = 0 +lo, hi = ( + min((tau_hat_test[:, treatment_idx]).min(), (tau_test[:, treatment_idx]).min()), + max((tau_hat_test[:, treatment_idx]).max(), (tau_test[:, treatment_idx]).max()), +) +plt.scatter(tau_test[:, treatment_idx], tau_hat_test[:, treatment_idx], alpha=0.5) +plt.plot([lo, hi], [lo, hi], color="red", linestyle="dashed", linewidth=2) +plt.xlabel("True tau") +plt.ylabel("Average estimated tau") +plt.title(f"Treatment {treatment_idx + 1}") +plt.show() +``` + +```{python} +treatment_idx = 1 +lo, hi = ( + min((tau_hat_test[:, treatment_idx]).min(), (tau_test[:, treatment_idx]).min()), + max((tau_hat_test[:, 
treatment_idx]).max(), (tau_test[:, treatment_idx]).max()), +) +plt.scatter(tau_test[:, treatment_idx], tau_hat_test[:, treatment_idx], alpha=0.5) +plt.plot([lo, hi], [lo, hi], color="red", linestyle="dashed", linewidth=2) +plt.xlabel("True tau") +plt.ylabel("Average estimated tau") +plt.title(f"Treatment {treatment_idx + 1}") +plt.show() +``` + +:::: + +Now compare the true versus estimated treatment terms of the model (i.e. $t_i = \sum_j(\tau_{i,j}(X) * Z_{i,j})$ where $i$ indexes observations and $j$ indexes treatments) + +::::{.panel-tabset group="language"} + +## R + +```{r} +tau_hat_test <- predict( + bcf_model, + X = X_test, + Z = Z_test, + propensity = pi_test, + terms = "cate", + type = "posterior" +) +treatment_term_mcmc <- apply(tau_hat_test, 3, function(tau_s) { + rowSums(tau_s * Z_test) +}) +true_treatment_term <- rowSums(tau_test * Z_test) +plot( + true_treatment_term, + rowMeans(treatment_term_mcmc), + xlab = "True treatment term", + ylab = "Average estimated treatment term" +) +abline(0, 1, col = "black", lty = 3) +``` + +## Python + +```{python} +tau_hat_test = bcf_model.predict( + X=X_test, Z=Z_test, propensity=pi_test, terms="cate", type="posterior" +) +treatment_term_mcmc_test = np.multiply( + np.atleast_3d(Z_test).swapaxes(1, 2), tau_hat_test +).sum(axis=2) +treatment_term_test = np.multiply(tau_test, Z_test).sum(axis=1) +treatment_term_hat_test = np.squeeze(treatment_term_mcmc_test).mean( + axis=1, keepdims=True +) +lo, hi = ( + min((treatment_term_hat_test).min(), (treatment_term_test).min()), + max((treatment_term_hat_test).max(), (treatment_term_test).max()), +) +plt.scatter(treatment_term_test, treatment_term_hat_test, alpha=0.5) +plt.plot([lo, hi], [lo, hi], color="red", linestyle="dashed", linewidth=2) +plt.xlabel("True value") +plt.ylabel("Average estimated value") +plt.title("Treatment Term") +plt.show() +``` + +:::: + +Compare true and predicted prognostic function values + +::::{.panel-tabset group="language"} + +## R + +```{r} 
+mu_hat_test <- predict( + bcf_model, + X = X_test, + Z = Z_test, + propensity = pi_test, + terms = "prognostic_function", + type = "mean" +) +plot( + mu_test, + mu_hat_test, + xlab = "True value", + ylab = "Average estimated value", + main = "Prognostic Function" +) +abline(0, 1, col = "black", lty = 3) +``` + +## Python + +```{python} +mu_hat_test = bcf_model.predict( + X=X_test, Z=Z_test, propensity=pi_test, terms="prognostic_function", type="mean" +) +lo, hi = ( + min((mu_hat_test).min(), (mu_test).min()), + max((mu_hat_test).max(), (mu_test).max()), +) +plt.scatter(mu_test, mu_hat_test, alpha=0.5) +plt.plot([lo, hi], [lo, hi], color="red", linestyle="dashed", linewidth=2) +plt.xlabel("True value") +plt.ylabel("Average estimated value") +plt.title("Prognostic Function") +plt.show() +``` + +:::: + +Finally, we inspect the traceplot of the global error variance, $\sigma^2$ + +::::{.panel-tabset group="language"} + +## R + +```{r} +sigma2_global_samples <- extractParameter(bcf_model, "sigma2_global") +plot( + sigma2_global_samples, + xlab = "Sample", + ylab = expression(sigma^2) +) +abline(h = 1, lty = 3, lwd = 3, col = "blue") +``` + +## Python + +```{python} +global_var_samples = bcf_model.extract_parameter("sigma2_global") +plt.plot(global_var_samples) +plt.axhline(1, color="blue", linestyle="dashed", linewidth=2) +plt.xlabel("Sample") +plt.ylabel(r"$\sigma^2$") +plt.title("Global variance parameter") +plt.show() +``` + +:::: diff --git a/vignettes/ordinal-outcome.qmd b/vignettes/ordinal-outcome.qmd new file mode 100644 index 000000000..b39bd59f7 --- /dev/null +++ b/vignettes/ordinal-outcome.qmd @@ -0,0 +1,528 @@ +--- +title: "BART with the Complementary Log-Log Link for Ordinal Outcomes" +bibliography: vignettes.bib +execute: + freeze: auto # re-render only when source changes +--- + +```{r} +#| include: false +reticulate::use_python( + Sys.getenv( + "RETICULATE_PYTHON", + unset = file.path(rprojroot::find_root(rprojroot::has_file(".here")), ".venv", "bin", 
"python") + ), + required = TRUE +) +``` + +This vignette demonstrates how to use BART to model ordinal outcomes with a +complementary log-log (cloglog) link function (@alam2025unified). + +## Introduction to Ordinal BART with CLogLog Link + +Ordinal data refers to outcomes that have a natural ordering but undefined distances +between categories. Examples include survey responses (strongly disagree, disagree, +neutral, agree, strongly agree), severity ratings (mild, moderate, severe), or +educational levels (elementary, high school, college, graduate). + +The complementary log-log (cloglog) model uses the cumulative link function +$$ +\text{cloglog}(p) = \log(-\log(1-p)) +$$ +to express cumulative category probabilities as a function of covariates +$$ +\text{cloglog}(P(Y \leq k \mid X = x)) = \log(-\log(1-P(Y \leq k \mid X = x))) = \gamma_k + \lambda(x) +$$ + +This link function is asymmetric and particularly appropriate when the probability of +being in higher categories changes rapidly at certain thresholds, making it different +from the symmetric probit or logit links commonly used in ordinal regression. + +In `stochtree`, we let $\lambda(x)$ be represented by a stochastic tree ensemble. + +## Setup + +::::{.panel-tabset group="language"} + +## R + +```{r} +library(stochtree) +``` + +## Python + +```{python} +import numpy as np +import matplotlib.pyplot as plt +from stochtree import BARTModel, OutcomeModel +``` + +:::: + +## Data Simulation + +We simulate a dataset with an ordinal outcome with three categories, +$y_i \in \left\{1,2,3\right\}$ whose probabilities depend on covariates, $X$. 
+ +::::{.panel-tabset group="language"} + +## R + +```{r} +# Set seed +random_seed <- 2026 +set.seed(random_seed) + +# Sample size and number of predictors +n <- 2000 +p <- 5 + +# Design matrix and true lambda function +X <- matrix(rnorm(n * p), n, p) +beta <- rep(1 / sqrt(p), p) +true_lambda_function <- X %*% beta + +# Set cutpoints for ordinal categories (3 categories: 1, 2, 3) +n_categories <- 3 +gamma_true <- c(-2, 1) +ordinal_cutpoints <- log(cumsum(exp(gamma_true))) + +# True ordinal class probabilities +true_probs <- matrix(0, nrow = n, ncol = n_categories) +for (j in 1:n_categories) { + if (j == 1) { + true_probs[, j] <- 1 - exp(-exp(gamma_true[j] + true_lambda_function)) + } else if (j == n_categories) { + true_probs[, j] <- 1 - rowSums(true_probs[, 1:(j - 1), drop = FALSE]) + } else { + true_probs[, j] <- exp(-exp(gamma_true[j - 1] + true_lambda_function)) * + (1 - exp(-exp(gamma_true[j] + true_lambda_function))) + } +} + +# Generate ordinal outcomes +y <- sapply(1:nrow(X), function(i) { + sample(1:n_categories, 1, prob = true_probs[i, ]) +}) +cat("Outcome distribution:", table(y), "\n") + +# Train test split +train_idx <- sample(1:n, size = floor(0.8 * n)) +test_idx <- setdiff(1:n, train_idx) +X_train <- X[train_idx, ] +y_train <- y[train_idx] +X_test <- X[test_idx, ] +y_test <- y[test_idx] +``` + +## Python + +```{python} +random_seed = 2026 +rng = np.random.default_rng(random_seed) + +# Sample size and number of predictors +n = 2000 +p = 5 + +# Design matrix and true lambda function +X = rng.standard_normal((n, p)) +beta = np.ones(p) / np.sqrt(p) +true_lambda = X @ beta + +# Set cutpoints for ordinal categories (3 categories: 1, 2, 3) +n_categories = 3 +gamma_true = np.array([-2.0, 1.0]) + +# True ordinal class probabilities +true_probs = np.zeros((n, n_categories)) +true_probs[:, 0] = 1 - np.exp(-np.exp(gamma_true[0] + true_lambda)) +for j in range(1, n_categories - 1): + true_probs[:, j] = ( + np.exp(-np.exp(gamma_true[j - 1] + true_lambda)) + * (1 - 
np.exp(-np.exp(gamma_true[j] + true_lambda))) + ) +true_probs[:, n_categories - 1] = 1 - true_probs[:, :-1].sum(axis=1) + +# Generate ordinal outcomes (1-indexed integers) +y = np.array( + [rng.choice(np.arange(1, n_categories + 1), p=true_probs[i]) for i in range(n)], + dtype=float, +) +unique, counts = np.unique(y, return_counts=True) +print("Outcome distribution:", dict(zip(unique.astype(int), counts))) + +# Train-test split +n_test = round(0.2 * n) +n_train = n - n_test +test_inds = rng.choice(n, n_test, replace=False) +train_inds = np.setdiff1d(np.arange(n), test_inds) +X_train = X[train_inds] +X_test = X[test_inds] +y_train = y[train_inds] +y_test = y[test_inds] +``` + +:::: + +## Model Fitting + +We specify the cloglog link function for modeling an ordinal outcome by setting +`outcome_model=OutcomeModel(outcome="ordinal", link="cloglog")` in the +`general_params` argument list. Since ordinal outcomes are incompatible with the +Gaussian global error variance model, we also set `sample_sigma2_global=FALSE`. + +We also override the default `num_trees` for the mean forest (200) in favor of +greater regularization for the ordinal model and set `sample_sigma2_leaf=FALSE`. 
+ +::::{.panel-tabset group="language"} + +## R + +```{r} +# Sample the cloglog ordinal BART model +bart_model <- bart( + X_train = X_train, + y_train = y_train, + X_test = X_test, + num_gfr = 0, + num_burnin = 1000, + num_mcmc = 1000, + general_params = list( + cutpoint_grid_size = 100, + sample_sigma2_global = FALSE, + keep_every = 1, + num_chains = 1, + verbose = FALSE, + random_seed = random_seed, + outcome_model = OutcomeModel(outcome = 'ordinal', link = 'cloglog') + ), + mean_forest_params = list(num_trees = 50, sample_sigma2_leaf = FALSE) +) +``` + +## Python + +```{python} +bart_model = BARTModel() +bart_model.sample( + X_train=X_train, + y_train=y_train, + X_test=X_test, + num_gfr=0, + num_burnin=1000, + num_mcmc=1000, + general_params={ + "num_threads": 1, + "cutpoint_grid_size": 100, + "sample_sigma2_global": False, + "keep_every": 1, + "num_chains": 1, + "random_seed": random_seed, + "outcome_model": OutcomeModel(outcome="ordinal", link="cloglog"), + }, + mean_forest_params={"num_trees": 50, "sample_sigma2_leaf": False}, +) +``` + +:::: + +## Prediction + +As with any other BART model in `stochtree`, we can use the `predict` function on +our ordinal model. Specifying `scale = "linear"` and `terms = "y_hat"` will simply +return predictions from the estimated $\lambda(x)$ function, but users can estimate +class probabilities via `scale = "probability"`, which by default returns an array of +dimension (`num_observations`, `num_categories`, `num_samples`). Specifying +`type = "mean"` collapses the output to a `num_observations` x `num_categories` +matrix with the average posterior class probability for each observation. Users can +also specify `type = "class"` for the maximum a posteriori (MAP) class label estimate +for each draw of each observation. + +Below we compute the posterior class probabilities for the train and test sets. 
+ +::::{.panel-tabset group="language"} + +## R + +```{r} +est_probs_train <- predict( + bart_model, + X = X_train, + scale = "probability", + terms = "y_hat" +) +est_probs_test <- predict( + bart_model, + X = X_test, + scale = "probability", + terms = "y_hat" +) +``` + +## Python + +```{python} +# predict returns (n_obs, n_categories) posterior mean class probabilities +est_probs_train = bart_model.predict(X=X_train, scale="probability", terms="y_hat", type="mean") +est_probs_test = bart_model.predict(X=X_test, scale="probability", terms="y_hat", type="mean") +``` + +:::: + +## Model Results and Interpretation + +Since one of the "cutpoints" is fixed for identifiability, we plot the posterior +distributions of the other two cutpoints and compare them to their true simulated +values (blue dotted lines). + +The cutpoint samples are accessed via `extractParameter(bart_model, "cloglog_cutpoints")` +(shape: `(n_categories - 1, num_samples)`) and are shifted by the per-sample mean of +the training predictions to account for the non-identifiable intercept. 
+ +::::{.panel-tabset group="language"} + +## R + +```{r} +y_hat_train_post <- predict( + bart_model, + X = X_train, + scale = "linear", + terms = "y_hat", + type = "posterior" +) +cutpoint_samples <- extractParameter(bart_model, "cloglog_cutpoints") +gamma1 <- cutpoint_samples[1, ] + colMeans(y_hat_train_post) +hist( + gamma1, + main = "Posterior Distribution of Cutpoint 1", + xlab = "Cutpoint 1", + freq = FALSE +) +abline(v = gamma_true[1], col = 'blue', lty = 3, lwd = 3) +gamma2 <- cutpoint_samples[2, ] + colMeans(y_hat_train_post) +hist( + gamma2, + main = "Posterior Distribution of Cutpoint 2", + xlab = "Cutpoint 2", + freq = FALSE +) +abline(v = gamma_true[2], col = 'blue', lty = 3, lwd = 3) +``` + +## Python + +```{python} +# cutpoint_samples shape: (n_categories - 1, num_samples) +# shifted by per-sample mean of train predictions to remove non-identifiable intercept +cutpoint_samples = bart_model.extract_parameter("cloglog_cutpoints") +y_hat_train_post = bart_model.predict(X=X_train, scale="linear", terms="y_hat", type="posterior") +gamma1 = cutpoint_samples[0, :] + y_hat_train_post.mean(axis=0) +gamma2 = cutpoint_samples[1, :] + y_hat_train_post.mean(axis=0) + +fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4)) +ax1.hist(gamma1, density=True, bins=40) +ax1.axvline(gamma_true[0], color="blue", linestyle="dotted", linewidth=2) +ax1.set_title("Posterior Distribution of Cutpoint 1") +ax1.set_xlabel("Cutpoint 1") +ax2.hist(gamma2, density=True, bins=40) +ax2.axvline(gamma_true[1], color="blue", linestyle="dotted", linewidth=2) +ax2.set_title("Posterior Distribution of Cutpoint 2") +ax2.set_xlabel("Cutpoint 2") +plt.tight_layout() +plt.show() +``` + +:::: + +We can compare the true value of the latent "utility function" $\lambda(x)$ to the +(mean-shifted) BART forest predictions. 
+ +::::{.panel-tabset group="language"} + +## R + +```{r} +# Train set predicted versus actual +y_hat_train <- predict( + bart_model, + X = X_train, + scale = "linear", + terms = "y_hat", + type = "mean" +) +lambda_pred_train <- y_hat_train - mean(y_hat_train) +plot( + lambda_pred_train, + true_lambda_function[train_idx], + main = "Train Set: Predicted vs Actual", + xlab = "Predicted", + ylab = "Actual" +) +abline(a = 0, b = 1, col = 'blue', lwd = 2) +cor_train <- cor(true_lambda_function[train_idx], lambda_pred_train) +text( + min(lambda_pred_train), + max(true_lambda_function[train_idx]), + paste('Correlation:', round(cor_train, 3)), + adj = 0, + col = 'red' +) + +# Test set predicted versus actual +y_hat_test <- predict( + bart_model, + X = X_test, + scale = "linear", + terms = "y_hat", + type = "mean" +) +lambda_pred_test <- y_hat_test - mean(y_hat_test) +plot( + lambda_pred_test, + true_lambda_function[test_idx], + main = "Test Set: Predicted vs Actual", + xlab = "Predicted", + ylab = "Actual" +) +abline(a = 0, b = 1, col = 'blue', lwd = 2) +cor_test <- cor(true_lambda_function[test_idx], lambda_pred_test) +text( + min(lambda_pred_test), + max(true_lambda_function[test_idx]), + paste('Correlation:', round(cor_test, 3)), + adj = 0, + col = 'red' +) +``` + +## Python + +```{python} +y_hat_train = bart_model.predict(X=X_train, scale="linear", terms="y_hat", type="mean") +y_hat_test = bart_model.predict(X=X_test, scale="linear", terms="y_hat", type="mean") +lambda_pred_train = y_hat_train - y_hat_train.mean() +lambda_pred_test = y_hat_test - y_hat_test.mean() +corr_train = np.corrcoef(true_lambda[train_inds], lambda_pred_train)[0, 1] +corr_test = np.corrcoef(true_lambda[test_inds], lambda_pred_test)[0, 1] + +fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5)) +ax1.scatter(lambda_pred_train, true_lambda[train_inds], alpha=0.3, s=10) +ax1.axline((0, 0), slope=1, color="blue", linewidth=2) +ax1.set_title("Train Set: Predicted vs Actual") 
+ax1.set_xlabel("Predicted") +ax1.set_ylabel("Actual") +ax1.text(0.05, 0.95, f"Correlation: {corr_train:.3f}", transform=ax1.transAxes, + color="red", verticalalignment="top") +ax2.scatter(lambda_pred_test, true_lambda[test_inds], alpha=0.3, s=10) +ax2.axline((0, 0), slope=1, color="blue", linewidth=2) +ax2.set_title("Test Set: Predicted vs Actual") +ax2.set_xlabel("Predicted") +ax2.set_ylabel("Actual") +ax2.text(0.05, 0.95, f"Correlation: {corr_test:.3f}", transform=ax2.transAxes, + color="red", verticalalignment="top") +plt.tight_layout() +plt.show() +``` + +:::: + +Finally, we compare the estimated class probabilities with their true simulated values +for each class on the training set. + +::::{.panel-tabset group="language"} + +## R + +```{r} +for (j in 1:n_categories) { + mean_probs <- rowMeans(est_probs_train[, j, ]) + plot( + true_probs[train_idx, j], + mean_probs, + main = paste("Training Set: True vs Estimated Probability, Class", j), + xlab = "True Class Probability", + ylab = "Estimated Class Probability" + ) + abline(a = 0, b = 1, col = 'blue', lwd = 2) + cor_train_prob <- cor(true_probs[train_idx, j], mean_probs) + text( + min(true_probs[train_idx, j]), + max(mean_probs), + paste('Correlation:', round(cor_train_prob, 3)), + adj = 0, + col = 'red' + ) +} +``` + +## Python + +```{python} +fig, axes = plt.subplots(1, n_categories, figsize=(15, 5)) +for j in range(n_categories): + corr = np.corrcoef(true_probs[train_inds, j], est_probs_train[:, j])[0, 1] + axes[j].scatter(true_probs[train_inds, j], est_probs_train[:, j], alpha=0.3, s=10) + axes[j].axline((0, 0), slope=1, color="blue", linewidth=2) + axes[j].set_title(f"Training Set: True vs Estimated Probability, Class {j + 1}") + axes[j].set_xlabel("True Class Probability") + axes[j].set_ylabel("Estimated Class Probability") + axes[j].text(0.05, 0.95, f"Correlation: {corr:.3f}", transform=axes[j].transAxes, + color="red", verticalalignment="top") +plt.tight_layout() +plt.show() +``` + +:::: + +And the 
same comparison on the test set. + +::::{.panel-tabset group="language"} + +## R + +```{r} +for (j in 1:n_categories) { + mean_probs <- rowMeans(est_probs_test[, j, ]) + plot( + true_probs[test_idx, j], + mean_probs, + main = paste("Test Set: True vs Estimated Probability, Class", j), + xlab = "True Class Probability", + ylab = "Estimated Class Probability" + ) + abline(a = 0, b = 1, col = 'blue', lwd = 2) + cor_test_prob <- cor(true_probs[test_idx, j], mean_probs) + text( + min(true_probs[test_idx, j]), + max(mean_probs), + paste('Correlation:', round(cor_test_prob, 3)), + adj = 0, + col = 'red' + ) +} +``` + +## Python + +```{python} +fig, axes = plt.subplots(1, n_categories, figsize=(15, 5)) +for j in range(n_categories): + corr = np.corrcoef(true_probs[test_inds, j], est_probs_test[:, j])[0, 1] + axes[j].scatter(true_probs[test_inds, j], est_probs_test[:, j], alpha=0.3, s=10) + axes[j].axline((0, 0), slope=1, color="blue", linewidth=2) + axes[j].set_title(f"Test Set: True vs Estimated Probability, Class {j + 1}") + axes[j].set_xlabel("True Class Probability") + axes[j].set_ylabel("Estimated Class Probability") + axes[j].text(0.05, 0.95, f"Correlation: {corr:.3f}", transform=axes[j].transAxes, + color="red", verticalalignment="top") +plt.tight_layout() +plt.show() +``` + +:::: + +# References diff --git a/vignettes/prior-calibration.qmd b/vignettes/prior-calibration.qmd new file mode 100644 index 000000000..a518169b4 --- /dev/null +++ b/vignettes/prior-calibration.qmd @@ -0,0 +1,277 @@ +--- +title: "Calibrating Leaf Node Scale Parameter Priors" +bibliography: vignettes.bib +execute: + freeze: auto # re-render only when source changes +--- + +```{r} +#| include: false +reticulate::use_python( + Sys.getenv( + "RETICULATE_PYTHON", + unset = file.path(rprojroot::find_root(rprojroot::has_file(".here")), ".venv", "bin", "python") + ), + required = TRUE +) +``` + +This vignette demonstrates prior calibration approaches for the parametric components +of stochastic tree 
ensembles (@chipman2010bart). + +# Background + +The "classic" BART model of @chipman2010bart + +\begin{equation*} +\begin{aligned} +y &= f(X) + \epsilon\\ +f(X) &\sim \text{BART}\left(\alpha, \beta\right)\\ +\epsilon &\sim \mathcal{N}\left(0,\sigma^2\right)\\ +\sigma^2 &\sim \text{IG}\left(a,b\right) +\end{aligned} +\end{equation*} + +is semiparametric, with a nonparametric tree ensemble $f(X)$ and a homoskedastic error +variance parameter $\sigma^2$. Note that in @chipman2010bart, $a$ and $b$ are +parameterized with $a = \frac{\nu}{2}$ and $b = \frac{\nu\lambda}{2}$. + +# Setting Priors on Variance Parameters in `stochtree` + +By default, `stochtree` employs a Jeffreys' prior for $\sigma^2$ +\begin{equation*} +\begin{aligned} +\sigma^2 &\propto \frac{1}{\sigma^2} +\end{aligned} +\end{equation*} +which corresponds to an improper prior with $a = 0$ and $b = 0$. + +We provide convenience functions for users wishing to set the $\sigma^2$ prior as in +@chipman2010bart. In this case, $\nu$ is set by default to 3 and $\lambda$ is +calibrated as follows: + +1. An "overestimate," $\hat{\sigma}^2$, of $\sigma^2$ is obtained via simple linear + regression of $y$ on $X$ +2. $\lambda$ is chosen to ensure that $p(\sigma^2 < \hat{\sigma}^2) = q$ for some value + $q$, typically set to a default value of 0.9. 
+ +# Setup + +Load the necessary packages + +:::{.panel-tabset group="language"} + +## R + +```{r} +#| message: false +library(stochtree) +``` + +## Python + +```{python} +import numpy as np +import matplotlib.pyplot as plt +from stochtree import BARTModel, calibrate_global_error_variance +``` + +::: + +Set a seed for reproducibility + +:::{.panel-tabset group="language"} + +## R + +```{r} +#| message: false +random_seed <- 1234 +set.seed(random_seed) +``` + +## Python + +```{python} +random_seed = 1234 +rng = np.random.default_rng(random_seed) +``` + +::: + +# Data Generation + +Generate data for a straightforward supervised learning problem + +::::{.panel-tabset group="language"} + +## R + +```{r} +n <- 500 +p <- 5 +X <- matrix(runif(n*p), ncol = p) +f_XW <- ( + ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + + ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) +) +noise_sd <- 1 +y <- f_XW + rnorm(n, 0, noise_sd) +``` + +## Python + +```{python} +n = 500 +p = 5 +X = rng.uniform(size=(n, p)) +f_XW = ( + ((X[:, 0] >= 0) & (X[:, 0] < 0.25)) * (-7.5) + + ((X[:, 0] >= 0.25) & (X[:, 0] < 0.5)) * (-2.5) + + ((X[:, 0] >= 0.5) & (X[:, 0] < 0.75)) * (2.5) + + ((X[:, 0] >= 0.75) & (X[:, 0] < 1.0)) * (7.5) +) +noise_sd = 1.0 +y = f_XW + rng.normal(0, noise_sd, n) +``` + +:::: + +Split into train and test set + +::::{.panel-tabset group="language"} + +## R + +```{r} +test_set_pct <- 0.2 +n_test <- round(test_set_pct*n) +n_train <- n - n_test +test_inds <- sort(sample(1:n, n_test, replace = FALSE)) +train_inds <- (1:n)[!((1:n) %in% test_inds)] +X_test <- X[test_inds,] +X_train <- X[train_inds,] +y_test <- y[test_inds] +y_train <- y[train_inds] +``` + +## Python + +```{python} +test_set_pct = 0.2 +n_test = round(test_set_pct * n) +n_train = n - n_test +test_inds = rng.choice(n, n_test, replace=False) +train_inds = np.setdiff1d(np.arange(n), test_inds) +X_test = X[test_inds] +X_train = X[train_inds] 
+y_test = y[test_inds] +y_train = y[train_inds] +``` + +:::: + +# Model Sampling + +First, we calibrate the scale parameter for the variance term as in Chipman et al (2010) + +::::{.panel-tabset group="language"} + +## R + +```{r} +nu <- 3 +lambda <- calibrateInverseGammaErrorVariance(y_train, X_train, nu = nu) +``` + +## Python + +```{python} +nu = 3 +lambda_ = calibrate_global_error_variance(X_train, y_train, nu=nu) +``` + +:::: + +Then, we run a BART model with this variance parameterization + +::::{.panel-tabset group="language"} + +## R + +```{r} +general_params <- list(sigma2_global_shape = nu/2, sigma2_global_scale = (nu*lambda)/2) +bart_model <- bart(X_train = X_train, y_train = y_train, X_test = X_test, + num_gfr = 0, num_burnin = 1000, num_mcmc = 100, + general_params = general_params) +``` + +## Python + +```{python} +bart_model = BARTModel() +bart_model.sample( + X_train=X_train, y_train=y_train, X_test=X_test, + num_gfr=0, num_burnin=1000, num_mcmc=100, + general_params={ + "num_threads": 1, + "sigma2_global_shape": nu / 2, + "sigma2_global_scale": (nu * lambda_) / 2, + }, +) +``` + +:::: + +Inspect the out-of-sample predictions of the model + +::::{.panel-tabset group="language"} + +## R + +```{r} +plot(rowMeans(bart_model$y_hat_test), y_test, xlab = "predicted", ylab = "actual") +abline(0,1,col="red",lty=3,lwd=3) +``` + +## Python + +```{python} +pred_mean = bart_model.y_hat_test.mean(axis=1) +lo = min(pred_mean.min(), y_test.min()) +hi = max(pred_mean.max(), y_test.max()) +plt.scatter(pred_mean, y_test, alpha=0.5) +plt.plot([lo, hi], [lo, hi], color="red", linestyle="dashed", linewidth=2) +plt.xlabel("Predicted") +plt.ylabel("Actual") +plt.show() +``` + +:::: + +Inspect the posterior samples of $\sigma^2$ + +::::{.panel-tabset group="language"} + +## R + +```{r} +plot(bart_model$sigma2_global_samples, ylab = "sigma^2", xlab = "iteration") +abline(h = noise_sd^2, col = "red", lty = 3, lwd = 3) +``` + +## Python + +```{python} 
+plt.plot(bart_model.global_var_samples) +plt.xlabel("Iteration") +plt.ylabel(r"$\sigma^2$") +plt.axhline(noise_sd**2, color="red", linestyle="dashed", linewidth=2) +plt.show() +``` + +:::: + +# References diff --git a/vignettes/rdd.qmd b/vignettes/rdd.qmd new file mode 100644 index 000000000..44253e1bc --- /dev/null +++ b/vignettes/rdd.qmd @@ -0,0 +1,564 @@ +--- +title: "Regression Discontinuity Design (RDD) with StochTree" +bibliography: vignettes.bib +execute: + freeze: auto # re-render only when source changes +--- + +::: {.hidden} +$$ +\newcommand{\ind}{\perp \!\!\! \perp} +\newcommand{\B}{\mathcal{B}} +\newcommand{\res}{\mathbf{r}} +\newcommand{\m}{\mathbf{m}} +\newcommand{\x}{\mathbf{x}} +\newcommand{\N}{\mathrm{N}} +\newcommand{\w}{\mathrm{w}} +\newcommand{\iidsim}{\stackrel{\mathrm{iid}}{\sim}} +\newcommand{\V}{\mathbb{V}} +\newcommand{\F}{\mathbf{F}} +\newcommand{\Y}{\mathbf{Y}} +$$ +::: + +```{r} +#| include: false +reticulate::use_python( + Sys.getenv( + "RETICULATE_PYTHON", + unset = file.path(rprojroot::find_root(rprojroot::has_file(".here")), ".venv", "bin", "python") + ), + required = TRUE +) +``` + +# Introduction + +We study conditional average treatment effect (CATE) estimation for regression +discontinuity designs (RDD), in which treatment assignment is based on whether a +particular covariate β€” referred to as the running variable β€” lies above or below a +known value, referred to as the cutoff value. Because treatment is deterministically +assigned as a known function of the running variable, RDDs are trivially deconfounded: +treatment assignment is independent of the outcome variable, given the running variable. +However, estimation of treatment effects in RDDs is more complicated than simply +controlling for the running variable, because doing so introduces a complete lack of +overlap. 
Nonetheless, the CATE _at the cutoff_, $X=c$, may still be identified +provided the conditional expectation $E[Y \mid X,W]$ is continuous at that point for +_all_ $W=w$. We exploit this assumption with the leaf regression BART model implemented +in stochtree, which allows us to define an explicit prior on the CATE. + +# Regression Discontinuity Design + +We conceptualize the treatment effect estimation problem via a quartet of random +variables $(Y, X, Z, U)$. The variable $Y$ is the outcome variable; $X$ is the running +variable; $Z$ is the treatment assignment indicator; and $U$ represents additional, +possibly unobserved, causal factors. What makes this an RDD is the stipulation that +$Z = I(X > c)$ for cutoff $c$. We assume $c = 0$ without loss of generality. + +The following figure depicts a causal diagram representing the assumed causal +relationships between these variables. Two key features are: (1) $X$ blocks the +impact of $U$ on $Z$, satisfying the back-door criterion; and (2) $X$ and $U$ are +not descendants of $Z$. + +![A causal directed acyclic graph representing the general structure of a regression discontinuity design problem.](R/RDD/RDD_DAG.png){width=40% fig-align="center"} + +Using this causal diagram, we may express $Y$ as some function of its graph parents +$(X,Z,U)$: $Y = \F(X,Z,U)$. We relate this to the potential outcomes framework via + +$$ +Y^1 = \F(X,1,U), \qquad Y^0 = \F(X,0,U). +$$ + +Defining conditional expectations +$$ +\mu_1(x) = E[Y \mid X=x, Z=1], \qquad \mu_0(x) = E[Y \mid X=x, Z=0], +$$ +the treatment effect function is $\tau(x) = \mu_1(x) - \mu_0(x)$. Because $Z = I(X > 0)$, +we can only learn $\mu_1(x)$ for $X > 0$ and $\mu_0(x)$ for $X < 0$. Overlap is +violated, so the overall ATE $\bar{\tau} = E(\tau(X))$ is unidentified. We instead +estimate $\tau(0) = \mu_1(0) - \mu_0(0)$, which is identified for continuous $X$ +under the assumption that $\mu_1$ and $\mu_0$ are suitably smooth at $x = 0$. 
+ +## Conditional Average Treatment Effects in RDD + +We are concerned with learning not only $\tau(0)$ but also RDD CATEs, +$\tau(0, \w)$ for covariate vector $\w$. Defining potential outcome means + +$$ +\mu_z(x,\w) = E[Y \mid X=x, W=\w, Z=z], +$$ + +our treatment effect function is $\tau(x,\w) = \mu_1(x,\w) - \mu_0(x,\w)$. We +must assume $\mu_1(x,\w)$ and $\mu_0(x,\w)$ are suitably smooth in $x$ for every $\w$. +CATE estimation in RDDs then reduces to estimating $E[Y \mid X=x, W=\w, Z=z]$, for +which we turn to BART. + +# The BARDDT Model + +We propose a BART model where the trees split on $(x,\w)$ but each leaf node parameter +is a vector of regression coefficients tailored to the RDD context. Let $\psi$ denote +the following basis vector: +$$ +\psi(x,z) = \begin{bmatrix} 1 & zx & (1-z)x & z \end{bmatrix}. +$$ + +The prediction function for tree $j$ is defined as $g_j(x, \w, z) = \psi(x, z) \Gamma_{b_j(x, \w)}$ +for leaf-specific regression vector $\Gamma_{b_j} = (\eta_{b_j}, \lambda_{b_j}, \theta_{b_j}, \Delta_{b_j})^t$. +The model for observations in leaf $b_j$ is + +$$ +\Y_{b_j} \mid \Gamma_{b_j}, \sigma^2 \sim \N(\Psi_{b_j} \Gamma_{b_j}, \sigma^2), \qquad +\Gamma_{b_j} \sim \N(0, \Sigma_0), +$$ + +where we set $\Sigma_0 = \frac{0.033}{J}\mathrm{I}$ as a default (for $x$ standardized +to unit variance in-sample). + +This choice of basis entails that the RDD CATE at $\w$, $\tau(0, \w)$, is the sum of +the $\Delta_{b_j(0, \w)}$ elements across all trees: + +$$ +\tau(0, \w) = \sum_{j=1}^J \Delta_{b_j(0, \w)}. +$$ + +The priors on the $\Delta$ coefficients directly regularize the treatment effect. + +The following figures illustrate how BARDDT fits a response surface and estimates CATEs. + +![Two regression trees with splits in $x$ and a single scalar $w$. Node images depict the $g(x,w,z)$ function defined by that node's coefficients. 
The vertical gap between line segments at $x=0$ is that node's contribution to the CATE.](R/RDD/trees1.png){width=70% fig-align="center"} + +![The same two trees represented as a partition of the $x$-$w$ plane. The bottom figure shows the combined partition; the red dashed line marks $W=w^*$.](R/RDD/trees2.png){width=70% fig-align="center"} + +![Left: the function fit at $W = w^*$ for the two trees, superimposed. Right: the aggregated fit. The magnitude of the discontinuity at $x = 0$ is the treatment effect.](R/RDD/trees3.png){width=70% fig-align="center"} + +An interesting property of BARDDT: by letting the regression trees split on the running +variable, there is no need to separately define a bandwidth as in polynomial RDD. The +regression trees automatically determine (in the course of posterior sampling) when to +prune away regions far from the cutoff. + +# Demo + +In this section, we provide code for implementing BARDDT in `stochtree` on a +popular RDD dataset. + +## Setup + +Load the necessary packages + +:::{.panel-tabset group="language"} + +## R + +```{r} +#| message: false +library(stochtree) +library(rpart) +library(rpart.plot) +library(foreach) +library(doParallel) +``` + +## Python + +```{python} +import matplotlib.pyplot as plt +import seaborn as sns +import numpy as np +import pandas as pd +from sklearn.tree import DecisionTreeRegressor, plot_tree +from stochtree import BARTModel +``` + +::: + +Set a seed for reproducibility + +:::{.panel-tabset group="language"} + +## R + +```{r} +#| message: false +random_seed <- 1234 +set.seed(random_seed) +``` + +## Python + +```{python} +random_seed = 1234 +rng = np.random.default_rng(random_seed) +``` + +::: + +## Dataset + +The data comes from @lindo2010ability, who analyze data on college students at a large +Canadian university to evaluate an academic probation policy. Students whose GPA falls +below a threshold are placed on academic probation. 
The running variable $X$ is the
+negative distance between a student's previous-term GPA and the probation threshold, so
+students on probation ($Z = 1$) have positive scores and the cutoff is 0. The outcome
+$Y$ is the student's GPA at the end of the current term. Potential moderators $W$ are:
+gender (`male`), age at university entry (`age_at_entry`), a dummy for being born in
+North America (`bpl_north_america`), credits taken in the first year
+(`totcredits_year1`), campus indicators (`loc_campus` 1–3), and high school GPA
+quantile (`hsgrade_pct`).
+
+:::{.panel-tabset group="language"}
+
+## R
+
+```{r}
+# Load and organize data
+data <- read.csv("https://raw.githubusercontent.com/rdpackages-replication/CIT_2024_CUP/refs/heads/main/CIT_2024_CUP_discrete.csv")
+y <- data$nextGPA
+x <- data$X
+n <- nrow(data)
+
+# Standardize x
+x <- x / sd(x)
+
+# Extract covariates
+w <- data[, 4:11]
+
+# Encode categorical features as ordered/unordered factors
+w$totcredits_year1 <- factor(w$totcredits_year1, ordered = TRUE)
+w$male <- factor(w$male, ordered = FALSE)
+w$bpl_north_america <- factor(w$bpl_north_america, ordered = FALSE)
+w$loc_campus1 <- factor(w$loc_campus1, ordered = FALSE)
+w$loc_campus2 <- factor(w$loc_campus2, ordered = FALSE)
+w$loc_campus3 <- factor(w$loc_campus3, ordered = FALSE)
+
+# x is normalized so the cutoff occurs at c = 0
+c <- 0
+
+# Binarize the running variable into a "treatment" indicator
+z <- as.numeric(x > c)
+
+# Window for prediction sample
+h <- 0.1
+
+# Define the prediction subset
+test <- -h < x & x < h
+ntest <- sum(test)
+```
+
+## Python
+
+```{python}
+# Load and organize data
+data = pd.read_csv("https://raw.githubusercontent.com/rdpackages-replication/CIT_2024_CUP/refs/heads/main/CIT_2024_CUP_discrete.csv")
+y = data.loc[:, "nextGPA"].to_numpy().squeeze()
+x = data.loc[:, "X"].to_numpy().squeeze()
+n = data.shape[0]
+
+# Standardize x
+x = x / np.std(x)
+
+# Extract covariates
+w = data.iloc[:, 3:11]
+
+# Encode categorical 
features as ordered/unordered factors +w["totcredits_year1"] = pd.Categorical( + w["totcredits_year1"], ordered=True +) +unordered_categorical_cols = [ + "male", + "bpl_north_america", + "loc_campus1", + "loc_campus2", + "loc_campus3", +] +for col in unordered_categorical_cols: + w.loc[:, col] = pd.Categorical(w.loc[:, col], ordered=False) + +# x is normalized so the cutoff occurs at c = 0 +c = 0 + +# Binarize the running variable into a "treatment" indicator +z = (x > c).astype(float) + +# Window for prediction sample +h = 0.1 + +# Define the prediction subset +test = (-h < x) & (x < h) +ntest = np.sum(test) +``` + +::: + +## Target Estimand + +Our estimand is the CATE function at $x = 0$, i.e. $\tau(0, \w)$. To focus on +feasible estimation points, we restrict to observed $\w_i$ such that $|x_i| \leq \delta$ +(here $\delta = 0.1$ after standardizing $X$). Our estimand is therefore + +$$ +\tau(0, \w_i) \quad \forall i \text{ such that } |x_i| \leq \delta. +$$ + +## Implementing BARDDT + +The $\psi$ basis vector for the leaf regression is +$\psi = [1,\, zx,\, (1-z)x,\, z]$, and the training covariate matrix is +$[x,\, W]$. The prediction basis at the cutoff for $Z=1$ and $Z=0$ is + +$$ +\psi_1 = [1, 0, 0, 1], \qquad \psi_0 = [1, 0, 0, 0]. +$$ + +:::{.panel-tabset group="language"} + +Define basis functions for model sampling + +## R + +```{r} +Psi <- cbind(rep(1, n), z * x, (1 - z) * x, z) +``` + +## Python + +```{python} +Psi = np.c_[np.ones(n), z * x, (1 - z) * x, z] +``` + +::: + +## Fitting the Model + +We run multiple chains and combine their posterior draws. To compute the CATE posterior, we obtain $Y(z)$ predictions by predicting from the model with $Z = z$ set in the basis. `stochtree` provides a function / method (`computeContrastBARTModel` in R, `compute_contrast` in Python) for directly computing this contrast from a sampled BART model. 
+ +:::{.panel-tabset group="language"} + +## R + +```{r} +# Define sampling parameters +num_chains <- 4 +num_gfr <- 4 +num_burnin <- 0 +num_mcmc <- 500 + +# Parameter lists for BART model fit +global_params <- list( + standardize = T, + sample_sigma_global = TRUE, + sigma2_global_init = 0.1, + random_seed = random_seed, + num_threads = 1, + num_chains = num_chains +) +forest_params <- list( + num_trees = 50, + min_samples_leaf = 20, + alpha = 0.95, + beta = 2, + max_depth = 20, + sample_sigma2_leaf = FALSE, + sigma2_leaf_init = 0.1 / 50 +) + +# Fit the BART model +bart_model <- bart( + X_train = cbind(x, w), + leaf_basis_train = Psi, + y_train = y, + num_gfr = num_gfr, + num_burnin = num_burnin, + num_mcmc = num_mcmc, + general_params = global_params, + mean_forest_params = forest_params +) + +# Compute the CATE posterior +Psi0 <- cbind(rep(1, n), rep(0, n), rep(0, n), rep(0, n))[test, ] +Psi1 <- cbind(rep(1, n), rep(0, n), rep(0, n), rep(1, n))[test, ] +covariates_test <- cbind(x = rep(0, n), w)[test, ] +cate_posterior <- computeContrastBARTModel( + bart_model, + X_0 = covariates_test, + X_1 = covariates_test, + leaf_basis_0 = Psi0, + leaf_basis_1 = Psi1, + type = "posterior", + scale = "linear" +) +``` + +## Python + +```{python} +# Define sampling parameters +num_chains = 4 +num_gfr = 4 +num_burnin = 0 +num_mcmc = 500 + +# Parameter lists for BART model fit +global_params = { + "standardize": True, + "sample_sigma_global": True, + "sigma2_global_init": 0.1, + "random_seed": random_seed, + "num_threads": 1, + "num_chains": num_chains +} +forest_params = { + "num_trees": 50, + "min_samples_leaf": 20, + "alpha": 0.95, + "beta": 2, + "max_depth": 20, + "sample_sigma2_leaf": False, + "sigma2_leaf_init": 0.1 / 50, +} + +# Fit the BART model +covariates_train = w +covariates_train.loc[:, "x"] = x +bart_model = BARTModel() +bart_model.sample( + X_train=covariates_train, + leaf_basis_train=Psi, + y_train=y, + num_gfr=num_gfr, + num_burnin=num_burnin, + num_mcmc=num_mcmc, 
+ general_params=global_params, + mean_forest_params=forest_params, +) + +# Compute the CATE posterior +Psi0 = np.c_[np.ones(n), np.zeros(n), np.zeros(n), np.zeros(n)][test, :] +Psi1 = np.c_[np.ones(n), np.zeros(n), np.zeros(n), np.ones(n)][test, :] +covariates_test = w.iloc[test, :] +covariates_test.loc[:, "x"] = np.zeros(ntest) +cate_posterior = bart_model.compute_contrast( + X_0=covariates_test, + X_1=covariates_test, + leaf_basis_0=Psi0, + leaf_basis_1=Psi1, + type="posterior", + scale="linear", +) +``` + +::: + +## Analyzing CATE Heterogeneity + +To summarize the CATE posterior we fit a regression tree to the posterior mean +point estimates $\bar{\tau}_i = \frac{1}{M} \sum_{h=1}^M \tau^{(h)}(0, \w_i)$, +using $W$ as predictors. We restrict to observations with $|x_i| \leq \delta$. + +:::{.panel-tabset group="language"} + +## R + +```{r} +#| fig-cap: "Regression tree fit to posterior point estimates of individual treatment effects. Top number in each box is the average subgroup treatment effect; lower number is the share of the sample." +cate <- rpart(y ~ ., data.frame(y = rowMeans(cate_posterior), w[test, ]), + control = rpart.control(cp = 0.015)) + +plot_cart <- function(rp) { + fr <- rp$frame + left <- which.min(fr$yval) + right <- which.max(fr$yval) + cols <- rep("lightblue3", nrow(fr)) + cols[fr$yval == fr$yval[left]] <- "tomato3" + cols[fr$yval == fr$yval[right]] <- "gold2" + cols +} + +rpart.plot(cate, main = "", box.col = plot_cart(cate)) +``` + +## Python + +```{python} +#| fig-cap: "Decision tree fit to posterior mean CATEs, used as an effect moderation summary." 
+y_surrogate = np.mean(cate_posterior, axis=1)
+X_surrogate = w.iloc[test, :]
+cp = 0.015
+min_impurity_decrease = cp * np.var(y_surrogate)
+cate_tree = DecisionTreeRegressor(min_impurity_decrease=min_impurity_decrease)
+cate_tree.fit(X=X_surrogate, y=y_surrogate)
+plot_tree(cate_tree, impurity=False, filled=True,
+          feature_names=w.columns, proportion=False,
+          label="root", node_ids=True)
+plt.show()
+```
+
+:::
+
+The resulting tree indicates that course load (`totcredits_year1`) in the academic term
+leading to probation is a strong moderator of the treatment effect. The tree also flags
+campus, age at entry, and gender as secondary moderators — all prima facie plausible.
+
+## Comparing Subgroup Posteriors
+
+The effect moderation tree is a posterior summary tool; it does not alter the
+posterior itself. We can compare any two subgroups by averaging their individual
+posterior draws. Consider the two groups at opposite ends of the effect range:
+
+- **Group A**: male student, entered college older than 19, attempted > 4.8 credits in
+  the first year (leftmost leaf, red)
+- **Group B**: any gender, entered college younger than 19, attempted 4.3–4.8 credits
+  in the first year (rightmost leaf, gold)
+
+Subgroup posteriors are
+
+$$
+\bar{\tau}_A^{(h)} = \frac{1}{n_A} \sum_{i \in A} \tau^{(h)}(0, \w_i),
+$$
+
+where $h$ indexes a posterior draw and $n_A$ is the group size.
+
+:::{.panel-tabset group="language"}
+
+## R
+
+```{r}
+#| fig-cap: "Joint kernel density estimate of the CATE posteriors for Groups A and B. Nearly all contour lines lie above the 45° line, indicating that Group B has persistently higher treatment effects." 
+cate_kde <- function(rp, pred) {
+  left <- rp$where == which.min(rp$frame$yval)
+  right <- rp$where == which.max(rp$frame$yval)
+  cate_a <- colMeans(pred[left, , drop = FALSE])
+  cate_b <- colMeans(pred[right, , drop = FALSE])
+  MASS::kde2d(cate_a, cate_b, n = 200)
+}
+contour(cate_kde(cate, cate_posterior), bty = "n",
+        xlab = "Group A", ylab = "Group B")
+abline(a = 0, b = 1)
+```
+
+## Python
+
+```{python}
+#| fig-cap: "Joint KDE of Group A and Group B CATE posteriors. Contours above the diagonal indicate Group B has persistently higher treatment effects."
+predicted_nodes = cate_tree.apply(X=X_surrogate)
+max_value_node = np.argmax(cate_tree.tree_.value)
+min_value_node = np.argmin(cate_tree.tree_.value)
+posterior_group_a = np.mean(cate_posterior[predicted_nodes == min_value_node, :], axis=0)
+posterior_group_b = np.mean(cate_posterior[predicted_nodes == max_value_node, :], axis=0)
+posterior_df = pd.DataFrame({"group_a": posterior_group_a,
+                             "group_b": posterior_group_b})
+sns.kdeplot(data=posterior_df, x="group_a", y="group_b")
+plt.axline((0, 0), slope=1, color="black", linestyle=(0, (3, 3)))
+plt.show()
+```
+
+:::
+
+The contour lines are nearly all above the 45° line, indicating that the posterior
+probability mass lies in the region where Group B has a larger treatment effect than
+Group A, even after accounting for estimation uncertainty.
+
+As always, CATEs that vary with observable factors do not necessarily represent a
+_causal_ moderating relationship; uncovering these patterns is crucial for suggesting
+causal mechanisms to investigate in future studies. 
+ +# References diff --git a/vignettes/serialization.qmd b/vignettes/serialization.qmd new file mode 100644 index 000000000..5ec5a58e2 --- /dev/null +++ b/vignettes/serialization.qmd @@ -0,0 +1,557 @@ +--- +title: "Saving and Loading Fitted Models" +bibliography: vignettes.bib +execute: + freeze: auto # re-render only when source changes +--- + +```{r} +#| include: false +reticulate::use_python( + Sys.getenv( + "RETICULATE_PYTHON", + unset = file.path(rprojroot::find_root(rprojroot::has_file(".here")), ".venv", "bin", "python") + ), + required = TRUE +) +``` + +This vignette demonstrates how to serialize ensemble models to JSON files and +deserialize back to an R or Python session, where the forests and other parameters +can be used for prediction and further analysis. + +# Setup + +Load necessary packages + +::::{.panel-tabset group="language"} + +## R + +```{r} +library(stochtree) +``` + +## Python + +```{python} +import json +import numpy as np +import matplotlib.pyplot as plt +import pandas as pd +from scipy.stats import norm +from stochtree import BARTModel, BCFModel +``` + +:::: + +Define several simple helper functions used in the data generating processes below + +::::{.panel-tabset group="language"} + +## R + +```{r} +g <- function(x) {ifelse(x[,5]==1,2,ifelse(x[,5]==2,-1,-4))} +mu1 <- function(x) {1+g(x)+x[,1]*x[,3]} +mu2 <- function(x) {1+g(x)+6*abs(x[,3]-1)} +tau1 <- function(x) {rep(3,nrow(x))} +tau2 <- function(x) {1+2*x[,2]*x[,4]} +``` + +## Python + +```{python} +def g(x): return np.where(x[:,4]==1, 2, np.where(x[:,4]==2, -1, -4)) +def mu1(x): return 1 + g(x) + x[:,0] * x[:,2] +def mu2(x): return 1 + g(x) + 6 * np.abs(x[:,2] - 1) +def tau1(x): return np.full(x.shape[0], 3.0) +def tau2(x): return 1 + 2 * x[:,1] * x[:,3] +``` + +:::: + +Set a seed for reproducibility + +::::{.panel-tabset group="language"} + +## R + +```{r} +random_seed = 1234 +set.seed(random_seed) +``` + +## Python + +```{python} +random_seed = 1234 +rng = 
np.random.default_rng(random_seed) +``` + +:::: + +# BART Serialization + +BART models are initially sampled and constructed using the `bart()` function. +Here we show how to save and reload models from JSON files on disk. + +## Model Building + +Draw from a relatively straightforward heteroskedastic supervised learning DGP. + +::::{.panel-tabset group="language"} + +## R + +```{r} +# Generate the data +n <- 500 +p_x <- 10 +X <- matrix(runif(n*p_x), ncol = p_x) +f_XW <- 0 +s_XW <- ( + ((0 <= X[,1]) & (0.25 > X[,1])) * (0.5*X[,3]) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (1*X[,3]) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2*X[,3]) + + ((0.75 <= X[,1]) & (1 > X[,1])) * (3*X[,3]) +) +y <- f_XW + rnorm(n, 0, 1)*s_XW + +# Split data into test and train sets +test_set_pct <- 0.2 +n_test <- round(test_set_pct*n) +n_train <- n - n_test +test_inds <- sort(sample(1:n, n_test, replace = FALSE)) +train_inds <- (1:n)[!((1:n) %in% test_inds)] +X_test <- as.data.frame(X[test_inds,]) +X_train <- as.data.frame(X[train_inds,]) +y_test <- y[test_inds] +y_train <- y[train_inds] +s_x_test <- s_XW[test_inds] +s_x_train <- s_XW[train_inds] +``` + +## Python + +```{python} +# Note: new rng here so Python Demo 2 is independent of Demo 1 +rng2 = np.random.default_rng(5678) + +n = 500 +p_x = 10 +X2 = rng2.uniform(size=(n, p_x)) +s_XW = ( + ((X2[:, 0] >= 0) & (X2[:, 0] < 0.25)) * (0.5 * X2[:, 2]) + + ((X2[:, 0] >= 0.25) & (X2[:, 0] < 0.5)) * (1.0 * X2[:, 2]) + + ((X2[:, 0] >= 0.5) & (X2[:, 0] < 0.75)) * (2.0 * X2[:, 2]) + + ((X2[:, 0] >= 0.75) & (X2[:, 0] < 1.0)) * (3.0 * X2[:, 2]) +) +y2 = rng2.standard_normal(n) * s_XW + +n_test2 = round(0.2 * n) +test_inds2 = rng2.choice(n, n_test2, replace=False) +train_inds2 = np.setdiff1d(np.arange(n), test_inds2) +X_test2 = pd.DataFrame(X2[test_inds2]) +X_train2 = pd.DataFrame(X2[train_inds2]) +y_test2 = y2[test_inds2] +y_train2 = y2[train_inds2] +``` + +:::: + +Sample a BART model. 
+ +::::{.panel-tabset group="language"} + +## R + +```{r} +num_gfr <- 10 +num_burnin <- 0 +num_mcmc <- 100 +general_params <- list(sample_sigma2_global = F) +mean_forest_params <- list(sample_sigma2_leaf = F, num_trees = 100, + alpha = 0.95, beta = 2, min_samples_leaf = 5) +variance_forest_params <- list(num_trees = 50, alpha = 0.95, + beta = 1.25, min_samples_leaf = 1) +bart_model <- stochtree::bart( + X_train = X_train, y_train = y_train, X_test = X_test, + num_gfr = num_gfr, num_burnin = num_burnin, num_mcmc = num_mcmc, + general_params = general_params, mean_forest_params = mean_forest_params, + variance_forest_params = variance_forest_params +) +``` + +## Python + +```{python} +num_gfr = 10 +num_burnin = 0 +num_mcmc = 100 +bart_model = BARTModel() +bart_model.sample( + X_train=X_train2, y_train=y_train2, X_test=X_test2, + num_gfr=num_gfr, num_burnin=num_burnin, num_mcmc=num_mcmc, + general_params={"num_threads": 1, "sample_sigma2_global": False}, + mean_forest_params={"sample_sigma2_leaf": False, "num_trees": 100, + "alpha": 0.95, "beta": 2.0, "min_samples_leaf": 5}, + variance_forest_params={"num_trees": 50, "alpha": 0.95, + "beta": 1.25, "min_samples_leaf": 1}, +) +``` + +:::: + +## Serialization + +Save the BART model to disk. + +::::{.panel-tabset group="language"} + +## R + +```{r} +saveBARTModelToJsonFile(bart_model, "bart_r.json") +``` + +## Python + +```{python} +bart_json_string = bart_model.to_json() +with open("bart_py.json", "w") as f: + json.dump(json.loads(bart_json_string), f) +``` + +:::: + +## Deserialization + +Reload the BART model from disk. + +::::{.panel-tabset group="language"} + +## R + +```{r} +bart_model_reload <- createBARTModelFromJsonFile("bart_r.json") +``` + +## Python + +```{python} +with open("bart_py.json", "r") as f: + bart_json_reload = json.dumps(json.load(f)) +bart_model_reload = BARTModel() +bart_model_reload.from_json(bart_json_reload) +``` + +:::: + +Check that the predictions align with those of the original model. 
+ +::::{.panel-tabset group="language"} + +## R + +```{r} +bart_preds_reload <- predict(bart_model_reload, X_train) +plot(rowMeans(bart_model$y_hat_train), rowMeans(bart_preds_reload$y_hat), + xlab = "Original", ylab = "Deserialized", main = "Conditional Mean Estimates") +abline(0,1,col="red",lwd=3,lty=3) +plot(rowMeans(bart_model$sigma2_x_hat_train), rowMeans(bart_preds_reload$variance_forest_predictions), + xlab = "Original", ylab = "Deserialized", main = "Conditional Variance Estimates") +abline(0,1,col="red",lwd=3,lty=3) +``` + +## Python + +```{python} +bart_preds_orig = bart_model.predict(X=X_train2, terms=["y_hat", "variance_forest"]) +bart_preds_reload = bart_model_reload.predict(X=X_train2, terms=["y_hat", "variance_forest"]) + +fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5)) +yhat_orig = bart_preds_orig["y_hat"].mean(axis=1) +yhat_reload = bart_preds_reload["y_hat"].mean(axis=1) +lo, hi = min(yhat_orig.min(), yhat_reload.min()), max(yhat_orig.max(), yhat_reload.max()) +ax1.scatter(yhat_orig, yhat_reload, alpha=0.4, s=10) +ax1.plot([lo, hi], [lo, hi], color="red", linestyle="dashed", linewidth=2) +ax1.set_xlabel("Original") +ax1.set_ylabel("Deserialized") +ax1.set_title("Conditional Mean Estimates") + +# multi-term predict returns variance forest under "variance_forest_predictions" +vhat_orig = bart_preds_orig["variance_forest_predictions"].mean(axis=1) +vhat_reload = bart_preds_reload["variance_forest_predictions"].mean(axis=1) +lo, hi = min(vhat_orig.min(), vhat_reload.min()), max(vhat_orig.max(), vhat_reload.max()) +ax2.scatter(vhat_orig, vhat_reload, alpha=0.4, s=10) +ax2.plot([lo, hi], [lo, hi], color="red", linestyle="dashed", linewidth=2) +ax2.set_xlabel("Original") +ax2.set_ylabel("Deserialized") +ax2.set_title("Conditional Variance Estimates") + +plt.tight_layout() +plt.show() +``` + +:::: + +# Bayesian Causal Forest (BCF) Serialization + +BCF models are initially sampled and constructed using the `bcf()` function. 
+Here we show how to save and reload models from JSON files on disk. + +## Model Building + +Draw from a modified version of the data generating process defined in +@hahn2020bayesian. + +::::{.panel-tabset group="language"} + +## R + +```{r} +# Generate synthetic data +n <- 1000 +snr <- 2 +x1 <- rnorm(n) +x2 <- rnorm(n) +x3 <- rnorm(n) +x4 <- as.numeric(rbinom(n,1,0.5)) +x5 <- as.numeric(sample(1:3,n,replace=TRUE)) +X <- cbind(x1,x2,x3,x4,x5) +p <- ncol(X) +mu_x <- mu1(X) +tau_x <- tau2(X) +pi_x <- 0.8*pnorm((3*mu_x/sd(mu_x)) - 0.5*X[,1]) + 0.05 + runif(n)/10 +Z <- rbinom(n,1,pi_x) +E_XZ <- mu_x + Z*tau_x +rfx_group_ids <- rep(c(1,2), n %/% 2) +rfx_coefs <- matrix(c(-1, -1, 1, 1), nrow=2, byrow=TRUE) +rfx_basis <- cbind(1, runif(n, -1, 1)) +rfx_term <- rowSums(rfx_coefs[rfx_group_ids,] * rfx_basis) +y <- E_XZ + rfx_term + rnorm(n, 0, 1)*(sd(E_XZ)/snr) +X <- as.data.frame(X) +X$x4 <- factor(X$x4, ordered = TRUE) +X$x5 <- factor(X$x5, ordered = TRUE) + +# Split data into test and train sets +test_set_pct <- 0.2 +n_test <- round(test_set_pct*n) +n_train <- n - n_test +test_inds <- sort(sample(1:n, n_test, replace = FALSE)) +train_inds <- (1:n)[!((1:n) %in% test_inds)] +X_test <- X[test_inds,] +X_train <- X[train_inds,] +pi_test <- pi_x[test_inds] +pi_train <- pi_x[train_inds] +Z_test <- Z[test_inds] +Z_train <- Z[train_inds] +y_test <- y[test_inds] +y_train <- y[train_inds] +mu_test <- mu_x[test_inds] +mu_train <- mu_x[train_inds] +tau_test <- tau_x[test_inds] +tau_train <- tau_x[train_inds] +rfx_group_ids_test <- rfx_group_ids[test_inds] +rfx_group_ids_train <- rfx_group_ids[train_inds] +rfx_basis_test <- rfx_basis[test_inds,] +rfx_basis_train <- rfx_basis[train_inds,] +rfx_term_test <- rfx_term[test_inds] +rfx_term_train <- rfx_term[train_inds] +``` + +## Python + +```{python} +random_seed = 1234 +rng = np.random.default_rng(random_seed) + +n = 1000 +snr = 2 +x1 = rng.standard_normal(n) +x2 = rng.standard_normal(n) +x3 = rng.standard_normal(n) +x4 = rng.binomial(1, 
0.5, n).astype(float) +x5 = rng.choice([1, 2, 3], n).astype(float) +X = np.column_stack([x1, x2, x3, x4, x5]) +mu_x = mu1(X) +tau_x = tau2(X) +pi_x = 0.8 * norm.cdf((3 * mu_x / np.std(mu_x)) - 0.5 * X[:, 0]) + 0.05 + rng.uniform(size=n) / 10 +Z = rng.binomial(1, pi_x) +E_XZ = mu_x + Z * tau_x +rfx_group_ids = np.tile([1, 2], n // 2) # 1-indexed group IDs +rfx_coefs = np.array([[-1.0, -1.0], [1.0, 1.0]]) +rfx_basis = np.column_stack([np.ones(n), rng.uniform(-1, 1, n)]) +rfx_term = np.sum(rfx_coefs[rfx_group_ids - 1] * rfx_basis, axis=1) +y = E_XZ + rfx_term + rng.standard_normal(n) * (np.std(E_XZ) / snr) + +# Ordered categoricals +X_df = pd.DataFrame(X, columns=["x1", "x2", "x3", "x4", "x5"]) +X_df["x4"] = pd.Categorical(X_df["x4"].astype(int), categories=[0, 1], ordered=True) +X_df["x5"] = pd.Categorical(X_df["x5"].astype(int), categories=[1, 2, 3], ordered=True) + +# Train/test split +test_set_pct = 0.2 +n_test = round(test_set_pct * n) +test_inds = rng.choice(n, n_test, replace=False) +train_inds = np.setdiff1d(np.arange(n), test_inds) +X_test = X_df.iloc[test_inds] +X_train = X_df.iloc[train_inds] +pi_test = pi_x[test_inds] +pi_train = pi_x[train_inds] +Z_test = Z[test_inds] +Z_train = Z[train_inds] +y_test = y[test_inds] +y_train = y[train_inds] +rfx_group_ids_test = rfx_group_ids[test_inds] +rfx_group_ids_train = rfx_group_ids[train_inds] +rfx_basis_test = rfx_basis[test_inds] +rfx_basis_train = rfx_basis[train_inds] +``` + +:::: + +Sample a BCF model. 
+ +::::{.panel-tabset group="language"} + +## R + +```{r} +num_gfr <- 10 +num_burnin <- 0 +num_mcmc <- 100 +prognostic_forest_params <- list(sample_sigma2_leaf = F) +treatment_effect_forest_params <- list(sample_sigma2_leaf = F) +bcf_model <- bcf( + X_train = X_train, Z_train = Z_train, y_train = y_train, propensity_train = pi_train, + rfx_group_ids_train = rfx_group_ids_train, rfx_basis_train = rfx_basis_train, + X_test = X_test, Z_test = Z_test, propensity_test = pi_test, + rfx_group_ids_test = rfx_group_ids_test, rfx_basis_test = rfx_basis_test, + num_gfr = num_gfr, num_burnin = num_burnin, num_mcmc = num_mcmc, + prognostic_forest_params = prognostic_forest_params, + treatment_effect_forest_params = treatment_effect_forest_params +) +``` + +## Python + +```{python} +num_gfr = 10 +num_burnin = 0 +num_mcmc = 100 +bcf_model = BCFModel() +bcf_model.sample( + X_train=X_train, Z_train=Z_train, y_train=y_train, propensity_train=pi_train, + rfx_group_ids_train=rfx_group_ids_train, rfx_basis_train=rfx_basis_train, + X_test=X_test, Z_test=Z_test, propensity_test=pi_test, + rfx_group_ids_test=rfx_group_ids_test, rfx_basis_test=rfx_basis_test, + num_gfr=num_gfr, num_burnin=num_burnin, num_mcmc=num_mcmc, + general_params={"num_threads": 1}, + prognostic_forest_params={"sample_sigma2_leaf": False}, + treatment_effect_forest_params={"sample_sigma2_leaf": False}, +) +``` + +:::: + +## Serialization + +Save the BCF model to disk. + +::::{.panel-tabset group="language"} + +## R + +```{r} +saveBCFModelToJsonFile(bcf_model, "bcf_r.json") +``` + +## Python + +```{python} +bcf_json_string = bcf_model.to_json() +with open("bcf_py.json", "w") as f: + json.dump(json.loads(bcf_json_string), f) +``` + +:::: + +## Deserialization + +Reload the BCF model from disk. 
+ +::::{.panel-tabset group="language"} + +## R + +```{r} +bcf_model_reload <- createBCFModelFromJsonFile("bcf_r.json") +``` + +## Python + +```{python} +with open("bcf_py.json", "r") as f: + bcf_json_reload = json.dumps(json.load(f)) +bcf_model_reload = BCFModel() +bcf_model_reload.from_json(bcf_json_reload) +``` + +:::: + +Check that the predictions align with those of the original model. + +::::{.panel-tabset group="language"} + +## R + +```{r} +bcf_preds_reload <- predict(bcf_model_reload, X_train, Z_train, pi_train, rfx_group_ids_train, rfx_basis_train) +plot(rowMeans(bcf_model$mu_hat_train), rowMeans(bcf_preds_reload$mu_hat), + xlab = "Original", ylab = "Deserialized", main = "Prognostic forest") +abline(0,1,col="red",lwd=3,lty=3) +plot(rowMeans(bcf_model$tau_hat_train), rowMeans(bcf_preds_reload$tau_hat), + xlab = "Original", ylab = "Deserialized", main = "Treatment forest") +abline(0,1,col="red",lwd=3,lty=3) +plot(rowMeans(bcf_model$y_hat_train), rowMeans(bcf_preds_reload$y_hat), + xlab = "Original", ylab = "Deserialized", main = "Overall outcome") +abline(0,1,col="red",lwd=3,lty=3) +``` + +## Python + +```{python} +bcf_preds_orig = bcf_model.predict( + X=X_train, Z=Z_train, propensity=pi_train, + rfx_group_ids=rfx_group_ids_train, rfx_basis=rfx_basis_train, + terms=["mu", "tau", "y_hat"], +) +bcf_preds_reload = bcf_model_reload.predict( + X=X_train, Z=Z_train, propensity=pi_train, + rfx_group_ids=rfx_group_ids_train, rfx_basis=rfx_basis_train, + terms=["mu", "tau", "y_hat"], +) + +fig, axes = plt.subplots(1, 3, figsize=(15, 5)) +for ax, term, title in zip( + axes, + ["mu_hat", "tau_hat", "y_hat"], + ["Prognostic forest", "Treatment forest", "Overall outcome"], +): + orig = bcf_preds_orig[term].mean(axis=1) + reload = bcf_preds_reload[term].mean(axis=1) + lo, hi = min(orig.min(), reload.min()), max(orig.max(), reload.max()) + ax.scatter(orig, reload, alpha=0.4, s=10) + ax.plot([lo, hi], [lo, hi], color="red", linestyle="dashed", linewidth=2) + 
ax.set_xlabel("Original") + ax.set_ylabel("Deserialized") + ax.set_title(title) +plt.tight_layout() +plt.show() +``` + +:::: + +# References diff --git a/vignettes/sklearn.qmd b/vignettes/sklearn.qmd new file mode 100644 index 000000000..fc40e8075 --- /dev/null +++ b/vignettes/sklearn.qmd @@ -0,0 +1,220 @@ +--- +title: "Using Stochtree via Sklearn-Compatible Estimators in Python" +execute: + freeze: auto # re-render only when source changes +--- + +```{r} +#| include: false +reticulate::use_python( + Sys.getenv( + "RETICULATE_PYTHON", + unset = file.path(rprojroot::find_root(rprojroot::has_file(".here")), ".venv", "bin", "python") + ), + required = TRUE +) +``` + +This vignette is python-specific and no similar interface is implemented for R. + +`stochtree.BARTModel` is fundamentally a Bayesian interface in which users specify a prior, provide data, sample from the posterior, and manage and inspect the resulting posterior samples. However, the basic BART model + +$$y_i \sim \mathcal{N}\left(f(X_i), \sigma^2\right)$$ + +involves samples of a nonparametric function $f$ which estimates the expected +value of $y$ given $X$. Averaging over these draws, the posterior mean $\bar{f}$ +alone may satisfy some supervised learning use cases. To serve this use case +straightforwardly, `stochtree` offers +[scikit-learn-compatible estimator](https://scikit-learn.org/stable/developers/develop.html) +wrappers around `BARTModel` which implement the familiar `sklearn` API. 
+
+- **`StochTreeBARTRegressor`**: continuous outcomes — provides `fit`, `predict`,
+  and `score`
+- **`StochTreeBARTBinaryClassifier`**: binary outcomes via probit BART —
+  provides `fit`, `predict`, `predict_proba`, `decision_function`, and `score`
+- Multi-class classification is supported by wrapping
+  [`OneVsRestClassifier`](https://scikit-learn.org/stable/modules/generated/sklearn.multiclass.OneVsRestClassifier.html)
+  around `StochTreeBARTBinaryClassifier`
+
+## Setup
+
+```{python}
+import matplotlib.pyplot as plt
+import numpy as np
+from sklearn.datasets import load_wine, load_breast_cancer
+from sklearn.model_selection import GridSearchCV
+from sklearn.multiclass import OneVsRestClassifier
+from stochtree import (
+    StochTreeBARTRegressor,
+    StochTreeBARTBinaryClassifier,
+)
+```
+
+```{python}
+random_seed = 1234
+rng = np.random.default_rng(random_seed)
+```
+
+## BART Regression
+
+We simulate simple regression data to demonstrate the continuous outcome case.
+
+```{python}
+n = 100
+p = 10
+X = rng.normal(size=(n, p))
+y = X[:, 0] * 3 + rng.normal(size=n)
+```
+
+We fit a BART regression model by initializing a `StochTreeBARTRegressor` and
+calling `fit()`. Since `BARTModel` is configured primarily through parameter
+dictionaries, downstream parameters are passed through as such — here we only
+specify the random seed.
+
+```{python}
+reg = StochTreeBARTRegressor(general_params={"random_seed": random_seed, "num_threads": 1})
+reg.fit(X, y)
+```
+
+We can then predict from the model and compare posterior mean predictions to
+the true outcome.
+
+```{python}
+pred = reg.predict(X)
+plt.scatter(pred, y)
+plt.xlabel("Predicted")
+plt.ylabel("Actual")
+plt.show()
+```
+
+We can also verify determinism by running the model again with the same seed
+and comparing predictions. 
+ +```{python} +reg2 = StochTreeBARTRegressor(general_params={"random_seed": random_seed, "num_threads": 1}) +reg2.fit(X, y) +pred2 = reg2.predict(X) +plt.scatter(pred, pred2) +plt.xlabel("First model") +plt.ylabel("Second model") +plt.show() +``` + +## Cross-Validating a BART Model + +While the default hyperparameters of `BARTModel` are designed to work well +out of the box, we can use posterior mean prediction error to cross-validate +the model's parameters. Below we use grid search to consider the effect of +several BART parameters: + +1. Number of GFR iterations (`num_gfr`) +2. Number of MCMC iterations (`num_mcmc`) +3. `num_trees`, `alpha`, and `beta` for the mean forest + +```{python} +param_grid = { + "num_gfr": [10, 40], + "num_mcmc": [0, 1000], + "mean_forest_params": [ + {"num_trees": 50, "alpha": 0.95, "beta": 2.0}, + {"num_trees": 100, "alpha": 0.90, "beta": 1.5}, + {"num_trees": 200, "alpha": 0.85, "beta": 1.0}, + ], +} +grid_search = GridSearchCV( + estimator=StochTreeBARTRegressor(general_params={"num_threads": 1}), + param_grid=param_grid, + cv=5, + scoring="r2", + n_jobs=1, # n_jobs=-1 deadlocks when stochtree's C++ thread pool is active +) +grid_search.fit(X, y) +``` + +Note that we set `n_jobs=1` above to avoid deadlocks arising from interactions between `reticulate` (which renders these python vignettes), `joblib`, and stochtree's own C++ multithreading model. Users running this vignette interactively or as a script do not need to fix `n_jobs=1`. 
+ +```{python} +cv_best_ind = np.argwhere(grid_search.cv_results_['rank_test_score'] == 1).item(0) +best_num_gfr = grid_search.cv_results_['param_num_gfr'][cv_best_ind].item(0) +best_num_mcmc = grid_search.cv_results_['param_num_mcmc'][cv_best_ind].item(0) +best_mean_forest_params = grid_search.cv_results_['param_mean_forest_params'][cv_best_ind] +best_num_trees = best_mean_forest_params['num_trees'] +best_alpha = best_mean_forest_params['alpha'] +best_beta = best_mean_forest_params['beta'] +print_message = f""" +Hyperparameters chosen by grid search: + num_gfr: {best_num_gfr} + num_mcmc: {best_num_mcmc} + num_trees: {best_num_trees} + alpha: {best_alpha} + beta: {best_beta} +""" +print(print_message) +``` + +## BART Classification + +### Binary Classification + +We load a binary outcome dataset from `sklearn`. + +```{python} +dataset = load_breast_cancer() +X = dataset.data +y = dataset.target +``` + +We fit a binary classification model using `StochTreeBARTBinaryClassifier`. + +```{python} +clf = StochTreeBARTBinaryClassifier(general_params={"random_seed": random_seed, "num_threads": 1}) +clf.fit(X=X, y=y) +``` + +In addition to class predictions, we can compute and visualize the predicted +probability of each class via `predict_proba()`. + +```{python} +probs = clf.predict_proba(X) +plt.hist(probs[:, 1], bins=30) +plt.xlabel("Predicted probability (class 1)") +plt.ylabel("Count") +plt.show() +``` + +### Multi-Class Classification + +For multi-class outcomes, we wrap `OneVsRestClassifier` around +`StochTreeBARTBinaryClassifier`. Here we use the Wine dataset, which has three +classes. + +```{python} +dataset = load_wine() +X = dataset.data +y = dataset.target +``` + +```{python} +clf = OneVsRestClassifier( + StochTreeBARTBinaryClassifier(general_params={"random_seed": random_seed, "num_threads": 1}) +) +clf.fit(X=X, y=y) +``` + +We visualize the histogram of predicted probabilities for each outcome category. 
+ +```{python} +fig, (ax1, ax2, ax3) = plt.subplots(3, 1) +fig.tight_layout(pad=3.0) +probs = clf.predict_proba(X) +ax1.hist(probs[y == 0, 0], bins=30) +ax1.set_title("Predicted Probabilities for Class 0") +ax1.set_xlim(0, 1) +ax2.hist(probs[y == 1, 1], bins=30) +ax2.set_title("Predicted Probabilities for Class 1") +ax2.set_xlim(0, 1) +ax3.hist(probs[y == 2, 2], bins=30) +ax3.set_title("Predicted Probabilities for Class 2") +ax3.set_xlim(0, 1) +plt.show() +``` diff --git a/vignettes/summary-plotting.qmd b/vignettes/summary-plotting.qmd new file mode 100644 index 000000000..a2b9f6381 --- /dev/null +++ b/vignettes/summary-plotting.qmd @@ -0,0 +1,443 @@ +--- +title: "Posterior Summary and Visualization Utilities" +bibliography: vignettes.bib +execute: + freeze: auto # re-render only when source changes +--- + +```{r} +#| include: false +reticulate::use_python( + Sys.getenv( + "RETICULATE_PYTHON", + unset = file.path(rprojroot::find_root(rprojroot::has_file(".here")), ".venv", "bin", "python") + ), + required = TRUE +) +``` + +This vignette demonstrates the summary and plotting utilities available for +`stochtree` models. + +# Setup + +Load necessary packages + +::::{.panel-tabset group="language"} + +## R + +```{r} +library(stochtree) +``` + +## Python + +```{python} +import numpy as np +import matplotlib.pyplot as plt +from stochtree import BARTModel, BCFModel, plot_parameter_trace +``` + +:::: + +Set a seed for reproducibility + +::::{.panel-tabset group="language"} + +## R + +```{r} +random_seed = 1234 +set.seed(random_seed) +``` + +## Python + +```{python} +random_seed = 1234 +rng = np.random.default_rng(random_seed) +``` + +:::: + +# Supervised Learning + +We begin with the supervised learning use case served by the `bart()` function. + +Below we simulate a simple regression dataset. 
+ +::::{.panel-tabset group="language"} + +## R + +```{r} +n <- 1000 +p_x <- 10 +p_w <- 1 +X <- matrix(runif(n * p_x), ncol = p_x) +W <- matrix(runif(n * p_w), ncol = p_w) +f_XW <- (((0 <= X[, 10]) & (0.25 > X[, 10])) * + (-7.5 * W[, 1]) + + ((0.25 <= X[, 10]) & (0.5 > X[, 10])) * (-2.5 * W[, 1]) + + ((0.5 <= X[, 10]) & (0.75 > X[, 10])) * (2.5 * W[, 1]) + + ((0.75 <= X[, 10]) & (1 > X[, 10])) * (7.5 * W[, 1])) +noise_sd <- 1 +y <- f_XW + rnorm(n, 0, 1) * noise_sd +``` + +## Python + +```{python} +n = 1000 +p_x = 10 +p_w = 1 +X = rng.uniform(size=(n, p_x)) +W = rng.uniform(size=(n, p_w)) +# R uses X[,10] (1-indexed) = Python X[:,9] +f_XW = ( + ((X[:, 9] >= 0) & (X[:, 9] < 0.25)) * (-7.5 * W[:, 0]) + + ((X[:, 9] >= 0.25) & (X[:, 9] < 0.5)) * (-2.5 * W[:, 0]) + + ((X[:, 9] >= 0.5) & (X[:, 9] < 0.75)) * ( 2.5 * W[:, 0]) + + ((X[:, 9] >= 0.75) & (X[:, 9] < 1.0)) * ( 7.5 * W[:, 0]) +) +noise_sd = 1.0 +y = f_XW + rng.standard_normal(n) * noise_sd +``` + +:::: + +Now we fit a simple BART model to the data. + +::::{.panel-tabset group="language"} + +## R + +```{r} +num_gfr <- 10 +num_burnin <- 0 +num_mcmc <- 1000 +general_params <- list( + num_threads = 1, + num_chains = 3 +) +bart_model <- stochtree::bart( + X_train = X, + y_train = y, + leaf_basis_train = W, + num_gfr = num_gfr, + num_burnin = num_burnin, + num_mcmc = num_mcmc, + general_params = general_params +) +``` + +## Python + +```{python} +bart_model = BARTModel() +bart_model.sample( + X_train=X, + y_train=y, + leaf_basis_train=W, + num_gfr=10, + num_burnin=0, + num_mcmc=1000, + general_params={ + "num_threads": 1, + "num_chains": 3 + }, +) +``` + +:::: + +We obtain a high level summary of the BART model by running `print()`. + +::::{.panel-tabset group="language"} + +## R + +```{r} +print(bart_model) +``` + +## Python + +```{python} +print(bart_model) +``` + +:::: + +For a more detailed summary (including the information above), we use the `summary()` +function. 
+ +::::{.panel-tabset group="language"} + +## R + +```{r} +summary(bart_model) +``` + +## Python + +```{python} +print(bart_model.summary()) +``` + +:::: + +We can use the `plot()` function to produce a traceplot of model terms like the global +error scale $\sigma^2$ or (if $\sigma^2$ is not sampled) the first observation of +cached train set predictions. + +::::{.panel-tabset group="language"} + +## R + +```{r} +plot(bart_model) +``` + +## Python + +```{python} +ax = plot_parameter_trace(bart_model, term="global_error_scale") +plt.show() +``` + +:::: + +For finer-grained control over which parameters to plot, we can also use the +`extractParameter()` function to pull the posterior distribution of any valid model +term (e.g., global error scale $\sigma^2$, leaf scale $\sigma^2_{\ell}$, in-sample +mean function predictions `y_hat_train`) and then plot any subset or transformation +of these values. + +::::{.panel-tabset group="language"} + +## R + +```{r} +y_hat_train_samples <- extractParameter(bart_model, "y_hat_train") +obs_index <- 1 +plot( + y_hat_train_samples[obs_index, ], + type = "l", + main = paste0("In-Sample Predictions Traceplot, Observation ", obs_index), + xlab = "Index", + ylab = "Parameter Values" +) +``` + +## Python + +```{python} +y_hat_train_samples = bart_model.extract_parameter("y_hat_train") +obs_index = 0 +fig, ax = plt.subplots() +ax.plot(y_hat_train_samples[obs_index, :]) +ax.set_title(f"In-Sample Predictions Traceplot, Observation {obs_index}") +ax.set_xlabel("Index") +ax.set_ylabel("Parameter Values") +plt.show() +``` + +:::: + +# Causal Inference + +We now run the same demo for the causal inference use case served by the `bcf()` function in R and the `BCFModel` Python class. + +Below we simulate a simple dataset for a causal inference problem with binary treatment and continuous outcome. 
+ +::::{.panel-tabset group="language"} + +## R + +```{r} +# Generate covariates and treatment +n <- 1000 +p_X = 5 +X = matrix(runif(n * p_X), ncol = p_X) +pi_X = 0.25 + 0.5 * X[, 1] +Z = rbinom(n, 1, pi_X) + +# Define the outcome mean functions (prognostic and treatment effects) +mu_X = pi_X * 5 + 2 * X[, 3] +tau_X = X[, 2] * 2 - 1 + +# Generate outcome +epsilon = rnorm(n, 0, 1) +y = mu_X + tau_X * Z + epsilon +``` + +## Python + +```{python} +# Generate covariates and treatment +n = 1000 +p_X = 5 +X = rng.uniform(size=(n, p_X)) +pi_X = 0.25 + 0.5 * X[:, 0] +Z = rng.binomial(1, pi_X, n).astype(float) + +# Define the outcome mean functions (prognostic and treatment effects) +mu_X = pi_X * 5 + 2 * X[:, 2] +tau_X = X[:, 1] * 2 - 1 + +# Generate outcome +epsilon = rng.standard_normal(n) +y = mu_X + tau_X * Z + epsilon +``` + +:::: + +Now we fit a simple BCF model to the data + +::::{.panel-tabset group="language"} + +## R + +```{r} +num_gfr <- 10 +num_burnin <- 0 +num_mcmc <- 1000 +general_params <- list( + num_threads = 1, + num_chains = 3, + adaptive_coding = TRUE +) +bcf_model <- stochtree::bcf( + X_train = X, + y_train = y, + Z_train = Z, + num_gfr = num_gfr, + num_burnin = num_burnin, + num_mcmc = num_mcmc, + general_params = general_params +) +``` + +## Python + +```{python} +bcf_model = BCFModel() +bcf_model.sample( + X_train=X, + Z_train=Z, + y_train=y, + propensity_train=pi_X, + num_gfr=10, + num_burnin=0, + num_mcmc=1000, + general_params={ + "num_threads": 1, + "num_chains": 3, + "adaptive_coding": True + }, +) +``` + +:::: + +We obtain a high level summary of the BCF model by running `print()`. + +::::{.panel-tabset group="language"} + +## R + +```{r} +print(bcf_model) +``` + +## Python + +```{python} +print(bcf_model) +``` + +:::: + +For a more detailed summary (including the information above), we use the `summary()` function / method. 
+ +::::{.panel-tabset group="language"} + +## R + +```{r} +summary(bcf_model) +``` + +## Python + +```{python} +print(bcf_model.summary()) +``` + +:::: + +In R, we have a `plot()` that produces a traceplot of model terms like the global error scale $\sigma^2$ or (if $\sigma^2$ is not sampled) the first observation of cached train set predictions. + +In Python, we provide a `plot_parameter_trace()` function for requesting a traceplot of a specific model parameter. + +::::{.panel-tabset group="language"} + +## R + +```{r} +plot(bcf_model) +``` + +## Python + +```{python} +ax = plot_parameter_trace(bcf_model, term="global_error_scale") +plt.show() +``` + +:::: + +For finer-grained control over which parameters to plot, we can also use the `extractParameter()` function in R or the `extract_parameter()` method in Python to query the posterior distribution of any valid model term (e.g., global error scale $\sigma^2$, prognostic forest leaf scale $\sigma^2_{\mu}$, CATE forest leaf scale $\sigma^2_{\tau}$, adaptive coding parameters $b_0$ and $b_1$ for binary treatment, in-sample mean function predictions `y_hat_train`, in-sample CATE function predictions `tau_hat_train`) and then plot any subset or transformation of these values. 
+ +::::{.panel-tabset group="language"} + +## R + +```{r} +adaptive_coding_samples <- extractParameter(bcf_model, "adaptive_coding") +plot( + adaptive_coding_samples[1, ], + type = "l", + main = "Adaptive Coding Parameter Traceplot", + xlab = "Index", + ylab = "Parameter Values", + ylim = range(adaptive_coding_samples), + col = "blue" +) +lines(adaptive_coding_samples[2, ], col = "orange") +legend( + "topright", + legend = c("Control", "Treated"), + lty = 1, + col = c("blue", "orange") +) +``` + +## Python + +```{python} +adaptive_coding_samples = bcf_model.extract_parameter("adaptive_coding") +fig, ax = plt.subplots() +ax.plot(adaptive_coding_samples[0, :], color="blue", label="Control") +ax.plot(adaptive_coding_samples[1, :], color="orange", label="Treated") +ax.set_title("Adaptive Coding Parameter Traceplot") +ax.set_xlabel("Index") +ax.set_ylabel("Parameter Values") +ax.legend(loc="upper right") +plt.show() +``` + +:::: diff --git a/vignettes/tree-inspection.qmd b/vignettes/tree-inspection.qmd new file mode 100644 index 000000000..94d9098b6 --- /dev/null +++ b/vignettes/tree-inspection.qmd @@ -0,0 +1,397 @@ +--- +title: "Examining Individual Trees in a Fitted Ensemble" +bibliography: vignettes.bib +execute: + freeze: auto # re-render only when source changes +--- + +```{r} +#| include: false +reticulate::use_python( + Sys.getenv( + "RETICULATE_PYTHON", + unset = file.path(rprojroot::find_root(rprojroot::has_file(".here")), ".venv", "bin", "python") + ), + required = TRUE +) +``` + +While out of sample evaluation and MCMC diagnostics on parametric BART components +(i.e. $\sigma^2$, the global error variance) are helpful, it's important to be able +to inspect the trees in a BART / BCF model. This vignette walks through some of the +features `stochtree` provides to query and understand the forests and trees in a model. 
+ +# Setup + +Load necessary packages + +::::{.panel-tabset group="language"} + +## R + +```{r} +library(stochtree) +``` + +## Python + +```{python} +import numpy as np +import matplotlib.pyplot as plt +from stochtree import BARTModel +``` + +:::: + +Set a seed for reproducibility + +::::{.panel-tabset group="language"} + +## R + +```{r} +random_seed = 1234 +set.seed(random_seed) +``` + +## Python + +```{python} +random_seed = 1234 +rng = np.random.default_rng(random_seed) +``` + +:::: + +# Data Generation + +Generate sample data where feature 10 is the only "important" feature + +::::{.panel-tabset group="language"} + +## R + +```{r} +n <- 500 +p_x <- 10 +X <- matrix(runif(n*p_x), ncol = p_x) +f_XW <- ( + ((0 <= X[,10]) & (0.25 > X[,10])) * (-7.5) + + ((0.25 <= X[,10]) & (0.5 > X[,10])) * (-2.5) + + ((0.5 <= X[,10]) & (0.75 > X[,10])) * (2.5) + + ((0.75 <= X[,10]) & (1 > X[,10])) * (7.5) +) +noise_sd <- 1 +y <- f_XW + rnorm(n, 0, 1)*noise_sd +``` + +## Python + +```{python} +n = 500 +p_x = 10 +X = rng.uniform(size=(n, p_x)) +# Feature 10 (R) = feature index 9 (Python, 0-indexed) +f_XW = ( + ((X[:, 9] >= 0) & (X[:, 9] < 0.25)) * (-7.5) + + ((X[:, 9] >= 0.25) & (X[:, 9] < 0.5)) * (-2.5) + + ((X[:, 9] >= 0.5) & (X[:, 9] < 0.75)) * (2.5) + + ((X[:, 9] >= 0.75) & (X[:, 9] < 1.0)) * (7.5) +) +noise_sd = 1.0 +y = f_XW + rng.standard_normal(n) * noise_sd +``` + +:::: + +Split into train and test sets + +::::{.panel-tabset group="language"} + +## R + +```{r} +test_set_pct <- 0.2 +n_test <- round(test_set_pct*n) +n_train <- n - n_test +test_inds <- sort(sample(1:n, n_test, replace = FALSE)) +train_inds <- (1:n)[!((1:n) %in% test_inds)] +X_test <- as.data.frame(X[test_inds,]) +X_train <- as.data.frame(X[train_inds,]) +y_test <- y[test_inds] +y_train <- y[train_inds] +``` + +## Python + +```{python} +n_test = round(0.2 * n) +test_inds = rng.choice(n, n_test, replace=False) +train_inds = np.setdiff1d(np.arange(n), test_inds) +X_test = X[test_inds] +X_train = X[train_inds] 
+y_test = y[test_inds] +y_train = y[train_inds] +``` + +:::: + +# Model Sampling + +Sample a BART model with 10 GFR and 100 MCMC iterations + +::::{.panel-tabset group="language"} + +## R + +```{r} +num_gfr <- 10 +num_burnin <- 0 +num_mcmc <- 100 +general_params <- list(keep_gfr = T) +bart_model <- stochtree::bart( + X_train = X_train, y_train = y_train, X_test = X_test, + num_gfr = num_gfr, num_burnin = num_burnin, num_mcmc = num_mcmc, + general_params = general_params +) +``` + +## Python + +```{python} +num_gfr = 10 +num_burnin = 0 +num_mcmc = 100 +bart_model = BARTModel() +bart_model.sample( + X_train=X_train, y_train=y_train, X_test=X_test, + num_gfr=num_gfr, num_burnin=num_burnin, num_mcmc=num_mcmc, + general_params={"num_threads": 1, "keep_gfr": True}, +) +``` + +:::: + +# Model Inspection + +Assess the global error variance traceplot and test set prediction quality + +::::{.panel-tabset group="language"} + +## R + +```{r} +sigma2_samples <- extractParameter(bart_model, "sigma2_global") +plot(sigma2_samples, ylab="sigma^2") +abline(h=noise_sd^2,col="red",lty=2,lwd=2.5) +y_hat_test <- predict(bart_model, X=X_test, type="mean", terms="y_hat") +plot(y_hat_test, y_test, pch=16, cex=0.75, xlab = "pred", ylab = "actual") +abline(0,1,col="red",lty=2,lwd=2.5) +``` + +## Python + +```{python} +fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4)) + +sigma2_samples = bart_model.extract_parameter("sigma2_global") +ax1.plot(sigma2_samples) +ax1.axhline(noise_sd**2, color="red", linestyle="dashed", linewidth=2) +ax1.set_ylabel(r"$\sigma^2$") + +y_hat_test = bart_model.predict(X=X_test, terms="y_hat", type="mean") +ax2.scatter(y_hat_test, y_test, s=15, alpha=0.6) +lo = min(y_hat_test.min(), y_test.min()) +hi = max(y_hat_test.max(), y_test.max()) +ax2.plot([lo, hi], [lo, hi], color="red", linestyle="dashed", linewidth=2) +ax2.set_xlabel("pred") +ax2.set_ylabel("actual") + +plt.tight_layout() +plt.show() +``` + +:::: + +## Variable Split Counts + +The 
`get_forest_split_counts` method of a BART model's internal forest objects allows us to compute the number of times each variable was used in a split rule across all trees in a given forest. + +Below we query this vector for the final GFR sample (1-indexed as 10 in R, 0-indexed as 9 in Python), where the second argument is the dimensionality of the covariates. + +::::{.panel-tabset group="language"} + +## R + +```{r} +bart_model$mean_forests$get_forest_split_counts(10, p_x) +``` + +## Python + +```{python} +bart_model.forest_container_mean.get_forest_split_counts(9, p_x) +``` + +:::: + +We can also compute split counts for each feature aggregated over all forests + +::::{.panel-tabset group="language"} + +## R + +```{r} +bart_model$mean_forests$get_aggregate_split_counts(p_x) +``` + +## Python + +```{python} +bart_model.forest_container_mean.get_overall_split_counts(p_x) +``` + +:::: + +The split counts appear relatively uniform across features, so let's dig deeper and +look at individual trees. + +The `get_granular_split_counts` method returns a 3-dimensional array of shape `(num_forests, num_trees, num_features)`, where each entry represents the number of times a feature was used in a split for a specific tree in a specific forest. + +That is, we can count the number of times feature $k$ was split on in tree $j$ of forest $i$ by looking at the `(i,j,k)` entry of this array. + +Below we compute the split count for all features in the first tree of the last GFR sample in our model (noting again the use of 1-indexing in R and 0-indexing in Python). + +::::{.panel-tabset group="language"} + +## R + +```{r} +splits = bart_model$mean_forests$get_granular_split_counts(p_x) +splits[10,1,] +``` + +## Python + +```{python} +splits = bart_model.forest_container_mean.get_granular_split_counts(p_x) +splits[9, 0, :] +``` + +:::: + +This tree has a single split on the only "important" feature (10). Now, let's look at +the second tree. 
+ +::::{.panel-tabset group="language"} + +## R + +```{r} +splits[10,2,] +``` + +## Python + +```{python} +splits[9, 1, :] +``` + +:::: + +And the 20th and 30th trees + +::::{.panel-tabset group="language"} + +## R + +```{r} +splits[10,20,] +``` + +## Python + +```{python} +splits[9, 19, :] +``` + +:::: + +::::{.panel-tabset group="language"} + +## R + +```{r} +splits[10,30,] +``` + +## Python + +```{python} +splits[9, 29, :] +``` + +:::: + +We see that "later" trees are splitting on other features, but we also note that these +trees are fitting an outcome that is already residualized by many "relevant splits" +made by trees 1 and 2. + +## Tree Structure + +Now, let's inspect the first tree for the last GFR sample in more depth, following +[this scikit-learn vignette](https://scikit-learn.org/stable/auto_examples/tree/plot_unveil_tree_structure.html). + +::::{.panel-tabset group="language"} + +## R + +```{r} +forest_num <- 9 +tree_num <- 0 +nodes <- sort(bart_model$mean_forests$nodes(forest_num, tree_num)) +for (nid in nodes) { + if (bart_model$mean_forests$is_leaf_node(forest_num, tree_num, nid)) { + node_depth <- bart_model$mean_forests$node_depth(forest_num, tree_num, nid) + space_text <- rep("\t", node_depth) + leaf_values <- bart_model$mean_forests$node_leaf_values(forest_num, tree_num, nid) + cat(space_text, "node=", nid, " is a leaf node with value=", + format(leaf_values, digits = 3), "\n", sep = "") + } else { + node_depth <- bart_model$mean_forests$node_depth(forest_num, tree_num, nid) + space_text <- rep("\t", node_depth) + left <- bart_model$mean_forests$left_child_node(forest_num, tree_num, nid) + feature <- bart_model$mean_forests$node_split_index(forest_num, tree_num, nid) + threshold <- bart_model$mean_forests$node_split_threshold(forest_num, tree_num, nid) + right <- bart_model$mean_forests$right_child_node(forest_num, tree_num, nid) + cat(space_text, "node=", nid, " is a split node, which tells us to go to node ", + left, " if X[:, ", feature, "] 
<= ", format(threshold, digits = 3), + " else to node ", right, "\n", sep = "") + } +} +``` + +## Python + +```{python} +forest_num = 9 +tree_num = 0 +fc = bart_model.forest_container_mean +nodes = np.sort(fc.nodes(forest_num, tree_num)) +for nid in nodes: + depth = fc.node_depth(forest_num, tree_num, nid) + indent = "\t" * depth + if fc.is_leaf_node(forest_num, tree_num, nid): + value = np.round(fc.node_leaf_values(forest_num, tree_num, nid), 3) + print(f"{indent}node={nid} is a leaf node with value={value}") + else: + + left = fc.left_child_node(forest_num, tree_num, nid) + feature = fc.node_split_index(forest_num, tree_num, nid) + threshold = round(fc.node_split_threshold(forest_num, tree_num, nid), 3) + right = fc.right_child_node(forest_num, tree_num, nid) + print(f"{indent}node={nid} is a split node, which tells us to go to node " + f"{left} if X[:, {feature}] <= {threshold} else to node {right}") +``` + +:::: diff --git a/vignettes/vignettes.bib b/vignettes/vignettes.bib new file mode 100644 index 000000000..3de94125f --- /dev/null +++ b/vignettes/vignettes.bib @@ -0,0 +1,237 @@ +@book{gelman2013bayesian, + title={Bayesian Data Analysis}, + edition={Third}, + author={Gelman, Andrew and Carlin, John B and Stern, Hal S and Dunson, David B and Vehtari, Aki and Rubin, Donald B}, + year={2013}, + publisher={Chapman and Hall/CRC} +} + +@article{friedman1991multivariate, + title={Multivariate adaptive regression splines}, + author={Friedman, Jerome H}, + journal={The annals of statistics}, + volume={19}, + number={1}, + pages={1--67}, + year={1991}, + publisher={Institute of Mathematical Statistics} +} + +@article{mcdonald1992effects, + title={Effects of computer reminders for influenza vaccination on morbidity during influenza epidemics.}, + author={McDonald, Clement J and Hui, Siu L and Tierney, William M}, + journal={MD computing: computers in medical practice}, + volume={9}, + number={5}, + pages={304--312}, + year={1992} +} + +@article{hirano2000assessing, + 
author = {Hirano, Keisuke and Imbens, Guido W. and Rubin, Donald B. and Zhou, Xiao-Hua}, + title = {Assessing the effect of an influenza vaccine in an + encouragement design }, + journal = {Biostatistics}, + volume = {1}, + number = {1}, + pages = {69-88}, + year = {2000}, + month = {03}, + issn = {1465-4644}, + doi = {10.1093/biostatistics/1.1.69}, + url = {https://doi.org/10.1093/biostatistics/1.1.69}, + eprint = {https://academic.oup.com/biostatistics/article-pdf/1/1/69/17744019/100069.pdf}, +} + +@incollection{richardson2011transparent, + author = {Richardson, Thomas S. and Evans, Robin J. and Robins, James M.}, + isbn = {9780199694587}, + title = {Transparent Parametrizations of Models for Potential Outcomes}, + booktitle = {Bayesian Statistics 9}, + publisher = {Oxford University Press}, + year = {2011}, + month = {10}, + doi = {10.1093/acprof:oso/9780199694587.003.0019}, + url = {https://doi.org/10.1093/acprof:oso/9780199694587.003.0019}, + eprint = {https://academic.oup.com/book/0/chapter/141661815/chapter-ag-pdf/45787772/book\_1879\_section\_141661815.ag.pdf}, +} + +@book{imbens2015causal, + place={Cambridge}, + title={Causal Inference for Statistics, Social, and Biomedical Sciences: An Introduction}, + publisher={Cambridge University Press}, + author={Imbens, Guido W. 
and Rubin, Donald B.}, + year={2015} +} + +@article{hahn2016bayesian, + title={A Bayesian partial identification approach to inferring the prevalence of accounting misconduct}, + author={Hahn, P Richard and Murray, Jared S and Manolopoulou, Ioanna}, + journal={Journal of the American Statistical Association}, + volume={111}, + number={513}, + pages={14--26}, + year={2016}, + publisher={Taylor \& Francis} +} + +@article{albert1993bayesian, + title={Bayesian analysis of binary and polychotomous response data}, + author={Albert, James H and Chib, Siddhartha}, + journal={Journal of the American statistical Association}, + volume={88}, + number={422}, + pages={669--679}, + year={1993}, + publisher={Taylor \& Francis} +} + +@article{papakostas2023forecasts, + title={Do forecasts of bankruptcy cause bankruptcy? A machine learning sensitivity analysis}, + author={Papakostas, Demetrios and Hahn, P Richard and Murray, Jared and Zhou, Frank and Gerakos, Joseph}, + journal={The Annals of Applied Statistics}, + volume={17}, + number={1}, + pages={711--739}, + year={2023}, + publisher={Institute of Mathematical Statistics} +} + +@article{lindo2010ability, + title={Ability, gender, and performance standards: Evidence from academic probation}, + author={Lindo, Jason M and Sanders, Nicholas J and Oreopoulos, Philip}, + journal={American economic journal: Applied economics}, + volume={2}, + number={2}, + pages={95--117}, + year={2010}, + publisher={American Economic Association} +} + +@article{murray2021log, + title={Log-linear Bayesian additive regression trees for multinomial logistic and count regression models}, + author={Murray, Jared S}, + journal={Journal of the American Statistical Association}, + volume={116}, + number={534}, + pages={756--769}, + year={2021}, + publisher={Taylor \& Francis} +} + +@article{pratola2020heteroscedastic, + title={Heteroscedastic BART via multiplicative regression trees}, + author={Pratola, Matthew T and Chipman, Hugh A and George, Edward I and 
McCulloch, Robert E},
+ journal={Journal of Computational and Graphical Statistics},
+ volume={29},
+ number={2},
+ pages={405--417},
+ year={2020},
+ publisher={Taylor \& Francis}
+}
+
+@article{murray2021log2,
+ title={Log-linear Bayesian additive regression trees for multinomial logistic and count regression models},
+ author={Murray, Jared S},
+ journal={Journal of the American Statistical Association},
+ volume={116},
+ number={534},
+ pages={756--769},
+ year={2021},
+ publisher={Taylor \& Francis}
+}
+
+@article{hahn2020bayesian,
+ title={Bayesian regression tree models for causal inference: Regularization, confounding, and heterogeneous effects (with discussion)},
+ author={Hahn, P Richard and Murray, Jared S and Carvalho, Carlos M},
+ journal={Bayesian Analysis},
+ volume={15},
+ number={3},
+ pages={965--1056},
+ year={2020},
+ publisher={International Society for Bayesian Analysis}
+}
+
+@article{chipman2010bart,
+author = {Hugh A. Chipman and Edward I. George and Robert E. McCulloch},
+title = {{BART: Bayesian additive regression trees}},
+volume = {4},
+journal = {The Annals of Applied Statistics},
+number = {1},
+publisher = {Institute of Mathematical Statistics},
+pages = {266 -- 298},
+keywords = {Bayesian backfitting, boosting, CART, classification, ensemble, MCMC, Nonparametric regression, probit model, random basis, regularization, sum-of-trees model, Variable selection, weak learner},
+year = {2010},
+doi = {10.1214/09-AOAS285},
+URL = {https://doi.org/10.1214/09-AOAS285}
+}
+
+@article{he2023stochastic,
+ title={Stochastic tree ensembles for regularized nonlinear regression},
+ author={He, Jingyu and Hahn, P Richard},
+ journal={Journal of the American Statistical Association},
+ volume={118},
+ number={541},
+ pages={551--570},
+ year={2023},
+ publisher={Taylor \& Francis}
+}
+
+@book{pearl2009causality,
+ title={Causality},
+ author={Pearl, Judea},
+ year={2009},
+ publisher={Cambridge university press}
+}
+
+@book{imbens2015causal2,
+ title={Causal inference in statistics, social, and biomedical sciences},
+ author={Imbens, Guido W and Rubin, Donald B},
+ year={2015},
+ publisher={Cambridge university press}
+}
+
+@inproceedings{krantsevich2023stochastic,
+ title={Stochastic tree ensembles for estimating heterogeneous effects},
+ author={Krantsevich, Nikolay and He, Jingyu and Hahn, P Richard},
+ booktitle={International Conference on Artificial Intelligence and Statistics},
+ pages={6120--6131},
+ year={2023},
+ organization={PMLR}
+}
+
+@Article{gramacy2010categorical,
+ title = {Categorical Inputs, Sensitivity Analysis, Optimization and Importance Tempering with {tgp} Version 2, an {R} Package for Treed Gaussian Process Models},
+ author = {Robert B. Gramacy and Matthew Taddy},
+ journal = {Journal of Statistical Software},
+ year = {2010},
+ volume = {33},
+ number = {6},
+ pages = {1--48},
+ url = {https://www.jstatsoft.org/v33/i06/},
+ doi = {10.18637/jss.v033.i06},
+}
+
+@book{gramacy2020surrogates,
+ title = {Surrogates: {G}aussian Process Modeling, Design and \
+ Optimization for the Applied Sciences},
+ author = {Robert B. Gramacy},
+ publisher = {Chapman Hall/CRC},
+ address = {Boca Raton, Florida},
+ note = {\url{http://bobby.gramacy.com/surrogates/}},
+ year = {2020}
+}
+
+@book{scholkopf2002learning,
+ title={Learning with kernels: support vector machines, regularization, optimization, and beyond},
+ author={Sch{\"o}lkopf, Bernhard and Smola, Alexander J},
+ year={2002},
+ publisher={MIT press}
+}
+
+@article{alam2025unified,
+ title={A Unified Bayesian Nonparametric Framework for Ordinal, Survival, and Density Regression Using the Complementary Log-Log Link},
+ author={Alam, Entejar and Linero, Antonio R},
+ journal={arXiv preprint arXiv:2502.00606},
+ year={2025}
+}