Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
64 commits
Select commit Hold shift + click to select a range
5019456
Move pyro_tools
dannys4 Apr 9, 2026
4cf254a
add pyro as an optional dependency
dannys4 Apr 9, 2026
605f670
update examples option
dannys4 Apr 12, 2026
acb2253
Working logistic regression example
dannys4 Apr 12, 2026
15fc812
Keep working on covtype
dannys4 Apr 15, 2026
b12fb18
Merge branch 'main' into dannys4/logistic_reg
dannys4 Apr 17, 2026
491f4a5
Fix __init__
dannys4 Apr 17, 2026
09274f7
Fix logistic regression prior
dannys4 Apr 18, 2026
5aabe44
Clean up simple logreg
dannys4 Apr 18, 2026
696c1f6
Work on covtype logreg
dannys4 Apr 18, 2026
d57250d
Start creating core loop
dannys4 Apr 21, 2026
bde9a8c
Remove old algorithms
dannys4 Apr 21, 2026
e12c463
First draft MSIP with new setup
dannys4 Apr 21, 2026
fd89354
Consolidate reused MSIP code
dannys4 Apr 23, 2026
a17e377
Finish initial implementations for the most part
dannys4 Apr 23, 2026
11b1eb4
Work on file organization
dannys4 Apr 23, 2026
bb24bae
Work on himmelblau
dannys4 Apr 23, 2026
19217ad
Fix Gaussian example
dannys4 Apr 23, 2026
545b0ae
Working GF-ALDI
dannys4 Apr 24, 2026
4b8b6aa
Fix EKS
dannys4 Apr 24, 2026
66e84b0
Remove redundant gaussian example
dannys4 Apr 24, 2026
5b7f93a
Fix gaussianmodel
dannys4 Apr 24, 2026
0227b6e
Fix sigma_sq
dannys4 Apr 24, 2026
ae35e0c
Fix MSIP_GS
dannys4 Apr 24, 2026
85e19e9
Add every algorithm so far to Gaussian example
dannys4 Apr 24, 2026
edc62fd
Work on AB example
dannys4 Apr 25, 2026
1e2683a
Merge branch 'dannys4/logistic_reg' into dannys4/main_loop
dannys4 Apr 25, 2026
bc5a158
Add DeepEnsembles
dannys4 Apr 25, 2026
fe304ee
Add scipy as dev dep
dannys4 Apr 26, 2026
b3b1cb3
Work on batching
dannys4 Apr 26, 2026
6e965af
Add batching to covtype
dannys4 Apr 26, 2026
e18229b
Add adagrad for svgd
dannys4 Apr 26, 2026
42c7064
Adjust logistic regression interface
dannys4 Apr 26, 2026
411bb70
Add svgd to covertype
dannys4 Apr 26, 2026
c7699e5
Add stanpy and posteriordb to deps for examples
dannys4 Apr 26, 2026
4341df7
Fix stan versioning
dannys4 Apr 26, 2026
39a26d5
Add basic stan model
dannys4 Apr 27, 2026
d1c1797
Change version dep of stan
dannys4 Apr 27, 2026
5e16218
Adapt to different stan interface
dannys4 Apr 27, 2026
3847703
make schools example work
dannys4 Apr 27, 2026
d332836
Start on PosteriorDB example
dannys4 Apr 27, 2026
8be95e6
Work on fixing logistic regression
dannys4 Apr 27, 2026
5a150c9
Add test for the logistic regression model
dannys4 Apr 28, 2026
6758e96
Fix grad-informed estimator
dannys4 Apr 28, 2026
f4fba16
switch to usual inv
dannys4 Apr 28, 2026
ebbe8a2
Satisfactory logreg
dannys4 Apr 28, 2026
c483716
Remove old comment
dannys4 Apr 28, 2026
6d7839f
Update readme to include stan info
dannys4 Apr 28, 2026
792229c
Finish BNN example prelim
dannys4 Apr 29, 2026
d8dc2dd
Merge branch 'dannys4/main_loop' into dannys4/stan
dannys4 Apr 29, 2026
86b830f
Change PDB version
dannys4 Apr 29, 2026
05f5e04
Move to cap'ing default kernel
dannys4 Apr 29, 2026
7b861f8
Add sample MMD metric
dannys4 Apr 29, 2026
bbad439
minimal PDB example
dannys4 Apr 29, 2026
2738e64
Fix up posteriordb ver
dannys4 May 2, 2026
15dfab6
Working PDB
dannys4 May 3, 2026
d3a86f4
Fix metrics test
dannys4 May 3, 2026
1366e5c
Merge pull request #35 from Nodes-and-Kernels/dannys4/stan
dannys4 May 3, 2026
0ece5a9
Specify device and dtype
dannys4 May 4, 2026
af1272c
Update type annotations for nak function parameters
dannys4 May 4, 2026
04fb72f
Fixing dtype and device issues
dannys4 May 4, 2026
3fa28c3
Add generator for MC quad
dannys4 May 4, 2026
2b6881c
Himmelblau stuff
dannys4 May 12, 2026
f93bbad
final himmelblau thing
dannys4 May 12, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@ __pycache__
.vscode
*.xml
*.pdf
.env
33 changes: 33 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -1 +1,34 @@
# Kernel-based quantization algorithms

## Installation
We recommend installing with `uv`. Currently, the way to install this locally would be
```bash
$ uv pip install -e git+https://github.com/Nodes-and-Kernels/nak_torch
```

If you plan on using the examples, make sure that `[examples]` option is installed. Also, make sure that there is no other installation of `pystan`, which is a dependency---we use a fork of the original package to reduce latency for our algorithms when using a stan posterior.

## List of Algorithms
### MSIP
We largely focus on _mean-shift interacting particle_ (MSIP) algorithms, and we are working to implement several of these. Currently, we have:

- MSIP
- MSIPGS

For these algorithms, we have multiple estimators---each of these produces a certain set of dynamics. In particular, we have:

- MSIPFredholm
- MSIPGradientFree
- MSIPGradientInformed
- MSIPGMMGaussianKernel

### Other algorithms
We also include several other typical interacting-particle sampling algorithms.

- Consensus-based sampler (`CBS`)
- Deep ensembles (`DeepEnsembles`)
- Ensemble Kalman Sampler (`EKS`)
- Gradient-informed affine-invariant Langevin dynamics (`GradALDI`)
- Gradient-free ALDI (`GradFreeALDI`)
- Stein variational gradient descent (`SVGD`)

180 changes: 110 additions & 70 deletions examples/aristoff_bangerth.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,17 @@
import math
import torch
from functions import aristoff_bangerth as ab, build_aristoff_bangerth
from nak_torch.algorithms import msip, svgd
import nak_torch
from nak_torch.algorithms import MSIP, SVGD
from nak_torch.algorithms.msip import MSIPFredholm
from matplotlib import ticker
import gc
import matplotlib.pyplot as plt
from nak_torch.tools.kernel import sqexp_kernel_matrix
from tqdm import tqdm
import pandas as pd
import pyro_tools
from nak_torch.tools.types import BatchGradLogDensityEvaluator
from nak_torch.tools import pyro_tools
from pyro.infer import mcmc

if torch.cuda.is_available():
Expand All @@ -18,25 +21,73 @@
torch.set_default_device("cpu")
torch.set_default_dtype(torch.float64)

# %%
def plot_samples(pts, max_side_len = 6):
n_particles = pts.shape[0]
side_len = min(max_side_len, int(math.floor(math.sqrt(n_particles))))
pts = pts[:side_len**2]
fig = plt.figure(figsize=(9, 6), layout='constrained')
gs = fig.add_gridspec(side_len, side_len + 2)
vabs = max(pts.min().abs(), pts.max().abs())
plt_kwargs = {'vmin': -vabs, 'vmax': vabs, 'extent': (0, 8, 0, 8)}

for i in range(side_len):
for j in range(side_len):
ax = fig.add_subplot(gs[i, j])
# ax.set_axis_off()
ax.set_aspect('equal')
t = ax.matshow(pts[i*side_len + j].reshape(8, 8), **plt_kwargs)
# ax.vlines(jnp.arange(1,8), -0.1, 8.1, color='w', lw=0.75)
# ax.hlines(jnp.arange(1,8), -0.1, 8.1, color='w', lw=0.75)
ax.minorticks_on()
ax.set_xticks([])
ax.set_yticks([])
ax.xaxis.set_minor_locator(ticker.MultipleLocator())
ax.yaxis.set_minor_locator(ticker.MultipleLocator())
ax.grid(which="both", linewidth=1.5, color="w")
ax.tick_params(which="minor", length=0)
ax_cb = fig.add_subplot(gs[:-2, -2:])
ax_cb.set_title(r"Scale of $\log\theta$", y=0.6)
cax_cb = ax_cb.inset_axes((0.1, 0.45, 0.8, 0.1))
ax_cb.axis('off')
fig.colorbar(t, cax=cax_cb, orientation='horizontal') # type: ignore
ax_true = fig.add_subplot(gs[-2:, -2:])
ax_true.set_aspect('equal')
ax_true.matshow(ab.theta_true.log().reshape(8, 8), **plt_kwargs)
ax_true.minorticks_on()
ax_true.set_xticks([])
ax_true.set_yticks([])
ax_true.xaxis.set_minor_locator(ticker.MultipleLocator())
ax_true.yaxis.set_minor_locator(ticker.MultipleLocator())
ax_true.set_title(r"True $\theta$")
ax_true.grid(which="both", linewidth=1.5, color="w")
ax_true.tick_params(which="minor", length=0)
return fig


# %%
use_compiled = True
model = build_aristoff_bangerth(use_compiled=use_compiled, dtype=torch.float64)
log_p = model.to_log_dens(use_compiled=use_compiled)
log_th = torch.randn(500, 64, dtype=torch.float64)
test_out = log_p(log_th)
log_th = torch.randn(25, 64, dtype=torch.float64)
test_out = log_p(log_th, None)

# %%
grad_log_p = torch.func.grad(lambda t: log_p(t).sum())
test_eval = grad_log_p(log_th)
def _tmp_log_p(log_theta, arg: None):
ret = log_p(log_theta, arg)
return ret.sum(), ret

grad_log_p = torch.func.grad(lambda t,a: log_p(t, a).sum())
grad_val_log_p = torch.func.grad(_tmp_log_p, has_aux=True)
test_grad = grad_log_p(log_th, None)
test_grad_2, test_out_2 = grad_val_log_p(log_th, None)

# %%
del log_th
del test_out
# del test_eval
del log_th, test_out, test_grad, test_grad_2, test_out_2
gc.collect()

# %%
n_particles, n_steps, dim = 500, 25, 64
n_particles, n_steps, dim = 25, 25, 64
kernel_bandwidth = 0.75

torch.manual_seed(1)
Expand All @@ -45,84 +96,73 @@
dtype=torch.float64,
) # Sample from prior

msip_args = {
default_kwargs = {
"dim": dim,
"bounds": (-8, 8),
"n_steps": n_steps,
"n_particles": n_particles,
"n_steps": n_steps, # "epochs" (passes over all particles)
"dim": 64,
"bounds": (-8, 8), # [a,b]^d
"gradient_informed": True,
"keep_all": False,
"lr": 1e-1,
"noise": 0.05, # currently unused
"kernel_lengthscale": 0.1,
"init_particles": init_particles,
"kernel_bandwidth": kernel_bandwidth,
"bandwidth_factor": 0.25,
"seed": 0,
"kernel_diag_infl": 1e-10,
"keep_all": False,
"device": None
"gradient_decay": 0.95,
"kernel_diag_infl": 1e-6,
}

# %%
trajectories_msip = msip(
log_p,
**msip_args
)
msip_kwargs = default_kwargs.copy()
msip_kwargs["lr"] = 1e-2
msip_kwargs["kernel_lengthscale_quantile"] = 0.05
msip = MSIP(**msip_kwargs)
target_msip_fr = MSIPFredholm(log_dens_grad_val=grad_val_log_p, **msip_kwargs)

# %%
trajectories_pts_msip_fr, trajectories_wts_msip_fr = nak_torch.nak(target_msip_fr, msip, **msip_kwargs)

# %%
n_steps_hmc = 1000
n_steps_hmc = 100
pyro_model = pyro_tools.PyroModel(model, dim)
hmc_kernel = mcmc.NUTS(pyro_model)
mcmc_setup = mcmc.MCMC(hmc_kernel, num_samples=n_steps_hmc, warmup_steps=100)
mcmc_setup = mcmc.MCMC(hmc_kernel, num_samples=n_steps_hmc, warmup_steps=10)
mcmc_setup.run(model.true_obs)

hmc_samples = mcmc_setup.get_samples()["theta"]

# %%
trajectories_svgd = svgd(
log_p,
is_log_density_batched=True,
**msip_args
target_svgd = BatchGradLogDensityEvaluator(
log_p, is_grad=False, is_batched=True
)

svgd = SVGD(
kernel_lengthscale_quantile=0.5, # Median heuristic
**msip_kwargs
)
svgd_kwargs = msip_kwargs.copy()
svgd_kwargs["lr"] = 1e-1
svgd_kwargs["n_steps"] = 100

trajectories_pts_svgd = nak_torch.nak(
target_svgd,
svgd,
**svgd_kwargs
)

# %%
pts_msip = trajectories_pts_msip_fr[-1] - init_particles
fig = plot_samples(pts_msip)
fig.suptitle("MSIP Samples")
plt.show()

# %%
pts_hmc = hmc_samples[10::3]
fig = plot_samples(pts_hmc)
fig.suptitle("HMC Samples")
plt.show()

# %%
side_len = min(6, int(math.floor(math.sqrt(n_particles))))
pts = trajectories_msip[-1][:side_len**2].detach().cpu()# - init_particles[:side_len**2]
fig = plt.figure(figsize=(9, 6), layout='constrained')
gs = fig.add_gridspec(side_len, side_len + 2)
vabs = max(pts.min().abs(), pts.max().abs())
plt_kwargs = {'vmin': -vabs, 'vmax': vabs, 'extent': (0, 8, 0, 8)}

for i in range(side_len):
for j in range(side_len):
ax = fig.add_subplot(gs[i, j])
# ax.set_axis_off()
ax.set_aspect('equal')
t = ax.matshow(pts[i*side_len + j].reshape(8, 8), **plt_kwargs)
# ax.vlines(jnp.arange(1,8), -0.1, 8.1, color='w', lw=0.75)
# ax.hlines(jnp.arange(1,8), -0.1, 8.1, color='w', lw=0.75)
ax.minorticks_on()
ax.set_xticks([])
ax.set_yticks([])
ax.xaxis.set_minor_locator(ticker.MultipleLocator())
ax.yaxis.set_minor_locator(ticker.MultipleLocator())
ax.grid(which="both", linewidth=1.5, color="w")
ax.tick_params(which="minor", length=0)
ax_cb = fig.add_subplot(gs[:-2, -2:])
ax_cb.set_title(r"Scale of $\log\theta$", y=0.6)
cax_cb = ax_cb.inset_axes((0.1, 0.45, 0.8, 0.1))
ax_cb.axis('off')
fig.colorbar(t, cax=cax_cb, orientation='horizontal')
ax_true = fig.add_subplot(gs[-2:, -2:])
ax_true.set_aspect('equal')
ax_true.matshow(ab.theta_true.log().reshape(8, 8), **plt_kwargs)
ax_true.minorticks_on()
ax_true.set_xticks([])
ax_true.set_yticks([])
ax_true.xaxis.set_minor_locator(ticker.MultipleLocator())
ax_true.yaxis.set_minor_locator(ticker.MultipleLocator())
ax_true.set_title(r"True $\theta$")
ax_true.grid(which="both", linewidth=1.5, color="w")
ax_true.tick_params(which="minor", length=0)
pts_svgd = trajectories_pts_svgd[-1]
fig = plot_samples(pts_svgd)
fig.suptitle("SVGD Samples")
plt.show()

# %%
Expand Down
Loading