import torch
import torchvision
import torch.nn as nn
from model import LeNet
import torch.optim as optim
import torchvision.transforms as transforms
def main():
transform = transforms.Compose(
[transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
# 50000张训练图片
# 第一次使用时要将download设置为True才会自动去下载数据集
train_set = torchvision.datasets.CIFAR10(root='./data', train=True,
download=False, transform=transform)
train_loader = torch.utils.data.DataLoader(train_set, batch_size=36,
shuffle=True, num_workers=0)
# 10000张验证图片
# 第一次使用时要将download设置为True才会自动去下载数据集
val_set = torchvision.datasets.CIFAR10(root='./data', train=False,
download=False, transform=transform)
val_loader = torch.utils.data.DataLoader(val_set, batch_size=5000,
shuffle=False, num_workers=0)
val_data_iter = iter(val_loader)
val_image, val_label = val_data_iter.next()
# classes = ('plane', 'car', 'bird', 'cat',
# 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
net = LeNet()
loss_function = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=0.001)
for epoch in range(5): # loop over the dataset multiple times
running_loss = 0.0
for step, data in enumerate(train_loader, start=0):
# get the inputs; data is a list of [inputs, labels]
inputs, labels = data
# zero the parameter gradients
optimizer.zero_grad()
# forward + backward + optimize
outputs = net(inputs)
loss = loss_function(outputs, labels)
loss.backward()
optimizer.step()
# print statistics
running_loss += loss.item()
if step % 500 == 499: # print every 500 mini-batches
with torch.no_grad():
outputs = net(val_image) # [batch, 10]
predict_y = torch.max(outputs, dim=1)[1]
accuracy = torch.eq(predict_y, val_label).sum().item() / val_label.size(0)
print('[%d, %5d] train_loss: %.3f test_accuracy: %.3f' %
(epoch + 1, step + 1, running_loss / 500, accuracy))
running_loss = 0.0
print('Finished Training')
save_path = './Lenet.pth'
torch.save(net.state_dict(), save_path)
if __name__ == '__main__':
main()
zuoyou-HPU
- 粉丝: 2644
- 资源: 20
最新资源
- 基于matlab的齿轮系统非线性动力学特性分析,综合考虑齿侧间隙、时变啮合刚度、综合啮合误差等因素下,参数阻尼比变化调节下,输出位移、相图、载荷、频率幅值结果 程序已调通,可直接运行
- 10kV线路微机继电保护装置源代码,配套pcb图纸和bom 适合自己学习的素材,也可作为基础版本工程,缩短开发周期 为源码和pcb图
- 无刷直流电机的调速 Matlab simulink仿真搭建模型 介绍:该模型展示了无刷直流电机的速度控制 无刷直流电机有完整的动态模型 将电机的实际转速与参考转速进行比较,以控制三相逆变器来调节端
- comsol枝晶生长 温度场相场溶质场三场耦合
- 水箱水位温度MCGS嵌入版7.7脚本程序动画仿真 带历史数据报表,实时数据报表,历史曲线,实时曲线 标价就是卖价
- 三相共直流母线式光储VSG 同步机 构网型 组网型逆变器 仿真包含前级光伏PV与Boost的扰动观察法最大功率追踪,共直流母线式储能Buck-boost变器,采用电压电流双闭环控制 三相VSG 同步
- TCP IP协议栈IP,纯RTL语言实现,包含tcp server,tcp client,icmp,ping 等,可移植任何平台
- 锂电池建模与热管理仿真 主要贡献: 1、 对并联或串联连接的任意所需数量的电池进行电池系统仿真; 2、拟串联电池的被动平衡; 3、自动将统计参数偏差分配给电池系统内的所有电池; 4、模拟不可逆和可逆电
- 燃料电池汽车参数匹配与能量管理 包含燃料电池汽车的燃料电池动力源功率选型,驱动电机参数匹配选型,蓄电池参数匹配选型,主减速比匹配,以满足最高车速,最大爬坡度,百公里加速时间等动力性要求 然后根据参
- 永磁同步电机三闭环控制Simulink仿真 电流内环 转速 位置外环 参数已经调好 原理与双闭环类似 有资料,仿真
- 无线电能传输仿真模型,电路采用S-S拓扑结构 闭环输出电压400v,输出效果良好 采用的是移相控制 另有主电路的参数设计过程
- 台达DVP 16ES2与3台 台达DT3系温控器通讯程序(TDES-3) 功能:采用台达DVP ES2型号PLC,对台达DT3温控器通过485方式,modbus协议,进行温度的设定,实际温度读
- comsol 激光抛光, 平顶激光,连续激光,高斯激光,都可以进行抛光,所用公式有文献参考处处
- seqlist-malloc
- dogstar-ui-html
- Maxwell和Simplorer联合仿真-永磁同步电机SVPWM控制 本仿真用AnsysEM实现永磁同步电机(PMSM)的仿真模拟,控制方式采用空间矢量控制,闭环方式采用电流环速度环双闭环控制
资源上传下载、课程学习等过程中有任何疑问或建议,欢迎提出宝贵意见哦~我们会及时处理!
点击此处反馈