import torch
import torch.nn as nn
import torchvision
from vggnet import VGGNet
from load_cifar10 import test_loader, train_loader
import os
import tensorboardX
# 查看gpu
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
epoch_num = 100
lr = 0.001
net = VGGNet().to(device)
# 多分类问题交叉商
loss_func = nn.CrossEntropyLoss()
# 优化器
optimizer = torch.optim.Adam(net.parameters(), lr=lr)
# 采用变长学习率,每五个epoch指数衰减
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.9)
# 用于记录过程的log
if not os.path.exists("log"):
os.mkdir("log")
writer = tensorboardX.SummaryWriter("log")
# step计数器
step_n = 0
for epoch in range(epoch_num):
print("epoch is", epoch)
net.train() # train BN dropout
batchsize = 128
for i, data in enumerate(train_loader):
inputs, labels = data
inputs, labels = inputs.to(device), labels.to(device)
outputs = net(inputs) # 获取输出
# 计算loss
loss = loss_func(outputs, labels)
optimizer.zero_grad() # 优化qi梯度为
loss.backward() # 反向传播
optimizer.step() # 更新参数
# 打印batch的loss
# print("step", i, "loss is :", loss.item())
_, pred = torch.max(outputs.data, dim=1)
correct = pred.eq(labels.data).cpu().sum()
print("train step", i, "loss is :", loss.item(), "batch correct is:", 100.0 * correct / batchsize)
# 记录loss。correct
writer.add_scalar("train loss", loss.item(), global_step=step_n)
writer.add_scalar("train correct", 100.0 * correct / batchsize, global_step=step_n)
im = torchvision.utils.make_grid(inputs)
writer.add_image("train image", im, global_step=step_n)
step_n += 1
# 保存每个epoch模型
if not os.path.exists("models"):
os.mkdir("models")
torch.save(net.state_dict(), "models/{}.pth".format(epoch + 1))
# 学习率更新
scheduler.step()
print("lr is:", optimizer.state_dict()["param_groups"][0]["lr"])
# 测试
sum_loss = 0
sum_correct = 0
net.eval()
for i, data in enumerate(test_loader):
inputs, labels = data
inputs, labels = inputs.to(device), labels.to(device)
outputs = net(inputs) # 获取输出
# 计算loss
loss = loss_func(outputs, labels)
# print("step", i, "loss is :", loss.item())
_, pred = torch.max(outputs.data, dim=1)
correct = pred.eq(labels.data).cpu().sum()
sum_correct += correct.item()
sum_loss += loss.item()
# 记录loss。correct
# print("test step", i, "test loss is :", loss.item(), "test batch correct is:", 100.0 * correct / batchsize)
im = torchvision.utils.make_grid(inputs)
writer.add_image("train image", im, global_step=step_n)
test_loss = sum_loss * 1.0 / len(test_loader)
test_correct = sum_correct * 100.0 / len(test_loader) / batchsize
writer.add_scalar("test loss", loss.item(), global_step=epoch + 1)
writer.add_scalar("test correct", 100.0 * correct / batchsize, global_step=epoch + 1)
print("test epoch", epoch + 1, "test loss is :", test_loss, "test batch correct is:", test_correct)
writer.close()
基于 PyTorch 的 cifar-10 图像分类
需积分: 0 77 浏览量
更新于2023-05-17
收藏 314.24MB ZIP 举报
B站的视频的源码。基于 PyTorch 的 cifar-10 图像分类,文中包括 cifar-10 数据集介绍、环境配置、实验代码、运行结果以及遇到的问题这几个部分,本实验采用了基本网络和VGG加深网络模型,其中VGG加深网络模型的识别准确率是要优于基本网络模型的。
m0_68035385
- 粉丝: 13
- 资源: 1
最新资源
- 利用五次多项式实现基于模型预测控制的道算法,可根据程序修改自己的算法 matlab2016b,carsim2018
- 西门子大型PLC程序,博途V13 v14 V15 V16 V17版,CPu1511,屏为1200,外加30台G120变频器PN通讯十ET200远程io,温度pid控制,压力处理,张力控制,收卷控
- simulink直流调速系统的仿真模型 晶闸管-直流电动机开环调速系统,基于Matalab2018a
- 电机马达PMSM电机负载观测转矩前馈simulink 基于Luenberger降阶状态观测器,包含PMSM数学模型,PMSM双闭环PI矢量控制,并添加了前馈控制,采用SVPWM调制
- Java毕设项目:基于spring+mybatis+maven+mysql实现的高校食堂订餐系统【含源码+数据库+毕业论文】
- MATLAB驱动防滑转控制ASR牵引力控制TCS模型 ASR模型 驱动防滑转模型 牵引力控制系统模型 选择PID控制算法以及对照控制算法,共两种控制算法,可进行选择 选择冰路面以及雪路面,共两
- 基于FPGA的点阵屏设计,基于Quartus ii开发,Verilog编程语言,也可移植到vivado开发 1、可以显示多个汉字 2、暂停、启动控制 3、左移右移控制 4、调速控制
- omron欧姆龙CJ CP程序 欧姆龙CP1H-XA,主机搭载CIF串口模块与2从机PC LINK通信控制, X,Y轴模组矩阵取放料控制,托盘升降机提升机控制应用 全自动LCD组装机,欧姆龙触摸
- Java毕设项目:基于spring+mybatis+maven+mysql实现的人力资源管理系统【含源码+数据库+毕业论文】
- OMRON大型PLC CJ2M项目案例,配套昆仑通泰触摸屏程序 包含模拟量称重模块,SCU串行通讯模块,通过CMND指令把Fins协议转为MODBUS RTU 控制32台三菱 D720
- 128W微型车载逆变器方案,包含原理图,PCB图,烧录文件,汇编源代码,注意是汇编语言
- MES机台看板系统 可连接24台机,还可以扩展更多 通过网口直接与PLC直接通讯,包含西门子全系列,倍福PLC,三菱,松下,欧姆龙主流PLC 可以读写PLC里面BOOL,int,字符串,汉字(源码
- No.1004 S7-200 PLC和组态王温室大棚温室 带解释的梯形图程序,接线图原理图图纸,io分配,组态画面
- Python编程项目:Labyrinth迷宫小游戏完整源码分享给需要的同学
- QT界面开发框架,完整资源与代码一套
- 恒压供水plc程序,三菱FX1N.2N系列plc+fx0n3a模拟量模块+昆仑通态tpc7062触摸屏.全套的图纸十程序,完整的注释,适合参考学习