diff --git a/environment-dev.yml b/environment-dev.yml index deda50d..7f934ba 100644 --- a/environment-dev.yml +++ b/environment-dev.yml @@ -15,4 +15,3 @@ dependencies: - pytest>=4.4.0 - pre-commit>=2.19 - ruff==0.9.1 - diff --git a/requirements-dev.txt b/requirements-dev.txt index 954b640..a889cba 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -9,3 +9,4 @@ bambi>=0.13.0 arviz_base>=0.5.0 ruff==0.9.1 numpyro>=0.17.0 +numba>=0.60.0 diff --git a/simuk/sbc.py b/simuk/sbc.py index 4c070b3..5ea6d1d 100644 --- a/simuk/sbc.py +++ b/simuk/sbc.py @@ -16,8 +16,10 @@ except ImportError: pass +from collections.abc import Mapping + import numpy as np -from arviz_base import extract, from_dict, from_numpyro +from arviz_base import dict_to_dataset, extract, from_dict, from_numpyro from tqdm import tqdm @@ -59,6 +61,9 @@ class SBC: data_dir : dict Keyword arguments passed to numpyro model, intended for use when providing an MCMC Kernel model. + simulator : callable + A custom simulator function that takes as input the model parameters and + a int parameter named `seed`, and must return a dictionary of named observations. Example ------- @@ -73,7 +78,15 @@ class SBC: sbc.run_simulations() """ - def __init__(self, model, num_simulations=1000, sample_kwargs=None, seed=None, data_dir=None): + def __init__( + self, + model, + num_simulations=1000, + sample_kwargs=None, + seed=None, + data_dir=None, + simulator=None, + ): if hasattr(model, "basic_RVs") and isinstance(model, pm.Model): self.engine = "pymc" self.model = model @@ -110,6 +123,22 @@ def __init__(self, model, num_simulations=1000, sample_kwargs=None, seed=None, d self._extract_variable_names() self.simulations = {name: [] for name in self.var_names} self._simulations_complete = 0 + if simulator is not None and not callable(simulator): + raise ValueError("simulator should be a function or None") + if simulator is not None and self.observed_vars: + logging.warning( + "Provided model contains both observed variables and a simulator. " + "Ignoring observed variables and using the simulator instead." + ) + if simulator is None and not self.observed_vars and self.engine == "pymc": + # Ideally, we could raise an error early for `numpyro` also, + # but `factor` also produces 'observed_vars' + raise ValueError( + "There are no observed variables, and PyMC will not generate prior " + "predictive samples. Either change the model or specify a simulator " + "with the `simulator` argument." + ) + self.simulator = simulator def _extract_variable_names(self): """Extract observed and free variables from the model.""" @@ -142,8 +171,30 @@ def _get_prior_predictive_samples(self): idata = pm.sample_prior_predictive( samples=self.num_simulations, random_seed=self._seeds[0] ) - prior_pred = extract(idata, group="prior_predictive", keep_dataset=True) prior = extract(idata, group="prior", keep_dataset=True) + if self.simulator is None: + prior_pred = extract(idata, group="prior_predictive", keep_dataset=True) + return prior, prior_pred + # Deal with custom simulator + prior_pred = [] + for i in range(prior.sizes["sample"]): + params = {var: prior[var].isel(sample=i).values for var in prior.data_vars} + params["seed"] = self._seeds[i] + try: + res = self.simulator(**params) + assert isinstance( + res, Mapping + ), f"Simulator must return a dictionary, got {type(res)}" + prior_pred.append(res) + except Exception as e: + raise ValueError( + f"Error generating prior predictive sample with parameters {params}: {e}." + ) + prior_pred = dict_to_dataset( + {key: np.stack([pp[key] for pp in prior_pred]) for key in prior_pred[0]}, + sample_dims=["sample"], + coords={**prior.coords}, + ) return prior, prior_pred def _get_prior_predictive_samples_numpyro(self): @@ -152,7 +203,15 @@ def _get_prior_predictive_samples_numpyro(self): free_vars_data = {k: v for k, v in self.data_dir.items() if k not in self.observed_vars} samples = predictive(jax.random.PRNGKey(self._seeds[0]), **free_vars_data) prior = {k: v for k, v in samples.items() if k not in self.observed_vars} - prior_pred = {k: v for k, v in samples.items() if k in self.observed_vars} + if self.simulator: + results = [] + for i, vals in enumerate(zip(*prior.values())): + params = dict(zip(prior.keys(), vals)) + params["seed"] = self._seeds[i] + results.append(self.simulator(**params)) + prior_pred = {key: [result[key] for result in results] for key in results[0]} + else: + prior_pred = {k: v for k, v in samples.items() if k in self.observed_vars} return prior, prior_pred def _get_posterior_samples(self, prior_predictive_draw): @@ -170,7 +229,12 @@ def _get_posterior_samples_numpyro(self, prior_predictive_draw): """Generate posterior samples using numpyro conditioned to a prior predictive sample.""" mcmc = MCMC(self.numpyro_model, **self.sample_kwargs) rng_seed = jax.random.PRNGKey(self._seeds[self._simulations_complete]) - free_vars_data = {k: v for k, v in self.data_dir.items() if k not in self.observed_vars} + # If using a custom simulator, some variables present in `prior_predictive_draw` + # might be missing from self.observed_vars. + # TODO: Not sure if the union is redundant here and perhaps prior_predictive_draw.keys() + # could be sufficient. + extended_observed_vars = set(prior_predictive_draw.keys()).union(self.observed_vars) + free_vars_data = {k: v for k, v in self.data_dir.items() if k not in extended_observed_vars} mcmc.run(rng_seed, **free_vars_data, **prior_predictive_draw) return from_numpyro(mcmc)["posterior"] diff --git a/simuk/tests/test_sbc.py b/simuk/tests/test_sbc.py index aab4b85..17e0e1a 100644 --- a/simuk/tests/test_sbc.py +++ b/simuk/tests/test_sbc.py @@ -1,33 +1,88 @@ import bambi as bmb +import jax.numpy as jnp import numpy as np import numpyro import numpyro.distributions as dist import pandas as pd import pymc as pm import pytest +from numba import njit from numpyro.infer import NUTS import simuk np.random.seed(1234) +# Test data data = np.array([28.0, 8.0, -3.0, 7.0, -1.0, 1.0, 18.0, 12.0]) sigma = np.array([15.0, 10.0, 16.0, 11.0, 9.0, 11.0, 10.0, 18.0]) +# PyMC models with pm.Model() as centered_eight: mu = pm.Normal("mu", mu=0, sigma=5) tau = pm.HalfCauchy("tau", beta=5) theta = pm.Normal("theta", mu=mu, sigma=tau, shape=8) y_obs = pm.Normal("y", mu=theta, sigma=sigma, observed=data) +with pm.Model() as centered_eight_no_observed: + mu = pm.Normal("mu", mu=0, sigma=5) + tau = pm.HalfCauchy("tau", beta=5) + theta = pm.Normal("theta", mu=mu, sigma=tau, shape=8) + + def log_likelihood(theta, observed): + return pm.math.sum(pm.logp(pm.Normal.dist(mu=theta, sigma=sigma), observed)) + + pm.Potential("y_loglike", log_likelihood(mu, data)) + +# Bambi model x = np.random.normal(0, 1, 20) y = 2 + np.random.normal(x, 1) df = pd.DataFrame({"x": x, "y": y}) bmb_model = bmb.Model("y ~ x", df) +# NumPyro models +def eight_schools_cauchy_prior(J, sigma, y=None): + mu = numpyro.sample("mu", dist.Normal(0, 5)) + tau = numpyro.sample("tau", dist.HalfCauchy(5)) + with numpyro.plate("J", J): + theta = numpyro.sample("theta", dist.Normal(mu, tau)) + numpyro.sample("y", dist.Normal(theta, sigma), obs=y) + + +def eight_schools_cauchy_prior_no_observed(J, sigma, y=None): + mu = numpyro.sample("mu", dist.Normal(0, 5)) + tau = numpyro.sample("tau", dist.HalfCauchy(5)) + with numpyro.plate("J", J): + theta = numpyro.sample("theta", dist.Normal(mu, tau)) + if y is not None: + log_likelihood = jnp.sum(dist.Normal(theta, sigma).log_prob(y)) + numpyro.factor("custom_likelihood", log_likelihood) + + +# Custom simulator functions +def centered_eight_simulator(theta, seed, **kwargs): + rng = np.random.default_rng(seed) + return {"y": rng.normal(theta, sigma)} + +@njit +def centered_eight_jitted_simulator(tau, mu, theta, seed): + # Some expensive computation + n = theta.shape[0] + y = np.zeros(n) + for i in range(n): + y[i] = theta[i] + return {"y": y} + + +def bmb_simulator(mu, sigma, seed, **kwargs): + rng = np.random.default_rng(seed) + return {"y": rng.normal(mu, sigma)} + + +# --- Tests with observed variables --- @pytest.mark.parametrize("model", [centered_eight, bmb_model]) -def test_sbc(model): +def test_sbc_with_observed_data(model): sbc = simuk.SBC( model, num_simulations=10, @@ -37,22 +92,89 @@ def test_sbc(model): assert "prior_sbc" in sbc.simulations -def test_sbc_numpyro(): - y = np.array([28.0, 8.0, -3.0, 7.0, -1.0, 1.0, 18.0, 12.0]) - sigma = np.array([15.0, 10.0, 16.0, 11.0, 9.0, 11.0, 10.0, 18.0]) +def test_sbc_numpyro_with_observed_data(): + sbc = simuk.SBC( + NUTS(eight_schools_cauchy_prior), + data_dir={"J": 8, "sigma": sigma, "y": data}, + num_simulations=10, + sample_kwargs={"num_warmup": 50, "num_samples": 25}, + ) + sbc.run_simulations() + assert "prior_sbc" in sbc.simulations - def eight_schools_cauchy_prior(J, sigma, y=None): - mu = numpyro.sample("mu", dist.Normal(0, 5)) - tau = numpyro.sample("tau", dist.HalfCauchy(5)) - with numpyro.plate("J", J): - theta = numpyro.sample("theta", dist.Normal(mu, tau)) - numpyro.sample("y", dist.Normal(theta, sigma), obs=y) +# --- Tests with custom simulators --- +@pytest.mark.parametrize( + "model,simulator", + [ + # Case 1: Both simulator function and observed variables present + (centered_eight, centered_eight_simulator), + # Case 2: Only simulator function present + (centered_eight_no_observed, centered_eight_simulator), + ], +) +def test_sbc_with_custom_simulator(model, simulator): sbc = simuk.SBC( - NUTS(eight_schools_cauchy_prior), - data_dir={"J": 8, "sigma": sigma, "y": y}, + model, num_simulations=10, sample_kwargs={"draws": 5, "tune": 5}, simulator=simulator + ) + sbc.run_simulations() + assert "prior_sbc" in sbc.simulations + + +@pytest.mark.skipif( + hasattr(bmb, "__version__") and tuple(map(int, bmb.__version__.split("."))) <= (0, 14), + reason="requires bambi version > 0.14", +) +def test_sbc_bambi_with_custom_simulator(): + sbc = simuk.SBC( + bmb_model, + num_simulations=10, + sample_kwargs={"draws": 5, "tune": 5}, + simulator=bmb_simulator, + ) + sbc.run_simulations() + assert "prior_sbc" in sbc.simulations + + +@pytest.mark.parametrize( + "model,simulator", + [ + # Case 1: Both simulator function and observed variables present + (eight_schools_cauchy_prior, centered_eight_simulator), + # Case 2: Only simulator function present + (eight_schools_cauchy_prior_no_observed, centered_eight_simulator), + ], +) +def test_sbc_numpyro_with_custom_simulator(model, simulator): + sbc = simuk.SBC( + NUTS(model), + data_dir={"J": 8, "sigma": sigma, "y": data}, num_simulations=10, sample_kwargs={"num_warmup": 50, "num_samples": 25}, + simulator=simulator, ) sbc.run_simulations() assert "prior_sbc" in sbc.simulations + + +# --- Error handling tests with custom simulators --- +def test_sbc_fail_no_observed_variable(): + with pytest.raises(ValueError, match="no observed variables"): + simuk.SBC( + centered_eight_no_observed, + num_simulations=10, + sample_kwargs={"draws": 5, "tune": 5}, + ) + + +def test_sbc_numpyro_fail_no_observed_variable(): + # Note: factor variables are catalogued as 'observed_vars' in NumPyro + # therefore, we cannot raise an early exception with an informative message + with pytest.raises(ValueError): + sbc = simuk.SBC( + NUTS(eight_schools_cauchy_prior_no_observed), + data_dir={"J": 8, "sigma": sigma, "y": data}, + num_simulations=10, + sample_kwargs={"num_warmup": 50, "num_samples": 25}, + ) + sbc.run_simulations()