
VAE/GAN

A TensorFlow implementation of VAE-GAN, following the paper: VAE/GAN. This was the first approach to use the discriminator as a learned similarity metric, i.e. as part of the loss, instead of a pixel-wise reconstruction error. The encoder and decoder are implemented with strided convolutional layers and transposed convolution layers respectively. The discriminator has the same architecture as the encoder, with one additional layer producing its output. As suggested by the authors, both the decoder and the prior are Gaussian.
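For orientation, below is a minimal sketch of how these three networks could be wired up with the TF 1.x layers API. The layer widths, kernel sizes, 64x64 input resolution and 128-dimensional latent code are assumptions for illustration; the actual values live in config.py and vae-gan.py.

import tensorflow as tf  # TF 1.x API

Z_DIM = 128  # latent dimensionality (assumed; see config.py for the real value)

def encoder(x, reuse=tf.AUTO_REUSE):
    # Down-sampling is done purely with strided convolutions (no pooling).
    with tf.variable_scope("encoder", reuse=reuse):
        h = tf.layers.conv2d(x, 64, 5, strides=2, padding="same", activation=tf.nn.relu)
        h = tf.layers.conv2d(h, 128, 5, strides=2, padding="same", activation=tf.nn.relu)
        h = tf.layers.conv2d(h, 256, 5, strides=2, padding="same", activation=tf.nn.relu)
        h = tf.layers.flatten(h)
        z_mean = tf.layers.dense(h, Z_DIM)     # mean of the Gaussian posterior
        z_log_var = tf.layers.dense(h, Z_DIM)  # log-variance of the Gaussian posterior
    return z_mean, z_log_var

def decoder(z, reuse=tf.AUTO_REUSE):
    # Transposed convolutions upsample the latent code back to a 64x64 image.
    with tf.variable_scope("decoder", reuse=reuse):
        h = tf.layers.dense(z, 8 * 8 * 256, activation=tf.nn.relu)
        h = tf.reshape(h, [-1, 8, 8, 256])
        h = tf.layers.conv2d_transpose(h, 128, 5, strides=2, padding="same", activation=tf.nn.relu)
        h = tf.layers.conv2d_transpose(h, 64, 5, strides=2, padding="same", activation=tf.nn.relu)
        x_recon = tf.layers.conv2d_transpose(h, 3, 5, strides=2, padding="same", activation=tf.nn.tanh)
    return x_recon

def discriminator(x, reuse=tf.AUTO_REUSE):
    # Same body as the encoder, plus one extra layer producing the real/fake logit.
    with tf.variable_scope("discriminator", reuse=reuse):
        h = tf.layers.conv2d(x, 64, 5, strides=2, padding="same", activation=tf.nn.leaky_relu)
        h = tf.layers.conv2d(h, 128, 5, strides=2, padding="same", activation=tf.nn.leaky_relu)
        feat = tf.layers.conv2d(h, 256, 5, strides=2, padding="same", activation=tf.nn.leaky_relu)
        logit = tf.layers.dense(tf.layers.flatten(feat), 1)
    # `feat` provides the learned similarity metric used for the reconstruction loss.
    return logit, feat

The reparameterisation trick (z = z_mean + exp(0.5 * z_log_var) * eps, with eps ~ N(0, I)) connects the encoder to the decoder, and the reconstruction error is measured on the discriminator features rather than on raw pixels.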

Setup

  • Python 3.5+
  • TensorFlow 1.9

Relevant Code Files

The file config.py contains the hyper-parameters used for the reported VAE/GAN results.
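A config file of this kind typically looks like the excerpt below. Every name and value here is a hypothetical placeholder for illustration, not a copy of the actual config.py:

batch_size = 64                 # hypothetical value
z_dim = 128                     # latent dimensionality (hypothetical)
learning_rate = 3e-4            # Adam learning rate (hypothetical)
num_epochs = 50                 # number of training epochs (hypothetical)
kl_weight = 1.0 / batch_size    # weight on the KL term; see "Empirical Observations" below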

The file vae-gan.py contains the code to train the VAE/GAN model.

As the name suggests, the file vae-gan_inference.py contains the code to test a trained VAE/GAN model.

Usage

Training a model

NOTE: For CelebA, make sure you have downloaded the dataset from here and placed it in the project's current directory.

python vae-gan.py

Testing a trained model

First place the model weights in model_directory (as specified in vae-gan_inference.py), then run:

python vae-gan_inference.py 
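Under the hood, the inference script restores the checkpoint and runs the model graph. A minimal sketch of that restore-and-sample step, assuming the decoder() definition sketched above and a checkpoint saved with tf.train.Saver, would look like:

import numpy as np
import tensorflow as tf  # TF 1.x API

model_directory = "./model_directory"  # placeholder; use the path set in vae-gan_inference.py
Z_DIM = 128                            # must match the value used during training

z = tf.placeholder(tf.float32, [None, Z_DIM], name="z")
generated = decoder(z)                 # decoder() as sketched in the architecture section

with tf.Session() as sess:
    saver = tf.train.Saver()
    ckpt = tf.train.latest_checkpoint(model_directory)  # newest checkpoint in the directory
    saver.restore(sess, ckpt)
    # Sample from the Gaussian prior and decode the samples into images.
    samples = sess.run(generated, feed_dict={z: np.random.randn(16, Z_DIM)})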

Empirical Observations

  • I observed that the KL-divergence term in the encoder loss sometimes makes training cumbersome.

  • The only hyper-parameter I tweaked to alleviate this is the weight multiplied with the KL term. A KL weight equal to 1/batch_size almost always works.

  • Another alternative I tried was making the KL weight a function of the epoch, i.e. sigmoid(epoch); both schemes are sketched after this list.

  • Intuitively, the dynamic KL weight makes more sense: the weight grows with the epochs, so the model pays little attention to the KL-divergence term in the initial iterations. But why would we want the model to ignore this term early on?

  • The reason is that in the initial iterations we leave the latent variables free to learn representations that are meaningful for reconstructing the input; as training progresses, the increasing KL weight pulls the latent distribution closer to the prior.

  • But why not use some other monotonic function, such as exp(epoch)?

  • Because while increasing the KL weight we need an upper limit, otherwise the model may end up focusing entirely on this term. We therefore choose a function that saturates for large inputs, which sigmoid does and exp does not.
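Here is a small sketch of the two KL-weighting schemes discussed above (the fixed 1/batch_size weight and the sigmoid(epoch) schedule); the default batch size is only illustrative:

import numpy as np

def kl_weight(epoch, batch_size=64, schedule="sigmoid"):
    # Returns the weight multiplied with the KL-divergence term in the encoder loss.
    if schedule == "constant":
        # Fixed weight: 1 / batch_size almost always trains stably.
        return 1.0 / batch_size
    if schedule == "sigmoid":
        # Starts small so the latent variables are free to learn useful representations,
        # grows with the epoch, and saturates at 1 so the KL term never dominates.
        # A non-saturating monotonic function such as exp(epoch) would grow without bound.
        return 1.0 / (1.0 + np.exp(-epoch))
    raise ValueError("unknown schedule: %s" % schedule)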

Model weights

The weights used for the results presented in this repository are linked below; they are shared via Google Drive.

Generations

Generated samples: MNIST | Celeb-A

Reconstructions

  • For the MNIST dataset

    • At epoch 1: MNIST Original | MNIST Reconstruction
    • At epoch 50: MNIST Original | MNIST Reconstruction

  • For the CelebA dataset

    • At epoch 1: Celeb-A Original | Celeb-A Reconstruction
    • At epoch 15: Celeb-A Original | Celeb-A Reconstruction
