没有合适的资源?快使用搜索试试~ 我知道了~
资源推荐
资源详情
资源评论













Pytorch使用使用MNIST数据集实现数据集实现CGAN和生成指定的数字方式和生成指定的数字方式
今天小编就为大家分享一篇Pytorch使用MNIST数据集实现CGAN和生成指定的数字方式,具有很好的参考价
值,希望对大家有所帮助。一起跟随小编过来看看吧
CGAN的全拼是Conditional Generative Adversarial Networks,条件生成对抗网络,在初始GAN的基础上增加了图片的相应信
息。
这里用传统的卷积方式实现CGAN。
import torch
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision import transforms
from torch import optim
import torch.nn as nn
import matplotlib.pyplot as plt
import numpy as np
from torch.autograd import Variable
import pickle
import copy
import matplotlib.gridspec as gridspec
import os
def save_model(model, filename): #保存为CPU中可以打开的模型
state = model.state_dict()
x=state.copy()
for key in x:
x[key] = x[key].clone().cpu()
torch.save(x, filename)
def showimg(images,count):
images=images.to('cpu')
images=images.detach().numpy()
images=images[[6, 12, 18, 24, 30, 36, 42, 48, 54, 60, 66, 72, 78, 84, 90, 96]]
images=255*(0.5*images+0.5)
images = images.astype(np.uint8)
grid_length=int(np.ceil(np.sqrt(images.shape[0])))
plt.figure(figsize=(4,4))
width = images.shape[2]
gs = gridspec.GridSpec(grid_length,grid_length,wspace=0,hspace=0)
for i, img in enumerate(images):
ax = plt.subplot(gs[i])
ax.set_xticklabels([])
ax.set_yticklabels([])
ax.set_aspect('equal')
plt.imshow(img.reshape(width,width),cmap = plt.cm.gray)
plt.axis('off')
plt.tight_layout()
# plt.tight_layout()
plt.savefig(r'./CGAN/images/%d.png'% count, bbox_inches='tight')
def loadMNIST(batch_size): #MNIST图片的大小是28*28
trans_img=transforms.Compose([transforms.ToTensor()])
trainset=MNIST('./data',train=True,transform=trans_img,download=True)
testset=MNIST('./data',train=False,transform=trans_img,download=True)
# device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
trainloader=DataLoader(trainset,batch_size=batch_size,shuffle=True,num_workers=10)
testloader = DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=10)
return trainset,testset,trainloader,testloader
class discriminator(nn.Module):
def __init__(self):
super(discriminator,self).__init__()
self.dis=nn.Sequential(
nn.Conv2d(1,32,5,stride=1,padding=2),
nn.LeakyReLU(0.2,True),
nn.MaxPool2d((2,2)),
nn.Conv2d(32,64,5,stride=1,padding=2),
nn.LeakyReLU(0.2,True),
nn.MaxPool2d((2,2))
)
self.fc=nn.Sequential(
nn.Linear(7 * 7 * 64, 1024),
nn.LeakyReLU(0.2, True),
nn.Linear(1024, 10),
资源评论

- qq_409231212021-07-23资源是PDF格式的

weixin_38610682
- 粉丝: 6
- 资源: 878
上传资源 快速赚钱
我的内容管理 展开
我的资源 快来上传第一个资源
我的收益
登录查看自己的收益我的积分 登录查看自己的积分
我的C币 登录后查看C币余额
我的收藏
我的下载
下载帮助


安全验证
文档复制为VIP权益,开通VIP直接复制
