import os
import torch
from torch import nn,optim
from torch.autograd import Variable
from torchvision import transforms,datasets
from torch.utils.data import DataLoader
import torch.nn.functional as F
from torchvision.utils import save_image
# Dataset Load
def get_data(batch_size=None, root='./mnist data', download=False):
    """Build a shuffled DataLoader over the MNIST training split.

    Images are converted to tensors and normalised with mean 0.5 / std 0.5,
    i.e. ``(pixel - 0.5) / 0.5``, which maps [0, 1] pixel values to [-1, 1].

    Args:
        batch_size: Mini-batch size. ``None`` (the default) falls back to the
            module-level ``batch_size`` setting, which is what the original
            zero-argument version always used.
        root: Directory holding (or receiving) the MNIST files.
        download: Whether torchvision should download MNIST if it is missing.

    Returns:
        A ``DataLoader`` yielding ``(images, labels)`` mini-batches in
        shuffled order.
    """
    if batch_size is None:
        # Backward-compatible default: read the module-level setting, which
        # must already be defined when this function is called.
        batch_size = globals()['batch_size']
    data_tf = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ])
    train_data = datasets.MNIST(root=root, train=True, transform=data_tf, download=download)
    return DataLoader(train_data, shuffle=True, batch_size=batch_size)
class VAE(nn.Module):
    """Fully-connected variational autoencoder over flattened images.

    The encoder maps an input vector to the mean and log-variance of a
    diagonal Gaussian in latent space; a latent vector is sampled with the
    reparametrization trick and decoded back to input space through a tanh,
    so reconstructions lie in [-1, 1].

    Args:
        input_dim: Length of each flattened input vector (784 = 28 * 28).
        hidden_dim: Width of the hidden layer in encoder and decoder.
        latent_dim: Dimensionality of the latent space.
    """

    def __init__(self, input_dim=784, hidden_dim=400, latent_dim=20):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)    # encoder hidden layer
        self.fc21 = nn.Linear(hidden_dim, latent_dim)  # latent mean (mu)
        self.fc22 = nn.Linear(hidden_dim, latent_dim)  # latent log-variance
        self.fc3 = nn.Linear(latent_dim, hidden_dim)   # decoder hidden layer
        self.fc4 = nn.Linear(hidden_dim, input_dim)    # decoder output layer

    def encoder(self, x):
        """Return ``(mu, logvar)``, each of shape [batch_size, latent_dim]."""
        h1 = F.relu(self.fc1(x))  # [batch_size, hidden_dim]
        mu = self.fc21(h1)
        logvar = self.fc22(h1)
        return mu, logvar

    def reparametrize(self, mu, logvar):
        """Sample z ~ N(mu, exp(logvar)) as ``mu + std * eps``.

        Writing the sample this way keeps it differentiable with respect to
        ``mu`` and ``logvar``; the randomness lives only in ``eps``.
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decoder(self, z):
        """Map latent vectors [batch_size, latent_dim] to [batch_size, input_dim] in [-1, 1]."""
        h3 = F.relu(self.fc3(z))  # [batch_size, hidden_dim]
        return torch.tanh(self.fc4(h3))

    def forward(self, x):
        """Encode ``x``, sample a latent vector and decode it.

        Returns:
            ``(recon_x, mu, logvar)``: the reconstruction plus the latent
            statistics needed for the KL term of the VAE loss.
        """
        mu, logvar = self.encoder(x)
        z = self.reparametrize(mu, logvar)
        return self.decoder(z), mu, logvar
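# Note from the original author (translated from Chinese): development and
# curation take real effort. For detailed documentation and code, the complete
# dataset, trained models, or further development, contact the author at
# WX: Q3101759565, QQ: 3101759565.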
def loss_function(recon_x, x, mu, log_var):
    """VAE objective: summed reconstruction error plus KL divergence.

    Args:
        recon_x: Decoder output (the reconstructed batch).
        x: Original input batch, same shape as ``recon_x``.
        mu: Latent means from the encoder.
        log_var: Latent log-variances, same shape as ``mu``.

    Returns:
        Scalar tensor summed over the whole batch (not averaged); callers
        that want a per-sample loss divide by the batch size themselves.
    """
    # Sum (not mean) of squared errors, so the reconstruction term is on the
    # same per-batch scale as the summed KL term below.
    mse = F.mse_loss(recon_x, x, reduction='sum')
    # KL(N(mu, sigma^2) || N(0, I)) = -0.5 * sum(1 + log sigma^2 - mu^2 - sigma^2),
    # summed over every latent dimension and every sample in the batch.
    kld = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
    return mse + kld
def to_Img(x):
    """Turn a batch of flat tanh-range vectors into image tensors.

    Values are shifted from [-1, 1] into [0, 1] (anything outside that range
    is clipped), and each sample is reshaped into a single-channel 28x28
    image.

    Args:
        x: Tensor whose first dimension is the batch size and which holds
            784 values per sample (e.g. shape [N, 784]).

    Returns:
        Tensor of shape [N, 1, 28, 28] with values in [0, 1].
    """
    batch = x.size(0)
    pixels = torch.clamp((x + 1.) * 0.5, 0, 1)
    return pixels.view(batch, 1, 28, 28)
# ---- Hyper-parameters (module-level; get_data() reads batch_size) ----
batch_size, lr, epoches = 128, 1e-3, 10

# ---- Model, data pipeline, reconstruction loss and optimiser ----
model = VAE()
train_data = get_data()
# Squared error summed (not averaged) over all elements of the batch.
reconstruction_function = nn.MSELoss(reduction='sum')
optimizer = optim.Adam(params=model.parameters(), lr=lr)
# Perform training: one optimiser step per mini-batch, reporting the average
# per-sample loss at the end of every epoch, then save the whole model.
model.train()
for epoch in range(epoches):
    epoch_loss = 0.0
    num_samples = 0
    for img, _ in train_data:
        # Flatten each image into a vector: [batch_size, 784]
        img = img.view(img.size(0), -1)
        # Forward pass: reconstruction plus the encoder's latent statistics.
        output, mu, logvar = model(img)
        # loss_function sums over the batch; divide to get a per-sample loss.
        loss = loss_function(output, img, mu, logvar) / img.size(0)
        # Backward pass and parameter update.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Accumulate the batch-summed loss so the epoch average is exact even
        # when the last batch is smaller.
        epoch_loss += loss.item() * img.size(0)
        num_samples += img.size(0)
    print("epoch = {}, loss is {}".format(epoch + 1, epoch_loss / max(num_samples, 1)))
torch.save(model, './vae.pth')
import matplotlib.pyplot as plt
# Generate a new image: sample a latent vector from the standard-normal prior
# and run it through the decoder only. Inference needs no gradients.
model.eval()
latent_dim = model.fc3.in_features  # decoder input size (20 for the default VAE)
with torch.no_grad():
    img = torch.normal(0, 1, size=(1, latent_dim))  # [1, latent_dim]
    decode = model.decoder(img)                       # [1, 784]
decode_img = to_Img(decode).squeeze()  # [28, 28], values in [0, 1]
decode_img = decode_img * 255           # scale to 0-255 grey levels
print(decode_img.shape)
plt.imshow(decode_img.numpy().astype('uint8'), cmap='gray')
plt.show()