import argparse
import os
import numpy as np
import math
import torchvision.transforms as transforms
from torchvision.utils import save_image
from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
import torch
# Output directory for the periodically sampled generator images.
os.makedirs("images", exist_ok=True)

# Command-line hyperparameters for training.
parser = argparse.ArgumentParser()
parser.add_argument("--n_epochs", type=int, default=100, help="number of epochs of training")
parser.add_argument("--batch_size", type=int, default=128, help="size of the batches")
parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate")
parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient")
parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of second order momentum of gradient")
# NOTE(review): --n_cpu is parsed but not passed to the DataLoader in this script.
parser.add_argument("--n_cpu", type=int, default=8, help="number of cpu threads to use during batch generation")
parser.add_argument("--latent_dim", type=int, default=100, help="dimensionality of the latent space")
parser.add_argument("--img_size", type=int, default=28, help="size of each image dimension")
parser.add_argument("--channels", type=int, default=1, help="number of image channels")
parser.add_argument("--sample_interval", type=int, default=400, help="interval between image samples")
opt = parser.parse_args()
print(opt)

# Shape of a single image tensor: (channels, img_size, img_size).
img_shape = (opt.channels, opt.img_size, opt.img_size)
# Train on the GPU whenever one is available.
cuda = torch.cuda.is_available()
class Generator(nn.Module):
    """MLP generator that maps a batch of latent vectors to a batch of images.

    Args:
        latent_dim: size of each input noise vector. Defaults to the
            script-level ``opt.latent_dim`` so ``Generator()`` keeps working.
        image_shape: shape of one output image, e.g. ``(C, H, W)``. Defaults
            to the script-level ``img_shape``.
    """

    def __init__(self, latent_dim=None, image_shape=None):
        super().__init__()
        # Fall back to the script configuration only when not given explicitly.
        if latent_dim is None:
            latent_dim = opt.latent_dim
        if image_shape is None:
            image_shape = img_shape
        self.img_shape = tuple(image_shape)

        def block(in_feat, out_feat, normalize=True):
            """Return the layers of one hidden stage: Linear, optional BatchNorm1d, LeakyReLU."""
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                # NOTE(review): the original passed 0.8 positionally, which
                # BatchNorm1d interprets as ``eps`` (not ``momentum``). Kept as
                # eps=0.8 to preserve behavior; confirm whether momentum=0.8
                # was intended.
                layers.append(nn.BatchNorm1d(out_feat, eps=0.8))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.model = nn.Sequential(
            *block(latent_dim, 128, normalize=False),
            *block(128, 256),
            *block(256, 512),
            *block(512, 1024),
            nn.Linear(1024, int(np.prod(self.img_shape))),
            # Tanh bounds every output pixel to [-1, 1].
            nn.Tanh(),
        )

    def forward(self, z):
        """Map latent vectors ``z`` (N, latent_dim) to images of shape (N, *img_shape)."""
        img = self.model(z)
        return img.view(img.size(0), *self.img_shape)
class Discriminator(nn.Module):
    """MLP discriminator that outputs, per image, a probability in [0, 1].

    The training loop uses 1 as the target for real images and 0 for
    generated ones.

    Args:
        image_shape: shape of one input image, e.g. ``(C, H, W)``. Defaults
            to the script-level ``img_shape`` so ``Discriminator()`` keeps
            working.
    """

    def __init__(self, image_shape=None):
        super().__init__()
        # Fall back to the script configuration only when not given explicitly.
        if image_shape is None:
            image_shape = img_shape
        self.model = nn.Sequential(
            nn.Linear(int(np.prod(image_shape)), 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
            # Sigmoid turns the score into a probability, as BCELoss expects.
            nn.Sigmoid(),
        )

    def forward(self, img):
        """Return a (N, 1) tensor of probabilities for a batch of N images."""
        # flatten(1) (unlike view) also accepts non-contiguous inputs.
        img_flat = img.flatten(1)
        validity = self.model(img_flat)
        return validity
# ----------
# Setup: objective, networks, data pipeline, optimizers
# ----------

# Binary cross-entropy on the Discriminator's sigmoid probability output.
adversarial_loss = nn.BCELoss()

# Build both networks (generator first, discriminator second).
generator = Generator()
discriminator = Discriminator()

# Move the networks and the loss module to the GPU when one is available.
if cuda:
    for _module in (generator, discriminator, adversarial_loss):
        _module.cuda()

# MNIST training split, downloaded on first use.
mnist_root = "./data/mnist"
os.makedirs(mnist_root, exist_ok=True)
mnist_transform = transforms.Compose(
    [
        transforms.Resize(opt.img_size),
        transforms.ToTensor(),
        # (x - 0.5) / 0.5 maps [0, 1] pixels into [-1, 1], the Generator's Tanh range.
        transforms.Normalize([0.5], [0.5]),
    ]
)
mnist_train = datasets.MNIST(mnist_root, train=True, download=True, transform=mnist_transform)
# NOTE(review): opt.n_cpu is not forwarded as num_workers here.
dataloader = DataLoader(mnist_train, batch_size=opt.batch_size, shuffle=True)

# One Adam optimizer per network, sharing the same hyperparameters.
adam_betas = (opt.b1, opt.b2)
optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr, betas=adam_betas)
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=adam_betas)

# Legacy tensor type matching the chosen device; used to convert batches.
if cuda:
    Tensor = torch.cuda.FloatTensor
else:
    Tensor = torch.FloatTensor
# ----------
# Training
# ----------
# Each batch does one generator update followed by one discriminator update.
for epoch in range(opt.n_epochs):
    for i, (imgs, _) in enumerate(dataloader):
        # Configure input: convert the batch to the training tensor type/device.
        real_imgs = imgs.type(Tensor)
        n_batch = real_imgs.size(0)

        # Adversarial ground truths: 1 = real, 0 = fake. new_ones/new_zeros
        # inherit dtype and device from real_imgs and do not require grad.
        valid = real_imgs.new_ones((n_batch, 1))
        fake = real_imgs.new_zeros((n_batch, 1))

        # -----------------
        # Train Generator
        # -----------------
        optimizer_G.zero_grad()

        # Sample standard-normal noise directly on the training device
        # (avoids a per-batch float64 NumPy array and host-to-device copy).
        z = torch.randn(n_batch, opt.latent_dim, device=real_imgs.device)

        # Generate a batch of images
        gen_imgs = generator(z)

        # Loss measures generator's ability to fool the discriminator
        g_loss = adversarial_loss(discriminator(gen_imgs), valid)
        g_loss.backward()
        optimizer_G.step()

        # ---------------------
        # Train Discriminator
        # ---------------------
        # zero_grad also clears the discriminator gradients that g_loss.backward() produced.
        optimizer_D.zero_grad()

        # Measure discriminator's ability to classify real from generated samples;
        # detach() keeps this step from back-propagating into the generator.
        real_loss = adversarial_loss(discriminator(real_imgs), valid)
        fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake)
        d_loss = (real_loss + fake_loss) / 2
        d_loss.backward()
        optimizer_D.step()

        print(
            "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"
            % (epoch, opt.n_epochs, i, len(dataloader), d_loss.item(), g_loss.item())
        )

        # Periodically save a 5x5 grid of the current generator samples.
        batches_done = epoch * len(dataloader) + i
        if batches_done % opt.sample_interval == 0:
            save_image(gen_imgs.detach()[:25], "images/%d.png" % batches_done, nrow=5, normalize=True)
没有合适的资源?快使用搜索试试~ 我知道了~
第八章:对抗生成网络11111
共131个文件
png:119个
gz:4个
pt:2个
需积分: 5 0 下载量 113 浏览量
2022-11-23
19:26:33
上传
评论
收藏 35.28MB ZIP 举报
温馨提示
第八章:对抗生成网络11111
资源推荐
资源详情
资源评论
收起资源包目录
第八章:对抗生成网络11111 (131个子文件)
train-images-idx3-ubyte.gz 9.45MB
t10k-images-idx3-ubyte.gz 1.57MB
train-labels-idx1-ubyte.gz 28KB
t10k-labels-idx1-ubyte.gz 4KB
0.png 48KB
800.png 42KB
400.png 40KB
1600.png 32KB
1200.png 31KB
3200.png 29KB
2800.png 29KB
4800.png 28KB
6400.png 28KB
4000.png 27KB
3600.png 27KB
4400.png 26KB
2400.png 26KB
7600.png 25KB
6000.png 25KB
2000.png 24KB
5600.png 24KB
10800.png 24KB
5200.png 22KB
9600.png 21KB
12400.png 21KB
QQ截图20191111121122.png 21KB
8000.png 20KB
10000.png 20KB
9200.png 20KB
13600.png 20KB
7200.png 19KB
10400.png 19KB
6800.png 19KB
8400.png 19KB
14000.png 18KB
16800.png 18KB
16400.png 18KB
17200.png 18KB
14400.png 18KB
12000.png 18KB
11600.png 17KB
14800.png 17KB
15200.png 17KB
40000.png 16KB
45200.png 16KB
11200.png 16KB
20800.png 16KB
8800.png 16KB
46000.png 15KB
13200.png 15KB
12800.png 15KB
19200.png 15KB
38400.png 15KB
18000.png 15KB
25200.png 15KB
29200.png 15KB
21600.png 15KB
26400.png 15KB
15600.png 15KB
24800.png 15KB
23600.png 15KB
39200.png 14KB
20400.png 14KB
22000.png 14KB
44000.png 14KB
18400.png 14KB
21200.png 14KB
28000.png 14KB
35600.png 14KB
42400.png 14KB
23200.png 14KB
27600.png 14KB
33600.png 14KB
16000.png 14KB
29600.png 14KB
24400.png 13KB
22400.png 13KB
19600.png 13KB
34800.png 13KB
20000.png 13KB
18800.png 13KB
36000.png 13KB
17600.png 13KB
30400.png 13KB
36400.png 13KB
37600.png 13KB
43200.png 13KB
42800.png 13KB
24000.png 13KB
30000.png 13KB
34000.png 13KB
44400.png 13KB
26800.png 13KB
22800.png 13KB
46400.png 13KB
32800.png 13KB
45600.png 13KB
30800.png 12KB
28400.png 12KB
31600.png 12KB
共 131 条
- 1
- 2
资源评论
巴黎左岸°C
- 粉丝: 0
- 资源: 38
上传资源 快速赚钱
- 我的内容管理 展开
- 我的资源 快来上传第一个资源
- 我的收益 登录查看自己的收益
- 我的积分 登录查看自己的积分
- 我的C币 登录后查看C币余额
- 我的收藏
- 我的下载
- 下载帮助
安全验证
文档复制为VIP权益,开通VIP直接复制
信息提交成功