import argparse
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision import datasets
import torch
"""
先训练生成器,在训练判别器
"""
from model import Generator,Discriminator
# ---- Command-line configuration ----
# All training hyper-parameters are exposed as CLI flags; the table below
# drives argparse so each option is declared exactly once.
parser = argparse.ArgumentParser()
for flag, flag_type, default, description in (
    ("--n_epochs", int, 10, "number of epochs of training"),
    ("--batch_size", int, 128, "size of the batches"),
    ("--lr", float, 0.0002, "adam: learning rate"),
    ("--b1", float, 0.5, "adam: decay of first order momentum of gradient"),
    ("--b2", float, 0.999, "adam: decay of first order momentum of gradient"),
    ("--img_size", int, 28, "size of each image dimension"),
    ("--latent_dim", int, 100, "dimensionality of the latent space"),
):
    parser.add_argument(flag, type=flag_type, default=default, help=description)
opt = parser.parse_args()  # namespace holding the parsed hyper-parameters
print(opt)
# Prefer the first CUDA device when available; otherwise fall back to CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print('device', device)
# Preprocessing pipeline: resize to img_size, convert to a tensor, then
# normalize single-channel pixels from [0, 1] to [-1, 1] (mean 0.5, std 0.5).
# Named `img_transform` instead of `transforms` so it no longer shadows the
# `torchvision.transforms` module imported at the top of the file.
img_transform = transforms.Compose(
    [
        transforms.Resize(opt.img_size),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),  # mean, std
    ]
)
# MNIST training set; downloaded into the current directory on first run.
train_datasets = datasets.MNIST(root='./', train=True, download=True, transform=img_transform)
print('训练集的数量', len(train_datasets))
# Reshuffle every epoch; batch size comes from the CLI.
train_loader = DataLoader(train_datasets, batch_size=opt.batch_size, shuffle=True)
# Loss: binary cross-entropy over the discriminator's sigmoid outputs,
# used for both the generator and the discriminator objectives.
adversarial_loss = torch.nn.BCELoss().to(device)
# Network definitions (architectures live in model.py), moved onto `device`.
generator = Generator().to(device)
discriminator = Discriminator().to(device)
# One Adam optimizer per network so G and D are updated independently;
# betas=(b1, b2) are Adam's exponential-decay rates for the moment estimates.
optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
# Training loop: the generator is trained every epoch (epoch % 1 == 0),
# the discriminator only every 5th epoch — the schedule from the original
# script, kept as-is.
for epoch in range(opt.n_epochs):
    total_d_loss = 0
    total_g_loss = 0
    for i, (img, _) in enumerate(train_loader):
        # Flatten each image to a 1-D vector and move every tensor onto the
        # same device as the models (the original left them on CPU, which
        # crashes whenever the models live on CUDA).
        real_img = img.view(img.size(0), -1).to(device)
        # Label 1 marks real images, label 0 marks generated (fake) ones.
        real_label = torch.ones(img.size(0), 1, device=device)
        fake_label = torch.zeros(img.size(0), 1, device=device)

        if epoch % 1 == 0:
            # ---- Generator step ----
            # Goal: produce images the (frozen) discriminator classifies as
            # real, i.e. minimize BCE between D(G(z)) and the "real" label.
            noise = torch.randn(img.size(0), opt.latent_dim, device=device)
            # BUG FIX: no .detach() here. Detaching severed the autograd
            # graph between the generator and the loss, so backward()
            # produced no generator gradients and G never learned.
            fake_img = generator(noise)
            output = discriminator(fake_img)
            g_loss = adversarial_loss(output, real_label)
            total_g_loss += g_loss.item()
            # Backpropagate and update the generator only.
            optimizer_G.zero_grad()
            g_loss.backward()
            optimizer_G.step()

        if epoch % 5 == 0:
            # ---- Discriminator step ----
            # Real images should score 1.
            real_out = discriminator(real_img)
            d_loss_real = adversarial_loss(real_out, real_label)
            # Fake images should score 0. .detach() IS correct here: the
            # generator must not receive gradients from D's objective.
            noise = torch.randn(img.size(0), opt.latent_dim, device=device)
            fake_img = generator(noise).detach()
            fake_out = discriminator(fake_img)
            d_loss_fake = adversarial_loss(fake_out, fake_label)
            # Total D loss: misclassifying real + misclassifying fake.
            d_loss = d_loss_real + d_loss_fake
            total_d_loss += d_loss.item()
            optimizer_D.zero_grad()  # clear gradients before backprop
            d_loss.backward()
            optimizer_D.step()
    # Report per-epoch average losses.
    print('Epoch[{}/{}],d_loss:{:.6f},g_loss:{:.6f}'.format(
        epoch, opt.n_epochs,
        total_d_loss / len(train_loader), total_g_loss / len(train_loader)))
# Persist the full model objects (loading these files requires model.py on
# the path; state_dict saving would change the file format for consumers).
torch.save(generator, './gen1.pth')
torch.save(discriminator, './dis1.pth')