import os
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.autograd import Variable
from torchvision import transforms
import torchvision.datasets as dst
from torchvision.utils import save_image
from model import VAE
# 训练超参数配置
epochs = 50
batch_size = 64
num_workers = 0
log_interval = 10
latent_dim = 32
# 损失函数
def loss_func(recon_x, x, mu, logvar):
BCE = F.binary_cross_entropy(recon_x, x, size_average=False)
KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
return BCE + KLD
if __name__ == '__main__':
# 加载数据集
transform = transforms.Compose([transforms.ToTensor()])
data_train = dst.MNIST('MNIST_data/', train=True, transform=transform, download=True)
data_test = dst.MNIST('MNIST_data/', train=False, transform=transform, download=True)
train_loader = torch.utils.data.DataLoader(dataset=data_train, num_workers=num_workers, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=data_test, num_workers=num_workers, batch_size=batch_size, shuffle=True)
# 创建VAE模型
vae = VAE(latent_dim).cuda()
# 创建优化器
optimizer = optim.Adam(vae.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)
os.makedirs('result', exist_ok=True) # 创建文件夹
# 开始训练
for epoch in range(1, epochs):
vae.train()
total_loss = 0
for i, (data, _) in enumerate(train_loader, 0):
data = Variable(data).cuda()
optimizer.zero_grad()
recon_x, mu, logvar = vae.forward(data)
loss = loss_func(recon_x, data, mu, logvar)
loss.backward()
total_loss += loss.item()
optimizer.step()
if i % log_interval == 0:
sample = Variable(torch.randn(64, latent_dim)).cuda()
sample = vae.decoder(vae.fc2(sample).view(64, 128, 7, 7)).cpu()
save_image(sample.data.view(64, 1, 28, 28),
'result/sample_' + str(epoch) + '.png')
print('Train Epoch:{} -- [{}/{} ({:.0f}%)] -- Loss:{:.6f}'.format(
epoch, i*len(data), len(train_loader.dataset),
100.*i/len(train_loader), loss.item()/len(data)))
print('====> Epoch: {} Average loss: {:.4f}'.format(
epoch, total_loss / len(train_loader.dataset)))
torch.save(vae.state_dict(), 'model.pth')
没有合适的资源?快使用搜索试试~ 我知道了~
温馨提示
# Pytorch实现VAE变分自动编码器生成MNIST手写数字图像 1. VAE模型的Pytorch源码,训练后其解码器就是生成模型; 2. 在MNIST数据集上训练了50个epochs,训练过程的生成效果放在result文件夹下,训练后的模型保存为model.pth,可用于生成新的手写数字图像; 3. 训练代码会自动下载MNIST数据集,运行代码即可自行训练。
资源推荐
资源详情
资源评论
收起资源包目录
Pytorch实现VAE变分自动编码器生成MNIST手写数字图像.zip (52个子文件)
vae
result
sample_15.png 41KB
sample_33.png 41KB
sample_25.png 41KB
sample_43.png 42KB
sample_14.png 40KB
sample_41.png 41KB
sample_28.png 41KB
sample_44.png 40KB
sample_39.png 41KB
sample_18.png 41KB
sample_23.png 42KB
sample_24.png 41KB
sample_1.png 55KB
sample_13.png 40KB
sample_12.png 42KB
sample_47.png 42KB
sample_8.png 42KB
sample_48.png 40KB
sample_45.png 41KB
sample_30.png 40KB
sample_32.png 42KB
sample_26.png 42KB
sample_42.png 40KB
sample_27.png 38KB
sample_11.png 42KB
sample_31.png 40KB
sample_37.png 41KB
sample_20.png 40KB
sample_2.png 45KB
sample_40.png 40KB
sample_22.png 40KB
sample_5.png 42KB
sample_38.png 41KB
sample_29.png 40KB
sample_21.png 40KB
sample_46.png 43KB
sample_9.png 41KB
sample_7.png 42KB
sample_35.png 41KB
sample_3.png 44KB
sample_4.png 42KB
sample_34.png 42KB
sample_16.png 42KB
sample_6.png 43KB
sample_19.png 41KB
sample_10.png 42KB
sample_17.png 40KB
sample_49.png 40KB
sample_36.png 41KB
train.py 3KB
model.pth 3.91MB
model.py 2KB
共 52 条
- 1
两只程序猿
- 粉丝: 381
- 资源: 159
上传资源 快速赚钱
- 我的内容管理 展开
- 我的资源 快来上传第一个资源
- 我的收益 登录查看自己的收益
- 我的积分 登录查看自己的积分
- 我的C币 登录后查看C币余额
- 我的收藏
- 我的下载
- 下载帮助
最新资源
- (源码)基于Spring Boot和MyBatis的社区问答系统.zip
- (源码)基于Spring Boot和WebSocket的人事管理系统.zip
- (源码)基于Spring Boot框架的云网页管理系统.zip
- (源码)基于Maude和深度强化学习的智能体验证系统.zip
- (源码)基于C语言的Papageno字符序列处理系统.zip
- (源码)基于Arduino的水质监测与控制系统.zip
- (源码)基于物联网的智能家居门锁系统.zip
- (源码)基于Python和FastAPI的Squint数据检索系统.zip
- (源码)基于Arduino的图片绘制系统.zip
- (源码)基于C++的ARMA53贪吃蛇游戏系统.zip
资源上传下载、课程学习等过程中有任何疑问或建议,欢迎提出宝贵意见哦~我们会及时处理!
点击此处反馈
安全验证
文档复制为VIP权益,开通VIP直接复制
信息提交成功
- 1
- 2
- 3
- 4
前往页