# @Time : 2022/8/25 15:02
# @Author :lgl
# @e-mail :GuanlinLi_BIT@163.com
import matplotlib.pyplot as plt
import glob
import time
import numpy
import torch
import torchvision
from torch.optim import SGD
import winsound
from res_model import *
from model import *
from torch.utils.data import DataLoader
from torch import nn
from torch.utils.tensorboard import SummaryWriter
# 把网络模型、数据(输入和标注)、损失函数加上 .cuda 就可以用GPU
# --- Setup: timing, logging, data, model, loss -------------------------------
old_time = time.time()  # wall-clock start; elapsed time is printed at the end

# TensorBoard writer (view with: tensorboard --logdir ../logs)
writer = SummaryWriter("../logs")

# CIFAR-10 training set, converted to tensors (downloaded on first run).
train_data = torchvision.datasets.CIFAR10(
    r"D:\AI_project\20220825_cifar10demo\cifar10", train=True,
    transform=torchvision.transforms.ToTensor(), download=True)
train_dataloader = DataLoader(train_data, batch_size=64, shuffle=True)
train_len = len(train_data)  # number of training samples
print("训练数据集长度:{}".format(train_len))

# CIFAR-10 test set.
test_data = torchvision.datasets.CIFAR10(
    r"D:\AI_project\20220825_cifar10demo\cifar10", train=False,
    transform=torchvision.transforms.ToTensor(), download=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)
test_len = len(test_data)  # number of test samples (original comment wrongly said "training")

# Look for the most recent saved checkpoint.
# glob returns every *.pth file under the model directory as a list.
path_file_num = glob.glob(r'D:\AI_project\20220825_cifar10demo\model\*.pth')
pth_num = len(path_file_num)  # how many .pth checkpoints already exist
if pth_num != 0:
    # NOTE(review): torch.load unpickles arbitrary objects — only load files you trust.
    history_model = torch.load(
        r"D:\AI_project\20220825_cifar10demo\model\neural_networks_all_{}.pth".format(pth_num))
    print("当前最新历史数据:{}".format("neural_networks_all_{}.pth".format(pth_num)))

# Model to train: res_neural_networks (fresh residual net) or history_model (resume).
model = res_neural_networks
print("残差网络结构: {}".format(res_neural_networks))  # fixed stray 'f' typo in message
print("网络结构: {}".format(neural_networks))
# Move the model to GPU when available; data and loss go there too (see below).
if torch.cuda.is_available():
    model = model.cuda()
    print("GPU is working")
else:
    print("CPU is working")

# Cross-entropy loss — the standard choice for multi-class classification.
ce_loss = nn.CrossEntropyLoss()
if torch.cuda.is_available():
    ce_loss = ce_loss.cuda()

total_train_step = 0  # global step counter across all epochs
# Number of epochs for this run.
epoch = 10
# Best (lowest) per-epoch training loss seen so far. Initialised ONCE, before
# the epoch loop: the original reset it to 1400.0 inside every epoch, which —
# together with an "if i == 0" override — made almost every epoch register as
# a new best and trigger a save.
min_total_loss = float('inf')
for i in range(epoch):
    # Linearly decay the learning rate from max toward min over the run.
    learning_rate_max = 0.05
    learning_rate_min = 0.001
    learning_rate = learning_rate_max - (learning_rate_max - learning_rate_min) * (i + 1) / epoch
    print("learning_rate:{}".format(learning_rate))
    # Re-create the SGD optimiser each epoch with the decayed rate
    # (plain SGD carries no state, so this is safe).
    opt = SGD(model.parameters(), lr=learning_rate)
    print("\n\n---------第{}轮训练开始---------".format(i + 1))
    total_train_loss = 0.0
    for data in train_dataloader:
        imgs, targets = data
        if torch.cuda.is_available():
            imgs = imgs.cuda()
            targets = targets.cuda()
        output = model(imgs)                # forward pass
        loss = ce_loss(output, targets)     # cross-entropy loss for this batch
        opt.zero_grad()                     # clear gradients from the previous step
        loss.backward()                     # back-propagate
        opt.step()                          # apply the parameter update
        # .item() detaches the scalar; accumulating the tensor itself (as the
        # original did) keeps every step's autograd graph alive — a memory leak.
        total_train_loss += loss.item()
        total_train_step += 1
        # Report progress every 100 steps.
        if total_train_step % 100 == 0:
            print("训练次数:{}, Loss:{}".format(total_train_step, loss.item()))
    # End of epoch: log the summed training loss.
    print("本轮总Loss:{}".format(total_train_loss))
    writer.add_scalar("loss", total_train_loss, i)
    # Save a checkpoint whenever this epoch's total loss is a new best.
    if total_train_loss <= min_total_loss:
        path_file_num = glob.glob(r'D:\AI_project\20220825_cifar10demo\model\*.pth')
        pth_num = len(path_file_num)  # number of existing .pth checkpoints
        if i == 0:
            pth_num += 1  # first save of this run gets a fresh file number
        torch.save(model, r"D:\AI_project\20220825_cifar10demo\model\neural_networks_all_{}.pth".format(pth_num))
        print("save, loss={}".format(total_train_loss))
        min_total_loss = total_train_loss
    # Evaluate test-set accuracy every 2 epochs (the original comment said
    # every 10, but the code tests when i % 2 == 0).
    if i % 2 == 0:
        total_accuracy = 0.0
        model.eval()  # freeze BatchNorm/Dropout so test data doesn't alter running stats
        with torch.no_grad():  # no gradients needed for evaluation
            for data in test_dataloader:
                test_imgs, test_targets = data
                if torch.cuda.is_available():
                    test_imgs = test_imgs.cuda()
                    test_targets = test_targets.cuda()
                test_output = model(test_imgs)
                # argmax(1) picks the predicted class per image; count correct ones.
                total_accuracy += (test_output.argmax(1) == test_targets).sum().item()
        model.train()  # restore training mode for the next epoch
        print("整体测试集上的准确率:{:.2%}".format(total_accuracy / test_len))
# --- Wrap-up: close the log and report total runtime -------------------------
current_time = time.time()
writer.close()  # flush TensorBoard event files
print("\n\n本次训练共训练{}轮".format(epoch))
print("运行时间为 {} s".format(current_time - old_time))
# Removed the garbled commented-out loss-history plotting code that was here.
# To inspect the loss curve, run in a terminal:
#   tensorboard --logdir D:\AI_project\20220825_cifar10demo\logs
# and open the reported URL in a browser.
没有合适的资源?快使用搜索试试~ 我知道了~
温馨提示
cifar10文件夹:包括了cifar10原数据库 kaggle文件夹:包括了在kaggle上训练好的模型,以及日志文件 model文件夹:包括了本地cpu训练好的模型 src文件夹:kaggle_tensorboard.py: 用于使用tensorboard展示kaggle上训练的日志 model.py: 神经网络模型 res_model.py: 残差网络模型 有问题就发邮件。GuanlinLi_BIT@163.com
资源详情
资源评论
资源推荐
收起资源包目录
cifar10 - pytorch - 模型源文件、train、test、use等源代码,kaggle上训练好的模型 (156个子文件)
events.out.tfevents.1661443636.DESKTOP-O150G6E.7416.0 25KB
events.out.tfevents.1662204940.DESKTOP-O150G6E.10072.0 22KB
events.out.tfevents.1662542623.DESKTOP-O150G6E.3496.0 22KB
events.out.tfevents.1661524803.DESKTOP-O150G6E.12688.0 21KB
events.out.tfevents.1661506627.DESKTOP-O150G6E.3088.0 4KB
events.out.tfevents.1661517374.DESKTOP-O150G6E.6652.0 2KB
events.out.tfevents.1661435467.DESKTOP-O150G6E.7020.0 2KB
events.out.tfevents.1661590996.DESKTOP-O150G6E.14200.0 2KB
events.out.tfevents.1661917174.DESKTOP-O150G6E.12896.0 122B
events.out.tfevents.1661504498.DESKTOP-O150G6E.7952.0 122B
events.out.tfevents.1661442725.DESKTOP-O150G6E.9896.0 80B
events.out.tfevents.1661917011.DESKTOP-O150G6E.15024.0 80B
events.out.tfevents.1661524643.DESKTOP-O150G6E.3380.0 80B
events.out.tfevents.1661442278.DESKTOP-O150G6E.8596.0 80B
events.out.tfevents.1661443469.DESKTOP-O150G6E.15448.0 80B
events.out.tfevents.1661442362.DESKTOP-O150G6E.14584.0 80B
events.out.tfevents.1661443311.DESKTOP-O150G6E.13096.0 80B
events.out.tfevents.1661443216.DESKTOP-O150G6E.13672.0 80B
events.out.tfevents.1661916735.DESKTOP-O150G6E.4792.0 80B
events.out.tfevents.1661516997.DESKTOP-O150G6E.6684.0 80B
events.out.tfevents.1661924401.DESKTOP-O150G6E.12916.0 80B
events.out.tfevents.1661506073.DESKTOP-O150G6E.1152.0 80B
events.out.tfevents.1661923853.DESKTOP-O150G6E.612.0 80B
events.out.tfevents.1661504099.DESKTOP-O150G6E.10660.0 80B
events.out.tfevents.1662258194.DESKTOP-O150G6E.12000.0 40B
events.out.tfevents.1661925576.DESKTOP-O150G6E.816.0 40B
events.out.tfevents.1662259410.DESKTOP-O150G6E.14776.0 40B
events.out.tfevents.1662087932.DESKTOP-O150G6E.15620.0 40B
events.out.tfevents.1662085063.DESKTOP-O150G6E.3096.0 40B
events.out.tfevents.1662003595.DESKTOP-O150G6E.5360.0 40B
events.out.tfevents.1662088520.DESKTOP-O150G6E.4080.0 40B
events.out.tfevents.1661506464.DESKTOP-O150G6E.13348.0 40B
events.out.tfevents.1661441415.DESKTOP-O150G6E.4852.0 40B
events.out.tfevents.1661922768.DESKTOP-O150G6E.7492.0 40B
events.out.tfevents.1661441845.DESKTOP-O150G6E.15964.0 40B
events.out.tfevents.1662257996.DESKTOP-O150G6E.12848.0 40B
events.out.tfevents.1661923488.DESKTOP-O150G6E.12940.0 40B
events.out.tfevents.1662088014.DESKTOP-O150G6E.7412.0 40B
events.out.tfevents.1661960877.DESKTOP-O150G6E.5892.0 40B
events.out.tfevents.1662087258.DESKTOP-O150G6E.5068.0 40B
events.out.tfevents.1662084953.DESKTOP-O150G6E.2668.0 40B
events.out.tfevents.1662087903.DESKTOP-O150G6E.17392.0 40B
events.out.tfevents.1662085555.DESKTOP-O150G6E.4644.0 40B
events.out.tfevents.1661504468.DESKTOP-O150G6E.12092.0 40B
events.out.tfevents.1661441517.DESKTOP-O150G6E.3904.0 40B
events.out.tfevents.1661916353.DESKTOP-O150G6E.400.0 40B
events.out.tfevents.1661960822.DESKTOP-O150G6E.3980.0 40B
events.out.tfevents.1662085447.DESKTOP-O150G6E.8860.0 40B
events.out.tfevents.1662086961.DESKTOP-O150G6E.9088.0 40B
events.out.tfevents.1661922717.DESKTOP-O150G6E.4308.0 40B
events.out.tfevents.1662088365.DESKTOP-O150G6E.15972.0 40B
events.out.tfevents.1662088311.DESKTOP-O150G6E.10868.0 40B
events.out.tfevents.1662088029.DESKTOP-O150G6E.5344.0 40B
events.out.tfevents.1661925569.DESKTOP-O150G6E.16256.0 40B
events.out.tfevents.1662088217.DESKTOP-O150G6E.6992.0 40B
events.out.tfevents.1662087837.DESKTOP-O150G6E.10784.0 40B
events.out.tfevents.1662088178.DESKTOP-O150G6E.16056.0 40B
events.out.tfevents.1662087222.DESKTOP-O150G6E.12512.0 40B
events.out.tfevents.1662259127.DESKTOP-O150G6E.5152.0 40B
events.out.tfevents.1661590981.DESKTOP-O150G6E.8196.0 40B
events.out.tfevents.1661516960.DESKTOP-O150G6E.8884.0 40B
events.out.tfevents.1662088404.DESKTOP-O150G6E.3744.0 40B
events.out.tfevents.1662258294.DESKTOP-O150G6E.15716.0 40B
events.out.tfevents.1662087151.DESKTOP-O150G6E.13300.0 40B
events.out.tfevents.1662258702.DESKTOP-O150G6E.10308.0 40B
events.out.tfevents.1661506288.DESKTOP-O150G6E.11796.0 40B
events.out.tfevents.1662259471.DESKTOP-O150G6E.15692.0 40B
events.out.tfevents.1661506411.DESKTOP-O150G6E.4960.0 40B
events.out.tfevents.1662086624.DESKTOP-O150G6E.9032.0 40B
events.out.tfevents.1662258442.DESKTOP-O150G6E.9964.0 40B
events.out.tfevents.1662257946.DESKTOP-O150G6E.10760.0 40B
events.out.tfevents.1661916050.DESKTOP-O150G6E.2832.0 40B
events.out.tfevents.1662258111.DESKTOP-O150G6E.11360.0 40B
events.out.tfevents.1661506259.DESKTOP-O150G6E.11924.0 40B
events.out.tfevents.1662085330.DESKTOP-O150G6E.5672.0 40B
events.out.tfevents.1661441372.DESKTOP-O150G6E.13836.0 40B
events.out.tfevents.1661504069.DESKTOP-O150G6E.4880.0 40B
events.out.tfevents.1662087798.DESKTOP-O150G6E.7932.0 40B
events.out.tfevents.1661442180.DESKTOP-O150G6E.10712.0 40B
events.out.tfevents.1662088333.DESKTOP-O150G6E.13120.0 40B
events.out.tfevents.1661441268.DESKTOP-O150G6E.3656.0 40B
events.out.tfevents.1661915921.DESKTOP-O150G6E.7700.0 40B
events.out.tfevents.1662086759.DESKTOP-O150G6E.9624.0 40B
events.out.tfevents.1662087286.DESKTOP-O150G6E.12696.0 40B
events.out.tfevents.1662087002.DESKTOP-O150G6E.14324.0 40B
events.out.tfevents.1662087486.DESKTOP-O150G6E.14708.0 40B
data_batch_1 29.6MB
data_batch_2 29.6MB
data_batch_3 29.6MB
data_batch_4 29.6MB
data_batch_5 29.6MB
.gitignore 50B
readme.html 88B
src.iml 334B
logs 4KB
batches.meta 158B
.name 8B
accuracy_history.npy 2KB
loss_history.npy 2KB
accuracy_history.npy 2KB
共 156 条
- 1
- 2
宇智波洛必达
- 粉丝: 36
- 资源: 2
上传资源 快速赚钱
- 我的内容管理 展开
- 我的资源 快来上传第一个资源
- 我的收益 登录查看自己的收益
- 我的积分 登录查看自己的积分
- 我的C币 登录后查看C币余额
- 我的收藏
- 我的下载
- 下载帮助
最新资源
- 毕设和企业适用springboot自动化仓库管理平台类及云计算资源管理平台源码+论文+视频.zip
- 毕设和企业适用springboot自动化仓库管理平台类及直播流媒体平台源码+论文+视频.zip
- 360图床HTML源码.zip
- 毕设和企业适用springboot订餐类及虚拟人类交互系统源码+论文+视频.zip
- 毕设和企业适用springboot二手跳蚤类及共享经济平台源码+论文+视频.zip
- 2023年总结,个人资料
- 2024年下半年计算机水平考试模拟盘.zip
- A10-Tray自动上料抓取工位工程图机械结构设计图纸和其它技术资料和技术方案非常好100%好用.zip
- MySQL基础-布尔全文搜索.pdf
- ANQU磁铁检测机工程图机械结构设计图纸和其它技术资料和技术方案非常好100%好用.zip
- AS014-XD10检测设备装配体工程图机械结构设计图纸和其它技术资料和技术方案非常好100%好用.zip
- 2023工作总结,个人使用
- 1212338883_2402103_10.2.1.1_20241216090042_951322129_a.apk
- 圣诞树html网页代码
- Linux应急响应手册
- 555构成的多路波形发生器.ms14
资源上传下载、课程学习等过程中有任何疑问或建议,欢迎提出宝贵意见哦~我们会及时处理!
点击此处反馈
安全验证
文档复制为VIP权益,开通VIP直接复制
信息提交成功
评论5