import os
import time

import numpy as np
import tensorflow as tf

from ops import *
from utils import *
from tensorflow.contrib.data import prefetch_to_device, shuffle_and_repeat, map_and_batch
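
# NOTE: tf.contrib was removed in TensorFlow 2.x, so this file requires TensorFlow 1.x.
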
class BigGAN(object):
    def __init__(self, sess, args):
        self.model_name = "BigGAN"  # name for checkpoint
        self.sess = sess
        self.dataset_name = args.dataset
        self.checkpoint_dir = args.checkpoint_dir
        self.sample_dir = args.sample_dir
        self.result_dir = args.result_dir
        self.log_dir = args.log_dir

        self.epoch = args.epoch
        self.iteration = args.iteration
        self.batch_size = args.batch_size
        self.print_freq = args.print_freq
        self.save_freq = args.save_freq
        self.img_size = args.img_size

        """ Generator """
        self.layer_num = int(np.log2(self.img_size)) - 3
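        # layer_num counts the resblocks after the first one: the 4x4 seed is doubled
        # (1 + layer_num) times, so e.g. img_size=64 -> layer_num=3, img_size=128 -> 4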
        self.z_dim = args.z_dim  # dimension of noise-vector
        self.gan_type = args.gan_type

        """ Discriminator """
        self.n_critic = args.n_critic
        self.sn = args.sn
        self.ld = args.ld
        self.args = args

        self.sample_num = args.sample_num  # number of generated images to be saved
        self.test_num = args.test_num

        # train
        self.g_learning_rate = args.g_lr
        self.d_learning_rate = args.d_lr
        self.beta1 = args.beta1
        self.beta2 = args.beta2

        self.custom_dataset = False

        if self.dataset_name == 'mnist':
            self.c_dim = 1
            self.data = load_mnist(size=self.img_size)
        elif self.dataset_name == 'cifar10':
            self.c_dim = 3
            self.data = load_cifar10(size=self.img_size)
        else:
            self.c_dim = 3
            print('------ dataset:', self.dataset_name)
            self.data = load_data(dataset_name=self.dataset_name, size=self.img_size)
            self.custom_dataset = True

        if self.args.phase == 'test':
            self.custom_dataset = False

        self.dataset_num = len(self.data)

        self.sample_dir = os.path.join(self.sample_dir, self.model_dir)
        check_folder(self.sample_dir)

        print()
        print("##### Information #####")
        print("# gan type : ", self.gan_type)
        print("# dataset : ", self.dataset_name)
        print("# dataset number : ", self.dataset_num)
        print("# batch_size : ", self.batch_size)
        print("# epoch : ", self.epoch)
        print("# iteration per epoch : ", self.iteration)

        print()
        print("##### Generator #####")
        print("# generator layer : ", self.layer_num)

        print()
        print("##### Discriminator #####")
        print("# discriminator layer : ", self.layer_num)
        print("# the number of critic : ", self.n_critic)
        print("# spectral normalization : ", self.sn)

    ##################################################################################
    # Generator
    ##################################################################################

    def generator(self, z, is_training=True, reuse=False):
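        """Map a noise vector z to an image in [-1, 1].

        Rough shape flow (e.g. img_size=64, layer_num=3): z -> fc -> [bs, 4, 4, 1024],
        then (1 + layer_num) up-sampling resblocks with self-attention in the middle,
        ending in a 3x3 conv down to c_dim channels and a tanh.
        """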
        with tf.variable_scope("generator", reuse=reuse):
            ch = 1024
            x = fully_connected(z, units=4 * 4 * ch, sn=self.sn, scope='fc')
            x = tf.reshape(x, [-1, 4, 4, ch])

            x = up_resblock(x, channels=ch, is_training=is_training, sn=self.sn, scope='front_resblock_0')

            for i in range(self.layer_num // 2):
                x = up_resblock(x, channels=ch // 2, is_training=is_training, sn=self.sn, scope='middle_resblock_' + str(i))
                ch = ch // 2

            x = self.google_attention(x, channels=ch, scope='self_attention')

            for i in range(self.layer_num // 2, self.layer_num):
                x = up_resblock(x, channels=ch // 2, is_training=is_training, sn=self.sn, scope='back_resblock_' + str(i))
                ch = ch // 2

            x = batch_norm(x, is_training)
            x = relu(x)
            x = conv(x, channels=self.c_dim, kernel=3, stride=1, pad=1, pad_type='zero', scope='g_logit')
            x = tanh(x)
            # x = tf.identity(x, name='fake_image')

            return x

    ##################################################################################
    # Discriminator
    ##################################################################################

    def discriminator(self, x, reuse=False):
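        """Score an image batch; returns unbounded logits of shape [bs, 1].

        Mirrors the generator: down-sampling resblocks with self-attention early on,
        then global sum pooling and a (spectrally normalized) linear layer.
        """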
        with tf.variable_scope("discriminator", reuse=reuse):
            ch = 64

            x = init_down_resblock(x, channels=ch, sn=self.sn, scope='init_resblock')
            x = down_resblock(x, channels=ch * 2, sn=self.sn, scope='front_down_resblock')
            x = self.google_attention(x, channels=ch * 2, scope='self_attention')

            ch = ch * 2

            for i in range(self.layer_num):
                if i == self.layer_num - 1:
                    x = down_resblock(x, channels=ch, sn=self.sn, to_down=False, scope='middle_down_resblock_' + str(i))
                else:
                    x = down_resblock(x, channels=ch * 2, sn=self.sn, scope='middle_down_resblock_' + str(i))

                ch = ch * 2

            x = lrelu(x, 0.2)
            x = global_sum_pooling(x)
            x = fully_connected(x, units=1, sn=self.sn, scope='d_logit')

            return x

    def attention(self, x, channels, scope='attention'):
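        """Self-attention block (SAGAN style, full-resolution keys and values).

        f/g project x down to channels // 8 (keys/queries); h keeps all channels
        (values). With N = h * w, the attention map softmax(g f^T) is [bs, N, N].
        The output is a gamma-scaled residual, with gamma learned starting from 0.
        """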
        with tf.variable_scope(scope):
            f = conv(x, channels // 8, kernel=1, stride=1, sn=self.sn, scope='f_conv')  # [bs, h, w, c // 8]
            g = conv(x, channels // 8, kernel=1, stride=1, sn=self.sn, scope='g_conv')  # [bs, h, w, c // 8]
            h = conv(x, channels, kernel=1, stride=1, sn=self.sn, scope='h_conv')  # [bs, h, w, c]

            # N = h * w
            s = tf.matmul(hw_flatten(g), hw_flatten(f), transpose_b=True)  # [bs, N, N]

            beta = tf.nn.softmax(s)  # attention map

            o = tf.matmul(beta, hw_flatten(h))  # [bs, N, c]
            gamma = tf.get_variable("gamma", [1], initializer=tf.constant_initializer(0.0))

            o = tf.reshape(o, shape=x.shape)  # [bs, h, w, c]
            o = conv(o, channels, kernel=1, stride=1, sn=self.sn, scope='attn_conv')
            x = gamma * o + x

            return x

    def google_attention(self, x, channels, scope='attention'):
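        """Memory-efficient self-attention (the pooled variant used in BigGAN).

        Keys (f) and values (h) are max-pooled 2x in each spatial dimension, so
        with N = h * w the attention map s is [bs, N, N // 4] rather than
        [bs, N, N], and h uses channels // 2, roughly quartering the map's memory.
        """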
        with tf.variable_scope(scope):
            batch_size, height, width, num_channels = x.get_shape().as_list()

            f = conv(x, channels // 8, kernel=1, stride=1, sn=self.sn, scope='f_conv')  # [bs, h, w, c // 8]
            f = max_pooling(f)  # [bs, h // 2, w // 2, c // 8]

            g = conv(x, channels // 8, kernel=1, stride=1, sn=self.sn, scope='g_conv')  # [bs, h, w, c // 8]

            h = conv(x, channels // 2, kernel=1, stride=1, sn=self.sn, scope='h_conv')  # [bs, h, w, c // 2]
            h = max_pooling(h)  # [bs, h // 2, w // 2, c // 2]

            # N = h * w
            s = tf.matmul(hw_flatten(g), hw_flatten(f), transpose_b=True)  # [bs, N, N // 4]

            beta = tf.nn.softmax(s)  # attention map

            o = tf.matmul(beta, hw_flatten(h))  # [bs, N, c // 2]
            gamma = tf.get_variable("gamma", [1], initializer=tf.constant_initializer(0.0))

            o = tf.reshape(o, shape=[batch_size, height, width, num_channels // 2])  # [bs, h, w, c // 2]
            o = conv(o, channels, kernel=1, stride=1, sn=self.sn, scope='attn_conv')
            x = gamma * o + x

            return x

    def gradient_penalty(self, real, fake):
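        """Gradient penalty for WGAN-GP / DRAGAN style regularization.

        Samples points x_hat near the data manifold (DRAGAN: real images plus
        bounded noise; otherwise: interpolations of real and fake) on which the
        critic's gradient norm is penalized, typically weighted by self.ld.
        """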
        if self.gan_type == 'dragan':
            shape = tf.shape(real)
            eps = tf.random_uniform(shape=shape, minval=0., maxval=1.)
            x_mean, x_var = tf.nn.moments(real, axes=[0, 1, 2, 3])
            x_std = tf.sqrt(x_var)  # magnitude of noise decides the size of local region
            noise = 0.5 * x_std * eps  # delta in paper

            # The author suggested U[0, 1] in the original paper, but admitted on GitHub
            # (https://github.com/kodalinaveen3/DRAGAN) that this was a bug: it should be two-sided.
            alpha = tf.random_uniform(shape=[shape[0], 1, 1, 1], minval=-1., maxval=1.)
            interpolated = tf.clip_by_value(real + alpha * noise, -1., 1.)  # x_hat should stay inside the data range