import argparse
import os
import numpy as np
import torchvision.transforms as transforms
from torchvision.utils import save_image
from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable
from models import WGAN
import torch
os.makedirs("images", exist_ok=True)
# 传入参数
parser = argparse.ArgumentParser()
parser.add_argument("--n_epochs", type=int, default=200, help="number of epochs of training")
parser.add_argument("--batch_size", type=int, default=64, help="size of the batches")
parser.add_argument("--lr", type=float, default=0.00005, help="learning rate")
parser.add_argument("--n_cpu", type=int, default=8, help="number of cpu threads to use during batch generation")
parser.add_argument("--latent_dim", type=int, default=100, help="dimensionality of the latent space")
parser.add_argument("--img_size", type=int, default=28, help="size of each image dimension")
parser.add_argument("--channels", type=int, default=1, help="number of image channels")
parser.add_argument("--n_critic", type=int, default=5, help="number of training steps for discriminator per iter")
parser.add_argument("--clip_value", type=float, default=0.01, help="lower and upper clip value for disc. weights")
parser.add_argument("--sample_interval", type=int, default=400, help="interval betwen image samples")
parser.add_argument("--checkpoint_interval", type=int, default=5000, help="interval between model checkpoints")
args = parser.parse_args()
print(args)
img_shape = (args.channels, args.img_size, args.img_size)
cuda = True if torch.cuda.is_available() else False
# 初始化WGAN模型
model = WGAN(args)
generator = model.generator
discriminator = model.discriminator
if cuda:
generator.cuda()
discriminator.cuda()
# 载入MNIST数据集
os.makedirs("data/mnist", exist_ok=True)
dataloader = DataLoader(
datasets.MNIST("data/mnist", train=True, download=True,
transform=transforms.Compose([transforms.ToTensor(),
transforms.Normalize([0.5], [0.5])]),
),
batch_size=args.batch_size,
shuffle=True,
)
# 优化器
optimizer_G = torch.optim.RMSprop(generator.parameters(), lr=args.lr)
optimizer_D = torch.optim.RMSprop(discriminator.parameters(), lr=args.lr)
Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
# 训练
batches_done = 0
for epoch in range(args.n_epochs):
for i, (imgs, _) in enumerate(dataloader):
# Configure input
real_imgs = Variable(imgs.type(Tensor))
# ---------------------
# Train Discriminator
# ---------------------
optimizer_D.zero_grad()
# Sample noise as generator input
z = Variable(Tensor(np.random.normal(0, 1, (imgs.shape[0], args.latent_dim))))
# Generate a batch of images
fake_imgs = generator(z).detach()
# Adversarial loss
loss_D = -torch.mean(discriminator(real_imgs)) + torch.mean(discriminator(fake_imgs))
loss_D.backward()
optimizer_D.step()
# Clip weights of discriminator
for p in discriminator.parameters():
p.data.clamp_(-args.clip_value, args.clip_value)
# Train the generator every n_critic iterations
if i % args.n_critic == 0:
# -----------------
# Train Generator
# -----------------
optimizer_G.zero_grad()
# Generate a batch of images
gen_imgs = generator(z)
# Adversarial loss
loss_G = -torch.mean(discriminator(gen_imgs))
loss_G.backward()
optimizer_G.step()
print(
"[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"
% (epoch, args.n_epochs, batches_done % len(dataloader), len(dataloader), loss_D.item(), loss_G.item())
)
if batches_done % args.sample_interval == 0:
save_image(gen_imgs.data[:25], "images/%d.png" % batches_done, nrow=5, normalize=True)
if batches_done % args.checkpoint_interval == 0:
torch.save(generator.state_dict(), 'generator.pth')
torch.save(discriminator.state_dict(), 'discriminator.pth')
batches_done += 1
没有合适的资源?快使用搜索试试~ 我知道了~
温馨提示
# WGAN生成对抗网络训练Pytorch代码 使用MNIST数据集生成数字图片 1. 完成了WGAN生成器和判别器的定义代码; 2. 包含使用MNIST训练集训练WGAN的代码,简洁易懂; 3. 包含使用训练完的生成器模型生成数字图片的代码; 4. 无需另外下载数据集,使用pytorch载入MNIST,首次运行代码自动下载; 5. 包含训练45000batch的模型权重文件;包含该次训练过程的生成图片样例。
资源推荐
资源详情
资源评论
收起资源包目录
WGAN生成对抗网络训练Pytorch代码 使用MNIST数据集生成数字图片.zip (17个子文件)
wgan_pytorch
models.py 2KB
discriminator.pth 2.04MB
sample.py 802B
generator.pth 5.78MB
images
47200.png 18KB
46400.png 19KB
0.png 47KB
48400.png 18KB
48800.png 19KB
47600.png 19KB
48000.png 20KB
46000.png 20KB
44800.png 19KB
45600.png 19KB
46800.png 19KB
45200.png 20KB
train.py 4KB
共 17 条
- 1
资源评论
- qunimabidashabi2023-09-18这个资源对我启发很大,受益匪浅,学到了很多,谢谢分享~
- qq_123062023-09-12资源很不错,内容和描述一致,值得借鉴,赶紧学起来!
两只程序猿
- 粉丝: 338
- 资源: 158
上传资源 快速赚钱
- 我的内容管理 展开
- 我的资源 快来上传第一个资源
- 我的收益 登录查看自己的收益
- 我的积分 登录查看自己的积分
- 我的C币 登录后查看C币余额
- 我的收藏
- 我的下载
- 下载帮助
安全验证
文档复制为VIP权益,开通VIP直接复制
信息提交成功