


andylolu2/jax-diffusion


JAX Diffusion

Unofficial implementation of Denoising Diffusion Probabilistic Models (DDPM) in JAX and Flax.
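For readers new to DDPM, the training objective can be summarised in a few lines of JAX. This is a minimal sketch, not this repository's code. It assumes the linear beta schedule from the original DDPM paper (T = 1000, betas from 1e-4 to 0.02), and `eps_model(params, x_t, t)` is a hypothetical stand-in for the Flax noise-prediction network.

```python
import jax
import jax.numpy as jnp

# Assumed linear schedule from the DDPM paper; this repo's configs may differ.
T = 1000
betas = jnp.linspace(1e-4, 0.02, T)
alphas_bar = jnp.cumprod(1.0 - betas)  # alpha_bar_t = prod_{s<=t} (1 - beta_s)


def q_sample(x0, t, noise):
    """Sample x_t ~ q(x_t | x_0) in closed form for a batch of (H, W, C) images."""
    a_bar = alphas_bar[t][:, None, None, None]  # broadcast over H, W, C
    return jnp.sqrt(a_bar) * x0 + jnp.sqrt(1.0 - a_bar) * noise


def ddpm_loss(eps_model, params, x0, key):
    """Simplified DDPM objective: MSE between the true and predicted noise.

    `eps_model` is a hypothetical callable (params, x_t, t) -> predicted noise.
    """
    t_key, n_key = jax.random.split(key)
    t = jax.random.randint(t_key, (x0.shape[0],), 0, T)  # one timestep per image
    noise = jax.random.normal(n_key, x0.shape)
    x_t = q_sample(x0, t, noise)
    eps_pred = eps_model(params, x_t, t)
    return jnp.mean((noise - eps_pred) ** 2)
```

In practice the loss would be wrapped in `jax.value_and_grad` and the Flax model's `apply` passed as `eps_model`.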

Denoising Diffusion Implicit Models (DDIM) sampling is also used.
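DDIM reuses a DDPM-trained noise predictor but samples deterministically over a strided subset of the training timesteps, which needs far fewer network calls. The sketch below shows the eta = 0 update under the same assumptions as any DDPM setup here: a linear beta schedule with T = 1000 and a hypothetical `eps_model(params, x_t, t)`. It is illustrative only and does not reproduce this repository's sampler.

```python
import jax
import jax.numpy as jnp

# Assumed linear schedule (T = 1000); the repo's actual schedule may differ.
T = 1000
betas = jnp.linspace(1e-4, 0.02, T)
alphas_bar = jnp.cumprod(1.0 - betas)


def ddim_step(x_t, eps, a_bar_t, a_bar_prev):
    """One deterministic DDIM update (eta = 0)."""
    # Estimate the clean image implied by the predicted noise.
    x0_pred = (x_t - jnp.sqrt(1.0 - a_bar_t) * eps) / jnp.sqrt(a_bar_t)
    # Re-noise that estimate to the previous (less noisy) timestep.
    return jnp.sqrt(a_bar_prev) * x0_pred + jnp.sqrt(1.0 - a_bar_prev) * eps


def ddim_sample(eps_model, params, shape, key, num_steps=50):
    """Generate samples from pure noise using num_steps strided timesteps."""
    timesteps = jnp.linspace(T - 1, 0, num_steps).round().astype(jnp.int32)
    x = jax.random.normal(key, shape)
    for i in range(num_steps):
        t = timesteps[i]
        a_bar_t = alphas_bar[t]
        # On the final step a_bar_prev = 1, so the update returns x0_pred.
        a_bar_prev = alphas_bar[timesteps[i + 1]] if i + 1 < num_steps else 1.0
        eps = eps_model(params, x, jnp.full((shape[0],), t))
        x = ddim_step(x, eps, a_bar_t, a_bar_prev)
    return x
```

Because eta = 0 injects no fresh noise, the same starting key always yields the same sample, which is convenient for comparing checkpoints.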

MNIST

[Figure: real MNIST samples (left) and generated samples (right)]

Training details

The model has 5.46M parameters and was trained on Colab (T4 GPU) for 100K steps with batch size 128, taking 8.5 hours.
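As a rough sanity check, the step count and batch size of the MNIST run can be converted into passes over the data and average throughput. This assumes the standard 60,000-image MNIST training split (the repo may use a different split), and the wall-clock figure presumably includes any evaluation or sampling overhead.

```python
def epochs_seen(steps: int, batch_size: int, dataset_size: int) -> float:
    """Number of passes over the training set implied by steps x batch size."""
    return steps * batch_size / dataset_size


def steps_per_second(steps: int, hours: float) -> float:
    """Average training throughput in optimizer steps per second."""
    return steps / (hours * 3600)


# MNIST run: 100K steps, batch size 128, 8.5 hours (assumed 60,000 training images).
mnist_epochs = epochs_seen(100_000, 128, 60_000)  # about 213 epochs
mnist_rate = steps_per_second(100_000, 8.5)       # about 3.3 steps/s
print(f"{mnist_epochs:.1f} epochs, {mnist_rate:.2f} steps/s")
# prints "213.3 epochs, 3.27 steps/s"
```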

Full hyperparameters can be found in configs/mnist.py.

Fashion MNIST

[Figure: real Fashion MNIST samples (left) and generated samples (right)]

Training details

The model has 9.70M parameters and was trained on Kaggle (TPU v3-8) for 40K steps with batch size 128, taking 2.5 hours.

Full hyperparameters can be found in configs/fashion_mnist.py.

Celeb A

Results

[Figure: real CelebA samples (left) and generated samples (right)]

Training details

Due to compute constraints, the model is trained only on 64 × 64 images.
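Bringing CelebA images down to 64 × 64 usually means a center crop followed by a resize. The sketch below uses `jax.image.resize`; the 140-pixel crop (a common choice for the aligned 178 × 218 CelebA images) and the [-1, 1] pixel scaling are assumptions, not necessarily what this repository does.

```python
import jax
import jax.numpy as jnp


def center_crop_and_resize(image, crop=140, size=64):
    """Center-crop an (H, W, C) image to crop x crop, then resize to size x size.

    The crop size is an assumed, commonly used value for aligned CelebA images.
    """
    h, w, c = image.shape
    top = (h - crop) // 2
    left = (w - crop) // 2
    cropped = image[top:top + crop, left:left + crop, :].astype(jnp.float32)
    resized = jax.image.resize(cropped, (size, size, c), method="bilinear")
    # Map uint8 pixel values from [0, 255] to [-1, 1], the usual DDPM data range.
    return resized / 127.5 - 1.0
```

Cropping before resizing keeps the face centred and avoids distorting the non-square source images.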

The model has 72.70M parameters and was trained on Kaggle (P100 GPU) for 60K steps with batch size 64, taking 22 hours.

Full hyperparameters can be found in configs/celeb_a64.py.