## 使用Pytorch通过卷积神经网络实现CIFAR10数据集的分类器
### 引言
在本次实验中,会使用Pytorch来实现一个卷积神经网络,之后对CIFAR-10数据集进行训练,保存训练模型参数,绘制loss图并保存,使用训练得到的模型对训练集与测试集的数据进行准确率测试,并将多次训练后得到的测试结果记录到对应的csv文件中。
### CIFAR-10数据集
CIFAR10数据集一共有60000张32*32的彩色图,共有十类,每类6000张,其中5000张训练图,1000张测试图;也就是一共有50000张图用来训练,10000张图用来测试。
十类分别是:plane, car, bird, cat, deer, dog, frog, horse, ship, truck.
更多关于CIFAR-10和CIFAR-100的信息可以看[这里](https://www.cs.toronto.edu/~kriz/cifar.html)
### 事前准备
- 声明:**文档中的代码在拷贝与修改的过程中可能会有错误,具体以实际的代码为准**
- 在data文件夹的声明CIFAR10数据文件夹中创建一个文件夹叫save_model,用来存储训练得到的模型参数
- 使用createcsv.py创建两个csv文件,用来存储每次训练之后测试得到的训练集与测试集的准确度,因为想是csv文件中第一行写上10个类别与总的准确度,以后每次测试完都在下面追加,就把第一行单独先写好,以后测试完直接打开追加即可,代码如下:
```python
import csv
classes = ['plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck','Total']
with open('TrainAccHistory.csv', 'w', newline='', encoding='utf-8') as f:
writer = csv.writer(f)
writer.writerow(classes)
with open('TestAccHistory.csv', 'w', newline='', encoding='utf-8') as f:
writer = csv.writer(f)
writer.writerow(classes)
```
- 下载数据集,可以通过[链接](https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz)直接下载,下载完记得在对应的data文件夹里解压,不然会报错,当然也可以在代码中通过设置download属性进行下载,后面放出来。
### 开始实验
#### 导入需要的模块
我们需要torch模块来创建网络,优化器,损失函数等;需要torchvision模块将下载的数据集做成dataset与loader,以便后续操作;需要matplotlib下的pyplot子库来显示一些图片,绘制并保存loss图;需要os库对文件夹内容进行判断以确定不同的行动分支;使用numpy显示一些数据集中的图片;需要time库查看训练时间等(可以不要);需要csv库将测试结果存到对应的csv文件中。
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as Data # to make Loader
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import os
import numpy as np
import time
import csv
```
#### 初始化一些参数
包括训练迭代的次数EPOCH;每次训练的批数目BATCH_SIZE;学习率LR;是否下载数据集DOWNLOAD_CIFAR10;
device用来表示使用cpu还是gpu;transform用来将数据进行能够训练的转化;还有10类数据的名字
```python
EPOCH = 2
BATCH_SIZE = 4
LR = 0.001
DOWNLOAD_CIFAR10 = False
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
transform = transforms.Compose(
[transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
)
train_data = torchvision.datasets.CIFAR10(
root='./data',
train=True,
download=DOWNLOAD_CIFAR10,
transform=transform
)
train_loader = Data.DataLoader(
train_data,
batch_size=BATCH_SIZE,
shuffle=True,
# num_workers=2 # ready to be commented(windows)
)
test_data = torchvision.datasets.CIFAR10(
root='./data',
train=False,
download=DOWNLOAD_CIFAR10,
transform=transform,
)
test_loader = Data.DataLoader(
test_data,
batch_size=BATCH_SIZE,
shuffle=False,
# num_workers=2
)
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
```
其中的一些参数说明:root表示数据集所在的目录;train=True表示将是训练集的数据提出来,False就是将测试的那部分提出来;download表示是否进行下载,因为前面用链接下载解压过了,所以设置的DOWNLOAD_CIFAR10为False;transform表示进行相应转化(转成灰度以及其他格式)。DataLoader中的第一个参数表示用来做成迭代器Loader的数据集;第二个参数batch_size即为每次load出来的图片数;shuffle参数表示是否在每个epoch开始的时候将数据集重新打乱;num_workers表示用来处理data的进程数,在windows下抱错,直接注释掉即可。(参考[文档](https://www.jianshu.com/p/ecd4549a5819))
#### 显示一些图片
下面试着显示一下几张图,看看有没有load成功:
```python
def imshow(img):
img = img /2 +0.5
npimg = img.numpy()
plt.imshow(np.transpose(npimg, (1,2,0)))
plt.show()
dataiter = iter(train_loader)
images, labels = dataiter.next()
imshow(torchvision.utils.make_grid(images))
print(' '.join('%5s' % classes[labels[j]] for j in range(4)))
```
因为上面设置的batch_size为4,所以这里就打印4个labels,得到的结果是:
![img](./Figure_1.png)
关闭窗口之后即可在终端看到对应的labels,分别为: cat frog frog car
将这段显示图片的注释掉,进行后续实验,不然每次运行都显示就有点烦(强迫症哈哈哈)
#### 定义卷积神经网络
采用LeNet5模型,即两层卷积核为5*5的卷积层,两层最大池化层,三层全连接层:
```python
class CNN(nn.Module):
def __init__(self):
super(CNN, self).__init__()
self.conv1 = nn.Conv2d(3, 6, 5)
self.pool = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(16*5*5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10) # result is ten kinds of item
def forward(self, x):
x = self.pool(F.relu(self.conv1(x)))
x = self.pool(F.relu(self.conv2(x)))
x = x.view(-1, 16*5*5) # reshape to 16*5*5 to fc
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
output = self.fc3(x)
return output
cnn = CNN()
```
注意每次卷积之后都要经过一层激活层再进行最大池化层,激活函数用的是relu().
#### 定义损失函数与优化器
```python
optimizer = optim.Adam(cnn.parameters())
loss_func = nn.CrossEntropyLoss()
```
损失函数使用的CrocsEntropyLoss(), 优化器使用的是据说是优化器中的集大成者: Adam,不需要传入学习率,但是在保存模型参数的时候需要将Adam优化器的参数也要保存,以便下次在这个模型参数基础上训练时使用该参数,如果不保存可能会影响结果,于是将状态信息定义成下面这样:
```python
state = {'cnn': cnn.state_dict(), 'optimizer': optimizer.state_dict()}
```
包括卷积神经网络的状态以及优化器的状态,将来训练完之后保存。
绘制并保存loss图可以用如下方式实现:
```python
losses = [] # record losses
def save_losses(losses):
t = np.arange(len(losses))
plt.plot(t, losses)
plt.savefig('loss.png')
# plt.show()
```
#### 定义训练函数
下面定义训练函数,对train_loader load的数据进行训练:
```python
def train():
global losses
for epoch in range(EPOCH):
running_loss = 0.0
for step, (inputs, labels) in enumerate(train_loader, 0):
# inputs, labels = data[0].to(device), data[1].to(device)
inputs = inputs.to(device) # 2
labels = labels.to(device)
没有合适的资源?快使用搜索试试~ 我知道了~
温馨提示
人工智能学习总结成果,希望可以帮到大家,有疑问欢迎随时沟通~ 人工智能学习总结成果,希望可以帮到大家,有疑问欢迎随时沟通~ 人工智能学习总结成果,希望可以帮到大家,有疑问欢迎随时沟通~ 人工智能学习总结成果,希望可以帮到大家,有疑问欢迎随时沟通~ 人工智能学习总结成果,希望可以帮到大家,有疑问欢迎随时沟通~
资源推荐
资源详情
资源评论
收起资源包目录
《人工智能》--人工智能原理课程实验1,numpy实现Lenet5,im2col方法实现的.zip (68个子文件)
Exer2
Astar.py 2KB
AstarNew.py 5KB
Dijkstra.py 4KB
Readme.md 50B
ExerRL
实验报告.md 22KB
mydqn
main.py 4KB
dqn_tf.py 7KB
pics
A3Ca.PNG 139KB
board.png 361KB
ACP.PNG 22KB
AandE.PNG 190KB
a3c1_t.PNG 41KB
AC.PNG 58KB
Ql.png 417KB
dqna.PNG 128KB
blog1.PNG 41KB
blog2.PNG 37KB
MDP1.PNG 4KB
13000.PNG 31KB
ACa.PNG 85KB
dqnb.png 291KB
AC2.PNG 89KB
A3Cres.PNG 50KB
MDP2.PNG 38KB
mydqnm.PNG 13KB
a3c2_t.PNG 42KB
dqn2.png 27KB
a3c3_t.PNG 46KB
dqn1.jpg 46KB
mydqns.PNG 12KB
RL_A3C
utils.py 2KB
main.py 5KB
gym_eval.py 5KB
environment.py 7KB
model.py 2KB
player_util.py 3KB
trained_models
Breakout-v0.dat 12.39MB
shared_optim.py 7KB
config.json 979B
logs
Breakout-v0_mon_log 2KB
Breakout-v0_log 790KB
result.gif 5.22MB
train.py 4KB
test.py 4KB
README.md 13B
testblog
testblog.py 10KB
Exer1
mnist.zip 11.06MB
app.py 22KB
readme 41B
weights.pkl 965KB
Exer3
cifarCNN.py 10KB
createcsv.py 381B
pics
10.png 6KB
9.png 17KB
3.png 4KB
loss.png 17KB
12.png 29KB
1.png 6KB
11.png 27KB
6.png 7KB
5.png 61KB
4.png 5KB
8.png 17KB
7.png 7KB
2.png 18KB
README.md 19KB
Figure_1.png 23KB
README.md 42B
共 68 条
- 1
资源评论
季风泯灭的季节
- 粉丝: 605
- 资源: 2920
上传资源 快速赚钱
- 我的内容管理 展开
- 我的资源 快来上传第一个资源
- 我的收益 登录查看自己的收益
- 我的积分 登录查看自己的积分
- 我的C币 登录后查看C币余额
- 我的收藏
- 我的下载
- 下载帮助
安全验证
文档复制为VIP权益,开通VIP直接复制
信息提交成功