import argparse
import os
import numpy as np
import torchvision.transforms as transforms
from torchvision.utils import save_image
from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable
import torch
from models import Generator, Discriminator, weights_init_normal
os.makedirs("images", exist_ok=True)
parser = argparse.ArgumentParser()
parser.add_argument("--n_epochs", type=int, default=200, help="number of epochs of training")
parser.add_argument("--batch_size", type=int, default=64, help="size of the batches")
parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate")
parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient")
parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of first order momentum of gradient")
parser.add_argument("--n_cpu", type=int, default=8, help="number of cpu threads to use during batch generation")
parser.add_argument("--latent_dim", type=int, default=100, help="dimensionality of the latent space")
parser.add_argument("--n_classes", type=int, default=10, help="number of classes for dataset")
parser.add_argument("--img_size", type=int, default=32, help="size of each image dimension")
parser.add_argument("--channels", type=int, default=1, help="number of image channels")
parser.add_argument("--sample_interval", type=int, default=400, help="interval between image sampling")
parser.add_argument("--checkpoint_interval", type=int, default=5000, help="interval between model checkpoints")
opt = parser.parse_args()
print(opt)
cuda = torch.cuda.is_available()
# Loss functions
adversarial_loss = torch.nn.BCELoss()
auxiliary_loss = torch.nn.CrossEntropyLoss()
# Initialize generator and discriminator
generator = Generator(opt)
discriminator = Discriminator(opt)
if cuda:
    generator.cuda()
    discriminator.cuda()
    adversarial_loss.cuda()
    auxiliary_loss.cuda()
# Initialize weights
generator.apply(weights_init_normal)
discriminator.apply(weights_init_normal)
# Configure data loader
os.makedirs("data/mnist", exist_ok=True)
dataloader = DataLoader(
    datasets.MNIST(
        "data/mnist",
        train=True,
        download=True,
        transform=transforms.Compose(
            [transforms.Resize(opt.img_size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]
        ),
    ),
    batch_size=opt.batch_size,
    shuffle=True,
)
# Optimizers
optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
FloatTensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
LongTensor = torch.cuda.LongTensor if cuda else torch.LongTensor
def sample_image(n_row, batches_done):
    """Saves a grid of generated digits ranging from 0 to n_classes"""
    # Sample noise
    z = Variable(FloatTensor(np.random.normal(0, 1, (n_row ** 2, opt.latent_dim))))
    # Get labels ranging from 0 to n_classes for n rows
    labels = np.array([num for _ in range(n_row) for num in range(n_row)])
    labels = Variable(LongTensor(labels))
    gen_imgs = generator(z, labels)
    save_image(gen_imgs.data, "images/%d.png" % batches_done, nrow=n_row, normalize=True)
# ----------
# Training
# ----------
for epoch in range(opt.n_epochs):
    for i, (imgs, labels) in enumerate(dataloader):

        batch_size = imgs.shape[0]

        # Adversarial ground truths
        valid = Variable(FloatTensor(batch_size, 1).fill_(1.0), requires_grad=False)
        fake = Variable(FloatTensor(batch_size, 1).fill_(0.0), requires_grad=False)

        # Configure input
        real_imgs = Variable(imgs.type(FloatTensor))
        labels = Variable(labels.type(LongTensor))

        # -----------------
        #  Train Generator
        # -----------------

        optimizer_G.zero_grad()

        # Sample noise and labels as generator input
        z = Variable(FloatTensor(np.random.normal(0, 1, (batch_size, opt.latent_dim))))
        gen_labels = Variable(LongTensor(np.random.randint(0, opt.n_classes, batch_size)))

        # Generate a batch of images
        gen_imgs = generator(z, gen_labels)

        # Loss measures generator's ability to fool the discriminator
        validity, pred_label = discriminator(gen_imgs)
        g_loss = 0.5 * (adversarial_loss(validity, valid) + auxiliary_loss(pred_label, gen_labels))

        g_loss.backward()
        optimizer_G.step()

        # ---------------------
        #  Train Discriminator
        # ---------------------

        optimizer_D.zero_grad()

        # Loss for real images
        real_pred, real_aux = discriminator(real_imgs)
        d_real_loss = (adversarial_loss(real_pred, valid) + auxiliary_loss(real_aux, labels)) / 2

        # Loss for fake images
        fake_pred, fake_aux = discriminator(gen_imgs.detach())
        d_fake_loss = (adversarial_loss(fake_pred, fake) + auxiliary_loss(fake_aux, gen_labels)) / 2

        # Total discriminator loss
        d_loss = (d_real_loss + d_fake_loss) / 2

        # Calculate discriminator accuracy
        pred = np.concatenate([real_aux.data.cpu().numpy(), fake_aux.data.cpu().numpy()], axis=0)
        gt = np.concatenate([labels.data.cpu().numpy(), gen_labels.data.cpu().numpy()], axis=0)
        d_acc = np.mean(np.argmax(pred, axis=1) == gt)

        d_loss.backward()
        optimizer_D.step()

        print(
            "[Epoch %d/%d] [Batch %d/%d] [D loss: %f, acc: %d%%] [G loss: %f]"
            % (epoch, opt.n_epochs, i, len(dataloader), d_loss.item(), 100 * d_acc, g_loss.item())
        )

        batches_done = epoch * len(dataloader) + i
        if batches_done % opt.sample_interval == 0:
            sample_image(n_row=10, batches_done=batches_done)
        if batches_done % opt.checkpoint_interval == 0:
            torch.save(generator.state_dict(), "generator_%d.pth" % batches_done)
            torch.save(discriminator.state_dict(), "discriminator_%d.pth" % batches_done)
# ACGAN Training Code in PyTorch: Generating Handwritten Digit Images of a Specified Class

1. Includes model definitions for the ACGAN generator and discriminator;
2. Includes the training code, concise and easy to get started with;
3. Includes code for generating images of a specified digit;
4. Includes a weight file from 5000 batches of training;
5. Includes sample images generated during training.
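Item 3 refers to sample.py in the archive, which is not reproduced on this page, so the following is only a minimal sketch of how the bundled generator_5000.pth checkpoint might be used to produce a chosen digit. It assumes the Generator constructor and call convention visible in the training script above (Generator(opt), invoked as generator(z, labels)); the hyperparameter values mirror the argparse defaults, and the names digit and n_samples are illustrative.

import argparse

import numpy as np
import torch
from torchvision.utils import save_image

from models import Generator

# Reuse the hyperparameters the generator was trained with (argparse defaults above).
opt = argparse.Namespace(latent_dim=100, n_classes=10, img_size=32, channels=1)

generator = Generator(opt)
generator.load_state_dict(torch.load("generator_5000.pth", map_location="cpu"))
generator.eval()

digit = 7        # class label to generate (0-9 for MNIST)
n_samples = 16   # number of samples of that digit

with torch.no_grad():
    # Sample latent vectors and a constant label vector, then run the generator
    z = torch.from_numpy(np.random.normal(0, 1, (n_samples, opt.latent_dim))).float()
    labels = torch.full((n_samples,), digit, dtype=torch.long)
    gen_imgs = generator(z, labels)  # same call convention as in the training loop

save_image(gen_imgs, "digit_%d.png" % digit, nrow=4, normalize=True)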
Archive contents: ACGAN生成对抗网络训练Pytorch代码 生成指定数字手写数字图片.zip (ACGAN PyTorch training code for generating specified handwritten digits; 20 files)

acgan_pytorch/
├── discriminator_5000.pth  (413 KB)
├── generator_5000.pth      (4.02 MB)
├── models.py               (3 KB)
├── sample.py               (1 KB)
├── train.py                (6 KB)
└── images/
    ├── 0.png     (140 KB)
    ├── 400.png   (93 KB)
    ├── 800.png   (121 KB)
    ├── 1200.png  (108 KB)
    ├── 1600.png  (102 KB)
    ├── 2000.png  (97 KB)
    ├── 2400.png  (94 KB)
    ├── 2800.png  (91 KB)
    ├── 3200.png  (93 KB)
    ├── 3600.png  (89 KB)
    ├── 4000.png  (90 KB)
    ├── 4400.png  (87 KB)
    ├── 4800.png  (86 KB)
    ├── 5200.png  (85 KB)
    └── 5600.png  (83 KB)