Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@ __pycache__
.vscode
*.xml
*.pdf
.env
2 changes: 1 addition & 1 deletion examples/himmelblau.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from nak_torch.algorithms.msip import MSIPFredholm, MSIPQuadGradientFree
from nak_torch.tools.quadrature import spherical_MC_radial_Laguerre
from datetime import datetime
from nak_torch.tools.kernel import kernel_optimal_weight_factory, default_kernel_matrix
from nak_torch.tools.kernel import kernel_optimal_weight_factory, DEFAULT_KERNEL_MATRIX

save_gif = False
function_name = "himmelblau"
Expand Down
1 change: 1 addition & 0 deletions examples/stan/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
build
150 changes: 150 additions & 0 deletions examples/stan/pdb_covid.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
# %%
from typing import Optional

import nest_asyncio
import os
import matplotlib.pyplot as plt
import torch
from tqdm import tqdm
import nak_torch
from nak_torch.algorithms import MSIP, SVGD
from nak_torch.algorithms.msip import MSIPFredholm
from nak_torch.tools import stan_tools
from nak_torch.tools.types import BatchGradLogDensityEvaluator, DeviceLike

nest_asyncio.apply() # See pystan documentation on why you need this when doing jupyter
import stan # noqa: E402
import posteriordb # noqa: E402

def covid_prior_sample(
    N_samples: int,
    M_y: int = 14,
    dtype: Optional[torch.dtype] = None,
    device: Optional[DeviceLike] = None,
    rng: Optional[torch.Generator] = None,
):
    """Draw random starting points for the COVID example, one per row.

    Every quantity is sampled independently per row from a fixed
    distribution, and the strictly positive ones are returned on the log
    scale:

    * ``tau``        -- Exp(1) / 0.03, i.e. exponential with rate 0.03
    * ``y``          -- Exp(1) / tau, ``M_y`` values per row
    * ``phi``        -- 5.0 * N(0, 1)
    * ``kappa``      -- 0.5 * N(0, 1)
    * ``mu``         -- 3.28 + kappa * N(0, 1), ``M_y`` values per row
    * ``alpha_hier`` -- Gamma(shape=0.1667, scale=1), 6 values per row
    * ``ifr_noise``  -- 1.0 + 0.1 * N(0, 1), ``M_y`` values per row

    Args:
        N_samples: Number of rows (independent samples) to draw.
        M_y: Length of the ``mu``, ``y`` and ``ifr_noise`` groups.
        dtype: Floating dtype of every draw; defaults to torch's default dtype.
        device: Device of every draw; defaults to torch's default device.
        rng: Generator used for all draws; defaults to
            ``torch.default_generator``.

    Returns:
        Tensor of shape ``(N_samples, 3 * M_y + 9)`` whose columns are, in
        order: ``mu`` (M_y), ``log(alpha_hier)`` (6), ``kappa`` (1),
        ``log(y)`` (M_y), ``phi`` (1), ``log(tau)`` (1), ``ifr_noise`` (M_y).
    """
    if rng is None:
        rng = torch.default_generator
    if dtype is None:
        dtype = torch.get_default_dtype()
    if device is None:
        device = torch.get_default_device()
    # Every draw shares the same dtype/device so that column_stack below never
    # has to promote dtypes or combine tensors living on different devices.
    factory = {"dtype": dtype, "device": device}

    def std_exponential(*shape):
        # Exp(1) draws; callers rescale in place.
        return torch.empty(shape, **factory).exponential_(generator=rng)

    def std_normal(*shape):
        # N(0, 1) draws; callers rescale/shift in place.
        return torch.randn(shape, generator=rng, **factory)

    # The order of the draws is kept fixed so a seeded generator reproduces
    # the same stream as before.
    tau = std_exponential(N_samples, 1).div_(0.03)
    # NOTE(review): dividing by tau gives y an exponential *rate* of tau. If
    # the Stan model specifies exponential(1 / tau) this should multiply by
    # tau instead -- confirm against the model before relying on these draws.
    y = std_exponential(N_samples, M_y).div_(tau)
    phi = std_normal(N_samples, 1).mul_(5.0)
    kappa = std_normal(N_samples, 1).mul_(0.5)
    # kappa has shape (N_samples, 1) and broadcasts across the M_y columns.
    mu = std_normal(N_samples, M_y).mul_(kappa).add_(3.28)
    alpha_hier = torch._standard_gamma(
        torch.as_tensor(0.1667, **factory).expand(N_samples, 6),
        generator=rng,
    )
    ifr_noise = std_normal(N_samples, M_y).mul_(0.1).add_(1.0)

    # In-place logs: tau, alpha_hier and y are not needed on the natural
    # scale past this point.
    log_tau = tau.log_()
    log_alpha_hier = alpha_hier.log_()
    log_y = y.log_()
    return torch.column_stack(
        (mu, log_alpha_hier, kappa, log_y, phi, log_tau, ifr_noise)
    )

# %%
# Fetch the posterior definition from posteriordb and compile its Stan program
# (together with its data) through PyStan.
pdb = posteriordb.PosteriorDatabase()
which_posterior = "ecdc0501-covid19imperial_v3"
posterior = pdb.posterior(which_posterior)
post_model = stan.build(posterior.model.stan_code(), data=posterior.data.values())
# Total dimension = sum of the entries of the posterior's "dimensions" metadata
# (presumably one size per named parameter -- confirm against posteriordb).
dim = sum(posterior.information["dimensions"].values())
stan_model = stan_tools.StanModel(post_model, dim)

# %%
# Sanity check at 10 random points: the fused gradient+value evaluator must
# agree with the separate log-density and gradient evaluators.
# NOTE(review): the meaning of the second positional argument (None) is not
# visible here -- see stan_tools.StanModel.
pts = torch.randn((10, stan_model.dim))
pdfs = stan_model.log_dens_batch(pts, None)
grad_log_pdfs = stan_model.grad_log_dens_batch(pts, None)
grad_log_pdfs_2, pdfs_2 = stan_model.grad_val_log_dens_batch(pts, None)
# Sum of squared differences must be numerically zero.
assert (pdfs - pdfs_2).square_().sum() < 1e-10
assert (grad_log_pdfs - grad_log_pdfs_2).square_().sum() < 1e-10

# %%
# Shared experiment settings; N_PARTICLES, KERNEL_LENGTHSCALE and BOUNDS are
# reused by the SVGD run further down.
GRADIENT_DECAY = 1.0
N_PARTICLES = 25
KERNEL_DIAG_INFL = 1e-2
KERNEL_LENGTHSCALE = 1e-1
BOUNDS = (-100.0, 100.0)
# MSIP target built on the fused gradient+value evaluator of the Stan model.
target_msip_fr = MSIPFredholm(GRADIENT_DECAY, stan_model.grad_val_log_dens_batch)
# Standard-normal initial particles.
# TODO: optionally initialise from covid_prior_sample(N_PARTICLES) instead.
init_particles = torch.randn((N_PARTICLES, stan_model.dim))

msip = MSIP(
stan_model.dim,
N_PARTICLES,
kernel_diag_infl=KERNEL_DIAG_INFL,
kernel_lengthscale=KERNEL_LENGTHSCALE,
)

# %%
# Run MSIP. N_STEPS and LR are reassigned before the SVGD run below.
N_STEPS = 1000
LR = 1e-3
trajectories_msip_fr = nak_torch.nak(
target_msip_fr,
msip,
N_STEPS,
LR,
init_particles=init_particles,
bounds=BOUNDS,
)
# For MSIP the result is unpacked as a pair: particle trajectory and weight
# trajectory (per the variable names).
trajectories_pts_msip_fr, trajectories_wts_msip_fr = trajectories_msip_fr

# %%
# SVGD baseline: wraps the Stan model's batched gradient of the log density.
target_svgd = BatchGradLogDensityEvaluator(
stan_model.grad_log_dens_batch,
is_grad=True,
is_batched=True
)

svgd = SVGD(
stan_model.dim,
N_PARTICLES,
kernel_lengthscale=KERNEL_LENGTHSCALE,
kernel_lengthscale_quantile=0.5
)

# %%
# Run SVGD from the same initial particles and bounds as MSIP.
# Note the learning rate here (1e-4) is smaller than the MSIP one (1e-3).
N_STEPS = 1000
LR = 1e-4
trajectories_pts_svgd = nak_torch.nak(
target_svgd,
svgd,
N_STEPS,
LR,
init_particles=init_particles,
bounds=BOUNDS,
)


# %%
# Cross entropy along both trajectories. MSIP passes particles and weights,
# SVGD passes particles only.
crossent = nak_torch.metrics.CrossEntropy(stan_model.log_dens_batch)
# zip() has no length, so tqdm needs an explicit total; N_STEPS+1 presumably
# accounts for the initial particles being part of the trajectory -- confirm.
msip_crossent = [crossent(p,w) for p,w in tqdm(zip(*trajectories_msip_fr), total=N_STEPS+1)]
svgd_crossent = [crossent(p) for p in tqdm(trajectories_pts_svgd)]

# %%
# Kernel Stein discrepancy along both trajectories, using the same kernel
# lengthscale constant as the samplers.
ksd = nak_torch.metrics.KernelSteinDiscrepancy(stan_model.grad_log_dens_batch, KERNEL_LENGTHSCALE)
msip_ksd = [ksd(p,w) for p,w in tqdm(zip(*trajectories_msip_fr), total=N_STEPS+1)]
svgd_ksd = [ksd(p) for p in tqdm(trajectories_pts_svgd)]

# %%
# Cross entropy per iteration for both methods.
plt.plot(msip_crossent, label="msip")
plt.plot(svgd_crossent, label="svgd")
plt.title("Cross Entropy")
plt.legend()
plt.show()

# %%
# KSD per iteration for both methods.
plt.plot(msip_ksd, label="msip")
plt.plot(svgd_ksd, label="svgd")
plt.legend()
plt.title("KSD")
plt.show()
118 changes: 118 additions & 0 deletions examples/stan/pdb_schools.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
# %%
import nest_asyncio
import os
import matplotlib.pyplot as plt
import torch
from tqdm import tqdm
import nak_torch
from nak_torch.algorithms import MSIP, SVGD
from nak_torch.algorithms.msip import MSIPFredholm
from nak_torch.tools import stan_tools
from nak_torch.tools.types import BatchGradLogDensityEvaluator

nest_asyncio.apply() # See pystan documentation on why you need this when doing jupyter
import stan # noqa: E402
from posteriordb import PosteriorDatabaseGithub # noqa: E402

# %%
# Fail early with a clear message when the GitHub token is missing
# (presumably PosteriorDatabaseGithub needs it to query GitHub -- confirm).
if "GITHUB_PAT" not in os.environ:
    raise ValueError("Expected GITHUB_PAT to be in environment. Please add this into, e.g., your .env file.")

my_pdb = PosteriorDatabaseGithub()
# Names of the available posteriors, kept for interactive inspection.
pos = my_pdb.posterior_names()

def sample_tau_prior(N_samples, loc: float = 0., scale: float = 5.):
    """Return ``N_samples`` absolute values of Cauchy(loc, scale) draws.

    Args:
        N_samples: Number of draws.
        loc: Location of the underlying Cauchy distribution.
        scale: Scale of the underlying Cauchy distribution.

    Returns:
        1-D tensor of length ``N_samples`` with non-negative entries.
    """
    cauchy = torch.distributions.Cauchy(loc=loc, scale=scale, validate_args=True)
    sample_shape = (N_samples,)
    draws = cauchy.rsample(sample_shape)
    # Fold the symmetric draws onto the non-negative half line, in place.
    draws.abs_()
    return draws

# %%
# Fetch the centered eight-schools posterior from posteriordb.
posterior = my_pdb.posterior("eight_schools-eight_schools_centered")

# %%
# Compile the Stan program together with its data through PyStan.
post_model = stan.build(posterior.model.stan_code(), data=posterior.data.values())

# %%
# Wrap the compiled model; no dimension is passed here, unlike the other
# examples (presumably StanModel infers it -- confirm).
stan_model = stan_tools.StanModel(post_model)

# %%
# Smoke test: evaluate the log density, its gradient and the fused evaluator
# at 100 random points (results are not compared here).
pts = torch.randn((100, stan_model.dim))
pdfs = stan_model.log_dens_batch(pts, None)
grad_log_pdfs = stan_model.grad_log_dens_batch(pts, None)
grad_log_pdfs_2, pdfs_2 = stan_model.grad_val_log_dens_batch(pts, None)

# %%
# Shared experiment settings; reused by the SVGD run further down.
GRADIENT_DECAY = 1.0
N_PARTICLES = 100
KERNEL_DIAG_INFL = 1e-6
KERNEL_LENGTHSCALE = 1e-1
BOUNDS = (-100.0, 100.0)
target_msip_fr = MSIPFredholm(GRADIENT_DECAY, stan_model.grad_val_log_dens_batch)
# Initial particles: eta ~ N(0, 1) (8 columns), tau = |Cauchy(0, 5)| clamped
# to BOUNDS, mu = 5 * N(0, 1).
# NOTE(review): the column order (mu, tau, eta) and the scale on which tau is
# expressed are assumed to match the parameter layout StanModel expects --
# confirm against the Stan program.
init_eta = torch.randn((N_PARTICLES, 8))
init_tau = sample_tau_prior(N_PARTICLES).clamp_(*BOUNDS)
init_mu = torch.randn((N_PARTICLES, 1)) * 5
# init_tau is 1-D; column_stack treats it as a single column.
init_particles = torch.column_stack((init_mu, init_tau, init_eta))

msip = MSIP(
stan_model.dim,
N_PARTICLES,
kernel_diag_infl=KERNEL_DIAG_INFL,
kernel_lengthscale=KERNEL_LENGTHSCALE,
)

# %%
# Run MSIP.
N_STEPS = 1000
LR = 1e-3
trajectories_msip_fr = nak_torch.nak(
target_msip_fr,
msip,
N_STEPS,
LR,
init_particles=init_particles,
bounds=BOUNDS,
)
# For MSIP the result is unpacked as a pair: particle trajectory and weight
# trajectory (per the variable names).
trajectories_pts_msip_fr, trajectories_wts_msip_fr = trajectories_msip_fr

# %%
# SVGD baseline: wraps the Stan model's batched gradient of the log density.
target_svgd = BatchGradLogDensityEvaluator(
stan_model.grad_log_dens_batch,
is_grad=True,
is_batched=True
)

svgd = SVGD(
stan_model.dim,
N_PARTICLES,
kernel_lengthscale=KERNEL_LENGTHSCALE,
kernel_lengthscale_quantile=0.5
)

# %%
# Run SVGD from the same initial particles, bounds, step count and learning
# rate as MSIP.
N_STEPS = 1000
LR = 1e-3
trajectories_pts_svgd = nak_torch.nak(
target_svgd,
svgd,
N_STEPS,
LR,
init_particles=init_particles,
bounds=BOUNDS,
)


# %%
# Reference draws for this posterior; not used by the cells below.
draws = stan_tools.get_draws(post_model, posterior)

# %%
cross_ent = nak_torch.metrics.CrossEntropy(stan_model.log_dens_batch)

# %%
# Cross entropy per iteration for both methods.
# NOTE(review): the MSIP weights (trajectories_wts_msip_fr) are not passed
# here, and the call takes two extra None arguments, whereas pdb_covid.py
# calls the metric as crossent(p, w) / crossent(p) -- confirm which call
# signature is intended.
msip_cross_ent = [cross_ent(pts, None, None) for pts in tqdm(trajectories_pts_msip_fr)]
svgd_cross_ent = [cross_ent(pts, None, None) for pts in tqdm(trajectories_pts_svgd)]

# %%
# Cross entropy per iteration for both methods.
plt.plot(msip_cross_ent, label="MSIP")
plt.plot(svgd_cross_ent, label="SVGD")
plt.title("Cross entropy across iterations")
plt.legend()
# Render the figure explicitly so the cell also works when run as a script,
# matching the plotting cells in pdb_covid.py.
plt.show()
# %%
90 changes: 90 additions & 0 deletions examples/stan/schools.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
# %%
import torch
import nest_asyncio
import nak_torch
from nak_torch.tools import stan_tools
from nak_torch.algorithms import MSIP, SVGD
from nak_torch.algorithms.msip import MSIPFredholm, MSIPQuadGradientFree

nest_asyncio.apply() # See pystan documentation on why you need this when doing jupyter
import stan # noqa: E402

# %%
# Example from https://github.com/stan-dev/pystan
# In this variant the scale is declared as an unbounded `log_tau` and enters
# theta through exp(log_tau). The inline Stan comment on `log_tau` still
# describes it as the standard deviation itself.
schools_code = """
data {
int<lower=0> J; // number of schools
array[J] real y; // estimated treatment effects
array[J] real<lower=0> sigma; // standard error of effect estimates
}
parameters {
real mu; // population treatment effect
real log_tau; // standard deviation in treatment effects
vector[J] eta; // unscaled deviation from mu by school
}
transformed parameters {
vector[J] theta = mu + exp(log_tau) * eta; // school treatment effects
}
model {
target += normal_lpdf(eta | 0, 1); // prior log-density
target += normal_lpdf(log_tau | 5, 1);
target += normal_lpdf(mu | 0, 10);
target += normal_lpdf(y | theta, sigma); // log-likelihood
}
"""

# Data block values: J schools with observed effects y and standard errors sigma.
schools_data = {
"J": 8,
"y": [28, 8, -3, 7, -1, 1, 18, 12],
"sigma": [15, 10, 16, 11, 9, 11, 10, 18],
}

# Compile the program with its data through PyStan.
posterior = stan.build(schools_code, data=schools_data)

# %%
# Ten dimensional (mu, log_tau, eta[1..8]): theta is a transformed parameter
# computed from these, so it adds no extra dimensions.
model = stan_tools.StanModel(posterior, dim=10)

# %%
# Smoke test: evaluate the log density, its gradient and the fused
# gradient+value evaluator at 100 random points.
pts = torch.randn((100, model.dim))
pdfs = model.log_dens_batch(pts, None)
grad_log_pdfs = model.grad_log_dens_batch(pts, None)
grad_log_pdfs_2, pdfs_2 = model.grad_val_log_dens_batch(pts, None)

# %%
# MSIP experiment settings.
GRADIENT_DECAY = 0.95
N_PARTICLES = 100
KERNEL_DIAG_INFL = 1e-6
KERNEL_LENGTHSCALE = 1e-2
target_msip_fr = MSIPFredholm(GRADIENT_DECAY, model.grad_val_log_dens_batch)
# Initial particles drawn from the priors written in the Stan program:
# eta ~ N(0, 1), log_tau ~ N(5, 1), mu ~ N(0, 10).
init_eta = torch.randn((N_PARTICLES, 8))
init_log_tau = torch.randn((N_PARTICLES, 1)) + 5
init_mu = torch.randn((N_PARTICLES, 1)) * 10
# Column layout: mu (col 0), log_tau (col 1), eta (cols 2..9), matching the
# declaration order of the Stan parameters block.
init_particles = torch.column_stack((init_mu, init_log_tau, init_eta))
msip = MSIP(
model.dim,
N_PARTICLES,
kernel_diag_infl=KERNEL_DIAG_INFL,
kernel_lengthscale=KERNEL_LENGTHSCALE,
kernel_lengthscale_quantile=0.05
)

# %%
# Run MSIP with every coordinate bounded to [-100, 100].
N_STEPS = 1000
LR = 1e-3
trajectories_msip_fr = nak_torch.nak(
target_msip_fr,
msip,
N_STEPS,
LR,
init_particles=init_particles,
bounds=(-100.0, 100.0),
)
# For MSIP the result is unpacked as a pair: particle trajectory and weight
# trajectory (per the variable names).
trajectories_pts_msip_fr, trajectories_wts_msip_fr = trajectories_msip_fr

# %%
# Compare the final MSIP particles with their starting positions.
msip_fr_end = trajectories_pts_msip_fr[-1]
# init_particles is laid out as (mu, log_tau, eta[8]), so the eta block is
# columns 2 onwards; slicing [:, :8] would mix mu and log_tau into it.
eta_end = msip_fr_end[:, 2:] - init_particles[:, 2:]
# Total squared displacement relative to the squared norm of the initial
# particles (a relative shift, despite the variable name).
mean_sq_shift = (msip_fr_end - init_particles).square().sum() / init_particles.square().sum()
print(mean_sq_shift)
5 changes: 5 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@ nak-torch = "nak_torch:main"
examples = [
"ipykernel>=7.2.0",
"matplotlib>=3.10.8",
"posteriordb>=0.2.0",
"pyro-ppl>=1.9.1",
"pystan",
"scipy>=1.17.1",
]

Expand Down Expand Up @@ -54,3 +56,6 @@ testpaths = [
# Unlike Flake8, Ruff doesn't enable pycodestyle warnings (`W`) or
# McCabe complexity (`C901`) by default.
ignore = ["F722"]

[tool.uv.sources]
pystan = { git = "https://github.com/dannys4/pystan", branch = "change_function_interface" }
Loading
Loading