Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added spectral normalization for UNet_256 G and NLayer D #750

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
147 changes: 145 additions & 2 deletions models/networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def define_G(input_nc, output_nc, ngf, netG, norm='batch', use_dropout=False, in
input_nc (int) -- the number of channels in input images
output_nc (int) -- the number of channels in output images
ngf (int) -- the number of filters in the last conv layer
netG (str) -- the architecture's name: resnet_9blocks | resnet_6blocks | unet_256 | unet_128
netG (str) -- the architecture's name: resnet_9blocks | resnet_6blocks | unet_256 | unet_256_spectral | unet_128
norm (str) -- the name of normalization layers used in the network: batch | instance | none
use_dropout (bool) -- if use dropout layers.
init_type (str) -- the name of our initialization method.
Expand Down Expand Up @@ -154,6 +154,8 @@ def define_G(input_nc, output_nc, ngf, netG, norm='batch', use_dropout=False, in
net = UnetGenerator(input_nc, output_nc, 7, ngf, norm_layer=norm_layer, use_dropout=use_dropout)
elif netG == 'unet_256':
net = UnetGenerator(input_nc, output_nc, 8, ngf, norm_layer=norm_layer, use_dropout=use_dropout)
elif netG == 'unet_256_spectral':
net = UnetGeneratorWithSpectralNorm(input_nc, output_nc, 8, ngf, norm_layer=norm_layer, use_dropout=use_dropout)
else:
raise NotImplementedError('Generator model name [%s] is not recognized' % netG)
return init_net(net, init_type, init_gain, gpu_ids)
Expand All @@ -165,7 +167,7 @@ def define_D(input_nc, ndf, netD, n_layers_D=3, norm='batch', init_type='normal'
Parameters:
input_nc (int) -- the number of channels in input images
ndf (int) -- the number of filters in the first conv layer
netD (str) -- the architecture's name: basic | n_layers | pixel
netD (str) -- the architecture's name: basic | n_layers | n_layers_spectral | pixel
n_layers_D (int) -- the number of conv layers in the discriminator; effective when netD=='n_layers'
norm (str) -- the type of normalization layers used in the network.
init_type (str) -- the name of the initialization method.
Expand Down Expand Up @@ -196,6 +198,8 @@ def define_D(input_nc, ndf, netD, n_layers_D=3, norm='batch', init_type='normal'
net = NLayerDiscriminator(input_nc, ndf, n_layers=3, norm_layer=norm_layer)
elif netD == 'n_layers': # more options
net = NLayerDiscriminator(input_nc, ndf, n_layers_D, norm_layer=norm_layer)
elif netD == 'n_layers_spectral':
net = NLayerDiscriminatorWithSpectralNorm(input_nc, ndf, n_layers_D)
elif netD == 'pixel': # classify if each pixel is real or fake
net = PixelDiscriminator(input_nc, ndf, norm_layer=norm_layer)
else:
Expand Down Expand Up @@ -534,6 +538,105 @@ def forward(self, x):
else: # add skip connections
return torch.cat([x, self.model(x)], 1)

class UnetGeneratorWithSpectralNorm(nn.Module):
"""Create a Unet-based generator with Spectral Normalization"""

def __init__(self, input_nc, output_nc, num_downs, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False):
"""Construct a Unet generator
Parameters:
input_nc (int) -- the number of channels in input images
output_nc (int) -- the number of channels in output images
num_downs (int) -- the number of downsamplings in UNet. For example, # if |num_downs| == 7,
image of size 128x128 will become of size 1x1 # at the bottleneck
ngf (int) -- the number of filters in the last conv layer
norm_layer -- normalization layer

We construct the U-Net from the innermost layer to the outermost layer.
It is a recursive process.
"""
super(UnetGeneratorWithSpectralNorm, self).__init__()
# construct unet structure
unet_block = UnetSkipConnectionBlockWithSpectralNorm(ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True) # add the innermost layer
for i in range(num_downs - 5): # add intermediate layers with ngf * 8 filters
unet_block = UnetSkipConnectionBlockWithSpectralNorm(ngf * 8, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer, use_dropout=use_dropout)
# gradually reduce the number of filters from ngf * 8 to ngf
unet_block = UnetSkipConnectionBlockWithSpectralNorm(ngf * 4, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
unet_block = UnetSkipConnectionBlockWithSpectralNorm(ngf * 2, ngf * 4, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
unet_block = UnetSkipConnectionBlockWithSpectralNorm(ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
self.model = UnetSkipConnectionBlockWithSpectralNorm(output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True, norm_layer=norm_layer) # add the outermost layer

def forward(self, input):
"""Standard forward"""
return self.model(input)

class UnetSkipConnectionBlockWithSpectralNorm(nn.Module):
"""Defines the Unet submodule with skip connection and spectral normalization.
X -------------------identity----------------------
|-- downsampling -- |submodule| -- upsampling --|
"""

def __init__(self, outer_nc, inner_nc, input_nc=None,
submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False):
"""Construct a Unet submodule with skip connections.

Parameters:
outer_nc (int) -- the number of filters in the outer conv layer
inner_nc (int) -- the number of filters in the inner conv layer
input_nc (int) -- the number of channels in input images/features
submodule (UnetSkipConnectionBlock) -- previously defined submodules
outermost (bool) -- if this module is the outermost module
innermost (bool) -- if this module is the innermost module
norm_layer -- normalization layer
user_dropout (bool) -- if use dropout layers.
"""
super(UnetSkipConnectionBlockWithSpectralNorm, self).__init__()
self.outermost = outermost
if type(norm_layer) == functools.partial:
use_bias = norm_layer.func == nn.InstanceNorm2d
else:
use_bias = norm_layer == nn.InstanceNorm2d
if input_nc is None:
input_nc = outer_nc
downconv = nn.utils.spectral_norm(nn.Conv2d(input_nc, inner_nc, kernel_size=4,
stride=2, padding=1, bias=use_bias))
downrelu = nn.LeakyReLU(0.2, True)
downnorm = norm_layer(inner_nc)
uprelu = nn.ReLU(True)
upnorm = norm_layer(outer_nc)

if outermost:
upconv = nn.utils.spectral_norm(nn.ConvTranspose2d(inner_nc * 2, outer_nc,
kernel_size=4, stride=2,
padding=1))
down = [downconv]
up = [uprelu, upconv, nn.Tanh()]
model = down + [submodule] + up
elif innermost:
upconv = nn.utils.spectral_norm(nn.ConvTranspose2d(inner_nc, outer_nc,
kernel_size=4, stride=2,
padding=1, bias=use_bias))
down = [downrelu, downconv]
up = [uprelu, upconv, upnorm]
model = down + up
else:
upconv = nn.utils.spectral_norm(nn.ConvTranspose2d(inner_nc * 2, outer_nc,
kernel_size=4, stride=2,
padding=1, bias=use_bias))
down = [downrelu, downconv, downnorm]
up = [uprelu, upconv, upnorm]

if use_dropout:
model = down + [submodule] + up + [nn.Dropout(0.5)]
else:
model = down + [submodule] + up

self.model = nn.Sequential(*model)

def forward(self, x):
if self.outermost:
return self.model(x)
else: # add skip connections
return torch.cat([x, self.model(x)], 1)

class NLayerDiscriminator(nn.Module):
"""Defines a PatchGAN discriminator"""
Expand Down Expand Up @@ -582,6 +685,46 @@ def forward(self, input):
"""Standard forward."""
return self.model(input)

class NLayerDiscriminatorWithSpectralNorm(nn.Module):
"""Defines a PatchGAN discriminator"""

def __init__(self, input_nc, ndf=64, n_layers=3):
"""Construct a PatchGAN discriminator

Parameters:
input_nc (int) -- the number of channels in input images
ndf (int) -- the number of filters in the last conv layer
n_layers (int) -- the number of conv layers in the discriminator
norm_layer -- normalization layer
"""
super(NLayerDiscriminatorWithSpectralNorm, self).__init__()
use_bias = True
kw = 4
padw = 1
sequence = [nn.utils.spectral_norm(nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw, bias=use_bias)), nn.LeakyReLU(0.2, True)]
nf_mult = 1
nf_mult_prev = 1
for n in range(1, n_layers): # gradually increase the number of filters
nf_mult_prev = nf_mult
nf_mult = min(2 ** n, 8)
sequence += [
nn.utils.spectral_norm(nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias)),
nn.LeakyReLU(0.2, True)
]

nf_mult_prev = nf_mult
nf_mult = min(2 ** n_layers, 8)
sequence += [
nn.utils.spectral_norm(nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias)),
nn.LeakyReLU(0.2, True)
]

sequence += [nn.utils.spectral_norm(nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw, bias=use_bias))] # output 1 channel prediction map
self.model = nn.Sequential(*sequence)

def forward(self, input):
"""Standard forward."""
return self.model(input)

class PixelDiscriminator(nn.Module):
"""Defines a 1x1 PatchGAN discriminator (pixelGAN)"""
Expand Down
4 changes: 2 additions & 2 deletions options/base_options.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ def initialize(self, parser):
parser.add_argument('--output_nc', type=int, default=3, help='# of output image channels: 3 for RGB and 1 for grayscale')
parser.add_argument('--ngf', type=int, default=64, help='# of gen filters in the last conv layer')
parser.add_argument('--ndf', type=int, default=64, help='# of discrim filters in the first conv layer')
parser.add_argument('--netD', type=str, default='basic', help='specify discriminator architecture [basic | n_layers | pixel]. The basic model is a 70x70 PatchGAN. n_layers allows you to specify the layers in the discriminator')
parser.add_argument('--netG', type=str, default='resnet_9blocks', help='specify generator architecture [resnet_9blocks | resnet_6blocks | unet_256 | unet_128]')
parser.add_argument('--netD', type=str, default='basic', help='specify discriminator architecture [basic | n_layers | n_layers_spectral | pixel]. The basic model is a 70x70 PatchGAN. n_layers allows you to specify the layers in the discriminator')
parser.add_argument('--netG', type=str, default='resnet_9blocks', help='specify generator architecture [resnet_9blocks | resnet_6blocks | unet_256 | unet_256_spectral | unet_128]')
parser.add_argument('--n_layers_D', type=int, default=3, help='only used if netD==n_layers')
parser.add_argument('--norm', type=str, default='instance', help='instance normalization or batch normalization [instance | batch | none]')
parser.add_argument('--init_type', type=str, default='normal', help='network initialization [normal | xavier | kaiming | orthogonal]')
Expand Down