Code from pages 216–220 (train.py / test.py / utils.py):
train.py:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision import datasets
from torchvision.utils import save_image
import os
from utils import *
from visdom import Visdom
# ---- script setup: sample folder, hyper-parameters, data pipeline ----
os.makedirs('./sample', exist_ok=True)  # race-free replacement for exists()+mkdir

batch_size = 128
num_epoch = 50
z_dimension = 100  # noise dimension fed to the generator

# MNIST is single-channel: normalize that channel to [-1, 1] so real
# images match the generator's tanh output range.
img_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])
])
# BUG FIX: download=True — without it a missing local copy raises a
# RuntimeError on the first run; with it, existing data is reused as-is.
mnist = datasets.MNIST(r'..\..\data', transform=img_transform, download=True)
dataloader = DataLoader(mnist, batch_size=batch_size, shuffle=True)
class Discriminator(nn.Module):
    """CNN discriminator for 28x28 single-channel (MNIST) images.

    Two conv + leaky-ReLU + average-pool stages, then a fully connected
    head ending in a sigmoid, so the output is a per-image probability
    of being real.
    """

    def __init__(self):
        # BUG FIX: original called super(discriminator, self) — the
        # lowercase name does not exist, so instantiation raised
        # NameError. Zero-argument super() is the Py3 idiom.
        super().__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(1, 32, 5, padding=2),   # batch, 32, 28, 28
            nn.LeakyReLU(0.2, True),
            nn.AvgPool2d(2, stride=2),        # batch, 32, 14, 14
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(32, 64, 5, padding=2),  # batch, 64, 14, 14
            nn.LeakyReLU(0.2, True),
            nn.AvgPool2d(2, stride=2)         # batch, 64, 7, 7
        )
        self.fc = nn.Sequential(
            nn.Linear(64 * 7 * 7, 1024),
            nn.LeakyReLU(0.2, True),
            nn.Linear(1024, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        """x: (batch, 1, 28, 28) image tensor.

        Returns a (batch, 1) tensor of real/fake probabilities in [0, 1].
        """
        x = self.conv1(x)
        x = self.conv2(x)
        x = x.view(x.size(0), -1)  # flatten to (batch, 64*7*7) for the fc head
        x = self.fc(x)
        return x
class Generator(nn.Module):
    """Generator mapping a noise vector to a 28x28 single-channel image.

    A linear layer expands the noise to a 56x56 feature map, which three
    conv stages refine and downsample to 28x28; the final tanh puts
    pixels in [-1, 1] (matching the normalized training data).
    """

    def __init__(self, input_size, num_feature):
        # BUG FIX: original called super(generator, self) — the
        # lowercase name does not exist, so instantiation raised
        # NameError. Zero-argument super() is the Py3 idiom.
        super().__init__()
        # num_feature is expected to be 3136 = 1*56*56 (see forward's view)
        self.fc = nn.Linear(input_size, num_feature)
        self.br = nn.Sequential(
            nn.BatchNorm2d(1),
            nn.ReLU(True)
        )
        self.downsample1 = nn.Sequential(
            nn.Conv2d(1, 50, 3, stride=1, padding=1),   # batch, 50, 56, 56
            nn.BatchNorm2d(50),
            nn.ReLU(True)
        )
        self.downsample2 = nn.Sequential(
            nn.Conv2d(50, 25, 3, stride=1, padding=1),  # batch, 25, 56, 56
            nn.BatchNorm2d(25),
            nn.ReLU(True)
        )
        self.downsample3 = nn.Sequential(
            nn.Conv2d(25, 1, 2, stride=2),              # batch, 1, 28, 28
            nn.Tanh()
        )

    def forward(self, x):
        """x: (batch, input_size) noise.

        Returns a (batch, 1, 28, 28) image tensor with values in [-1, 1].
        """
        x = self.fc(x)
        x = x.view(x.size(0), 1, 56, 56)  # reshape fc output into a 56x56 map
        x = self.br(x)
        x = self.downsample1(x)
        x = self.downsample2(x)
        x = self.downsample3(x)
        return x
if __name__=="__main__":
D = Discriminator().cuda() # discriminator model
G = Generator(z_dimension, 3136).cuda() # generator model
criterion = nn.BCELoss() # binary cross entropy
d_optimizer = torch.optim.Adam(D.parameters(), lr=0.0003)
g_optimizer = torch.optim.Adam(G.parameters(), lr=0.0003)
# train
viz=Visdom(env='GAN')
nCount=0
for epoch in range(num_epoch):
for img, _ in dataloader:
real_img = img.cuda()
real_label = torch.ones([img.size(0),1]).cuda()
fake_label = torch.zeros([img.size(0),1]).cuda()
# =================train D lock G
# loss of real
real_out = D(real_img) # hope close to 1
d_loss_real = criterion(real_out, real_label)
# loss of fake
z = torch.randn(img.size(0), z_dimension).cuda()
fake_img = G(z)
fake_out = D(fake_img.detach()) # hope close to 0
d_loss_fake = criterion(fake_out, fake_label)
# optimize D
d_loss = 0.5*(d_loss_real + d_loss_fake)
d_optimizer.zero_grad()
d_loss.backward()
d_optimizer.step()
# ===============train G lock D
# loss of fake
# z = torch.randn(batch_size, z_dimension).cuda()
output = D(fake_img)
g_loss = criterion(output, real_label)
# optimize G
g_optimizer.zero_grad()
g_loss.backward()
g_optimizer.step()
nCount+=1
if nCount % 100 == 0:
print('Epoch [{}/{}], d_loss: {:.6f}, g_loss: {:.6f}'
.format(epoch, num_epoch, d_loss.data, g_loss.data))
# visible
LossDV=d_loss.cpu().detach().numpy()
LossGV=g_loss.cpu().detach().numpy()
viz.line(Y=[LossDV],X=[nCount//100],win='Discriminator',update='append',opts=dict(title='D Train Loss',
xlabel='step/100',ylabel='loss'))
viz.line(Y=[LossGV],X=[nCount//100],win='Generator',update='append',opts=dict(title='G Train Loss',
xlabel='step/100',ylabel='loss'))
for n in range(6):
FakeV=to_img(fake_img.cpu().data)
viz.image(FakeV[n], win='Generate[%d]' % n)
for n in range(6):
imgV=img.cpu().data
viz.image(imgV[n], win='Real[%d]' % n)
fake_images = to_img(fake_img.cpu().data)
save_image(fake_images, './sample/fake_images-{}.png'.format(epoch+1))
torch.save({'epoch':epoch,'state_dict':G.state_dict()}, 'ckpt/generator.pth')
torch.save({'epoch':epoch,'state_dict':D.state_dict()}, 'ckpt/discriminator.pth')
test.py:
# test.py — generate images from a trained generator checkpoint.
import torch, os
from torchvision.utils import save_image
# BUG FIX: train.py defines the class as `Generator` (capitalized);
# `from train import generator` raised ImportError.
from train import Generator
from utils import *

if not os.path.exists('./output'):
    os.mkdir('./output')

G_model = Generator(100, 3136)
ckpt = torch.load(r'ckpt\generator.pth')
G_model.load_state_dict(ckpt['state_dict'])
G_model.to(torch.device('cuda'))
G_model.eval()

g_num = 150   # number of batches to generate
batch = 128   # images per batch
with torch.no_grad():  # inference only — skip autograd bookkeeping
    for i in range(g_num):
        z = torch.randn([batch, 100]).cuda()
        fake = G_model(z)
        out = to_img(fake)  # map tanh output [-1, 1] back to [0, 1]
        save_image(out, r'output\%d.jpg' % i)
        print(i)
utils.py:
import torchvision.transforms as transforms
from torchvision.utils import make_grid
from torch.utils import data
from PIL import Image
from glob import glob
import os
from sklearn.model_selection import train_test_split
import numpy as np
from matplotlib import pyplot as plt
class Inpaint(data.Dataset):
    """Dataset of RGB images whose file paths are listed in a text file.

    Each line of *datatxt* names one image; images are opened with PIL,
    converted to RGB, and passed through *transform* when one is given.
    """

    def __init__(self, datatxt, transform=None):
        super(Inpaint, self).__init__()
        self.transform = transform
        with open(datatxt) as fr:
            # one file path per line; strip only the trailing newline
            self.Names = [line.rstrip('\n') for line in fr.readlines()]

    def __getitem__(self, index):
        path = self.Names[index]
        image = Image.open(path).convert('RGB')
        if self.transform is not None:
            image = self.transform(image)
        return image

    def __len__(self):
        return len(self.Names)
def describe(dFold):
    """Split the *.jpg files under *dFold* into train/test subsets and
    write them to CelebAtrain.txt / CelebAtest.txt, one path per line.
    """
    files = glob(os.path.join(dFold, '*.jpg'))
    train, test = train_test_split(files, test_size=0.33)
    for txt_name, names in (('CelebAtrain.txt', train), ('CelebAtest.txt', test)):
        with open(txt_name, 'w') as fw:
            fw.writelines(name + '\n' for name in names)
    print('OK\n')
def to_input(img):
    """Map pixel values from the display range [0, 1] to the model's
    input range [-1, 1]."""
    return 2 * img - 1
def to_img(inp):
    """Map model outputs from [-1, 1] back to the displayable range
    [0, 1], clamping anything that falls outside."""
    rescaled = (inp + 1) * 0.5
    return rescaled.clamp(0., 1.)
if __name__ == '__main__':
# visualize
if (not os.path.exists('CelebAtrain.txt')) or (not os.path.exists('CelebAtest.txt')):
describe(r'E:\datasets\img_align_celeba')
trans = transforms.Compose([transforms.RandomHorizontalFlip(),
transforms.Resize((100,150)),
transforms.CenterCrop(64),
[深度卷积神经网络原理与实践][周浦城 等][程序源代码]
版权申诉
142 浏览量
2022-04-14
10:56:34
上传
评论
收藏 23KB RAR 举报
人工智能教学实践
- 粉丝: 531
- 资源: 253
最新资源
- TG-2024-05-23-204718255.mp4
- 候志强@181 5428 8938_20240420112107.amr
- spispispispispi
- 实验二:IP协议分析.zip
- 驱动代码驱动代码驱动代码驱动代码
- SVID_20240523_141155_1.mp4
- Code for the complete guide to tkinter tutorial
- 关于百货中心供应链管理系统.zip
- SimpleFolderIcon-master 修改Unity的Project下的文件夹图标
- A python Tkinter widget to display tile based maps
资源上传下载、课程学习等过程中有任何疑问或建议,欢迎提出宝贵意见哦~我们会及时处理!
点击此处反馈