import torch
import torchvision
from torch import nn
from torch.utils.tensorboard import SummaryWriter
import time
# from p24_model import *
# 1. 准备数据集
# 训练数据
from torch.utils.data import DataLoader
train_data = torchvision.datasets.CIFAR10(root='./dataset', train=True, transform=torchvision.transforms.ToTensor(),
download=True)
# 测试数据
test_data = torchvision.datasets.CIFAR10(root='./dataset', train=False, transform=torchvision.transforms.ToTensor(),
download=True)
# 查看数据大小--size
print("训练数据集大小:", len(train_data))
print("测试数据集大小:", len(test_data))
# 利用DataLoader来加载数据集
train_loader = DataLoader(dataset=train_data, batch_size=64)
test_loader = DataLoader(dataset=test_data, batch_size=64)
# 2. 导入模型结构 创建模型
class Cifar10Net(nn.Module):
def __init__(self):
super(Cifar10Net, self).__init__()
self.net = nn.Sequential(
nn.Conv2d(in_channels=3, out_channels=32, kernel_size=5, stride=1, padding=2),
nn.MaxPool2d(kernel_size=2),
nn.Conv2d(32, 32, 5, 1, 2),
nn.MaxPool2d(kernel_size=2),
nn.Conv2d(32, 64, 5, 1, 2),
nn.MaxPool2d(kernel_size=2),
nn.Flatten(),
nn.Linear(64*4*4, 64),
nn.ReLU(),
nn.Linear(64, 10)
)
def forward(self, x):
x = self.net(x)
return x
# 设置cpu或者gpu机器
device = torch.device('cpu') # 使用cpu
device = torch.device('cuda') # 单台gpu
device = torch.device('cuda:0') # 使用第一台gpu机器
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)
net = Cifar10Net()
# if torch.cuda.is_available():
# net = net.cuda()
# net.to(device)
net = net.to(device) # 和上面的代码效果一样, 可以不重新赋值新的变量
# 3. 创建损失函数 分类问题--交叉熵
loss_fn = nn.CrossEntropyLoss()
# if torch.cuda.is_available():
# loss_fn = loss_fn.cuda()
# loss_fn.to(device)
loss_fn = loss_fn.to(device) # 和上面的代码效果一样, 可以不重新赋值新的变量
# 4. 创建优化器
# learing_rate = 0.01
# 1e-2 = 1 * 10^(-2) = 0.01
learing_rate = 1e-2
print(learing_rate)
optimizer = torch.optim.SGD(net.parameters(), lr=learing_rate)
# 设置训练网络的一些参数
epoch = 10 # 记录训练的轮数
total_train_step = 0 # 记录训练的次数
total_test_step = 0 # 记录测试的次数
# 利用tensorboard显示训练loss趋势
writer = SummaryWriter('./train_logs')
start_time = time.time()
print('start_time: ', start_time)
print(torch.cuda.is_available())
for i in range(epoch):
# 训练步骤开始
net.train() # 可以加可以不加 只有当模型结构有 Dropout BatchNorml层才会起作用
for data in train_loader:
imgs, targets = data # 获取数据
# if torch.cuda.is_available():
# imgs = imgs.cuda()
# targets = targets.cuda()
imgs = imgs.to(device)
targets = targets.to(device)
output = net(imgs) # 数据输入模型
loss = loss_fn(output, targets) # 损失函数计算损失 看计算的输出和真实的标签误差是多少
# 优化器开始优化模型 1.梯度清零 2.反向传播 3.参数优化
optimizer.zero_grad() # 利用优化器把梯度清零 全部设置为0
loss.backward() # 设置计算的损失值,调用损失的反向传播,计算每个参数结点的参数
optimizer.step() # 调用优化器的step()方法 对其中的参数进行优化
# 优化一次 认为训练了一次
total_train_step += 1
if total_train_step % 100 == 0:
print('训练次数: {} loss: {}'.format(total_train_step, loss))
end_time = time.time()
print('训练100次需要的时间:', end_time-start_time)
# 直接打印loss是tensor数据类型,打印loss.item()是打印的int或float真实数值, 真实数值方便做数据可视化【损失可视化】
# print('训练次数: {} loss: {}'.format(total_train_step, loss.item()))
writer.add_scalar('train-loss', loss.item(), global_step=total_train_step)
# 利用现有模型做模型测试
# 测试步骤开始
total_test_loss = 0
accuracy = 0
net.eval() # 可以加可以不加 只有当模型结构有 Dropout BatchNorml层才会起作用
with torch.no_grad():
for data in test_loader:
imgs, targets = data
# if torch.cuda.is_available():
# imgs = imgs.cuda()
# targets = targets.cuda()
imgs = imgs.to(device)
targets = targets.to(device)
output = net(imgs)
loss = loss_fn(output, targets)
total_test_loss += loss.item()
# 计算测试集的正确率
preds = (output.argmax(1)==targets).sum()
accuracy += preds
# writer.add_scalar('test-loss', total_test_loss, global_step=i+1)
writer.add_scalar('test-loss', total_test_loss, global_step=total_test_step)
writer.add_scalar('test-accracy', accuracy/len(test_data), total_test_step)
total_test_step += 1
print("---------test loss: {}--------------".format(total_test_loss))
print("---------test accuracy: {}--------------".format(accuracy/len(test_data)))
# 保存每一个epoch训练得到的模型
torch.save(net, './net_epoch{}.pth'.format(i))
writer.close()
没有合适的资源?快使用搜索试试~ 我知道了~
资源推荐
资源详情
资源评论
![zip](https://img-home.csdnimg.cn/images/20210720083736.png)
![ipynb](https://img-home.csdnimg.cn/images/20210720083646.png)
![ipynb](https://img-home.csdnimg.cn/images/20210720083646.png)
![ipynb](https://img-home.csdnimg.cn/images/20210720083646.png)
![rar](https://img-home.csdnimg.cn/images/20210720083606.png)
![zip](https://img-home.csdnimg.cn/images/20210720083736.png)
![zip](https://img-home.csdnimg.cn/images/20210720083736.png)
![ipynb](https://img-home.csdnimg.cn/images/20210720083646.png)
![ipynb](https://img-home.csdnimg.cn/images/20210720083646.png)
![zip](https://img-home.csdnimg.cn/images/20210720083736.png)
![zip](https://img-home.csdnimg.cn/images/20210720083736.png)
![zip](https://img-home.csdnimg.cn/images/20210720083736.png)
![zip](https://img-home.csdnimg.cn/images/20210720083736.png)
![zip](https://img-home.csdnimg.cn/images/20210720083736.png)
![zip](https://img-home.csdnimg.cn/images/20210720083736.png)
![zip](https://img-home.csdnimg.cn/images/20210720083736.png)
![zip](https://img-home.csdnimg.cn/images/20210720083736.png)
![zip](https://img-home.csdnimg.cn/images/20210720083736.png)
![zip](https://img-home.csdnimg.cn/images/20210720083736.png)
![zip](https://img-home.csdnimg.cn/images/20210720083736.png)
![7z](https://img-home.csdnimg.cn/images/20210720083312.png)
![txt](https://img-home.csdnimg.cn/images/20210720083642.png)
![zip](https://img-home.csdnimg.cn/images/20210720083736.png)
收起资源包目录
![package](https://csdnimg.cn/release/downloadcmsfe/public/img/package.f3fc750b.png)
![folder](https://csdnimg.cn/release/downloadcmsfe/public/img/folder.005fa2e5.png)
![file-type](https://csdnimg.cn/release/download/static_files/pc/images/minetype/UNKNOWN.png)
![file-type](https://csdnimg.cn/release/download/static_files/pc/images/minetype/UNKNOWN.png)
![file-type](https://csdnimg.cn/release/download/static_files/pc/images/minetype/UNKNOWN.png)
![file-type](https://csdnimg.cn/release/download/static_files/pc/images/minetype/UNKNOWN.png)
![file-type](https://csdnimg.cn/release/download/static_files/pc/images/minetype/UNKNOWN.png)
![file-type](https://csdnimg.cn/release/download/static_files/pc/images/minetype/UNKNOWN.png)
![file-type](https://csdnimg.cn/release/download/static_files/pc/images/minetype/UNKNOWN.png)
![file-type](https://csdnimg.cn/release/download/static_files/pc/images/minetype/UNKNOWN.png)
![file-type](https://csdnimg.cn/release/download/static_files/pc/images/minetype/UNKNOWN.png)
![file-type](https://csdnimg.cn/release/download/static_files/pc/images/minetype/UNKNOWN.png)
![file-type](https://csdnimg.cn/release/download/static_files/pc/images/minetype/UNKNOWN.png)
![file-type](https://csdnimg.cn/release/download/static_files/pc/images/minetype/UNKNOWN.png)
![file-type](https://csdnimg.cn/release/download/static_files/pc/images/minetype/UNKNOWN.png)
![folder](https://csdnimg.cn/release/downloadcmsfe/public/img/folder.005fa2e5.png)
![file-type](https://csdnimg.cn/release/download/static_files/pc/images/minetype/UNKNOWN.png)
![file-type](https://csdnimg.cn/release/download/static_files/pc/images/minetype/UNKNOWN.png)
![file-type](https://csdnimg.cn/release/download/static_files/pc/images/minetype/UNKNOWN.png)
![folder](https://csdnimg.cn/release/downloadcmsfe/public/img/folder.005fa2e5.png)
![file-type](https://csdnimg.cn/release/download/static_files/pc/images/minetype/UNKNOWN.png)
![file-type](https://csdnimg.cn/release/download/static_files/pc/images/minetype/UNKNOWN.png)
![file-type](https://csdnimg.cn/release/download/static_files/pc/images/minetype/UNKNOWN.png)
![file-type](https://csdnimg.cn/release/download/static_files/pc/images/minetype/UNKNOWN.png)
![file-type](https://csdnimg.cn/release/download/static_files/pc/images/minetype/ZIP.png)
![file-type](https://csdnimg.cn/release/download/static_files/pc/images/minetype/UNKNOWN.png)
![file-type](https://csdnimg.cn/release/download/static_files/pc/images/minetype/UNKNOWN.png)
![file-type](https://csdnimg.cn/release/download/static_files/pc/images/minetype/UNKNOWN.png)
![file-type](https://csdnimg.cn/release/download/static_files/pc/images/minetype/UNKNOWN.png)
![file-type](https://csdnimg.cn/release/download/static_files/pc/images/minetype/UNKNOWN.png)
![file-type](https://csdnimg.cn/release/download/static_files/pc/images/minetype/UNKNOWN.png)
![file-type](https://csdnimg.cn/release/download/static_files/pc/images/minetype/UNKNOWN.png)
![folder](https://csdnimg.cn/release/downloadcmsfe/public/img/folder.005fa2e5.png)
![file-type](https://csdnimg.cn/release/download/static_files/pc/images/minetype/JPG.png)
![file-type](https://csdnimg.cn/release/download/static_files/pc/images/minetype/JPG.png)
![file-type](https://csdnimg.cn/release/download/static_files/pc/images/minetype/UNKNOWN.png)
![folder](https://csdnimg.cn/release/downloadcmsfe/public/img/folder.005fa2e5.png)
![file-type](https://csdnimg.cn/release/download/static_files/pc/images/minetype/UNKNOWN.png)
![file-type](https://csdnimg.cn/release/download/static_files/pc/images/minetype/UNKNOWN.png)
![file-type](https://csdnimg.cn/release/download/static_files/pc/images/minetype/UNKNOWN.png)
![file-type](https://csdnimg.cn/release/download/static_files/pc/images/minetype/UNKNOWN.png)
![file-type](https://csdnimg.cn/release/download/static_files/pc/images/minetype/UNKNOWN.png)
![file-type](https://csdnimg.cn/release/download/static_files/pc/images/minetype/UNKNOWN.png)
![file-type](https://csdnimg.cn/release/download/static_files/pc/images/minetype/UNKNOWN.png)
共 38 条
- 1
资源评论
![avatar-default](https://csdnimg.cn/release/downloadcmsfe/public/img/lazyLogo2.1882d7f4.png)
![avatar](https://profile-avatar.csdnimg.cn/ecb7b54aa0ae4830b4b30cdb5a99c861_weixin_42831564.jpg!1)
陌上阳光
- 粉丝: 580
- 资源: 1
上传资源 快速赚钱
我的内容管理 展开
我的资源 快来上传第一个资源
我的收益
登录查看自己的收益我的积分 登录查看自己的积分
我的C币 登录后查看C币余额
我的收藏
我的下载
下载帮助
![voice](https://csdnimg.cn/release/downloadcmsfe/public/img/voice.245cc511.png)
![center-task](https://csdnimg.cn/release/downloadcmsfe/public/img/center-task.c2eda91a.png)
最新资源
- Dockerfile: Ubuntu18.04 + Python3.10
- Python毕业设计-基于深度学习的水果识别系统的源代码+文档说明+数据集+模型(高分项目)
- PyTorch 开源的机器学习框架结合了动态计算图的灵活性和深度学习库的强大功能
- Python高分毕设-基于深度学习的水果识别系统的+源代码+文档说明+数据集+模型.zip
- TensorFlow 开源的机器学习框架支持大规模的机器学习
- cocosCreator-js动态生成二维码
- STM3232位ARM Cortex-M微控制器
- 基于jsp+servlet+mysql宠物店购物商城系统+源代码+文档说明+数据库(毕业设计).zip
- sdgsApp.apk
- 单片机呜呜呜单片机单片机单片机单片机
资源上传下载、课程学习等过程中有任何疑问或建议,欢迎提出宝贵意见哦~我们会及时处理!
点击此处反馈
![feedback](https://img-home.csdnimg.cn/images/20220527035711.png)
![feedback](https://img-home.csdnimg.cn/images/20220527035711.png)
![feedback-tip](https://img-home.csdnimg.cn/images/20220527035111.png)
安全验证
文档复制为VIP权益,开通VIP直接复制
![dialog-icon](https://csdnimg.cn/release/downloadcmsfe/public/img/green-success.6a4acb44.png)