# 手把手教你使用Pytorch训练自己的分类模型
![封面](https://vehicle4cm.oss-cn-beijing.aliyuncs.com/imgs/%E5%B0%81%E9%9D%A2.png)
之前更新过一起tf版本的训练自己的物体分类模型,但是很多兄弟反应tf版本的代码在GPU上无法运行,这个原因是tf在30系显卡上没有很好的支持。所以我们重新更新一期Pytorch版本的物体分类模型训练教程,在这个教程里面,你将会学会**物体分类的基本概念+数据集的处理+模型的训练和测试+图形化界面的构建**。我这里使用的显卡是NVIDIA RTX3060 6G的笔记本显卡。为了避免带货的嫌疑,我就不说具体的机器型号了,实际的体验中呢,一般4G以上的显存跑个resnet和yolo之类的是没有问题的,如果你是科研人员的话(科研人员估计也不会看我的博客),则需要更牛的服务器来支持你的研究。
## 基本概念
![gogo](https://vehicle4cm.oss-cn-beijing.aliyuncs.com/imgs/gogo.jpg)
从左向右依次是图像分类,目标检测,语义分割和实例分割。
**图像分类**是指为输入图像分配类别标签。自 2012 年采用深度卷积网络方法设计的 AlexNet 夺得 ImageNet 竞赛冠军后,图像分类开始全面采用深度卷积网络。2015 年,微软提出的 ResNet 采用残差思想,将输入中的一部分数据不经过神经网络而直接进入到输出中,解决了反向传播时的梯度弥散问题,从而使得网络深度达到 152 层,将错误率降低到 3.57%,远低于 5.1%的人眼识别错误率,夺得了ImageNet 大赛的冠军。
**目标检测**指用框标出物体的位置并给出物体的类别。2013 年加州大学伯克利分校的 Ross B. Girshick 提出 RCNN 算法之后,基于卷积神经网络的目标检测成为主流。之后的检测算法主要分为两类,一是基于区域建议的目标检测算法,通过提取候选区域,对相应区域进行以深度学习方法为主的分类,如 RCNN、Fast-RCNN、Faster-RCNN、SPP-net 和 Mask R-CNN 等系列方法。二是基于回归的目标检测算法,如 YOLO、SSD 和 DenseBox 等。
**图像分割**指将图像细分为多个图像子区域。2015 年开始,以全卷积神经网络(FCN)为代表的一系列基于卷积神经网络的语义分割方法相继提出,不断提高图像语义分割精度,成为目前主流的图像语义分割方法。实例分割则是实例级别的语义分割。
我们本期教程主要是<font color='red'>图像分类</font>,即给定一张图片,模型判断出他的具体类别。
## 环境配置
### Anaconda 和 Pycahrm安装
nvidia-驱动下载地址:[官方驱动 | NVIDIA](https://www.nvidia.cn/Download/index.aspx?lang=cn)
![image-20221206174728937](https://vehicle4cm.oss-cn-beijing.aliyuncs.com/imgs/image-20221206174728937.png)
使用代码之前请先确保电脑上已经安装好了anaconda和pycharm。
环境的基本配置请看这期博客:[如何在pycharm中配置anaconda的虚拟环境_肆十二的博客-CSDN博客_pycharm配置anaconda虚拟环境](https://blog.csdn.net/ECHOSON/article/details/117220445)
miniconda下载地址:[Index of /anaconda/miniconda/ | 清华大学开源软件镜像站 | Tsinghua Open Source Mirror](https://mirrors.tuna.tsinghua.edu.cn/anaconda/miniconda/)
![image-20221206173858594](https://vehicle4cm.oss-cn-beijing.aliyuncs.com/imgs/image-20221206173858594.png)
conda加速
```bash
conda config --add channels https://mirrors.ustc.edu.cn/anaconda/pkgs/main/
conda config --add channels https://mirrors.ustc.edu.cn/anaconda/pkgs/free/
conda config --add channels https://mirrors.bfsu.edu.cn/anaconda/cloud/pytorch/
conda config --set show_channel_urls yes
pip config set global.index-url https://mirrors.ustc.edu.cn/pypi/web/simple
```
Pycharm的下载地址:[Other Versions - PyCharm (jetbrains.com)](https://www.jetbrains.com/pycharm/download/other.html)
![image-20221206173934245](https://vehicle4cm.oss-cn-beijing.aliyuncs.com/imgs/image-20221206173934245.png)
### 代码环境配置
代码环境配置步骤较多,建议按照视频教程操作,下面只列出关键命令,方便大家复制粘贴。
```bash
conda create -n cls-42 python==3.8.5
conda activate cls-42
conda install pytorch==1.10.0 torchvision==0.11.0 torchaudio==0.10.0 cudatoolkit=11.3
cd 自己本地的代码目录 (或者在本地代码目录的上方打开cmd)
pip install -r requirements.txt
```
## 数据集
### 数据集的搜集
数据集一般有两种方式获取,一种可以通过自己拍摄或者是爬虫爬取建立自建的数据集,这里在本科毕设和大作业的过程中用的比较多,另外一种是使用公开的数据集,后续我这边也会更新一些视觉相关的数据集,大家可以在这里自行查找:[肆十二的博客_CSDN博客-大作业,目标检测,个人心得领域博主](https://blog.csdn.net/ECHOSON?type=download)
![image-20221206174854041](https://vehicle4cm.oss-cn-beijing.aliyuncs.com/imgs/image-20221206174854041.png)
对于公开数据集,比如医学分割,我们一般从这个网址获取:
```bash
https://www.isic-archive.com/#!/onlyHeaderTop/gallery
```
我们这里提供了一个爬虫的程序,可以帮助大家从百度图片中爬取自己需要的图片,程序的名称是`data_get.py`,使用起来非常方便,大家直接运行程序之后,属于自己想要爬取的图片即可,这段程序我直接放在这里。
```python
# -*- coding: utf-8 -*-
# @Time : 2021/6/17 20:29
# @File : get_data.py
# @Software: PyCharm
# @Brief : 爬取百度图片
import requests
import re
import os
headers = {
'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/84.0.4147.125 Safari/537.36'}
name = input('请输入要爬取的图片类别:')
num = 0
num_1 = 0
num_2 = 0
x = input('请输入要爬取的图片数量?(1等于60张图片,2等于120张图片):')
list_1 = []
for i in range(int(x)):
name_1 = os.getcwd()
name_2 = os.path.join(name_1, 'data/' + name)
url = 'https://image.baidu.com/search/flip?tn=baiduimage&ie=utf-8&word=' + name + '&pn=' + str(i * 30)
res = requests.get(url, headers=headers)
htlm_1 = res.content.decode()
a = re.findall('"objURL":"(.*?)",', htlm_1)
if not os.path.exists(name_2):
os.makedirs(name_2)
for b in a:
try:
b_1 = re.findall('https:(.*?)&', b)
b_2 = ''.join(b_1)
if b_2 not in list_1:
num = num + 1
img = requests.get(b)
f = open(os.path.join(name_1, 'data/' + name, name + str(num) + '.jpg'), 'ab')
print('---------正在下载第' + str(num) + '张图片----------')
f.write(img.content)
f.close()
list_1.append(b_2)
elif b_2 in list_1:
num_1 = num_1 + 1
continue
except Exception as e:
print('---------第' + str(num) + '张图片无法下载----------')
num_2 = num_2 + 1
continue
# 为了防止下载的数据有坏图,直接在下载过程中对数据进行清洗
print('下载完成,总共下载{}张,成功下载:{}张,重复下载:{}张,下载失败:{}张'.format(num + num_1 + num_2, num, num_1, num_2))
```
比如这里我想要爬取向日葵的图片,运行之后输入向日葵,然后输入想要爬取的图片数量即可。
![image-20221129140549793](https://cmfighting.oss-cn-shenzhen.aliyuncs.com/iiimgs/image-20221129140549793.png)
输入完成之后,爬取之后的图片将会自动保存在data目录下。
![image-20221129140629126](https://cmfighting.oss-cn-shenzhen.aliyuncs.com/iiimgs/image-20221129140629126.png)
### 数据集清洗
在实际的使用中,opencv对中文的支持并不好,在一些封装好的以opencv作为后端的api中�
- 1
- 2
- 3
前往页