andylolu2/jax-diffusion

JAX Diffusion

Unofficial implementation of Denoising Diffusion Probabilistic Models (DDPM) in JAX and Flax.
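
For orientation, the DDPM forward (noising) process can be sketched as follows. This is a minimal illustration under stated assumptions, not this repository's code: the linear beta schedule, the 1000-step horizon, and the names `make_schedule` and `q_sample` are hypothetical choices made for the example.

```python
import jax
import jax.numpy as jnp


def make_schedule(num_steps=1000, beta_start=1e-4, beta_end=0.02):
    """Linear beta schedule (an assumption; the repo's configs may differ)."""
    betas = jnp.linspace(beta_start, beta_end, num_steps)
    alpha_bars = jnp.cumprod(1.0 - betas)  # running product of (1 - beta_t)
    return betas, alpha_bars


def q_sample(x0, t, noise, alpha_bars):
    """Draw x_t from q(x_t | x_0) = N(sqrt(abar_t) * x_0, (1 - abar_t) * I)."""
    # Reshape the per-example abar_t so it broadcasts over image dimensions.
    abar = alpha_bars[t].reshape((-1,) + (1,) * (x0.ndim - 1))
    return jnp.sqrt(abar) * x0 + jnp.sqrt(1.0 - abar) * noise


key_x, key_t, key_eps = jax.random.split(jax.random.PRNGKey(0), 3)
_, alpha_bars = make_schedule()
# Stand-in batch of 28 x 28 grayscale images scaled to [-1, 1].
x0 = jax.random.uniform(key_x, (4, 28, 28, 1), minval=-1.0, maxval=1.0)
t = jax.random.randint(key_t, (4,), 0, 1000)  # one random timestep per example
noise = jax.random.normal(key_eps, x0.shape)
x_t = q_sample(x0, t, noise, alpha_bars)
```

In the simplified DDPM objective, a network that receives `x_t` and `t` is trained to predict `noise` with a mean-squared-error loss.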

Sampling with Denoising Diffusion Implicit Models (DDIM) is also used.
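
As a rough sketch of what a DDIM update looks like, the deterministic variant (eta = 0) can be written as below. The function name `ddim_step` and its arguments are hypothetical and not taken from this repository; `eps_pred` stands for the output of a trained noise-prediction network, and `abar_t` / `abar_prev` are the cumulative alpha products at the current and the target timestep.

```python
import jax.numpy as jnp


def ddim_step(x_t, eps_pred, abar_t, abar_prev):
    """One deterministic DDIM update (eta = 0) from timestep t to an earlier one."""
    # Clean-image estimate implied by the predicted noise.
    x0_pred = (x_t - jnp.sqrt(1.0 - abar_t) * eps_pred) / jnp.sqrt(abar_t)
    # Re-noise the estimate to the earlier timestep using the same predicted noise.
    return jnp.sqrt(abar_prev) * x0_pred + jnp.sqrt(1.0 - abar_prev) * eps_pred


# Toy check with a known clean image and known noise (no network involved).
x0 = jnp.ones((2, 4, 4, 1))
eps = jnp.full(x0.shape, 0.5)
abar_t, abar_prev = 0.5, 0.9
x_t = jnp.sqrt(abar_t) * x0 + jnp.sqrt(1.0 - abar_t) * eps
x_prev = ddim_step(x_t, eps, abar_t, abar_prev)
```

Because the update is deterministic, the sampler can jump between non-adjacent timesteps, which is why DDIM typically needs far fewer network evaluations than ancestral DDPM sampling.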

MNIST

[Figure: real (left) and generated (right) MNIST samples]

Training details

The model has 5.46M parameters and was trained on Colab (T4) for 100K steps with batch size 128, which took 8.5 hours.

Full hyperparameters can be found in configs/mnist.py.

Fashion MNIST

[Figure: real (left) and generated (right) Fashion MNIST samples]

Training details

The model has 9.70M parameters and was trained on Kaggle (TPUv3-8) for 40K steps with batch size 128, which took 2.5 hours.

Full hyperparameters can be found in configs/fashion_mnist.py.

Celeb A

Results

[Figure: real (left) and generated (right) Celeb A samples]

Training details

Due to compute constraints, the model is trained only on 64 x 64 images.

The model has 72.70M parameters and was trained on Kaggle (P100) for 60K steps with batch size 64, which took 22 hours.

Full hyperparameters can be found in configs/celeb_a64.py.
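
To put the three training runs side by side, the reported step counts, batch sizes, and wall-clock times can be turned into rough throughput figures. The helper `throughput` below is a hypothetical convenience for this arithmetic, and the results are only approximate because the reported hours are rounded and may include compilation or evaluation time.

```python
def throughput(steps, batch_size, hours):
    """Return (total images processed, optimizer steps per second)."""
    images_seen = steps * batch_size
    steps_per_second = steps / (hours * 3600)
    return images_seen, steps_per_second


# (steps, batch size, wall-clock hours) as reported in the training details.
runs = {
    "MNIST (T4)": (100_000, 128, 8.5),
    "Fashion MNIST (TPUv3-8)": (40_000, 128, 2.5),
    "Celeb A 64x64 (P100)": (60_000, 64, 22.0),
}

for name, (steps, batch_size, hours) in runs.items():
    images_seen, steps_per_second = throughput(steps, batch_size, hours)
    print(f"{name}: {images_seen:,} images, {steps_per_second:.2f} steps/s")
```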
