A JAX-based, NumPyro-compatible implementation of Non-Reversible Parallel Tempering (NRPT)
Warning: nrpt is under active development.
Optional: if you want to run your NumPyro models on an accelerator (GPU/TPU), make sure to install the correct version of JAX before proceeding. Otherwise, the following will install the default, CPU-only version of JAX.
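For example, at the time of writing, the JAX installation guide suggests the following for a machine with a CUDA 12-compatible GPU (check the JAX documentation for the current command for your platform):

```bash
pip install -U "jax[cuda12]"
```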
Using pip
```bash
pip install "automcmc @ git+https://github.com/UBC-Stat-ML/automcmc.git"
pip install "nrpt @ git+https://github.com/Estep-Bingham-Lab/nrpt.git"
```

Note: in the following we will require the additional packages pandas and
corner, which can be installed from common repositories.
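For example, both are available on PyPI:

```bash
pip install pandas corner
```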
To showcase the power of nrpt, we will analyze a challenging benchmark problem
described in Ballnus et al. 2017.
The objective is to estimate the parameters of an Ordinary Differential Equation
(ODE) given noisy observations of its solution. The ODE itself was described in
Leonhardt et al. 2014, while the
Bayesian formulation of the inference problem is from
Ballnus et al. 2017. The latter
shows an empirical comparison of several MCMC samplers on the ODE problem,
indicating that schemes that used Parallel Tempering were the only ones able
to accurately describe the posterior distribution. Indeed, its density is
bimodal and features narrow ridges.
Here we will show that nrpt can leverage an automatically tuned
sampler described in Liu et al. (2025)
to tackle this inference task.
For brevity, we won't go into the details of the model here; be sure to check
the references if you are curious. We also assume that you are familiar with
NRPT. Beyond the original paper, a good
reference is the documentation of the Julia package
Pigeons.jl; nrpt is heavily inspired by it.
We will aim to reproduce Figure 6 in
Ballnus et al. 2017, which shows
a corner plot of the posterior samples of the unknown parameters of the ODE.
The model has been written in NumPyro and included in nrpt.
We can load it, along with all the required dependencies, using
```python
from jax import random
from jax import numpy as jnp
from numpyro.diagnostics import print_summary
from automcmc import autohmc
from nrpt import initialization
from nrpt import sampling
from nrpt import toy_examples
import numpy as np
import corner

model, model_args, model_kwargs = toy_examples.mrna()
```

`model` is a Python function written using NumPyro primitives. This function
takes as input the observation times -- contained in the tuple `model_args` --
and the noisy observations inside the `model_kwargs` dictionary.
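For instance, you can peek at the structure of these inputs before sampling. This is a purely illustrative check; only the tuple/dictionary structure is guaranteed by the description above, and the exact keys depend on the model definition:

```python
# model_args is a tuple of observation times; model_kwargs is a dictionary
# holding the noisy observations (keys depend on the model definition).
print(type(model_args), len(model_args))
print(list(model_kwargs.keys()))
```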
Following the NumPyro convention, we wrap the model in an MCMC sampler. In
nrpt, this sampler will be used as the explorer, in NRPT terminology.
Currently, nrpt only works with the MCMC samplers of the
automcmc package. For this
example, we will use the AutoHMC sampler with the default 32 leapfrog steps.
```python
kernel = autohmc.AutoHMC(model)
```

With the explorer in place, we can proceed to instantiate a PT sampler object
```python
pt_sampler = initialization.PT(
    kernel,
    rng_key=random.key(1),
    n_rounds=14,
    n_replicas=15,
    n_refresh=2,
    model_args=model_args,
    model_kwargs=model_kwargs
)
```

Note that the model arguments are passed to the constructor. There are several other settings being provided:
- A JAX PRNG key, used to draw (pseudo-)random variates.
- The number of NRPT rounds is set to 14, so that a total of $2^{14} = 16384$ samples -- corresponding to the last round -- are returned (see the sketch after this list).
- The number of replicas is set to 15, which is roughly 2.5 times the global barrier $\Lambda$ of the problem. This is the ideal minimum number of chains needed for NRPT to correctly bootstrap itself via adaptation. The number of replicas can be increased past this point until either device memory is exhausted or significant speed deterioration is observed. Of course, since $\Lambda$ is a priori unknown, setting `n_replicas` requires some iteration.
- The number of explorer refreshments within each exploration step is set to 2. This allows us to achieve a worst-case autocorrelation of the log-likelihood (across replicas) of less than 0.95.
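As a quick check of the rounds arithmetic -- assuming, as is standard for NRPT implementations, that each round doubles the number of scans and that only the last round's samples are kept:

```python
# n_rounds = 14 implies the returned sample size is 2**14
# (a sanity check of the figure quoted above; illustrative arithmetic only).
n_rounds = 14
print(2 ** n_rounds)  # 16384
```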
We can run NRPT by typing the following (this takes less than 5 minutes on an Nvidia RTX 2000 Ada generation laptop GPU)

```python
pt_sampler = sampling.run(pt_sampler)
```

The above will produce an output similar to this:
```
 R |   Δt    |   ETA    |  Λ  |   logZ    | ρ (mean/max/amax) | newβ₁ | α (min/mean) | AC (mean/max)
----------------------------------------------------------------------------------------------------------
 1 | 0:00:12 | 54:56:08 | 3.9 | -8.56e+02 |  0.28 / 0.92 / 12 | 1e-19 |  0.00 / 0.45 | -0.30 / 1.33
 2 | 0:00:00 |  0:06:56 | 4.4 | -7.83e+02 |   0.32 / 0.95 / 2 | 1e-05 |  0.32 / 0.56 |  0.54 / 1.00
 3 | 0:00:00 |  0:08:40 | 6.3 | -7.48e+02 |   0.45 / 0.79 / 8 | 6e-08 |  0.37 / 0.61 |  0.42 / 1.01
 4 | 0:00:00 |  0:05:44 | 8.4 | -4.73e+02 |   0.60 / 0.91 / 8 | 7e-06 |  0.47 / 0.67 |  0.62 / 1.06
 5 | 0:00:00 |  0:05:31 | 6.2 | -4.08e+02 |   0.44 / 0.77 / 3 | 4e-09 |  0.56 / 0.75 |  0.78 / 0.98
 6 | 0:00:01 |  0:04:55 | 5.1 | -3.68e+02 |   0.36 / 0.99 / 4 | 1e-06 |  0.77 / 0.88 |  0.47 / 0.87
 7 | 0:00:01 |  0:03:02 | 5.2 | -3.66e+02 |   0.37 / 0.93 / 4 | 5e-08 |  0.79 / 0.92 |  0.58 / 0.94
 8 | 0:00:02 |  0:04:44 | 5.6 | -3.66e+02 |   0.40 / 0.56 / 5 | 5e-07 |  0.70 / 0.90 |  0.65 / 0.97
 9 | 0:00:04 |  0:03:39 | 6.2 | -3.71e+02 |   0.44 / 0.65 / 6 | 7e-07 |  0.72 / 0.88 |  0.68 / 0.93
10 | 0:00:08 |  0:03:59 | 6.3 | -3.70e+02 |  0.45 / 0.52 / 12 | 3e-07 |  0.70 / 0.87 |  0.67 / 0.93
11 | 0:00:13 |  0:02:59 | 6.2 | -3.71e+02 |   0.44 / 0.50 / 6 | 3e-07 |  0.72 / 0.88 |  0.69 / 0.92
12 | 0:00:27 |  0:02:41 | 6.3 | -3.71e+02 |   0.45 / 0.50 / 9 | 5e-07 |  0.72 / 0.87 |  0.70 / 0.92
13 | 0:00:52 |  0:01:44 | 6.1 | -3.70e+02 |  0.44 / 0.48 / 11 | 4e-07 |  0.70 / 0.87 |  0.69 / 0.92
14 | 0:01:43 |  0:00:00 | 6.2 | -3.70e+02 |  0.44 / 0.46 / 10 | 4e-07 |  0.72 / 0.89 |  0.70 / 0.92
```
From left to right, the figures shown here correspond to:
- The round index
- The duration of the round
- The estimated time until sampling is completed. Note that this is very inaccurate in the earlier rounds. It begins to stabilize roughly after round 7, depending on the complexity of the target.
- Estimate of the global barrier, which at the last round is $\Lambda \approx 6.2$ (see the sketch after this list for a quick cross-check).
- Estimate of the log-normalization constant, which in the last round gives $\log(\mathcal{Z}) \approx -370$.
- Average and worst-case swap rejection probabilities. When the average is close to the maximum -- as in the last 3 rounds -- the ideal equi-rejection condition has been approximately attained.
- The `amax` field in the previous column indicates the chain index that shows the highest rejection probability. That is, when `amax=i`, the swap between chains `i` and `i+1` shows the highest rejection rate. The next column shows the updated value of the first non-zero inverse temperature. This helps with diagnosing high rejection rates for `amax=0`.
- Average and worst-case explorer acceptance probabilities. If the explorer is working correctly along the path of distributions, we expect both values to be away from 0 and 1.
- Average and worst-case autocorrelation (AC) of the log-likelihood before and after the exploration steps. As described above, the number of refreshments was set so that the maximum was below 0.95. Note: the estimator does not behave well in small samples, which is why we can see autocorrelation values larger than one in earlier rounds.
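As an aside, the barrier estimate can be cross-checked against the swap statistics: in NRPT (Syed et al., 2022), the global barrier is approximated by summing the rejection probabilities over the communication interfaces. A back-of-the-envelope version, using only the figures in the last row of the table above:

```python
# 15 replicas define 14 swap interfaces; the last round reports a mean
# rejection probability of about 0.44 (illustrative arithmetic only).
n_replicas = 15
mean_rejection = 0.44
print((n_replicas - 1) * mean_rejection)  # ≈ 6.2, matching the Λ column
```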
We can now extract the samples and use the print_summary function from
NumPyro to show a brief description of the latent values of the model
```python
samples = pt_sampler.pt_state.samples
print_summary(samples, group_by_chain=False)
```

```
                mean       std    median      5.0%     95.0%     n_eff     r_hat
     lbeta     -1.51      0.92     -0.80     -2.60     -0.66    185.15      1.00
    ldelta     -1.81      1.06     -2.09     -3.07     -0.66    148.99      1.00
      lkm0      1.08      0.02      1.08      1.05      1.12    347.15      1.00
   log_lik   -349.62      1.80   -349.29   -352.20   -347.06    466.17      1.00
 log_prior     -7.65      0.42     -7.52     -7.94     -7.37    382.82      1.00
    lsigma      0.40      0.02      0.39      0.36      0.44    241.64      1.01
       lt0      0.18      0.03      0.18      0.13      0.22    347.33      1.01
```
The summary includes the model parameters as well as the log prior, log likelihood,
and log joint -- the latter corresponding to the log posterior plus the log density
of the momentum. For all these quantities, we see effective sample sizes (`n_eff`)
of over 100, together with `r_hat` values of essentially one.
Finally, we can recreate the corner plot in Ballnus et al. (2017) using
```python
# Stack the sampled parameters into an (n_samples, 5) array for corner
transformed_samples = np.array(jnp.vstack(
    [
        samples['lt0'],
        samples['lkm0'],
        samples['lbeta'],
        samples['ldelta'],
        samples['lsigma']
    ]
).swapaxes(0, 1))
figure = corner.corner(
    transformed_samples,
    labels=[
        r"$\log_{10}(t_0)$",
        r"$\log_{10}(\kappa)$",
        r"$\log_{10}(\beta)$",
        r"$\log_{10}(\delta)$",
        r"$\log_{10}(\sigma)$",
    ],
    quantiles=[0.16, 0.5, 0.84],
    show_titles=True,
    title_kwargs={"fontsize": 12},
    plot_contours=False,
    smooth=False,
    plot_density=False,
    data_kwargs={'color': (0.0, 0.6056031611752245, 0.9786801175696073)}
)
figure.savefig('mrna_corner.png', bbox_inches='tight')
```

Note that the posterior is clearly bimodal. Not only that, the shapes
of these two modes are completely different. Moreover, narrow ridges are
clearly visible in several of the pairwise marginals.
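As a quick numeric companion to the plot, one can check how the `lbeta` samples split between the two modes. The cut point below is hypothetical, eyeballed from the marginal histogram rather than prescribed by the model:

```python
# Fraction of posterior samples on each side of an (assumed) cut point
# separating the two lbeta modes.
lbeta = np.asarray(samples['lbeta'])
cut = -1.5  # hypothetical threshold between the modes
print(f"left mode: {(lbeta < cut).mean():.2f}, right mode: {(lbeta >= cut).mean():.2f}")
```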
References

Syed, S., Bouchard-Côté, A., Deligiannidis, G., & Doucet, A. (2022). Non-reversible parallel tempering: a scalable highly parallel MCMC scheme. Journal of the Royal Statistical Society Series B: Statistical Methodology, 84(2), 321-350.
Liu, T., Surjanovic, N., Biron-Lattes, M., Bouchard-Côté, A., & Campbell, T. (2024). AutoStep: Locally adaptive involutive MCMC. arXiv preprint arXiv:2410.18929. Accepted to ICML 2025.
