import os
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.autograd import Variable
from torchvision import transforms
import torchvision.datasets as dst
from torchvision.utils import save_image
from model import VAE
# 训练超参数配置
epochs = 50
batch_size = 64
num_workers = 0
log_interval = 10
latent_dim = 32
# 损失函数
def loss_func(recon_x, x, mu, logvar):
BCE = F.binary_cross_entropy(recon_x, x, size_average=False)
KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
return BCE + KLD
if __name__ == '__main__':
# 加载数据集
transform = transforms.Compose([transforms.ToTensor()])
data_train = dst.MNIST('MNIST_data/', train=True, transform=transform, download=True)
data_test = dst.MNIST('MNIST_data/', train=False, transform=transform, download=True)
train_loader = torch.utils.data.DataLoader(dataset=data_train, num_workers=num_workers, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=data_test, num_workers=num_workers, batch_size=batch_size, shuffle=True)
# 创建VAE模型
vae = VAE(latent_dim).cuda()
# 创建优化器
optimizer = optim.Adam(vae.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)
os.makedirs('result', exist_ok=True) # 创建文件夹
# 开始训练
for epoch in range(1, epochs):
vae.train()
total_loss = 0
for i, (data, _) in enumerate(train_loader, 0):
data = Variable(data).cuda()
optimizer.zero_grad()
recon_x, mu, logvar = vae.forward(data)
loss = loss_func(recon_x, data, mu, logvar)
loss.backward()
total_loss += loss.item()
optimizer.step()
if i % log_interval == 0:
sample = Variable(torch.randn(64, latent_dim)).cuda()
sample = vae.decoder(vae.fc2(sample).view(64, 128, 7, 7)).cpu()
save_image(sample.data.view(64, 1, 28, 28),
'result/sample_' + str(epoch) + '.png')
print('Train Epoch:{} -- [{}/{} ({:.0f}%)] -- Loss:{:.6f}'.format(
epoch, i*len(data), len(train_loader.dataset),
100.*i/len(train_loader), loss.item()/len(data)))
print('====> Epoch: {} Average loss: {:.4f}'.format(
epoch, total_loss / len(train_loader.dataset)))
torch.save(vae.state_dict(), 'model.pth')
两只程序猿
- 粉丝: 382
- 资源: 159
最新资源
- 电动汽车模型的各模块的Simulink模型,包括驾驶员模块,整车控制器模块,电机模块,变速器模块,主减速器模块,车轮模块,车速模块以及BMS模块 附有说明文档,文档详细的描述了模型的建模过程及功能
- 西门子200smart与东元Teco N310变频器通讯实战程序 器件:西门子s7 200 smart PLC,东元Teco N310变频器,昆仑通态触摸屏(带以太网),中途可以加路由器
- 三菱FX3U 485ADP与东元TECO变频器N310通讯实战程序 功能:通过三菱fx3u 485ADP-MB板对东元Teco N310变频器进行modbus通讯,实现频率设定,启停控制,输出
- 【Matlab Simulink】电动汽车双向充电桩电路仿真 交流侧采用普通三相桥式变电路,SVPWM控制生成开关信号,控制系统采用电压外环电流内环控制 可实现整流,逆变以及指定功率输出,无功补偿 直
- 基于MATLAB的圆形检测算法:在MATLAB中实现的,利用图像边缘的梯度信息 进行圆形检测的算法m文件可直接运行 相比于传统的霍夫变检测圆的算法速度有极大提升
- 电动汽车充电站选址定容Matlab程序代码实现 在一定区域内的电动汽车充电站多目标规划选址定容的Matlab程序 使用PSO和Voronoi图联合求解
- 基于遗传算法的电动汽车有序充电优化调度 软件:Matlab 利用遗传算法对电动汽车有序充电进行优化;优化目标包括充电费用最低,电动汽车充到足够的电,负荷峰谷差最小 分别利用传统、精英和变异遗传算法进
- 无迹卡尔曼滤波UKF,平方根无迹卡尔曼滤波SRUKF,自适应平方根无迹卡尔曼滤波ASRUKF估算电池SOC
- 多目标粒子群算法CCHP联供综合能源系统 说明书MATLAB代码:基于多目标粒子群算法冷热电联供综合能源系统运行优化关键词:综合能源 冷热电三联供 粒子群算法 多目标优化参考文档:基于多目标算法的
- 运用Matlab,LBP分割脸部特征,从而达到识别人物面部表情的效果
- FPGA Verilog 舵机驱动代码,FPGA驱动舵机
- 西门子S7-1500PLC与西门子V90 PN伺服通讯控制项 西门子S7-1500PLC与西门子V90 PN伺服通讯控制项目程序项目程序包含S7-1500 PLC,KTP系列触摸屏,西门子V90 PN
- 碳交易机制下考虑需求响应的综合能源系统优化运行 首先,根据负荷响应特性将需求响应分为价格型和替代型 2 类,分别建立了基于价格弹性矩阵的价格型需求响应模型,及考虑用能侧电能和热能相互转的替代型需求响应
- 质子交膜燃料电池系统模型(PEMFC),基于MATLAB simulink开发 主要部分有空压机模型,供气系统模型(阴极和阳极),背压阀模型,电堆模型等 可进行控制策略等仿真开发工作
- 基于.net6的跨平台物联网网关 通过可视化配置,轻松的连接到你的任何设备和系统(如PLC、扫码枪、CNC、数据库、串口设备、上位机、OPC Server、OPC UA Server、Mqtt Se
- 不确定性决策理论及其军事与自动化应用
资源上传下载、课程学习等过程中有任何疑问或建议,欢迎提出宝贵意见哦~我们会及时处理!
点击此处反馈
- 1
- 2
- 3
- 4
前往页