From 6477a0d646d061b60145c05e1c34f61bd4666214 Mon Sep 17 00:00:00 2001 From: Curro Campuzano Date: Sat, 30 Aug 2025 11:17:41 +0200 Subject: [PATCH 1/8] Minimal function --- simuk/sbc.py | 80 +++++++++++++++++++++++++-- simuk/tests/test_sbc.py | 120 ++++++++++++++++++++++++++++++++++++---- 2 files changed, 184 insertions(+), 16 deletions(-) diff --git a/simuk/sbc.py b/simuk/sbc.py index 4c070b3..25f7f37 100644 --- a/simuk/sbc.py +++ b/simuk/sbc.py @@ -17,6 +17,7 @@ pass import numpy as np +import xarray as xr from arviz_base import extract, from_dict, from_numpyro from tqdm import tqdm @@ -59,6 +60,9 @@ class SBC: data_dir : dict Keyword arguments passed to numpyro model, intended for use when providing an MCMC Kernel model. + simulator : callable + A custom simulator function that takes as input the model parameters and + returns a dictionary of named observations. Example ------- @@ -73,7 +77,15 @@ class SBC: sbc.run_simulations() """ - def __init__(self, model, num_simulations=1000, sample_kwargs=None, seed=None, data_dir=None): + def __init__( + self, + model, + num_simulations=1000, + sample_kwargs=None, + seed=None, + data_dir=None, + simulator=None, + ): if hasattr(model, "basic_RVs") and isinstance(model, pm.Model): self.engine = "pymc" self.model = model @@ -110,6 +122,22 @@ def __init__(self, model, num_simulations=1000, sample_kwargs=None, seed=None, d self._extract_variable_names() self.simulations = {name: [] for name in self.var_names} self._simulations_complete = 0 + if simulator is not None and not callable(simulator): + raise ValueError("simulator should be a function or None") + if simulator is not None and self.observed_vars: + logging.warning( + "Provided model contains both observed variables and a simulator. " + "Ignoring observed variables and using simulator instead." + ) + if simulator is None and not self.observed_vars and self.engine == "pymc": + # Ideally, we could raise an error early for `numpyro` also, + # but `factor` also produces 'observed_vars' + raise ValueError( + "There are no observed variables and PyMC will not generate prior " + "predictive samples. Either change the model or specify a simulator " + "with the `simulator` argument." + ) + self.simulator = simulator def _extract_variable_names(self): """Extract observed and free variables from the model.""" @@ -142,8 +170,41 @@ def _get_prior_predictive_samples(self): idata = pm.sample_prior_predictive( samples=self.num_simulations, random_seed=self._seeds[0] ) - prior_pred = extract(idata, group="prior_predictive", keep_dataset=True) prior = extract(idata, group="prior", keep_dataset=True) + if self.simulator is None: + prior_pred = extract(idata, group="prior_predictive", keep_dataset=True) + return prior, prior_pred + # Deal with custom simulator + prior_pred = [] + for i in range(prior.sizes["sample"]): + params = {var: prior[var].isel(sample=i).values for var in prior.data_vars} + try: + res = self.simulator(**params) + assert isinstance( + res, dict + ), f"Simulator must return a dictionary, got {type(res)}" + prior_pred.append(res) + except Exception as e: + raise ValueError( + f"Error generating prior predictive sample with parameters {params}: {e}." + ) + # --- Convert list of dicts to xarray.Dataset --- + # Get the sample coordinate (keeps chain/draw MultiIndex intact) + sample_coord = prior.coords["sample"] + + # Collect variables into dict-of-arrays, stacking along 'sample' + data_vars = {} + for key in prior_pred[0].keys(): + stacked = np.stack([np.asarray(d[key]) for d in prior_pred], axis=-1) + dims = [f"{key}_dim_{i}" for i in range(stacked.ndim - 1)] + ["sample"] + data_vars[key] = (dims, stacked) + + # Build dataset + prior_pred = xr.Dataset( + data_vars=data_vars, + coords={**{k: v for k, v in prior.coords.items()}, "sample": sample_coord}, + attrs=prior.attrs, + ) return prior, prior_pred def _get_prior_predictive_samples_numpyro(self): @@ -152,7 +213,13 @@ def _get_prior_predictive_samples_numpyro(self): free_vars_data = {k: v for k, v in self.data_dir.items() if k not in self.observed_vars} samples = predictive(jax.random.PRNGKey(self._seeds[0]), **free_vars_data) prior = {k: v for k, v in samples.items() if k not in self.observed_vars} - prior_pred = {k: v for k, v in samples.items() if k in self.observed_vars} + if self.simulator: + results = [ + self.simulator(**dict(zip(prior.keys(), values))) for values in zip(*prior.values()) + ] + prior_pred = dict(zip(results[0].keys(), zip(*[result.values() for result in results]))) + else: + prior_pred = {k: v for k, v in samples.items() if k in self.observed_vars} return prior, prior_pred def _get_posterior_samples(self, prior_predictive_draw): @@ -170,7 +237,12 @@ def _get_posterior_samples_numpyro(self, prior_predictive_draw): """Generate posterior samples using numpyro conditioned to a prior predictive sample.""" mcmc = MCMC(self.numpyro_model, **self.sample_kwargs) rng_seed = jax.random.PRNGKey(self._seeds[self._simulations_complete]) - free_vars_data = {k: v for k, v in self.data_dir.items() if k not in self.observed_vars} + # If using a custom simulator, some variables present in `prior_predictive_draw` + # might be missing from self.observed_vars. + # TODO: Not sure if the union is redundant here and perhaps prior_predictive_draw.keys() + # could be sufficient. + extended_observed_vars = set(prior_predictive_draw.keys()).union(self.observed_vars) + free_vars_data = {k: v for k, v in self.data_dir.items() if k not in extended_observed_vars} mcmc.run(rng_seed, **free_vars_data, **prior_predictive_draw) return from_numpyro(mcmc)["posterior"] diff --git a/simuk/tests/test_sbc.py b/simuk/tests/test_sbc.py index aab4b85..640caac 100644 --- a/simuk/tests/test_sbc.py +++ b/simuk/tests/test_sbc.py @@ -1,4 +1,5 @@ import bambi as bmb +import jax.numpy as jnp import numpy as np import numpyro import numpyro.distributions as dist @@ -11,23 +12,64 @@ np.random.seed(1234) +# Test data data = np.array([28.0, 8.0, -3.0, 7.0, -1.0, 1.0, 18.0, 12.0]) sigma = np.array([15.0, 10.0, 16.0, 11.0, 9.0, 11.0, 10.0, 18.0]) +# PyMC models with pm.Model() as centered_eight: mu = pm.Normal("mu", mu=0, sigma=5) tau = pm.HalfCauchy("tau", beta=5) theta = pm.Normal("theta", mu=mu, sigma=tau, shape=8) y_obs = pm.Normal("y", mu=theta, sigma=sigma, observed=data) +with pm.Model() as centered_eight_no_observed: + mu = pm.Normal("mu", mu=0, sigma=5) + tau = pm.HalfCauchy("tau", beta=5) + theta = pm.Normal("theta", mu=mu, sigma=tau, shape=8) + + def log_likelihood(theta, observed): + return pm.math.sum(pm.logp(pm.Normal.dist(mu=theta, sigma=sigma), observed)) + + pm.Potential("y_loglike", log_likelihood(mu, data)) + +# Bambi model x = np.random.normal(0, 1, 20) y = 2 + np.random.normal(x, 1) df = pd.DataFrame({"x": x, "y": y}) bmb_model = bmb.Model("y ~ x", df) +# NumPyro models +def eight_schools_cauchy_prior(J, sigma, y=None): + mu = numpyro.sample("mu", dist.Normal(0, 5)) + tau = numpyro.sample("tau", dist.HalfCauchy(5)) + with numpyro.plate("J", J): + theta = numpyro.sample("theta", dist.Normal(mu, tau)) + numpyro.sample("y", dist.Normal(theta, sigma), obs=y) + + +def eight_schools_cauchy_prior_no_observed(J, sigma, y=None): + mu = numpyro.sample("mu", dist.Normal(0, 5)) + tau = numpyro.sample("tau", dist.HalfCauchy(5)) + with numpyro.plate("J", J): + theta = numpyro.sample("theta", dist.Normal(mu, tau)) + if y is not None: + log_likelihood = jnp.sum(dist.Normal(theta, sigma).log_prob(y)) + numpyro.factor("custom_likelihood", log_likelihood) + +# Custom simulator functions +def centered_eight_simulator(theta, **kwargs): + return {"y": np.random.normal(theta, sigma)} + + +def bmb_simulator(mu, sigma, **kwargs): + return {"y": np.random.normal(mu, sigma)} + + +# --- Tests with observed variables --- @pytest.mark.parametrize("model", [centered_eight, bmb_model]) -def test_sbc(model): +def test_sbc_with_observed_data(model): sbc = simuk.SBC( model, num_simulations=10, @@ -37,22 +79,76 @@ def test_sbc(model): assert "prior_sbc" in sbc.simulations -def test_sbc_numpyro(): - y = np.array([28.0, 8.0, -3.0, 7.0, -1.0, 1.0, 18.0, 12.0]) - sigma = np.array([15.0, 10.0, 16.0, 11.0, 9.0, 11.0, 10.0, 18.0]) +def test_sbc_numpyro_with_observed_data(): + sbc = simuk.SBC( + NUTS(eight_schools_cauchy_prior), + data_dir={"J": 8, "sigma": sigma, "y": data}, + num_simulations=10, + sample_kwargs={"num_warmup": 50, "num_samples": 25}, + ) + sbc.run_simulations() + assert "prior_sbc" in sbc.simulations - def eight_schools_cauchy_prior(J, sigma, y=None): - mu = numpyro.sample("mu", dist.Normal(0, 5)) - tau = numpyro.sample("tau", dist.HalfCauchy(5)) - with numpyro.plate("J", J): - theta = numpyro.sample("theta", dist.Normal(mu, tau)) - numpyro.sample("y", dist.Normal(theta, sigma), obs=y) +# --- Tests with custom simulators --- +@pytest.mark.parametrize( + "model,simulator", + [ + # Case 1: Both simulator function and observed variables present + (centered_eight, centered_eight_simulator), + # Case 2: Only simulator function present + (centered_eight_no_observed, centered_eight_simulator), + # Case 3: bambi model with custom simulator + (bmb_model, bmb_simulator), + ], +) +def test_sbc_with_custom_simulator(model, simulator): sbc = simuk.SBC( - NUTS(eight_schools_cauchy_prior), - data_dir={"J": 8, "sigma": sigma, "y": y}, + model, num_simulations=10, sample_kwargs={"draws": 5, "tune": 5}, simulator=simulator + ) + sbc.run_simulations() + assert "prior_sbc" in sbc.simulations + + +@pytest.mark.parametrize( + "model,simulator", + [ + # Case 1: Both simulator function and observed variables present + (eight_schools_cauchy_prior, centered_eight_simulator), + # Case 2: Only simulator function present + (eight_schools_cauchy_prior_no_observed, centered_eight_simulator), + ], +) +def test_sbc_numpyro_with_custom_simulator(model, simulator): + sbc = simuk.SBC( + NUTS(model), + data_dir={"J": 8, "sigma": sigma, "y": data}, num_simulations=10, sample_kwargs={"num_warmup": 50, "num_samples": 25}, + simulator=simulator, ) sbc.run_simulations() assert "prior_sbc" in sbc.simulations + + +# --- Error handling tests with custom simulators --- +def test_sbc_fail_no_observed_variable(): + with pytest.raises(ValueError, match="no observed variables"): + simuk.SBC( + centered_eight_no_observed, + num_simulations=10, + sample_kwargs={"draws": 5, "tune": 5}, + ) + + +def test_sbc_numpyro_fail_no_observed_variable(): + # Note: factor variables are catalogued as 'observed_vars' in NumPyro + # therefore, we cannot raise an early exception with an informative message + with pytest.raises(ValueError): + sbc = simuk.SBC( + NUTS(eight_schools_cauchy_prior_no_observed), + data_dir={"J": 8, "sigma": sigma, "y": data}, + num_simulations=10, + sample_kwargs={"num_warmup": 50, "num_samples": 25}, + ) + sbc.run_simulations() From 29ff47507c8f7a9639ec8662446d4811792d0419 Mon Sep 17 00:00:00 2001 From: Curro Campuzano Date: Sat, 30 Aug 2025 12:48:52 +0200 Subject: [PATCH 2/8] Add seed parameter --- simuk/sbc.py | 15 +++++++++------ simuk/tests/test_sbc.py | 10 ++++++---- 2 files changed, 15 insertions(+), 10 deletions(-) diff --git a/simuk/sbc.py b/simuk/sbc.py index 25f7f37..2e4b7e5 100644 --- a/simuk/sbc.py +++ b/simuk/sbc.py @@ -62,7 +62,7 @@ class SBC: an MCMC Kernel model. simulator : callable A custom simulator function that takes as input the model parameters and - returns a dictionary of named observations. + a int parameter named `seed`, and returns a dictionary of named observations. Example ------- @@ -127,13 +127,13 @@ def __init__( if simulator is not None and self.observed_vars: logging.warning( "Provided model contains both observed variables and a simulator. " - "Ignoring observed variables and using simulator instead." + "Ignoring observed variables and using the simulator instead." ) if simulator is None and not self.observed_vars and self.engine == "pymc": # Ideally, we could raise an error early for `numpyro` also, # but `factor` also produces 'observed_vars' raise ValueError( - "There are no observed variables and PyMC will not generate prior " + "There are no observed variables, and PyMC will not generate prior " "predictive samples. Either change the model or specify a simulator " "with the `simulator` argument." ) @@ -178,6 +178,7 @@ def _get_prior_predictive_samples(self): prior_pred = [] for i in range(prior.sizes["sample"]): params = {var: prior[var].isel(sample=i).values for var in prior.data_vars} + params["seed"] = self._seeds[i] try: res = self.simulator(**params) assert isinstance( @@ -214,9 +215,11 @@ def _get_prior_predictive_samples_numpyro(self): samples = predictive(jax.random.PRNGKey(self._seeds[0]), **free_vars_data) prior = {k: v for k, v in samples.items() if k not in self.observed_vars} if self.simulator: - results = [ - self.simulator(**dict(zip(prior.keys(), values))) for values in zip(*prior.values()) - ] + results = [] + for i, vals in enumerate(zip(*prior.values())): + params = dict(zip(prior.keys(), vals)) + params["seed"] = self._seeds[i] + results.append(self.simulator(**params)) prior_pred = dict(zip(results[0].keys(), zip(*[result.values() for result in results]))) else: prior_pred = {k: v for k, v in samples.items() if k in self.observed_vars} diff --git a/simuk/tests/test_sbc.py b/simuk/tests/test_sbc.py index 640caac..dcdf405 100644 --- a/simuk/tests/test_sbc.py +++ b/simuk/tests/test_sbc.py @@ -59,12 +59,14 @@ def eight_schools_cauchy_prior_no_observed(J, sigma, y=None): # Custom simulator functions -def centered_eight_simulator(theta, **kwargs): - return {"y": np.random.normal(theta, sigma)} +def centered_eight_simulator(theta, seed, **kwargs): + rng = np.random.default_rng(seed) + return {"y": rng.normal(theta, sigma)} -def bmb_simulator(mu, sigma, **kwargs): - return {"y": np.random.normal(mu, sigma)} +def bmb_simulator(mu, sigma, seed, **kwargs): + rng = np.random.default_rng(seed) + return {"y": rng.normal(mu, sigma)} # --- Tests with observed variables --- From 81f994e2c5b68963300a39732a666458fa52fdf2 Mon Sep 17 00:00:00 2001 From: Curro Campuzano <69399781+currocam@users.noreply.github.com> Date: Wed, 3 Sep 2025 12:47:53 +0200 Subject: [PATCH 3/8] Apply suggestions from code review Co-authored-by: Osvaldo A Martin --- simuk/sbc.py | 23 +++++------------------ 1 file changed, 5 insertions(+), 18 deletions(-) diff --git a/simuk/sbc.py b/simuk/sbc.py index 2e4b7e5..9822106 100644 --- a/simuk/sbc.py +++ b/simuk/sbc.py @@ -189,23 +189,10 @@ def _get_prior_predictive_samples(self): raise ValueError( f"Error generating prior predictive sample with parameters {params}: {e}." ) - # --- Convert list of dicts to xarray.Dataset --- - # Get the sample coordinate (keeps chain/draw MultiIndex intact) - sample_coord = prior.coords["sample"] - - # Collect variables into dict-of-arrays, stacking along 'sample' - data_vars = {} - for key in prior_pred[0].keys(): - stacked = np.stack([np.asarray(d[key]) for d in prior_pred], axis=-1) - dims = [f"{key}_dim_{i}" for i in range(stacked.ndim - 1)] + ["sample"] - data_vars[key] = (dims, stacked) - - # Build dataset - prior_pred = xr.Dataset( - data_vars=data_vars, - coords={**{k: v for k, v in prior.coords.items()}, "sample": sample_coord}, - attrs=prior.attrs, - ) + prior_pred = dict_to_dataset({key: np.stack([pp[key] for pp in prior_pred]) for key in prior_pred[0]}, + sample_dims=["sample"], + coords={**prior.coords}, + ) return prior, prior_pred def _get_prior_predictive_samples_numpyro(self): @@ -220,7 +207,7 @@ def _get_prior_predictive_samples_numpyro(self): params = dict(zip(prior.keys(), vals)) params["seed"] = self._seeds[i] results.append(self.simulator(**params)) - prior_pred = dict(zip(results[0].keys(), zip(*[result.values() for result in results]))) + prior_pred = {key: [result[key] for result in results] for key in results[0]} else: prior_pred = {k: v for k, v in samples.items() if k in self.observed_vars} return prior, prior_pred From 3ece99c992af304ccf109d14224d2933f7834b38 Mon Sep 17 00:00:00 2001 From: Curro Campuzano Date: Wed, 3 Sep 2025 12:50:38 +0200 Subject: [PATCH 4/8] Add missing import --- simuk/sbc.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/simuk/sbc.py b/simuk/sbc.py index 9822106..b4354f9 100644 --- a/simuk/sbc.py +++ b/simuk/sbc.py @@ -18,10 +18,9 @@ import numpy as np import xarray as xr -from arviz_base import extract, from_dict, from_numpyro +from arviz_base import extract, from_dict, dict_to_dataset, from_numpyro from tqdm import tqdm - class quiet_logging: """Turn off logging for PyMC, Bambi and PyTensor.""" From b23422bd1158da1f23a8ee137e50be85b9cc3dd3 Mon Sep 17 00:00:00 2001 From: Curro Campuzano Date: Wed, 3 Sep 2025 13:09:26 +0200 Subject: [PATCH 5/8] Check `simulator` returns a `Mapping` + linter --- simuk/sbc.py | 17 ++++++++++------- simuk/tests/test_sbc.py | 11 +++++++++++ 2 files changed, 21 insertions(+), 7 deletions(-) diff --git a/simuk/sbc.py b/simuk/sbc.py index b4354f9..0aff66d 100644 --- a/simuk/sbc.py +++ b/simuk/sbc.py @@ -16,11 +16,13 @@ except ImportError: pass +from collections.abc import Mapping + import numpy as np -import xarray as xr -from arviz_base import extract, from_dict, dict_to_dataset, from_numpyro +from arviz_base import dict_to_dataset, extract, from_dict, from_numpyro from tqdm import tqdm + class quiet_logging: """Turn off logging for PyMC, Bambi and PyTensor.""" @@ -181,17 +183,18 @@ def _get_prior_predictive_samples(self): try: res = self.simulator(**params) assert isinstance( - res, dict + res, Mapping ), f"Simulator must return a dictionary, got {type(res)}" prior_pred.append(res) except Exception as e: raise ValueError( f"Error generating prior predictive sample with parameters {params}: {e}." ) - prior_pred = dict_to_dataset({key: np.stack([pp[key] for pp in prior_pred]) for key in prior_pred[0]}, - sample_dims=["sample"], - coords={**prior.coords}, - ) + prior_pred = dict_to_dataset( + {key: np.stack([pp[key] for pp in prior_pred]) for key in prior_pred[0]}, + sample_dims=["sample"], + coords={**prior.coords}, + ) return prior, prior_pred def _get_prior_predictive_samples_numpyro(self): diff --git a/simuk/tests/test_sbc.py b/simuk/tests/test_sbc.py index dcdf405..50f95e6 100644 --- a/simuk/tests/test_sbc.py +++ b/simuk/tests/test_sbc.py @@ -6,6 +6,7 @@ import pandas as pd import pymc as pm import pytest +from numba import njit from numpyro.infer import NUTS import simuk @@ -64,6 +65,16 @@ def centered_eight_simulator(theta, seed, **kwargs): return {"y": rng.normal(theta, sigma)} +@njit +def centered_eight_jitted_simulator(tau, mu, theta, seed): + # Some expensive computation + n = theta.shape[0] + y = np.zeros(n) + for i in range(n): + y[i] = theta[i] + return {"y": y} + + def bmb_simulator(mu, sigma, seed, **kwargs): rng = np.random.default_rng(seed) return {"y": rng.normal(mu, sigma)} From ab941b4cae14a3a3e71e893349371a6d2133b46d Mon Sep 17 00:00:00 2001 From: Curro Campuzano Date: Wed, 3 Sep 2025 13:42:55 +0200 Subject: [PATCH 6/8] Add Numba as dependency + docstring --- requirements-dev.txt | 1 + simuk/sbc.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/requirements-dev.txt b/requirements-dev.txt index 954b640..a889cba 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -9,3 +9,4 @@ bambi>=0.13.0 arviz_base>=0.5.0 ruff==0.9.1 numpyro>=0.17.0 +numba>=0.60.0 diff --git a/simuk/sbc.py b/simuk/sbc.py index 0aff66d..5ea6d1d 100644 --- a/simuk/sbc.py +++ b/simuk/sbc.py @@ -63,7 +63,7 @@ class SBC: an MCMC Kernel model. simulator : callable A custom simulator function that takes as input the model parameters and - a int parameter named `seed`, and returns a dictionary of named observations. + a int parameter named `seed`, and must return a dictionary of named observations. Example ------- From 5102f8dafbc1124206fcfb674eda3b46fb385b51 Mon Sep 17 00:00:00 2001 From: Curro Campuzano Date: Thu, 4 Sep 2025 18:52:23 +0200 Subject: [PATCH 7/8] Bump bambi version --- environment-dev.yml | 3 +-- requirements-dev.txt | 2 +- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/environment-dev.yml b/environment-dev.yml index deda50d..b4880b0 100644 --- a/environment-dev.yml +++ b/environment-dev.yml @@ -7,7 +7,7 @@ dependencies: - pip - python >= 3.11 - pymc>=5.20.1 - - bambi>=0.13.0 + - bambi>=0.14.0 - arviz>=0.20.0 - black=22.3.0 - click=8.0.4 @@ -15,4 +15,3 @@ dependencies: - pytest>=4.4.0 - pre-commit>=2.19 - ruff==0.9.1 - diff --git a/requirements-dev.txt b/requirements-dev.txt index a889cba..507256c 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -5,7 +5,7 @@ pytest>=4.4.0 pre-commit>=2.19 ipytest==0.13.0 pymc>=5.20.1 -bambi>=0.13.0 +bambi>=0.14.0 arviz_base>=0.5.0 ruff==0.9.1 numpyro>=0.17.0 From f094ea83a45a699fa21a2edaee9313587fbe95b0 Mon Sep 17 00:00:00 2001 From: Curro Campuzano Date: Thu, 4 Sep 2025 19:25:44 +0200 Subject: [PATCH 8/8] Revert bump version & skip test --- environment-dev.yml | 2 +- requirements-dev.txt | 2 +- simuk/tests/test_sbc.py | 17 +++++++++++++++-- 3 files changed, 17 insertions(+), 4 deletions(-) diff --git a/environment-dev.yml b/environment-dev.yml index b4880b0..7f934ba 100644 --- a/environment-dev.yml +++ b/environment-dev.yml @@ -7,7 +7,7 @@ dependencies: - pip - python >= 3.11 - pymc>=5.20.1 - - bambi>=0.14.0 + - bambi>=0.13.0 - arviz>=0.20.0 - black=22.3.0 - click=8.0.4 diff --git a/requirements-dev.txt b/requirements-dev.txt index 507256c..a889cba 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -5,7 +5,7 @@ pytest>=4.4.0 pre-commit>=2.19 ipytest==0.13.0 pymc>=5.20.1 -bambi>=0.14.0 +bambi>=0.13.0 arviz_base>=0.5.0 ruff==0.9.1 numpyro>=0.17.0 diff --git a/simuk/tests/test_sbc.py b/simuk/tests/test_sbc.py index 50f95e6..17e0e1a 100644 --- a/simuk/tests/test_sbc.py +++ b/simuk/tests/test_sbc.py @@ -111,8 +111,6 @@ def test_sbc_numpyro_with_observed_data(): (centered_eight, centered_eight_simulator), # Case 2: Only simulator function present (centered_eight_no_observed, centered_eight_simulator), - # Case 3: bambi model with custom simulator - (bmb_model, bmb_simulator), ], ) def test_sbc_with_custom_simulator(model, simulator): @@ -123,6 +121,21 @@ def test_sbc_with_custom_simulator(model, simulator): assert "prior_sbc" in sbc.simulations +@pytest.mark.skipif( + hasattr(bmb, "__version__") and tuple(map(int, bmb.__version__.split("."))) <= (0, 14), + reason="requires bambi version > 0.14", +) +def test_sbc_bambi_with_custom_simulator(): + sbc = simuk.SBC( + bmb_model, + num_simulations=10, + sample_kwargs={"draws": 5, "tune": 5}, + simulator=bmb_simulator, + ) + sbc.run_simulations() + assert "prior_sbc" in sbc.simulations + + @pytest.mark.parametrize( "model,simulator", [