diff --git a/docs/examples.rst b/docs/examples.rst index b813285..ad98a55 100644 --- a/docs/examples.rst +++ b/docs/examples.rst @@ -17,3 +17,15 @@ The gallery below presents examples that demonstrate the use of Simuk. +++ SBC + + .. grid-item-card:: + :link: ./examples/gallery/posterior_sbc.html + :text-align: center + :shadow: none + :class-card: example-gallery + + .. image:: examples/img/posterior_sbc.png + :alt: Posterior SBC + + +++ + Posterior SBC diff --git a/docs/examples/gallery/posterior_sbc.md b/docs/examples/gallery/posterior_sbc.md new file mode 100644 index 0000000..fa61ca4 --- /dev/null +++ b/docs/examples/gallery/posterior_sbc.md @@ -0,0 +1,157 @@ +--- +jupytext: + text_representation: + extension: .md + format_name: myst +kernelspec: + display_name: Python 3 + language: python + name: python3 +--- + +# Posterior simulation based calibration + +```{jupyter-execute} + +from arviz_plots import plot_ecdf_pit, style +import numpy as np +import simuk +style.use("arviz-variat") +``` + +This example demonstrates how to use the `SBC` class for posterior simulation-based calibration (SBC), supporting PyMC, Bambi and Numpyro models. In this version of SBC, we aim to validate the inference conditioned on the observed data and we restrict the analysis to the space of parameters supported by the posterior distribution. + +::::::{tab-set} +:class: full-width + +:::::{tab-item} PyMC +:sync: pymc_default + +First, define a PyMC model and sample from the posterior distribution. In this example, we will use the centered eight schools model. 
+ +```{jupyter-execute} + +import pymc as pm + +data = np.array([28.0, 8.0, -3.0, 7.0, -1.0, 1.0, 18.0, 12.0]) +sigma = np.array([15.0, 10.0, 16.0, 11.0, 9.0, 11.0, 10.0, 18.0]) + +with pm.Model() as centered_eight: + mu = pm.Normal('mu', mu=0, sigma=5) + tau = pm.HalfCauchy('tau', beta=5) + theta = pm.Normal('theta', mu=mu, sigma=tau, shape=8) + y_obs = pm.Normal('y', mu=theta, sigma=sigma, observed=data) + trace = pm.sample(1000) +``` + +Pass the model and the trace to the `SBC` class, set the number of simulations to 100, and run the simulations. Parameters will be drawn from the provided trace (which are from the posterior distribution). If the trace is not provided, the model will be sampled from the prior distribution (prior SBC). This process may take some time since the model runs multiple times (100 in this example). + +```{jupyter-execute} + +sbc = simuk.SBC(centered_eight, + num_simulations=100, + trace=trace, + sample_kwargs={'draws': 25, 'tune': 50}) + +sbc.run_simulations(); +``` + +We compare the posterior distribution (conditional on our observed data) and the posterior distribution conditional on the data and the simulated data using the ArviZ function `plot_ecdf_pit`. We expect a uniform distribution; the gray envelope corresponds to the 94% credible interval. + +```{jupyter-execute} + +plot_ecdf_pit(sbc.simulations, + visuals={"xlabel":False}, +); +``` + +::::: + +:::::{tab-item} Bambi +:sync: bambi_default + +Now, we define a Bambi Model and sample from the posterior distribution. + +```{jupyter-execute} + +import bambi as bmb +import pandas as pd + +x = np.random.normal(0, 1, 200) +y = 2 + np.random.normal(x, 1) +df = pd.DataFrame({"x": x, "y": y}) +bmb_model = bmb.Model("y ~ x", df) +trace = bmb_model.fit(num_samples=25, tune=50) +``` + +Pass the model and the trace to the `SBC` class, set the number of simulations to 100, and run the simulations. +Parameters will be drawn from the provided trace (which are from the posterior distribution). 
If the trace is not provided, the model will be sampled from the prior distribution (prior SBC). This process may take some time, as the model runs multiple times. + +```{jupyter-execute} + +sbc = simuk.SBC(bmb_model, + num_simulations=100, + trace=trace, + sample_kwargs={'draws': 25, 'tune': 50}) + +sbc.run_simulations(); +``` + +We compare the posterior distribution (conditional on our observed data) and the posterior distribution conditional on the data and the simulated data using the ArviZ function `plot_ecdf_pit`. We expect a uniform distribution; the gray envelope corresponds to the 94% credible interval. + + +```{jupyter-execute} +plot_ecdf_pit(sbc.simulations) +``` + +::::: + +:::::{tab-item} Numpyro +:sync: numpyro_default + +We define a Numpyro model; here we use the centered eight schools model. + +```{jupyter-execute} +import numpyro +import numpyro.distributions as dist +from jax import random +from numpyro.infer import MCMC, NUTS + +y = np.array([28.0, 8.0, -3.0, 7.0, -1.0, 1.0, 18.0, 12.0]) +sigma = np.array([15.0, 10.0, 16.0, 11.0, 9.0, 11.0, 10.0, 18.0]) + +def eight_schools_cauchy_prior(J, sigma, y=None): + mu = numpyro.sample("mu", dist.Normal(0, 5)) + tau = numpyro.sample("tau", dist.HalfCauchy(5)) + with numpyro.plate("J", J): + theta = numpyro.sample("theta", dist.Normal(mu, tau)) + numpyro.sample("y", dist.Normal(theta, sigma), obs=y) +``` + +We obtain samples from the posterior by running MCMC. +```{jupyter-execute} +# We use the NUTS sampler +nuts_kernel = NUTS(eight_schools_cauchy_prior) +mcmc = MCMC(nuts_kernel, num_warmup=500, num_samples=1000) +mcmc.run(random.PRNGKey(0), J=8, sigma=sigma, y=y) +``` + +Pass the model and the `mcmc` object to the `SBC` class, set the number of simulations to 100, and run the simulations. For Numpyro models, we also need to pass the `data_dir` parameter. 
+ +```{jupyter-execute} +sbc = simuk.SBC(nuts_kernel, + sample_kwargs={"num_warmup": 50, "num_samples": 75}, + trace=mcmc, + num_simulations=100, + data_dir={"J": 8, "sigma": sigma, "y": y}, +) +sbc.run_simulations() +``` + +We compare the posterior distribution (conditional on our observed data) and the posterior distribution conditional on the data and the simulated data using the ArviZ function `plot_ecdf_pit`. We expect a uniform distribution; the gray envelope corresponds to the 94% credible interval. + +```{jupyter-execute} +plot_ecdf_pit(sbc.simulations, + visuals={"xlabel":False}, +); +``` diff --git a/docs/examples/gallery/sbc.md b/docs/examples/gallery/sbc.md index f7b66f4..12bd166 100644 --- a/docs/examples/gallery/sbc.md +++ b/docs/examples/gallery/sbc.md @@ -11,8 +11,6 @@ kernelspec: # Simulation based calibration -This example demonstrates how to use the `SBC` class for simulation-based calibration, supporting both PyMC and Bambi models. - ```{jupyter-execute} from arviz_plots import plot_ecdf_pit, style @@ -21,11 +19,15 @@ import simuk style.use("arviz-variat") ``` +## Out-of-the-box SBC +This example demonstrates how to use the `SBC` class for simulation-based calibration, supporting PyMC, Bambi and Numpyro models. By default, the generative model implied by the probabilistic model is used. + + ::::::{tab-set} :class: full-width :::::{tab-item} PyMC -:sync: pymc +:sync: pymc_default First, define a PyMC model. In this example, we will use the centered eight schools model. @@ -69,7 +71,7 @@ plot_ecdf_pit(sbc.simulations, ::::: :::::{tab-item} Bambi -:sync: bambi +:sync: bambi_default Now, we define a Bambi Model. @@ -106,7 +108,7 @@ plot_ecdf_pit(sbc.simulations) ::::: :::::{tab-item} Numpyro -:sync: numpyro +:sync: numpyro_default We define a Numpyro Model, we use the centered eight schools model. 
@@ -150,3 +152,89 @@ plot_ecdf_pit(sbc.simulations, visuals={"xlabel":False}, ); ``` + +::::: + +:::::: + +## Custom simulator SBC + +::::::{tab-set} +:class: full-width + +:::::{tab-item} PyMC +:sync: pymc_custom + +In certain scenarios, you might want to pass a custom function to the `SBC` class to generate the data. For instance, if you aim to evaluate the effect of model misspecification by generating data from a different model than the one used for model fitting. + +Next, we determine the impact of occasional large deviations (outliers) by drawing from a Laplace distribution instead of a normal distribution (which we use to fit the model). + +```{jupyter-execute} +def simulator(theta, seed, **kwargs): + rng = np.random.default_rng(seed) + # Here we use a Laplace distribution, but it could also be some mechanistic simulator + scale = sigma / np.sqrt(2) + return {"y": rng.laplace(theta, scale)} + +sbc = simuk.SBC(centered_eight, + num_simulations=100, + simulator=simulator, + sample_kwargs={'draws': 25, 'tune': 50}) + +sbc.run_simulations(); +``` + +::::: + +:::::{tab-item} Bambi +:sync: bambi_custom + +In certain scenarios, you might want to pass a custom function to the `SBC` class to generate the data. For instance, if you aim to evaluate the effect of model misspecification by generating data from a different model than the one used for model fitting. + +Next, we determine the impact of occasional large deviations (outliers) by drawing from a Laplace distribution instead of a normal distribution (which we use to fit the model). 
+ +```{jupyter-execute} +def simulator(mu, seed, sigma, **kwargs): + rng = np.random.default_rng(seed) + # Here we use a Laplace distribution, but it could also be some mechanistic simulator + scale = sigma / np.sqrt(2) + return {"y": rng.laplace(mu, scale)} + +sbc = simuk.SBC(bmb_model, + num_simulations=100, + simulator=simulator, + sample_kwargs={'draws': 25, 'tune': 50}) + +sbc.run_simulations(); +``` + +::::: + + +:::::{tab-item} Numpyro +:sync: numpyro_custom + +In certain scenarios, you might want to pass a custom function to the `SBC` class to generate the data. For instance, if you aim to evaluate the effect of model misspecification by generating data from a different model than the one used for model fitting. + +Next, we determine the impact of occasional large deviations (outliers) by drawing from a Laplace distribution instead of a normal distribution (which we use to fit the model). + +```{jupyter-execute} +def simulator(theta, seed, **kwargs): + rng = np.random.default_rng(seed) + # Here we use a Laplace distribution, but it could also be some mechanistic simulator + scale = sigma / np.sqrt(2) + return {"y": rng.laplace(theta, scale)} + +sbc = simuk.SBC(nuts_kernel, + sample_kwargs={"num_warmup": 50, "num_samples": 75}, + num_simulations=100, + simulator=simulator, + data_dir={"J": 8, "sigma": sigma, "y": y} +) + +sbc.run_simulations(); +``` + +::::: + +:::::: diff --git a/docs/examples/img/posterior_sbc.png b/docs/examples/img/posterior_sbc.png new file mode 100644 index 0000000..94fffae Binary files /dev/null and b/docs/examples/img/posterior_sbc.png differ diff --git a/simuk/sbc.py b/simuk/sbc.py index 5ea6d1d..3458ecb 100644 --- a/simuk/sbc.py +++ b/simuk/sbc.py @@ -18,6 +18,8 @@ from collections.abc import Mapping +import arviz +import xarray import numpy as np from arviz_base import dict_to_dataset, extract, from_dict, from_numpyro from tqdm import tqdm @@ -61,9 +63,14 @@ class SBC: data_dir : dict Keyword arguments passed to numpyro model, 
intended for use when providing an MCMC Kernel model. - simulator : callable - A custom simulator function that takes as input the model parameters and - a int parameter named `seed`, and must return a dictionary of named observations. + simulator : callable + A custom simulator function that takes as input the model parameters and + an int parameter named `seed`, and must return a dictionary of named observations. + trace : arviz.InferenceData | numpyro.infer.mcmc.MCMC + Trace generated from fitting the model to the data. If provided, posterior SBC + (rather than prior SBC) will be performed. In this version, we aim to validate + the inference conditionally on observed data by drawing parameters from the + posterior distribution. Cannot be combined with `simulator`. Example ------- @@ -86,6 +93,7 @@ def __init__( seed=None, data_dir=None, simulator=None, + trace=None, ): if hasattr(model, "basic_RVs") and isinstance(model, pm.Model): self.engine = "pymc" @@ -139,6 +147,37 @@ def __init__( "with the `simulator` argument." ) self.simulator = simulator + # Validate the trace; when one is given, posterior SBC is run instead of prior SBC + self._trace_posterior = None + if trace is not None: + if simulator is not None: + raise NotImplementedError("Posterior SBC does not support a custom simulator.") + if self.engine == "numpyro" and isinstance(trace, MCMC): + # Recall that we are not calling `arviz_base.from_dict` but `arviz.from_dict` + # TODO: check if this is intended or some transitions between API versions + trace = arviz.from_dict(trace.get_samples(group_by_chain=True)) + if not hasattr(trace, "posterior"): + raise ValueError("The provided trace does not contain a `posterior` attribute.") + # Flatten the posterior + posterior = trace.posterior.stack(sample=("chain", "draw")) + n_draws = posterior.sizes["sample"] + if n_draws < self.num_simulations: + # TODO: We could sample with replacement from the trace. 
+ # Or set num_simulations to n_draws with a warning + raise ValueError( + "The provided trace does not contain enough samples, " + f"it contains {n_draws} draws, " + f"but {self.num_simulations} are required." + ) + rng = np.random.default_rng(seed) + sampled_indices = rng.choice(n_draws, size=self.num_simulations, replace=False) + posterior = posterior.isel(sample=sampled_indices) + # Reshape posterior to comply with PyMC requirements + posterior = posterior.reset_index("sample", drop=True) + posterior = posterior.expand_dims({"chain": [0]}).rename({"sample": "draw"}) + self._trace_posterior = posterior + # Only the PyMC/Bambi path augments with observed data; a numpyro trace may lack it + self._observed_data = None if self.engine == "numpyro" else trace.observed_data def _extract_variable_names(self): """Extract observed and free variables from the model.""" @@ -165,8 +204,51 @@ def _get_seeds(self): rng = np.random.default_rng(self.seed) return rng.integers(0, 2**30, size=self.num_simulations) + def _get_augmented_predictive_samples(self): + """Generate samples to use for the simulations in posterior SBC.""" + # Draw parameters from posterior for posterior SBC + with self.model: + # We treat the posterior as the new prior and generate synthetic data + prior = extract(self._trace_posterior, group="posterior", keep_dataset=True) + if self.simulator is None: + idata = pm.sample_posterior_predictive( + trace=self._trace_posterior, random_seed=self._seeds[0] + ) + prior_pred = extract(idata, group="posterior_predictive", keep_dataset=True) + # Augment `prior_pred` with observed data + observed = self._observed_data + y_dim = [d for d in observed.dims if d.startswith("y")][0] + sample_dim = "sample" + observed_b = observed.expand_dims( + {sample_dim: prior_pred[sample_dim]} + ).assign_coords({sample_dim: prior_pred[sample_dim]}) + augmented = xarray.concat([prior_pred, observed_b], dim=y_dim) + return prior, augmented + else: + # Deal with custom simulator + prior_pred = [] + for i in 
range(prior.sizes["sample"]): + params = {var: prior[var].isel(sample=i).values for var in prior.data_vars} + params["seed"] = self._seeds[i] + try: + res = self.simulator(**params) + assert isinstance(res, Mapping), ( + f"Simulator must return a dictionary, got {type(res)}" + ) + prior_pred.append(res) + except Exception as e: + raise ValueError( + f"Error generating prior predictive sample with parameters {params}: {e}." + ) + prior_pred = dict_to_dataset( + {key: np.stack([pp[key] for pp in prior_pred]) for key in prior_pred[0]}, + sample_dims=["sample"], + coords={**prior.coords}, + ) + def _get_prior_predictive_samples(self): - """Generate samples to use for the simulations.""" + """Generate samples to use for the simulations in prior SBC.""" + # Draw parameters from prior for prior SBC with self.model: idata = pm.sample_prior_predictive( samples=self.num_simulations, random_seed=self._seeds[0] @@ -175,53 +257,76 @@ def _get_prior_predictive_samples(self): if self.simulator is None: prior_pred = extract(idata, group="prior_predictive", keep_dataset=True) return prior, prior_pred - # Deal with custom simulator - prior_pred = [] - for i in range(prior.sizes["sample"]): - params = {var: prior[var].isel(sample=i).values for var in prior.data_vars} - params["seed"] = self._seeds[i] - try: - res = self.simulator(**params) - assert isinstance( - res, Mapping - ), f"Simulator must return a dictionary, got {type(res)}" - prior_pred.append(res) - except Exception as e: - raise ValueError( - f"Error generating prior predictive sample with parameters {params}: {e}." 
- ) - prior_pred = dict_to_dataset( - {key: np.stack([pp[key] for pp in prior_pred]) for key in prior_pred[0]}, - sample_dims=["sample"], - coords={**prior.coords}, - ) + # Deal with custom simulator + prior_pred = [] + for i in range(prior.sizes["sample"]): + params = {var: prior[var].isel(sample=i).values for var in prior.data_vars} + params["seed"] = self._seeds[i] + try: + res = self.simulator(**params) + assert isinstance(res, Mapping), ( + f"Simulator must return a dictionary, got {type(res)}" + ) + prior_pred.append(res) + except Exception as e: + raise ValueError( + f"Error generating prior predictive sample with parameters {params}: {e}." + ) + prior_pred = dict_to_dataset( + {key: np.stack([pp[key] for pp in prior_pred]) for key in prior_pred[0]}, + sample_dims=["sample"], + coords={**prior.coords}, + ) return prior, prior_pred def _get_prior_predictive_samples_numpyro(self): """Generate samples to use for the simulations using numpyro.""" - predictive = Predictive(self.model, num_samples=self.num_simulations) - free_vars_data = {k: v for k, v in self.data_dir.items() if k not in self.observed_vars} - samples = predictive(jax.random.PRNGKey(self._seeds[0]), **free_vars_data) - prior = {k: v for k, v in samples.items() if k not in self.observed_vars} - if self.simulator: - results = [] - for i, vals in enumerate(zip(*prior.values())): - params = dict(zip(prior.keys(), vals)) - params["seed"] = self._seeds[i] - results.append(self.simulator(**params)) - prior_pred = {key: [result[key] for result in results] for key in results[0]} + # Draw parameters from prior for prior SBC + if self._trace_posterior is None: + predictive = Predictive(self.model, num_samples=self.num_simulations) + free_vars_data = {k: v for k, v in self.data_dir.items() if k not in self.observed_vars} + samples = predictive(jax.random.PRNGKey(self._seeds[0]), **free_vars_data) + prior = {k: v for k, v in samples.items() if k not in self.observed_vars} + if self.simulator is None: + 
prior_pred = {k: v for k, v in samples.items() if k in self.observed_vars} + # Draw parameters from posterior for posterior SBC else: - prior_pred = {k: v for k, v in samples.items() if k in self.observed_vars} + # Use the flattened posterior as prior samples + prior = {} + for var_name in self.var_names: + if var_name in self._trace_posterior: + # Convert from xarray to numpy array + # TODO: this feels like a hack to maintain the shape of + # the array one gets if running mcmc.get_samples(). + # Perhaps consider an alternative approach + prior[var_name] = np.squeeze( + self._trace_posterior[var_name].to_numpy(), axis=0 + ).T + if self.simulator is None: + predictive = Predictive(self.model, posterior_samples=prior) + free_vars_data = { + k: v for k, v in self.data_dir.items() if k not in self.observed_vars + } + pred_samples = predictive(jax.random.PRNGKey(self._seeds[0]), **free_vars_data) + prior_pred = {k: v for k, v in pred_samples.items() if k in self.observed_vars} + if self.simulator is None: + return prior, prior_pred + # If using custom simulator, we don't have to differentiate between prior and posterior SBC + results = [] + for i, vals in enumerate(zip(*prior.values())): + params = dict(zip(prior.keys(), vals)) + params["seed"] = self._seeds[i] + results.append(self.simulator(**params)) + prior_pred = {key: [result[key] for result in results] for key in results[0]} return prior, prior_pred def _get_posterior_samples(self, prior_predictive_draw): - """Generate posterior samples conditioned to a prior predictive sample.""" + """Generate posterior samples conditioned to a prior predictive sample or an augmented posterior predictive sample.""" new_model = pm.observe(self.model, prior_predictive_draw) with new_model: check = pm.sample( **self.sample_kwargs, random_seed=self._seeds[self._simulations_complete] ) - posterior = extract(check, group="posterior", keep_dataset=True) return posterior @@ -260,8 +365,10 @@ def run_simulations(self): seed was passed 
initially, it will still be respected (that is, the resulting simulations will be identical to running without pausing in the middle). """ - prior, prior_pred = self._get_prior_predictive_samples() - + if self._trace_posterior is None: + prior, prior_pred = self._get_prior_predictive_samples() + else: + prior, prior_pred = self._get_augmented_predictive_samples() progress = tqdm( initial=self._simulations_complete, total=self.num_simulations, diff --git a/simuk/tests/test_sbc.py b/simuk/tests/test_sbc.py index 1a53b0d..3a74b2b 100644 --- a/simuk/tests/test_sbc.py +++ b/simuk/tests/test_sbc.py @@ -6,8 +6,9 @@ import pandas as pd import pymc as pm import pytest +from jax import random from numba import njit -from numpyro.infer import NUTS +from numpyro.infer import MCMC, NUTS import simuk @@ -179,3 +180,23 @@ def test_sbc_numpyro_fail_no_observed_variable(): sample_kwargs={"num_warmup": 50, "num_samples": 25}, ) sbc.run_simulations() + + +# Test posterior SBC + +with pm.Model() as posterior_model: + mu = pm.Normal("mu", mu=0, sigma=5) + y_obs = pm.Normal("y", mu=mu, sigma=1.0, observed=data) + + +def test_posterior_sbc_pymc_with_observed_variables(): + with posterior_model: + trace = pm.sample(draws=100, tune=100, chains=4) + sbc = simuk.SBC( + posterior_model, + trace=trace, + num_simulations=10, + sample_kwargs={"draws": 5, "tune": 5}, + ) + sbc.run_simulations() + assert "posterior_sbc" in sbc.simulations