Python interface for the Stan to NumPyro compiler.
Install the package by running the following from the root of the repository:

```
pip install .
```
Let's start with the simple eight schools example from Gelman et al. (Bayesian Data Analysis, Sec. 5.5, 2003). First, save the following Stan code, e.g., in a file `8schools.stan`:
```stan
data {
  int<lower=0> J;          // number of schools
  real y[J];               // estimated treatment
  real<lower=0> sigma[J];  // std of estimated effect
}
parameters {
  real theta[J];           // treatment effect in school j
  real mu;                 // hyper-parameter of mean
  real<lower=0> tau;       // hyper-parameter of sd
}
model {
  tau ~ cauchy(0, 5);      // a non-informative prior
  theta ~ normal(mu, tau);
  y ~ normal(theta, sigma);
  mu ~ normal(0, 5);
}
```

Then, to compile the model and run inference with the NumPyro runtime:
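```python
from jax import random
from stannumpyro import NumPyroModel

if __name__ == "__main__":
    stanfile = "8schools.stan"
    data = {
        "J": 8,
        "y": [28.0, 8.0, -3.0, 7.0, -1.0, 1.0, 18.0, 12.0],
        "sigma": [15.0, 10.0, 16.0, 11.0, 9.0, 11.0, 10.0, 18.0],
    }
    # Compile the Stan file and load the generated NumPyro code
    model = NumPyroModel(stanfile)
    # Configure MCMC (NUTS by default): 2 chains, 100 warmup steps,
    # 1000 samples per chain, keeping every 2nd post-warmup sample
    mcmc = model.mcmc(
        samples=1000,
        warmups=100,
        chains=2,
        thin=2,
    )
    # The random key is passed explicitly, followed by the data dictionary
    mcmc.run(random.PRNGKey(0), data)
    # summary() prints the posterior statistics itself
    mcmc.summary()
```

The `NumPyroModel` constructor compiles and loads the model.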
The rest of the API is similar to NumPyro's.
Note that `mcmc.run` requires an explicit random key (e.g., `jax.random.PRNGKey(0)`) as its first argument.
`class NumPyroModel(stanfile, recompile=True, mode="comprehensive", compiler=["stanc"], build_dir="_tmp")`

Simplified interface to compile and run Stan models using the Stan to NumPyro compiler.
Parameters
- `stanfile`: location of the Stan file
- `recompile`: whether to recompile the model (default `True`)
- `mode`: compilation mode, one of `"comprehensive"`, `"mixed"`, or `"generative"` (default `"comprehensive"`)
- `compiler`: compiler command as a list (default `["stanc"]`)
- `build_dir`: name of the build directory (default `"_tmp"`)
`mcmc(samples, warmups=0, chains=1, thin=1, kernel=None, **kwargs) -> MCMCProxy`

Provides access to Markov Chain Monte Carlo inference algorithms in NumPyro (see http://num.pyro.ai/en/stable/mcmc.html). The default kernel is NUTS.
Parameters
- `samples`: number of samples to generate from the Markov chain
- `warmups`: number of warmup steps (default `0`)
- `chains`: number of MCMC chains to run (default `1`)
- `thin`: positive integer that controls the fraction of post-warmup samples that are retained (default `1`)
- `kernel`: MCMC kernel to use (default `None`, i.e., NUTS)
- `**kwargs`: other options passed directly to the NumPyro `MCMC` constructor
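As a rough illustration of how `samples`, `thin`, and `chains` combine, the sketch below counts the draws kept after thinning. It assumes these arguments map onto NumPyro's `num_samples`, `thinning`, and `num_chains`, and that every `thin`-th post-warmup sample is kept, i.e., `samples // thin` draws per chain; `retained_draws` is a hypothetical helper, not part of stannumpyro.

```python
from typing import Tuple


def retained_draws(samples: int, thin: int = 1, chains: int = 1) -> Tuple[int, int]:
    """Return (draws kept per chain, draws kept across all chains).

    Hypothetical helper: assumes NumPyro-style thinning, where every
    `thin`-th post-warmup sample is kept (samples // thin per chain).
    Warmup steps are discarded and never counted.
    """
    if samples < 0 or thin < 1 or chains < 1:
        raise ValueError("need samples >= 0, thin >= 1 and chains >= 1")
    per_chain = samples // thin
    return per_chain, per_chain * chains


# Settings from the eight schools example: 1000 samples, thin=2, 2 chains
print(retained_draws(1000, thin=2, chains=2))  # (500, 1000)
```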
`svi(optim, loss) -> SVIProxy`

Provides access to Stochastic Variational Inference given an ELBO loss objective (see http://num.pyro.ai/en/stable/svi.html).
Parameters
- `optim`: an instance of `_NumPyroOptim` (a NumPyro optimizer, e.g., `numpyro.optim.Adam`)
- `loss`: ELBO loss, i.e., the negative Evidence Lower Bound, to minimize
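To make the `loss` concrete, the sketch below estimates a negative ELBO by Monte Carlo for a toy model (z ~ Normal(0, 1), x ~ Normal(z, 1)) with a Normal(m, s) guide, averaging log q(z) - log p(x, z) over draws from the guide. It is a pure-Python illustration of the quantity being minimized, not how NumPyro's loss classes (e.g., `Trace_ELBO`) are implemented; the toy model, `neg_elbo`, and its arguments are hypothetical.

```python
import math
import random


def normal_logpdf(x: float, loc: float, scale: float) -> float:
    """Log density of Normal(loc, scale) at x."""
    return -0.5 * ((x - loc) / scale) ** 2 - math.log(scale) - 0.5 * math.log(2 * math.pi)


def neg_elbo(x: float, m: float, s: float, num_particles: int = 1000, seed: int = 0) -> float:
    """Monte Carlo estimate of -ELBO for the toy model z ~ N(0, 1), x ~ N(z, 1)
    with guide q(z) = N(m, s)."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(num_particles):
        z = rng.gauss(m, s)  # draw z from the guide
        log_joint = normal_logpdf(z, 0.0, 1.0) + normal_logpdf(x, z, 1.0)
        total += log_joint - normal_logpdf(z, m, s)
    return -total / num_particles  # negate: SVI minimizes the loss


x = 1.0
# The exact posterior is N(x / 2, sqrt(1/2)); with it every particle gives -log p(x)
exact = neg_elbo(x, m=x / 2, s=math.sqrt(0.5))
rough = neg_elbo(x, m=0.0, s=1.0)  # guide equal to the prior
print(exact < rough)  # a better guide gives a lower loss
```

With the exact posterior as guide, the loss equals -log p(x), the smallest value the negative ELBO can take.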
`class MCMCProxy(mcmc, module)`

Wrapper for NumPyro MCMC (see http://num.pyro.ai/en/stable/mcmc.html).
Parameters
- `mcmc`: an instance of NumPyro `MCMC` (obtained from the `mcmc` method of `NumPyroModel`)
- `module`: the module containing the compiled code (see the `compile` function)
`run(rng_key, kwargs)`

Runs the inference.
Parameters
- `rng_key`: random number generator key used for sampling (e.g., `jax.random.PRNGKey(0)`)
- `kwargs`: the data, passed as a dictionary
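The keys of the data dictionary correspond to the names declared in the Stan `data` block. As a hedged illustration (whether stannumpyro itself validates data is not stated here), the hypothetical helper below checks the eight schools data against the declarations in `8schools.stan` before it is passed to `run`.

```python
from typing import Any, Dict, List


def eight_schools_data_problems(data: Dict[str, Any]) -> List[str]:
    """List mismatches between `data` and the data block of 8schools.stan.

    Hypothetical helper for illustration; an empty list means the data fits.
    """
    missing = [name for name in ("J", "y", "sigma") if name not in data]
    if missing:
        return [f"missing data entry: {name}" for name in missing]
    problems = []
    J = data["J"]
    if not isinstance(J, int) or J < 0:  # int<lower=0> J;
        problems.append("J must be an integer >= 0")
    for name in ("y", "sigma"):  # both arrays are declared with size [J]
        if len(data[name]) != J:
            problems.append(f"{name} must have length J = {J}")
    if any(s < 0 for s in data["sigma"]):  # real<lower=0> sigma[J];
        problems.append("sigma must be >= 0")
    return problems


data = {
    "J": 8,
    "y": [28.0, 8.0, -3.0, 7.0, -1.0, 1.0, 18.0, 12.0],
    "sigma": [15.0, 10.0, 16.0, 11.0, 9.0, 11.0, 10.0, 18.0],
}
print(eight_schools_data_problems(data))  # []
```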
`get_samples()`

Gets the samples from the MCMC run.
`summary(prob=0.9)`

Prints summary statistics of the posterior samples collected during this MCMC run.
Parameters
- `prob`: the probability mass of samples within the credible interval (default `0.9`)
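The credible interval controlled by `prob` can be illustrated with a highest posterior density interval (HPDI): the narrowest interval covering a fraction `prob` of the sorted samples. To our understanding, NumPyro's summary reports this kind of interval; the pure-Python `hpdi` below is a sketch for illustration, not the code used by stannumpyro or NumPyro.

```python
from typing import Sequence, Tuple


def hpdi(samples: Sequence[float], prob: float = 0.9) -> Tuple[float, float]:
    """Narrowest interval [xs[i], xs[i + k]] over the sorted samples xs,
    with k = int(prob * n)."""
    if not 0.0 < prob < 1.0:
        raise ValueError("prob must be in (0, 1)")
    xs = sorted(samples)
    n = len(xs)
    k = int(prob * n)  # index distance between the interval endpoints
    if k < 1:
        raise ValueError("not enough samples for this prob")
    # Compare every candidate window [xs[i], xs[i + k]] and keep the narrowest
    start = min(range(n - k), key=lambda i: xs[i + k] - xs[i])
    return xs[start], xs[start + k]


draws = [0.0, 1.0, 2.0, 3.0, 100.0]
print(hpdi(draws, prob=0.5))  # (0.0, 2.0): the outlier at 100.0 is excluded
```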
`class SVIProxy(svi, module)`

Wrapper for NumPyro SVI (see http://num.pyro.ai/en/stable/svi.html). :warning: Still experimental.
Parameters
- `svi`: an instance of NumPyro `SVI` (obtained from the `svi` method of `NumPyroModel`)
- `module`: the module containing the compiled code (see the `compile` function)
`compile(mode, stanfile, compiler=["stanc"], recompile=True, build_dir="_tmp")`

Compiles a Stan model to NumPyro.
Parameters
- `mode`: compilation mode, one of `"comprehensive"`, `"mixed"`, or `"generative"`
- `stanfile`: location of the Stan file
- `compiler`: compiler command as a list (default `["stanc"]`)
- `recompile`: whether to recompile the model (default `True`)
- `build_dir`: name of the build directory (default `"_tmp"`)
This function generates a Python file, named after the Stan file, in the build directory. The file contains the generated code and can then be used as a Python module.
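A file name such as `8schools` is not a valid Python identifier, so a plain `import` statement cannot load that generated module, but `importlib` can. The sketch below assumes the generated file is named `<stem>.py` inside the build directory (for example, `_tmp/8schools.py`); that naming and the helper `load_generated_module` are assumptions for illustration, not part of the stannumpyro API.

```python
import importlib.util
import os
from types import ModuleType


def load_generated_module(stanfile: str, build_dir: str = "_tmp") -> ModuleType:
    """Load the Python file generated for `stanfile` from `build_dir`.

    Hypothetical helper: assumes the generated file is <build_dir>/<stem>.py,
    where <stem> is the Stan file name without its extension.
    """
    stem = os.path.splitext(os.path.basename(stanfile))[0]
    path = os.path.join(build_dir, stem + ".py")
    # importlib accepts module names that are not valid identifiers (e.g. "8schools")
    spec = importlib.util.spec_from_file_location(stem, path)
    if spec is None or spec.loader is None:
        raise ImportError(f"cannot load generated module from {path}")
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)  # runs the generated code
    return module


# Example (after compiling 8schools.stan with the default build_dir):
# module = load_generated_module("8schools.stan")
```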
The files `distributions.py` and `stanlib.py` contain the NumPyro implementation of the Stan standard library.
The complete documentation of the Stan standard library (the Stan Functions Reference) can be found here: https://mc-stan.org/docs/2_26/functions-reference/index.html
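Each library function has to reproduce the semantics documented in that reference. As one example, Stan's `log_sum_exp` computes log(sum(exp(x))); a numerically stable version subtracts the maximum before exponentiating. The pure-Python sketch below is only an illustration of that technique, not the code in `stanlib.py`.

```python
import math
from typing import Iterable


def log_sum_exp(values: Iterable[float]) -> float:
    """Compute log(sum(exp(v) for v in values)) without overflow."""
    xs = list(values)
    if not xs:
        return -math.inf  # the log of an empty sum (0) is -inf
    m = max(xs)
    if math.isinf(m):
        return m  # all -inf, or some +inf: shifting by m would produce nan
    # Shift by the maximum so the largest exponent is exp(0) = 1
    return m + math.log(sum(math.exp(x - m) for x in xs))


# math.exp(1000.0) alone would overflow; the shifted form gives 1000 + log(2)
print(log_sum_exp([1000.0, 1000.0]))
```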