import torch
import torch.nn.functional as nn
import torch.autograd as autograd
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import os
from torch.autograd import Variable
from tensorflow.examples.tutorials.mnist import input_data
# Load the MNIST splits from the data directory, with labels one-hot encoded.
mnist = input_data.read_data_sets('../../MNIST_data', one_hot=True)

# Feature and label widths are read off the second axis of the loaded arrays.
X_dim = mnist.train.images.shape[1]
y_dim = mnist.train.labels.shape[1]

# Model / optimisation hyperparameters.
mb_size = 64    # number of samples per minibatch
Z_dim = 100     # length of the latent noise vector fed to the generator
h_dim = 128     # hidden-layer width shared by generator and discriminator
lr = 1e-3       # learning rate

# Running index used to number the sample images written during training.
c = 0
def xavier_init(size):
    """Create a trainable weight tensor drawn from a scaled normal distribution.

    Entries are sampled from a standard normal and multiplied by
    ``1 / sqrt(in_dim / 2)`` (equivalently ``sqrt(2 / in_dim)``), where
    ``in_dim`` is the first entry of ``size``.

    Args:
        size: Sequence of dimension sizes for the tensor; ``size[0]`` is
            taken as the fan-in that determines the scale.

    Returns:
        A ``Variable`` of shape ``size`` with ``requires_grad=True``.
    """
    in_dim = size[0]
    # Convert the NumPy scalar to a plain float so the product below is an
    # ordinary tensor-times-Python-scalar multiplication.
    xavier_stddev = float(1. / np.sqrt(in_dim / 2.))
    return Variable(torch.randn(*size) * xavier_stddev, requires_grad=True)
""" ==================== GENERATOR ======================== """
# Generator weights: latent -> hidden (Wzh) and hidden -> data (Whx).
# Wzh is drawn before Whx, which fixes the order in which random numbers
# are consumed.
Wzh = xavier_init([Z_dim, h_dim])
Whx = xavier_init([h_dim, X_dim])

# Matching biases, one entry per output unit of each layer, starting at zero.
bzh = Variable(torch.zeros(h_dim), requires_grad=True)
bhx = Variable(torch.zeros(X_dim), requires_grad=True)
def G(z):
    """Generator: map a batch of latent vectors to generated samples.

    Computes ``sigmoid(relu(z @ Wzh + bzh) @ Whx + bhx)`` using the
    module-level generator parameters.

    Args:
        z: 2-D batch of latent vectors; its second dimension must match the
            first dimension of ``Wzh``.

    Returns:
        A tensor with one row per row of ``z`` and ``Whx.size(1)`` columns,
        with values in [0, 1] from the final sigmoid.
    """
    # The 1-D biases broadcast across the batch dimension, so there is no
    # need to materialise a per-row copy with ``repeat`` on every call.
    h = nn.relu(z @ Wzh + bzh)
    X = nn.sigmoid(h @ Whx + bhx)
    return X
""" ==================== DISCRIMINATOR ======================== """
# Discriminator weights: data -> hidden (Wxh) and hidden -> single score (Why).
# Wxh is drawn before Why, which fixes the order in which random numbers
# are consumed.
Wxh = xavier_init([X_dim, h_dim])
Why = xavier_init([h_dim, 1])

# Matching biases, one entry per output unit of each layer, starting at zero.
bxh = Variable(torch.zeros(h_dim), requires_grad=True)
bhy = Variable(torch.zeros(1), requires_grad=True)
def D(X):
    """Discriminator: score each row of a batch with a value in [0, 1].

    Computes ``sigmoid(relu(X @ Wxh + bxh) @ Why + bhy)`` using the
    module-level discriminator parameters.

    Args:
        X: 2-D batch of samples; its second dimension must match the first
            dimension of ``Wxh``.

    Returns:
        A tensor with one row per row of ``X`` and ``Why.size(1)`` columns,
        holding the sigmoid output for each sample.
    """
    # The 1-D biases broadcast across the batch dimension, so there is no
    # need to materialise a per-row copy with ``repeat`` on every call.
    h = nn.relu(X @ Wxh + bxh)
    y = nn.sigmoid(h @ Why + bhy)
    return y
# Trainable tensors grouped per network; each group gets its own optimiser.
G_params = [Wzh, bzh, Whx, bhx]
D_params = [Wxh, bxh, Why, bhy]

# Every trainable tensor, generator first, for whole-model gradient resets.
params = [*G_params, *D_params]
""" ===================== TRAINING ======================== """
def reset_grad():
    """Zero the accumulated gradient of every tensor in ``params``.

    Tensors whose ``.grad`` is still ``None`` (no backward pass has produced
    a gradient for them yet) are skipped.
    """
    for p in params:
        if p.grad is not None:
            # Zero the existing buffer in place instead of allocating a
            # replacement tensor for every parameter on every call.
            # detach_() first cuts any graph history the gradient carries.
            p.grad.detach_()
            p.grad.zero_()
# One Adam optimiser per network so that each step only updates that
# network's own parameters.  Both take the file-level `lr` hyperparameter
# rather than repeating the literal value.
G_solver = optim.Adam(G_params, lr=lr)
D_solver = optim.Adam(D_params, lr=lr)

# Constant binary-cross-entropy targets, one row per minibatch sample:
# all ones and all zeros.
ones_label = Variable(torch.ones(mb_size, 1))
zeros_label = Variable(torch.zeros(mb_size, 1))
# Main training loop: alternate one discriminator update and one generator
# update per iteration, and periodically report losses and save samples.
for it in range(100000):
    # Sample data: a minibatch of latent noise and a minibatch of images.
    z = Variable(torch.randn(mb_size, Z_dim))
    X, _ = mnist.train.next_batch(mb_size)
    X = Variable(torch.from_numpy(X))

    # Discriminator forward-loss-backward-update
    G_sample = G(z)
    D_real = D(X)
    # Detach the generated batch: this step only updates D_params, and any
    # gradient that reached the generator here would be cleared by
    # reset_grad() below before the generator's own backward pass.
    D_fake = D(G_sample.detach())

    D_loss_real = nn.binary_cross_entropy(D_real, ones_label)
    D_loss_fake = nn.binary_cross_entropy(D_fake, zeros_label)
    D_loss = D_loss_real + D_loss_fake

    D_loss.backward()
    D_solver.step()

    # Housekeeping - reset gradient
    reset_grad()

    # Generator forward-loss-backward-update
    z = Variable(torch.randn(mb_size, Z_dim))
    G_sample = G(z)
    D_fake = D(G_sample)

    # The generator is trained against the "ones" target for its samples.
    G_loss = nn.binary_cross_entropy(D_fake, ones_label)

    G_loss.backward()
    G_solver.step()

    # Housekeeping - reset gradient (also clears what G_loss left on D_params)
    reset_grad()

    # Print and plot every now and then
    if it % 1000 == 0:
        print('Iter-{}; D_loss: {}; G_loss: {}'.format(it, D_loss.data.numpy(), G_loss.data.numpy()))

        # First 16 generated samples, laid out on a 4x4 grid.
        samples = G(z).data.numpy()[:16]

        fig = plt.figure(figsize=(4, 4))
        gs = gridspec.GridSpec(4, 4)
        gs.update(wspace=0.05, hspace=0.05)

        for i, sample in enumerate(samples):
            ax = plt.subplot(gs[i])
            plt.axis('off')
            ax.set_xticklabels([])
            ax.set_yticklabels([])
            ax.set_aspect('equal')
            plt.imshow(sample.reshape(28, 28), cmap='Greys_r')

        # exist_ok avoids the check-then-create race of a separate
        # os.path.exists() test.
        os.makedirs('out/', exist_ok=True)

        plt.savefig('out/{}.png'.format(str(c).zfill(3)), bbox_inches='tight')
        c += 1
        # Close the figure so figures do not accumulate over the run.
        plt.close(fig)
没有合适的资源?快使用搜索试试~ 我知道了~
资源推荐
资源详情
资源评论
收起资源包目录
Demo.rar (6个子文件)
Demo
MNIST_data
t10k-images-idx3-ubyte.gz 1.57MB
t10k-labels-idx1-ubyte.gz 4KB
train-images-idx3-ubyte.gz 9.45MB
train-labels-idx1-ubyte.gz 28KB
vanilla_gan
gan_pytorch.py 4KB
out
gan_tensorflow.py 4KB
共 6 条
- 1
资源评论
- m0_701871972023-02-03挺好的,拿来学习学习涨涨见识
Uingll
- 粉丝: 10
- 资源: 9
上传资源 快速赚钱
- 我的内容管理 展开
- 我的资源 快来上传第一个资源
- 我的收益 登录查看自己的收益
- 我的积分 登录查看自己的积分
- 我的C币 登录后查看C币余额
- 我的收藏
- 我的下载
- 下载帮助
安全验证
文档复制为VIP权益,开通VIP直接复制
信息提交成功