🛠️ work-in-progress 🛠️
Bayalign is a lightweight JAX-based library for rigid point cloud registration using efficient Bayesian inference via geodesic slice sampling on the sphere (GeoSSS). See the geosss package for details.
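For intuition about the sampler, here is a minimal sketch of shrinkage-based geodesic slice sampling on the unit sphere. It is written in plain NumPy rather than JAX and is not the bayalign or geosss implementation. The function names `geodesic_shrinkage_step` and `sample` are hypothetical. Each step does three things:

- draws a random great circle through the current point,
- draws a slice level below the current log density,
- shrinks an angular bracket around the current point until a point on the circle lies above that level.

```python
import numpy as np


def geodesic_shrinkage_step(log_prob, x, rng):
    """One shrinkage-based geodesic slice sampling step on the unit sphere.

    `x` is a unit vector and `log_prob` an unnormalized log density on the sphere.
    """
    # Random unit direction in the tangent space at x (orthogonal to x).
    v = rng.standard_normal(x.shape)
    v -= v.dot(x) * x
    v /= np.linalg.norm(v)

    # Slice level: log t = log p(x) + log U with U ~ Uniform(0, 1).
    log_t = log_prob(x) + np.log(rng.uniform())

    # Initial angle and bracket [theta - 2*pi, theta], which always contains 0.
    theta = rng.uniform(0.0, 2.0 * np.pi)
    lo, hi = theta - 2.0 * np.pi, theta

    while True:
        # Point on the great circle through x in direction v.
        y = np.cos(theta) * x + np.sin(theta) * v
        if log_prob(y) > log_t:
            return y / np.linalg.norm(y)  # renormalize against round-off
        # Reject: shrink the bracket toward theta = 0 (the current state).
        if theta < 0.0:
            lo = theta
        else:
            hi = theta
        theta = rng.uniform(lo, hi)


def sample(log_prob, x0, n_samples, seed=0):
    """Run the chain from x0 and return an (n_samples, dim) array of unit vectors."""
    rng = np.random.default_rng(seed)
    x = x0 / np.linalg.norm(x0)
    draws = []
    for _ in range(n_samples):
        x = geodesic_shrinkage_step(log_prob, x, rng)
        draws.append(x)
    return np.stack(draws)
```

As an example, take `log_prob = lambda x: 20.0 * mu @ x`, an unnormalized von Mises-Fisher density with mean direction `mu` on S^3. The draws then concentrate around `mu`. The bracket always contains angle 0, where the point is the current state and lies above the slice level. So the shrinkage loop is guaranteed to terminate.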
The package targets rigid registration problems (translation estimation is currently not supported). It is mainly motivated by scientific applications such as cryo-EM, where the goal is to estimate the rotation of a 3D structure that best aligns with noisy or partial 2D projections.
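A rotation can be parameterized by a unit quaternion, which is why the posterior lives on the sphere S^3. The sketch below shows one way to turn a quaternion into a rotation matrix and compare a rotated 3D structure with 2D data. It rests on two assumptions: quaternions are scalar-first (w, x, y, z), and the projection is orthographic along the z-axis. bayalign's own conventions, for example in `RotationProjection`, may differ. The function names are hypothetical.

```python
import numpy as np


def quat_to_matrix(q):
    """Rotation matrix of a scalar-first quaternion q = (w, x, y, z); q is normalized first."""
    w, x, y, z = q / np.linalg.norm(q)
    return np.array([
        [1 - 2 * (y * y + z * z), 2 * (x * y - w * z), 2 * (x * z + w * y)],
        [2 * (x * y + w * z), 1 - 2 * (x * x + z * z), 2 * (y * z - w * x)],
        [2 * (x * z - w * y), 2 * (y * z + w * x), 1 - 2 * (x * x + y * y)],
    ])


def rotate_and_project(points, q):
    """Rotate (M, 3) points by q, then project orthographically onto the xy-plane."""
    rotated = points @ quat_to_matrix(q).T
    return rotated[:, :2]  # drop the z coordinate -> (M, 2)
```

The quaternions q and -q produce the same rotation matrix, so every rotation appears twice on S^3. Posteriors over quaternions are therefore antipodally symmetric.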
- Supports 3D-2D and 3D-3D rigid registration
- GPU acceleration, automatic differentiation, and JIT compilation via JAX
- Fast inference via GeoSSS
- Uses Gaussian mixture models (GMMs) to score rigid poses
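As an illustration of GMM-based scoring, the sketch below computes the weighted log-likelihood of target points under an isotropic Gaussian mixture. The mixture components sit at the rotated source points, which are also projected in the 3D-2D case. This is an assumed formulation for illustration only, not bayalign's `GaussianMixtureModel`. In particular, the `k` argument of `GaussianMixtureModel` is not modeled here, and the function name is hypothetical.

```python
import numpy as np


def gmm_log_likelihood(target, target_w, centers, center_w, sigma):
    """Weighted log-likelihood of (N, d) target points under an isotropic GMM.

    Components are centered at the (M, d) `centers` (e.g. transformed source
    points) with mixture weights `center_w` and shared standard deviation `sigma`.
    """
    d = target.shape[1]
    center_w = center_w / center_w.sum()  # normalize mixture weights
    # Squared distances between every target point and every center: (N, M)
    sq = ((target[:, None, :] - centers[None, :, :]) ** 2).sum(axis=-1)
    # Log of each weighted Gaussian component evaluated at each target point
    log_comp = (
        np.log(center_w)[None, :]
        - 0.5 * sq / sigma**2
        - 0.5 * d * np.log(2.0 * np.pi * sigma**2)
    )
    # Log-sum-exp over components for numerical stability: (N,)
    m = log_comp.max(axis=1, keepdims=True)
    log_mix = m[:, 0] + np.log(np.exp(log_comp - m).sum(axis=1))
    return float((target_w * log_mix).sum())
```

A rotation that places the source points on top of the target points scores higher than one that leaves them misaligned. The exponential of this score can therefore act as an unnormalized likelihood for each candidate rotation.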
```bash
pip install bayalign
```

A basic example of 3D-to-2D registration:
```python
import jax.numpy as jnp
from jax import random

from bayalign.pointcloud import PointCloud, RotationProjection
from bayalign.score import GaussianMixtureModel
from bayalign.inference import ShrinkageSphericalSliceSampler
from bayalign.sphere_utils import sample_sphere

# Define 2D target and 3D source point clouds.
# target_positions has shape (N, 2) and source_positions has shape (M, 3);
# replace these placeholders and their weights with your own data.
target = PointCloud(target_positions, target_weights)
source = RotationProjection(source_positions, source_weights)

# Define a target probability model over rotations using a GMM
target_pdf = GaussianMixtureModel(target, source, sigma=1.0, k=20)

# Sample from the posterior over 3D rotations (unit quaternions)
init_q = sample_sphere(random.key(645), d=3)  # initial quaternion, shape (4,)
sampler = ShrinkageSphericalSliceSampler(target_pdf, init_q, seed=123)
samples = sampler.sample(n_samples=100, burnin=0.2)

# Pick the sample with the highest log probability as the best rotation
log_probs = jnp.asarray([target_pdf.log_prob(q) for q in samples])
best_rot = samples[jnp.argmax(log_probs)]
transformed_source = source.transform_positions(best_rot)
```

For 3D-3D registration, use `PointCloud` for both the target and the source. Check out the `examples` directory for detailed use cases with synthetic and cryo-EM data.
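In the example above, `sample_sphere(random.key(645), d=3)` returns an initial quaternion of shape (4,), that is, a point on the 3-sphere S^3 in R^4. A common way to draw such a point uniformly is to normalize a standard Gaussian vector, as sketched below in NumPy. Whether `sample_sphere` works this way is an assumption, and `sample_sphere_uniform` is a hypothetical name.

```python
import numpy as np


def sample_sphere_uniform(rng, d=3, n=None):
    """Uniform sample(s) on the unit sphere S^d embedded in R^(d+1).

    Returns shape (d + 1,) if n is None, else (n, d + 1).
    """
    shape = (d + 1,) if n is None else (n, d + 1)
    x = rng.standard_normal(shape)
    # The standard Gaussian is rotation invariant, so its direction is uniform.
    return x / np.linalg.norm(x, axis=-1, keepdims=True)
```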
To run the examples, you need to install some optional dependencies. Follow the steps below to set up your environment.
Clone the repository and navigate to its root directory:
```bash
git clone https://github.com/ShantanuKodgirwar/bayalign.git
cd bayalign
```

The package `bayalign` and all its locked dependencies are managed by uv and can be installed into a virtual environment with:
```bash
uv sync --extra all
```

This also installs the dependencies needed to run the examples.
For GPU support, also include a CUDA extra:

```bash
uv sync --extra all,cuda12  # or cuda13 for a newer GPU (SM > 7.5)
```

If you encounter any problems or have questions, please feel free to open an issue.
