import torch
from torch import nn
from torch import optim
from torch import autograd
from matplotlib import pyplot as plt
import numpy as np
import visdom
import random
# Hidden-layer width shared by Generator and Discriminator, and the number of
# samples per batch used throughout the script.
h_dim, batchsz = 400, 512
# Visdom client used for the loss curve and contour plots; created at import time.
viz = visdom.Visdom()
class Generator(nn.Module):
    """Map 2-D noise vectors to 2-D sample points.

    A four-layer MLP 2 -> h_dim -> h_dim -> h_dim -> 2 with in-place ReLU
    between the linear layers and no output activation, so generated
    coordinates are unbounded.
    """

    def __init__(self):
        super().__init__()
        dims = (2, h_dim, h_dim, h_dim)
        layers = []
        # Hidden stack: Linear + ReLU for each consecutive width pair.
        for d_in, d_out in zip(dims, dims[1:]):
            layers += [nn.Linear(d_in, d_out), nn.ReLU(True)]
        # Linear read-out back to 2-D points.
        layers.append(nn.Linear(h_dim, 2))
        self.net = nn.Sequential(*layers)

    def forward(self, z):
        """Map noise ``z`` of shape [..., 2] to points of shape [..., 2]."""
        return self.net(z)
class Discriminator(nn.Module):
    """Score 2-D points with a single value in (0, 1).

    A four-layer MLP 2 -> h_dim -> h_dim -> h_dim -> 1 with in-place ReLU
    hidden activations and a final Sigmoid.

    NOTE(review): a WGAN-GP critic is normally left unbounded (no Sigmoid);
    confirm the squashing here is intentional.
    """

    def __init__(self):
        super().__init__()
        widths = [2, h_dim, h_dim, h_dim]
        modules = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            modules.extend([nn.Linear(fan_in, fan_out), nn.ReLU(True)])
        # Single-unit head squashed by Sigmoid.
        modules.extend([nn.Linear(h_dim, 1), nn.Sigmoid()])
        self.net = nn.Sequential(*modules)

    def forward(self, x):
        """Return scores of shape [..., 1] for input points ``x`` of shape [..., 2]."""
        return self.net(x)
def data_generator(batch_size=None, scale=2.0, noise_std=0.02):
    """Yield endless batches drawn from a ring of eight tight 2-D Gaussians.

    Centres lie on a circle of radius ``scale`` at every multiple of 45
    degrees. Each sample is a uniformly chosen centre plus isotropic Gaussian
    noise with standard deviation ``noise_std``. Every batch is finally
    divided by 1.414 (about sqrt(2)), so the ring radius becomes
    ``scale / 1.414``.

    Args:
        batch_size: points per batch; ``None`` uses the module-level ``batchsz``.
        scale: ring radius before the final rescale.
        noise_std: per-coordinate noise standard deviation.

    Yields:
        float32 ndarray of shape ``[batch_size, 2]``.
    """
    if batch_size is None:
        batch_size = batchsz
    inv_sqrt2 = 1. / np.sqrt(2)
    centers = scale * np.array([
        (1, 0),
        (-1, 0),
        (0, 1),
        (0, -1),  # previously missing: the ring needs all eight compass points
        (inv_sqrt2, inv_sqrt2),
        (inv_sqrt2, -inv_sqrt2),
        (-inv_sqrt2, inv_sqrt2),
        (-inv_sqrt2, -inv_sqrt2),
    ])
    while True:
        # Pick centres with NumPy rather than Python's `random`, so that
        # np.random.seed(...) makes the data stream reproducible.
        idx = np.random.randint(len(centers), size=batch_size)
        noise = np.random.randn(batch_size, 2) * noise_std  # noise_std * N(0, 1)
        dataset = (centers[idx] + noise).astype(np.float32)
        dataset /= 1.414
        yield dataset
def generate_image(D, G, xr, epoch):
    """Plot the critic's contour map with real and generated samples to Visdom.

    Evaluates ``D`` on a 128 x 128 grid over [-3, 3]^2 and draws labelled
    contour lines. It then scatters the real batch ``xr`` (orange dots) and
    ``batchsz`` fresh generator samples (green pluses), and pushes the figure
    to the Visdom window ``'contour'``.

    Args:
        D: nn.Module critic mapping [n, 2] points to [n, 1] scores.
        G: nn.Module generator mapping [n, 2] noise to [n, 2] points.
        xr: real batch of shape [b, 2], as a NumPy array or a torch tensor
            on any device.
        epoch: training step shown in the window title.
    """
    N_POINTS = 128
    RANGE = 3
    # Create probe points / noise on the models' own devices so this also
    # works after D and G have been moved to a GPU.
    d_device = next(D.parameters()).device
    g_device = next(G.parameters()).device
    plt.clf()

    axis = np.linspace(-RANGE, RANGE, N_POINTS)
    points = np.zeros((N_POINTS, N_POINTS, 2), dtype='float32')
    points[:, :, 0] = axis[:, None]  # x varies along the first grid axis
    points[:, :, 1] = axis[None, :]  # y varies along the second grid axis
    points = points.reshape((-1, 2))  # [N_POINTS**2, 2]

    # draw contour of the critic output
    with torch.no_grad():
        probe = torch.from_numpy(points).to(d_device)
        disc_map = D(probe).cpu().numpy()  # [N_POINTS**2, 1]
    # Grid rows index x, but contour() expects Z[y, x]; hence the transpose.
    cs = plt.contour(axis, axis, disc_map.reshape((N_POINTS, N_POINTS)).transpose())
    plt.clabel(cs, inline=1, fontsize=10)

    # draw real and generated samples
    with torch.no_grad():
        z = torch.randn(batchsz, 2, device=g_device)  # [b, 2]
        samples = G(z).cpu().numpy()  # [b, 2]
    if torch.is_tensor(xr):
        # matplotlib cannot consume CUDA tensors; move to host memory first.
        xr = xr.detach().cpu().numpy()
    plt.scatter(xr[:, 0], xr[:, 1], c='orange', marker='.')
    plt.scatter(samples[:, 0], samples[:, 1], c='green', marker='+')
    viz.matplot(plt, win='contour', opts=dict(title='p(x):%d' % epoch))
def weights_init(m):
    """Initialise ``m`` in place if it is an ``nn.Linear``; otherwise do nothing.

    Weights get Kaiming-normal initialisation and the bias (when present) is
    zeroed. Intended for use as ``model.apply(weights_init)``.
    """
    if not isinstance(m, nn.Linear):
        return
    nn.init.kaiming_normal_(m.weight)
    # nn.Linear(..., bias=False) stores m.bias = None; skip it instead of crashing.
    if m.bias is not None:
        nn.init.zeros_(m.bias)
def gradient_penalty(D, xr, xf, lambda_=0.3):
    """WGAN-GP gradient penalty on random interpolates of real and fake data.

    Draws one mixing weight ``alpha ~ U[0, 1)`` per sample and forms
    ``x_hat = alpha * xr + (1 - alpha) * xf``. Returns
    ``lambda_ * mean((||dD(x_hat)/dx_hat||_2 - 1) ** 2)``.

    Args:
        D: critic, called on a tensor shaped like ``xr``.
        xr: real batch, shape [b, ...].
        xf: fake batch, same shape as ``xr``.
        lambda_: penalty weight (default 0.3 as before; the WGAN-GP paper
            uses 10).

    Returns:
        Scalar tensor. It is differentiable w.r.t. D's parameters
        (``create_graph=True``), so it can be added to the critic loss.
    """
    # Only the critic is constrained: cut any autograd path back into G.
    xr = xr.detach()
    xf = xf.detach()
    batch = xr.size(0)
    # One alpha per sample, broadcast over the remaining dims. It is sized from
    # the actual batch (not the global batchsz) and created on xr's device/dtype.
    alpha = torch.rand(batch, *([1] * (xr.dim() - 1)), device=xr.device, dtype=xr.dtype)
    interpolates = (alpha * xr + (1 - alpha) * xf).requires_grad_(True)
    disc_interpolates = D(interpolates)
    gradients = autograd.grad(outputs=disc_interpolates, inputs=interpolates,
                              grad_outputs=torch.ones_like(disc_interpolates),
                              create_graph=True, retain_graph=True)[0]
    # Flatten per sample so the norm covers every input dimension.
    grad_norm = gradients.reshape(batch, -1).norm(2, dim=1)
    return ((grad_norm - 1) ** 2).mean() * lambda_
def main():
    """Train a WGAN-GP on the toy ring data, logging to Visdom.

    Each of the 5000 iterations runs 5 critic updates (real score, fake score
    and gradient penalty) followed by one generator update. Every 100
    iterations the losses are appended to the 'loss' window, printed, and a
    contour plot is drawn with generate_image.
    """
    # Seed Python's `random` as well as torch/NumPy so every RNG the script
    # can draw from is fixed and runs are reproducible.
    random.seed(23)
    torch.manual_seed(23)
    np.random.seed(23)

    data_iter = data_generator()

    G = Generator()
    D = Discriminator()
    # NOTE(review): weights_init is defined in this file but never applied
    # (e.g. via G.apply(weights_init)); confirm default init is intended.

    optim_G = optim.Adam(G.parameters(), lr=5e-4, betas=(0.5, 0.9))
    optim_D = optim.Adam(D.parameters(), lr=5e-4, betas=(0.5, 0.9))

    viz.line([[0, 0]], [0], win='loss', opts=dict(title='loss', legend=['D', 'G']))

    for epoch in range(5000):
        # 1. train the critic (Discriminator) several steps per generator step
        for _ in range(5):
            # 1.1 real data: maximise D(xr), i.e. minimise -mean(D(xr))
            xr = torch.from_numpy(next(data_iter))  # [b, 2]
            lossr = -D(xr).mean()

            # 1.2 fake data: G is held fixed here, so detach its output
            z = torch.randn(batchsz, 2)
            xf = G(z).detach()  # [b, 2]
            lossf = D(xf).mean()

            # 1.3 gradient penalty on real/fake interpolates
            gp = gradient_penalty(D, xr, xf)

            loss_D = lossr + lossf + gp
            optim_D.zero_grad()
            loss_D.backward()
            optim_D.step()

        # 2. train the Generator: maximise D(G(z)), i.e. minimise -mean(D(G(z))).
        # Gradients this leaves on D's parameters are cleared by
        # optim_D.zero_grad() before the next critic update.
        z = torch.randn(batchsz, 2)
        xf = G(z)
        loss_G = -D(xf).mean()
        optim_G.zero_grad()
        loss_G.backward()
        optim_G.step()

        if epoch % 100 == 0:
            viz.line([[loss_D.item(), loss_G.item()]], [epoch], win='loss', update='append')
            print("loss_D:", loss_D.item(), "\tloss_G:", loss_G.item())
            generate_image(D, G, xr, epoch)
if __name__ == '__main__':
    # Start training only when run as a script, never on import.
    main()
没有合适的资源?快使用搜索试试~ 我知道了~
深度学习 pytorch demo.zip
共32个文件
py:24个
xml:5个
gitignore:2个
需积分: 5 0 下载量 85 浏览量
2024-05-08
10:03:26
上传
评论
收藏 28KB ZIP 举报
温馨提示
深度学习 pytorch demo.zip
资源推荐
资源详情
资源评论
收起资源包目录
深度学习 pytorch demo.zip (32个子文件)
content
Cifar10
main.py 3KB
lenet5.py 2KB
cifar10_tutorial_use_resnet18.py 6KB
resnet.py 3KB
cifar10_tutorial.py 5KB
Mnist
utils.py 799B
mnist_MLP.py 2KB
mnist_MLP_train_val_test.py 4KB
.idea
vcs.xml 183B
misc.xml 192B
Mnist.iml 284B
inspectionProfiles
Project_Default.xml 1022B
profiles_settings.xml 174B
modules.xml 262B
.gitignore 47B
mnist_MLP_with_visdom.py 3KB
mnist_simple_w_and_b.py 2KB
mnist_simple_MLP.py 3KB
RNN
rnn.py 821B
rnn_pred.py 2KB
GAN
gan.py 5KB
wgan.py 6KB
test.py 227B
.gitignore 107B
Pokemon
utils.py 629B
train_scratch.py 3KB
resnet.py 3KB
train_transfer.py 4KB
pokemon.py 5KB
AutoEncoder
main.py 2KB
autoencoder.py 959B
vae.py 1KB
共 32 条
- 1
资源评论
生瓜蛋子
- 粉丝: 3829
- 资源: 6140
上传资源 快速赚钱
- 我的内容管理 展开
- 我的资源 快来上传第一个资源
- 我的收益 登录查看自己的收益
- 我的积分 登录查看自己的积分
- 我的C币 登录后查看C币余额
- 我的收藏
- 我的下载
- 下载帮助
安全验证
文档复制为VIP权益,开通VIP直接复制
信息提交成功