
Prediction Optimizer (to stabilize GAN training)

Introduction

This is a PyTorch implementation of the 'prediction method' introduced in the following paper:

  • Abhay Yadav et al., Stabilizing Adversarial Nets with Prediction Methods, ICLR 2018, Link
  • (Just for clarification, I'm not an author of the paper.)

The authors proposed a simple but effective method to stabilize GAN training. With this Prediction Optimizer, you can easily apply the method to your existing GAN code. This implementation is compatible with most PyTorch optimizers and network structures. (Please let me know if you run into any issues using it.)
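For reference, the core of the method is a one-step linear extrapolation of the opponent's parameters: after the k-th update, the other network is trained against theta_k + step * (theta_k - theta_{k-1}) instead of theta_k itself. The snippet below is only a conceptual illustration of that rule written for this README (predicted_params is a made-up helper, not part of prediction.py):

    import torch

    def predicted_params(prev, curr, step=1.0):
        # Linear extrapolation of each parameter tensor:
        #   theta_bar = theta_curr + step * (theta_curr - theta_prev)
        return [c + step * (c - p) for p, c in zip(prev, curr)]

    # Example with two snapshots of a single weight tensor
    prev = [torch.tensor([0.0, 1.0])]
    curr = [torch.tensor([0.5, 1.5])]
    print(predicted_params(prev, curr))             # -> [tensor([1., 2.])]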

How-to-use

Instructions

  • Import prediction.py
    • from prediction import PredOpt
  • Initialize just like an optimizer
    • pred = PredOpt(net.parameters())
  • Run the model in a 'with' block to get results from a model with predicted params.
    • With the 'step' argument, you can control the lookahead step size (1.0 by default)
    • with pred.lookahead(step=1.0):
          output = net(input)
  • Call step() after an update of the network parameters
    • optim_net.step()
      pred.step()
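If it helps to see how such a wrapper can be structured, here is a minimal conceptual sketch written for this README. It is not the contents of prediction.py; SketchPredOpt and its internals are illustrative, and only the names lookahead() and step() mirror the actual API:

    import contextlib
    import torch

    class SketchPredOpt:
        # Illustrative only: remembers how far each parameter moved in the last
        # update and temporarily applies an extrapolated copy inside lookahead()
        def __init__(self, params):
            self.params = list(params)
            self.prev = [p.detach().clone() for p in self.params]
            self.diff = [torch.zeros_like(p) for p in self.params]

        def step(self):
            # Call right after optimizer.step(): record the latest parameter move
            for prev, diff, p in zip(self.prev, self.diff, self.params):
                diff.copy_(p.detach() - prev)
                prev.copy_(p.detach())

        @contextlib.contextmanager
        def lookahead(self, step=1.0):
            # Swap in theta + step * (latest move), restore the real values on exit
            backup = [p.detach().clone() for p in self.params]
            with torch.no_grad():
                for diff, p in zip(self.diff, self.params):
                    p.add_(step * diff)
            try:
                yield
            finally:
                with torch.no_grad():
                    for b, p in zip(backup, self.params):
                        p.copy_(b)

In practice, use the PredOpt shipped in prediction.py; the sample below shows the intended usage.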

Samples

  • You can find sample code in this repository (example_gan.py)
  • A sample snippet:
  • import torch.optim as optim
    from prediction import PredOpt
    
    
    # ...
    
    optim_G = optim.Adam(netG.parameters(), lr=0.01)
    optim_D = optim.Adam(netD.parameters(), lr=0.01)
    
    pred_G = PredOpt(netG.parameters())             # Create a prediction optimizer with target parameters
    pred_D = PredOpt(netD.parameters())
    
    
    for i, data in enumerate(dataloader, 0):
        # (1) Training D with samples from predicted generator
        with pred_G.lookahead(step=1.0):            # in the 'with' block, the model works as a 'predicted' model
            fake_predicted = netG(Z)                           
        
            # Compute gradients and loss 
        
            optim_D.step()
            pred_D.step()
        
        # (2) Training G
        with pred_D.lookahead(step=1.0):            # 'Predicted D'
            fake = netG(Z)                          # Draw samples from the real model. (not predicted one)
            D_outs = netD(fake)
    
            # Compute gradients and loss
    
            optim_G.step()
            pred_G.step()                           # You should call PredOpt.step() after each update

Output samples

You can find more images in the issues of this repository.

Training w/ large learning rate (0.01)

(Images: CIFAR-10 and CelebA samples after 25 epochs, vanilla DCGAN vs. DCGAN w/ prediction, step=1.0)

Training w/ medium learning rate (1e-4)

(Images: CIFAR-10 and CelebA samples after 25 epochs, vanilla DCGAN vs. DCGAN w/ prediction, step=1.0)

Training w/ small learning rate (1e-5)

(Images: CIFAR-10 and CelebA samples after 25 epochs, vanilla DCGAN vs. DCGAN w/ prediction, step=1.0)

External links

TODOs

  • Impl. as an optimizer
  • Support pip install
  • Add some experimental results
