# 实验介绍
pytorch实现狗品种二分类以及tensorrt加速
# 开发环境
- windows 11
- pytorch 1.12.0
- torchvision 0.13.0
- python 3.7.13
- tensorboard 1.15.0
- tensorboardx 2.5.1
# 数据集介绍
使用Imagewoof数据集,这是一个由旧金山大学于2020年发布的狗十分类数据集。主要品种有:澳大利亚梗、边境梗、萨摩耶、比格犬、西施犬、英国猎狐犬、罗得西亚脊背犬、澳洲野狗、金毛猎犬、古英国牧羊犬。本实验选取了澳大利亚梗跟萨摩耶做二分类。
下载路径:https://www.aliyundrive.com/s/iVwzXD28suG ,验证码:h80q
# 网络架构
实验过程中一共使用了四种不同的网络,即net9、net44、resnet18、resnet34以及resnet50。
- net9为自己设计的简单网络,一共有9层,其对应于代码中的Net1:
<p>
<img src="images/net9.png">
</p>
- net44也为自己设计的网络,一共有44层,其中使用了残差块,其对应于代码中的Net2:
<p>
<img src="images/net44.png">
</p>
- resnet18、resnet34以及resnet50为pytorch自带的预训练网络,其对应于代码中的Net3。
备注:
Linear(输入数据大小,输出数据大小)
Conv2d(输入通道数,输出通道数,卷积核大小,步长(默认为1))
MaxPool2d(池化核大小,步长)
# 脚本运行
启动resnet34训练脚本(评估会在每个epoch结束之后评估一次):
```python
python main.py -m 3 -t 34 -dp datasets所在路径
```
命令行参数解释:
| 参数简称 | 参数全称 | 参数含义 |
|------|---------------|-----------------------------------------------------------------------------------|
| -m | --model | 网络运行时选用的模型。1:选用net9;2:选用net44;3:选用resnet系列网络,默认值为3。 |
| -c | --channels | 训练时图片的通道数。1:使用单通道图片,3:使用三通道图片,默认值为3。 |
| -t | --type | resnet网络的类型,只有在使用resnet为backbone时才会用到。18:resnet18,34:resnet34,50:resnet50,默认值为34。 |
| -l | --lr_rate | 学习率,默认值为8e-05。 |
| -e | --epoch | epoch数,默认值为30。 |
| -p | --pth_name | 保存的权重名称。 |
| -bs | --batch_size | batch size数,默认值为32。 |
| -gm | --gamma | ExponentialLR中使用的超参gamma,默认值为0.99。 |
| -dp | --datasets_path | 数据集路径。 |
| -n | --num_classes | 分类数,默认值为2。 |
启动resnet34评估脚本:
```python
python detector.py -m 3 -dp datasets所在路径
```
命令行参数解释:
| 参数简称 | 参数全称 | 参数含义 |
|------|---------------|-----------------------------------------------------------------------------------|
| -m | --model | 网络运行时选用的模型。1:选用net9;2:选用net44;3:选用resnet系列网络,默认值为3。 |
| -c | --channels | 训练时图片的通道数。1:使用单通道图片,3:使用三通道图片,默认值为3。 |
| -t | --type | resnet网络的类型,只有在使用resnet为backbone时才会用到。18:resnet18,34:resnet34,50:resnet50,默认值为34。 |
| -p | --pth_name | 保存的权重名称。 |
| -dp | --datasets_path | 数据集路径。 |
| -n | --num_classes | 分类数,默认值为2。 |
# 实验结果
在学习率为8e-05、epoch为30以及batch size为32时,不同的网络不使用数据增强训练三通道图片的精度如下表所示:
| 网络 | 测试集最大acc |
|----------|----------|
| net9 | 93.42% |
| net44 | 93.21% |
| resnet18 | 99.15% |
| resnet34 | 99.58% |
| resnet50 | 99.79% |
# TensorRT加速
## 1.Python版本
### 加速环境
- Windows 11
- NVIDIA GeForce RTX 3060
- python 3.7.13
- TensorRT 8.4.1.5
- cuda 11.6
- cudnn 8.4
- onnxruntime 1.12.1
### pth权重文件转onnx权重文件
```python
python to_onnx.py --torch_file_path pth文件路径 --onnx_file_path onnx文件输出路径
```
### onnx文件转trt文件
使用TensorRT/bin目录下的trtexec进行模型转换(默认情况下转换为TF32数据格式):
```bash
.\trtexec.exe --onnx=onnx文件路径 --saveEngine=trt文件输出路径 --workspace=6000
```
float16转换:
```bash
.\trtexec.exe --onnx=onnx文件路径 --saveEngine=trt文件输出路径 --workspace=6000 --fp16
```
int8转换:
```bash
.\trtexec.exe --onnx=onnx文件路径 --saveEngine=trt文件输出路径 --workspace=6000 --int8
```
### python调用trt文件进行推理
```python
python trt_infer.py --trt_file_path trt文件路径 -dp datasets所在路径
```
### 实验结果
| 模型 | acc | 平均latency(单位s) | 文件大小(单位MB) |
|------------------------|------------|-------------------|-------------|
| resnet34-xxx.pth | 99.58% | 0.004614 | 81.3 |
| resnet34-xxx.trt(tf32) | 99.58% | 0.000749 | 115 |
| resnet34-xxx-fp16.trt | 99.58% | 0.000557 | 41.0 |
| resnet34-xxx-int8.trt | **99.79%** | **0.000452** | **20.7** |
### 该部分参考博客
https://zhuanlan.zhihu.com/p/467401558
https://zhuanlan.zhihu.com/p/371239130
https://zhuanlan.zhihu.com/p/527238167
## 2.C++版本
仍在实验当中
没有合适的资源?快使用搜索试试~ 我知道了~
人工智能与机器学习的课程实验并尝试用TensorRT加速.zip
共47个文件
py:14个
laptop-nkjuvedn:11个
png:6个
需积分: 5 0 下载量 189 浏览量
2024-04-18
12:39:52
上传
评论
收藏 234.24MB ZIP 举报
温馨提示
机器学习是一门多领域交叉学科,涉及概率论、统计学、逼近论、凸分析、算法复杂度理论等多门学科。它专门研究计算机如何模拟或实现人类的学习行为,以获取新的知识或技能,并重新组织已有的知识结构,从而不断改善自身的性能。机器学习是人工智能的核心,也是使计算机具有智能的根本途径。 应用: 机器学习在各个领域都有广泛的应用。在医疗保健领域,它可用于医疗影像识别、疾病预测、个性化治疗等方面。在金融领域,机器学习可用于风控、信用评分、欺诈检测以及股票预测。此外,在零售和电子商务、智能交通、生产制造等领域,机器学习也发挥着重要作用,如商品推荐、需求预测、交通流量预测、质量控制等。 优点: 机器学习模型能够处理大量数据,并在相对短的时间内产生可行且效果良好的结果。 它能够同时处理标称型和数值型数据,并可以处理具有缺失属性的样本。 机器学习算法如决策树,易于理解和解释,可以可视化分析,容易提取出规则。 一些机器学习模型,如随机森林或提升树,可以有效地解决过拟合问题。 缺点: 机器学习模型在处理某些特定问题时可能会出现过拟合或欠拟合的情况,导致预测结果不准确。 对于某些复杂的非线性问题,单一的机器学习算法可能难以有效地进行建模和预测。 机器学习模型的训练通常需要大量的数据和计算资源,这可能会增加实施成本和时间。 总的来说,机器学习虽然具有许多优点和应用领域,但也存在一些挑战和限制。在实际应用中,需要根据具体问题和需求选择合适的机器学习算法和模型,并进行适当的优化和调整。
资源推荐
资源详情
资源评论
收起资源包目录
人工智能与机器学习的课程实验并尝试用TensorRT加速.zip (47个子文件)
content
main.py 528B
exam1
__init__.py 0B
infer.py 1KB
runs
Layer5_with_conv_lr0.001_epoch20
events.out.tfevents.1666960689.LAPTOP-NKJUVEDN 101KB
Layer5_with_conv_lr0.001_epoch40
events.out.tfevents.1666961047.LAPTOP-NKJUVEDN 203KB
Layer5_with_conv_lr0.0005_epoch40
events.out.tfevents.1666961523.LAPTOP-NKJUVEDN 203KB
Layer5_without_conv_lr0.001_epoch20
events.out.tfevents.1666958753.LAPTOP-NKJUVEDN 101KB
Layer5_with_conv_lr0.0005_epoch10
events.out.tfevents.1666962273.LAPTOP-NKJUVEDN 51KB
Layer5_without_conv_lr0.0005_epoch20
events.out.tfevents.1666958376.LAPTOP-NKJUVEDN 101KB
Layer5_without_conv_lr0.0005_epoch40
events.out.tfevents.1666959045.LAPTOP-NKJUVEDN 203KB
Layer5_with_conv_lr0.0005_epoch20
events.out.tfevents.1666958097.LAPTOP-NKJUVEDN 101KB
Layer5_without_conv_lr0.001_epoch40
events.out.tfevents.1666960109.LAPTOP-NKJUVEDN 203KB
Layer5_without_conv_lr0.0005_epoch10
events.out.tfevents.1666962556.LAPTOP-NKJUVEDN 51KB
trainer.py 3KB
datasets
MNIST
processed
training.pt 45.32MB
test.pt 7.55MB
raw
train-images-idx3-ubyte 44.86MB
t10k-images-idx3-ubyte 7.48MB
t10k-labels-idx1-ubyte 10KB
train-images-idx3-ubyte.gz 0B
train-labels-idx1-ubyte 59KB
.keep 0B
build_trt_net.py 5KB
model
model_without_conv.pth 5.74MB
trt_infer.py 2KB
images
loss.png 115KB
网络一.png 18KB
网络二.png 20KB
acc.png 191KB
net.py 2KB
README.md 2KB
exam2
nets.py 5KB
__init__.py 0B
to_onnx.py 2KB
main.py 5KB
runs
resnet34_without_data_augmentation-lr8e-05-epoch30-channels3
events.out.tfevents.1671116262.LAPTOP-NKJUVEDN 75KB
detector.py 3KB
gen_datasets.py 3KB
models
resnet34_without_data_augmentation-lr8e-05-epoch30-channels3.onnx 81.18MB
resnet34_without_data_augmentation-lr8e-05-epoch30-channels3.pth 81.33MB
resnet34_without_data_augmentation-lr8e-05-epoch30-channels3-fp16.trt 41.05MB
resnet34_without_data_augmentation-lr8e-05-epoch30-channels3-int8.trt 20.76MB
trt_infer.py 5KB
images
net9.png 37KB
net44.png 64KB
README.md 6KB
README.md 168B
共 47 条
- 1
资源评论
普通网友
- 粉丝: 3908
- 资源: 7442
上传资源 快速赚钱
- 我的内容管理 展开
- 我的资源 快来上传第一个资源
- 我的收益 登录查看自己的收益
- 我的积分 登录查看自己的积分
- 我的C币 登录后查看C币余额
- 我的收藏
- 我的下载
- 下载帮助
最新资源
资源上传下载、课程学习等过程中有任何疑问或建议,欢迎提出宝贵意见哦~我们会及时处理!
点击此处反馈
安全验证
文档复制为VIP权益,开通VIP直接复制
信息提交成功