from __future__ import print_function
try:
import argparse
import os
import numpy as np
from torch.autograd import Variable
from torch.autograd import grad as torch_grad
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets
import torchvision.transforms as transforms
from torchvision.utils import save_image
from itertools import chain as ichain
except ImportError as e:
print(e)
    raise
os.makedirs("images", exist_ok=True)
parser = argparse.ArgumentParser(description="ClusterGAN Training Script")
parser.add_argument("-n", "--n_epochs", dest="n_epochs", default=200, type=int, help="Number of epochs")
parser.add_argument("-b", "--batch_size", dest="batch_size", default=64, type=int, help="Batch size")
parser.add_argument("-i", "--img_size", dest="img_size", type=int, default=28, help="Size of image dimension")
parser.add_argument("-d", "--latent_dim", dest="latent_dim", default=30, type=int, help="Dimension of latent space")
parser.add_argument("-l", "--lr", dest="learning_rate", type=float, default=0.0001, help="Learning rate")
parser.add_argument("-c", "--n_critic", dest="n_critic", type=int, default=5, help="Number of training steps for discriminator per iter")
parser.add_argument("-w", "--wass_flag", dest="wass_flag", action='store_true', help="Flag for Wasserstein metric")
args = parser.parse_args()
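# Example invocation (illustrative sketch; replace the script name with this file's
# actual name, and add -w to enable the Wasserstein/WGAN-GP objective):
#   python clustergan.py -n 200 -b 64 -i 28 -d 30 -l 0.0001 -c 5 -w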
# Sample a random latent space vector
def sample_z(shape=64, latent_dim=10, n_c=10, fix_class=-1, req_grad=False):
assert (fix_class == -1 or (fix_class >= 0 and fix_class < n_c) ), "Requested class %i outside bounds."%fix_class
Tensor = torch.cuda.FloatTensor
# Sample noise as generator input, zn
zn = Variable(Tensor(0.75*np.random.normal(0, 1, (shape, latent_dim))), requires_grad=req_grad)
######### zc, zc_idx variables with grads, and zc to one-hot vector
# Pure one-hot vector generation
zc_FT = Tensor(shape, n_c).fill_(0)
    # zc_idx: long tensor of length `shape` that will hold one cluster index per sample
zc_idx = torch.empty(shape, dtype=torch.long)
if (fix_class == -1):
        # Fill with random integers in the range [0, n_c - 1]
zc_idx = zc_idx.random_(n_c).cuda()
        # Put a 1 at the sampled index in each row to build the one-hot encoding
zc_FT = zc_FT.scatter_(1, zc_idx.unsqueeze(1), 1.)
else:
zc_idx[:] = fix_class
zc_FT[:, fix_class] = 1
zc_idx = zc_idx.cuda()
zc_FT = zc_FT.cuda()
zc = Variable(zc_FT, requires_grad=req_grad)
# Return components of latent space variable
return zn, zc, zc_idx
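
# Illustrative usage of sample_z (a sketch, not part of the training loop; assumes a
# CUDA device is available and the default n_c = 10 clusters):
#   zn, zc, zc_idx = sample_z(shape=args.batch_size, latent_dim=args.latent_dim, n_c=10)
#   zn.shape -> (batch_size, latent_dim); zc.shape -> (batch_size, n_c); zc_idx holds the class index per row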
def calc_gradient_penalty(netD, real_data, generated_data):
# GP strength
LAMBDA = 10
b_size = real_data.size()[0]
# Calculate interpolation
alpha = torch.rand(b_size, 1, 1, 1)
alpha = alpha.expand_as(real_data)
alpha = alpha.cuda()
interpolated = alpha * real_data.data + (1 - alpha) * generated_data.data
interpolated = Variable(interpolated, requires_grad=True)
interpolated = interpolated.cuda()
# Calculate probability of interpolated examples
prob_interpolated = netD(interpolated)
# Calculate gradients of probabilities with respect to examples
gradients = torch_grad(outputs=prob_interpolated, inputs=interpolated,
grad_outputs=torch.ones(prob_interpolated.size()).cuda(),
create_graph=True, retain_graph=True)[0]
# Gradients have shape (batch_size, num_channels, img_width, img_height),
# so flatten to easily take norm per example in batch
gradients = gradients.view(b_size, -1)
# Derivatives of the gradient close to 0 can cause problems because of
# the square root, so manually calculate norm and add epsilon
gradients_norm = torch.sqrt(torch.sum(gradients ** 2, dim=1) + 1e-12)
# Return gradient penalty
return LAMBDA * ((gradients_norm - 1) ** 2).mean()
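
# Illustrative use of the gradient penalty inside a critic (discriminator) update
# (a sketch; `discriminator`, `real_imgs`, and `gen_imgs` are placeholder names assumed
# to exist in the training loop, not defined here):
#   grad_penalty = calc_gradient_penalty(discriminator, real_imgs, gen_imgs)
#   d_loss = discriminator(gen_imgs).mean() - discriminator(real_imgs).mean() + grad_penalty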
# Weight Initializer
def initialize_weights(net):
for m in net.modules():
if isinstance(m, nn.Conv2d):
m.weight.data.normal_(0, 0.02)
m.bias.data.zero_()
elif isinstance(m, nn.ConvTranspose2d):
m.weight.data.normal_(0, 0.02)
m.bias.data.zero_()
elif isinstance(m, nn.Linear):
m.weight.data.normal_(0, 0.02)
m.bias.data.zero_()
# Softmax function
def softmax(x):
return F.softmax(x, dim=1)
class Reshape(nn.Module):
"""
Class for performing a reshape as a layer in a sequential model.
"""
def __init__(self, shape=[]):
super(Reshape, self).__init__()
self.shape = shape
def forward(self, x):
return x.view(x.size(0), *self.shape)
def extra_repr(self):
        # (Optional) Extra information about this module; it shows up when an
        # instance of this class is printed.
return 'shape={}'.format(
self.shape
)
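
# Minimal sketch of Reshape in isolation (illustrative only):
#   layer = Reshape((128, 7, 7))
#   out = layer(torch.randn(4, 128 * 7 * 7))   # out.shape == (4, 128, 7, 7)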
class Generator_CNN(nn.Module):
"""
CNN to model the generator of a ClusterGAN
    Input is a vector from representation space of dimension z_dim
    Output is a vector from image space of dimension X_dim
"""
# Architecture : FC1024_BR-FC7x7x128_BR-(64)4dc2s_BR-(1)4dc2s_S
def __init__(self, latent_dim, n_c, x_shape, verbose=False):
super(Generator_CNN, self).__init__()
self.name = 'generator'
self.latent_dim = latent_dim
self.n_c = n_c
self.x_shape = x_shape
self.ishape = (128, 7, 7)
self.iels = int(np.prod(self.ishape))
self.verbose = verbose
self.model = nn.Sequential(
# Fully connected layers
torch.nn.Linear(self.latent_dim + self.n_c, 1024),
nn.BatchNorm1d(1024),
nn.LeakyReLU(0.2, inplace=True),
torch.nn.Linear(1024, self.iels),
nn.BatchNorm1d(self.iels),
nn.LeakyReLU(0.2, inplace=True),
# Reshape to 128 x (7x7)
Reshape(self.ishape),
# Upconvolution layers
nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1, bias=True),
nn.BatchNorm2d(64),
nn.LeakyReLU(0.2, inplace=True),
nn.ConvTranspose2d(64, 1, 4, stride=2, padding=1, bias=True),
nn.Sigmoid()
)
initialize_weights(self)
if self.verbose:
print("Setting up {}...\n".format(self.name))
print(self.model)
def forward(self, zn, zc):
        # Concatenate noise and one-hot class vectors along dim 1, e.g. (64, 30 + 10)
z = torch.cat((zn, zc), 1)
x_gen = self.model(z)
# Reshape for output
x_gen = x_gen.view(x_gen.size(0), *self.x_shape)
return x_gen
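
# Illustrative generator call (a sketch; assumes CUDA and 10 clusters, consistent with
# the defaults used above):
#   generator = Generator_CNN(args.latent_dim, 10, (1, args.img_size, args.img_size)).cuda()
#   zn, zc, zc_idx = sample_z(shape=args.batch_size, latent_dim=args.latent_dim, n_c=10)
#   gen_imgs = generator(zn, zc)   # -> (batch_size, 1, 28, 28)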
class Encoder_CNN(nn.Module):
"""
CNN to model the encoder of a ClusterGAN
    Input is a vector X from image space of dimension X_dim
    Output is a vector z from representation space of dimension z_dim
"""
def __init__(self, latent_dim, n_c, verbose=False):
super(Encoder_CNN, self).__init__()
self.name = 'encoder'
self.channels = 1
self.latent_dim = latent_dim
self.n_c = n_c
self.cshape = (128, 5, 5)
self.iels = int(np.prod(self.cshape))
self.lshape = (self.iels,)
self.verbose = verbose
self.model = nn.Sequential(
# Convolutional layers
nn.Conv2d(self.channels, 64, 4, stride=2, bias=True),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(64, 128, 4, stride=2, bias=True),
nn.LeakyReLU(0.2, inplace=True),
# Flatten
Reshape(self.lshape),
# Fully connected layers
torch.nn.Linear(self.iels, 1024),
nn.LeakyReLU(0.2, inplace=True),
torch.nn.Linear(1024, latent_dim + n_c)
)
initialize_weights(self)
        if self.verbose:
            print("Setting up {}...\n".format(self.name))
            print(self.model)