import torch
from torch import nn
from torch.autograd import Variable
import torchvision.transforms as tfs
from torch.utils.data import DataLoader, sampler
from torchvision.datasets import MNIST
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
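# Vanilla GAN on MNIST: a fully-connected discriminator and generator,
# both trained with BCE-with-logits losses on images normalized to [-1, 1].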
def show_images(images):
    # Tile a batch of flattened images into a roughly square grid.
    images = np.reshape(images, [images.shape[0], -1])  # (num_images, pixels)
    sqrtn = int(np.ceil(np.sqrt(images.shape[0])))       # grid side length
    sqrtimg = int(np.ceil(np.sqrt(images.shape[1])))     # side length of each square image

    fig = plt.figure(figsize=(sqrtn, sqrtn))
    gs = gridspec.GridSpec(sqrtn, sqrtn)
    gs.update(wspace=0.05, hspace=0.05)

    for i, img in enumerate(images):
        ax = plt.subplot(gs[i])
        plt.axis('off')
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.set_aspect('equal')
        plt.imshow(img.reshape([sqrtimg, sqrtimg]))
    return
def preprocess_img(x):
    # Convert the PIL image to a tensor in [0, 1], then normalize to [-1, 1]
    # so that real images match the Tanh output range of the generator.
    x = tfs.ToTensor()(x)
    return (x - 0.5) / 0.5
def deprocess_img(x):
    # Map images from [-1, 1] back to [0, 1] for display.
    return (x + 1.0) / 2.0
class ChunkSampler(sampler.Sampler):
    # Yields `num_samples` consecutive indices starting at `start`; used to
    # carve the MNIST training set into train and validation chunks.
    def __init__(self, num_samples, start=0):
        self.num_samples = num_samples
        self.start = start

    def __iter__(self):
        return iter(range(self.start, self.start + self.num_samples))

    def __len__(self):
        return self.num_samples
NUM_TRAIN = 50000
NUM_VAL = 5000
NOISE_DIM = 96
batch_size = 128

# Training split: indices [0, NUM_TRAIN) of the MNIST training set.
train_set = MNIST('./mnist data', train=True, download=False, transform=preprocess_img)
train_data = DataLoader(train_set, batch_size=batch_size, sampler=ChunkSampler(NUM_TRAIN, 0))

# Validation split: the next NUM_VAL indices, [NUM_TRAIN, NUM_TRAIN + NUM_VAL).
val_set = MNIST('./mnist data', train=True, download=False, transform=preprocess_img)
val_data = DataLoader(val_set, batch_size=batch_size, sampler=ChunkSampler(NUM_VAL, NUM_TRAIN))

print('Data loaders ready.\n')
# Sanity check: pull one batch (batch_size, 1, 28, 28), flatten each image to
# 784 pixels, undo the normalization, and display the grid.
imgs = deprocess_img(next(iter(train_data))[0].view(batch_size, 784)).numpy().squeeze()
show_images(imgs)
plt.show()
# Discriminator network:
def discriminator():
    net = nn.Sequential(
        nn.Linear(784, 256),
        nn.LeakyReLU(0.2),
        nn.Linear(256, 256),
        nn.LeakyReLU(0.2),
        nn.Linear(256, 1)
    )
    return net
# Generator network: maps a NOISE_DIM-dimensional noise vector to a 784-pixel image in [-1, 1].
def generator(noise_dim=NOISE_DIM):
    net = nn.Sequential(
        nn.Linear(noise_dim, 1024),
        nn.ReLU(True),
        nn.Linear(1024, 1024),
        nn.ReLU(True),
        nn.Linear(1024, 784),
        nn.Tanh()
    )
    return net
# BCEWithLogitsLoss applies the sigmoid internally, so the discriminator outputs raw logits.
bce_loss = nn.BCEWithLogitsLoss()

# Discriminator loss: real images should be scored as 1, generated images as 0.
def discriminator_loss(logits_real, logits_fake):
    size = logits_fake.shape[0]
    # One label per image in the batch: 1 for real, 0 for fake.
    true_labels = Variable(torch.ones(size, 1)).float()
    false_labels = Variable(torch.zeros(size, 1)).float()
    loss = bce_loss(logits_real, true_labels) + bce_loss(logits_fake, false_labels)
    return loss
# Generator loss (non-saturating): the generator is rewarded when the
# discriminator scores its fake images as real.
def generator_loss(logits_fake):
    size = logits_fake.shape[0]
    true_labels = Variable(torch.ones(size, 1)).float()
    loss = bce_loss(logits_fake, true_labels)
    return loss
def get_optimizer(net):
    # Adam with lr=3e-4 and beta1=0.5, a common choice for GAN training.
    optimizer = torch.optim.Adam(net.parameters(), lr=3e-4, betas=(0.5, 0.999))
    return optimizer
def train_a_gan(D_net, G_net, D_optimizer, G_optimizer, discriminator_loss, generator_loss, show_every=250,
                noise_size=96, num_epochs=10):
    iter_count = 0
    for epoch in range(num_epochs):
        for x, _ in train_data:
            bs = x.shape[0]

            # Discriminator step: score a batch of real images...
            real_data = Variable(x).view(bs, -1)
            logits_real = D_net(real_data)

            # ...and a batch of fake images generated from uniform noise in [-1, 1].
            sample_noise = (torch.rand(bs, noise_size) - 0.5) / 0.5
            g_fake_seed = Variable(sample_noise)
            fake_images = G_net(g_fake_seed)
            logits_fake = D_net(fake_images.detach())  # detach so this step does not touch G

            d_total_error = discriminator_loss(logits_real, logits_fake)
            D_optimizer.zero_grad()
            d_total_error.backward()
            D_optimizer.step()

            # Generator step: push the (updated) discriminator to score the fakes as real.
            gen_logits_fake = D_net(fake_images)
            g_error = generator_loss(gen_logits_fake)
            G_optimizer.zero_grad()
            g_error.backward()
            G_optimizer.step()

            if iter_count % show_every == 0:
                print('Iter: {}, D: {:.4}, G:{:.4}'.format(iter_count, d_total_error.item(), g_error.item()))
                imgs_numpy = deprocess_img(fake_images.data.numpy())
                show_images(imgs_numpy[0:4])
                plt.show()
            iter_count += 1
# Build the networks and their optimizers, then train.
D = discriminator()
G = generator()
D_optim = get_optimizer(D)
G_optim = get_optimizer(G)
train_a_gan(D, G, D_optim, G_optim, discriminator_loss, generator_loss, show_every=250, noise_size=96, num_epochs=10)
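# After training, new digits can be sampled directly from the generator.
# A minimal sketch of sampling (assumes the training run above has completed):
with torch.no_grad():
    sample_z = (torch.rand(64, NOISE_DIM) - 0.5) / 0.5  # uniform noise in [-1, 1]
    samples = G(sample_z)                               # (64, 784) images in [-1, 1]
show_images(deprocess_img(samples.numpy()))
plt.show()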