# pytorch-classifier
image classifier implement in pytoch.
# Directory
1. **[Introduction](#Introduction)**
2. **[How to use](#Howtouse)**
3. **[Argument Explanation](#ArgumentExplanation)**
4. **[Model Zoo](#ModelZoo)**
5. **[Some explanation](#Someexplanation)**
6. **[TODO](#TODO)**
7. **[Reference](#Reference)**
8. **[Update Log](#Reference)**
<a id="Introduction"></a>
## Introduction
为什么推荐你使用这个代码?
- **丰富的可视化功能**
1. 训练图像可视化.
2. 损失函数,精度,学习率迭代图像可视化.
3. 热力图可视化.
4. TSNE可视化.
5. 数据集识别情况可视化.(metrice.py文件中--visual参数,开启可以自动把识别正确和错误的文件路径,类别,概率保存到csv中,方便后续分析)
6. 类别精度可视化.(可视化训练集,验证集,测试集中的总精度,混淆矩阵,每个类别的precision,recall,accuracy,f0.5,f1,f2,auc,aupr)
7. 总体精度可视化.(kappa,precision,recll,f1,accuracy,mpa)
- **丰富的模型库**
1. 由作者整合的丰富模型库,主流的模型基本全部支持,支持的模型个数高达50+,其全部支持ImageNet的预训练权重,[详细请看Model Zoo.(变形金刚系列后续更新)](#ModelZoo)
2. 目前支持的模型都是通过作者从github和torchvision整合,因此支持修改、改进模型进行实验,并不是直接调用库创建模型.
- **丰富的训练策略**
1. 支持断点续训,只需要设定一个参数(--resume).
2. 支持多种常见的损失函数.(目前支持PolyLoss,CrossEntropyLoss,FocalLoss)
3. 支持一个参数即可设置类别平衡.
4. 支持混合精度训练.(使你的机器能支持更大的batchsize)
5. 支持[知识蒸馏](Knowledge_Distillation.md).
- **丰富的数据增强策略**
1. 支持RandAugment, AutoAugment, TrivialAugmentWide, AugMix, Mixup, CutMix, CutOut, TTA等强大的数据增强.
2. 支持添加torchvision中的数据增强.
3. 支持添加自定义数据增强.[详细看Some explanation第十四点](#2)
- **丰富的学习率调整策略**
本程序支持学习率预热,支持预热后的自定义学习率策略.[详细看Some explanation第五点](#1)
- **支持导出各种常用推理框架模型**
目前支持导出torchscript,onnx,tensorrt推理模型.
<a id="6"></a>
- **简单的安装过程**
1. 安装好pytorch, torchvision(pytorch==1.12.0+ torchvision==0.13.0+)
可以在[pytorch](https://pytorch.org/get-started/previous-versions/)官网找到对应的命令进行安装.
2. pip install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple
- **人性化的设定**
1. 大部分可视化数据(混淆矩阵,tsne,每个类别的指标)都会以csv或者log的格式保存到本地,方便后期美工图像.
2. 程序大部分输出信息使用PrettyTable进行美化输出,大大增加可观性.
<a id="Howtouse"></a>
## How to use
1. 安装程序所需的[环境](#6).
2. 根据[Some explanation中的第三点](#5)处理好数据集.
<a id="ArgumentExplanation"></a>
## Argument Explanation
- **main.py**
实现训练的主要程序.
参数解释:
- **model_name**
type: string, default: resnet18
选择的模型类型.
- **pretrained**
default: False
是否加载预训练权重.
- **weight**
type: string, default: ''
载入权重的路径.跟pretrained没有关系,pretrained是自动从网上下载权重载入到模型中.[详细解释请看Some explanation中的第十一点](#4)
- **config**
type: string, default: config/config.py
配置文件的路径.
- **device**
type: string, default: ''
使用的设备.(cuda device, i.e. 0 or 0,1,2,3 or cpu)
- **train_path**
type: string, default: dataset/train
训练集的路径.
- **val_path**
type: string, default: dataset/val
验证集的路径.
- **test_path**
type: string, default: dataset/test
测试集的路径.
- **label_path**
type: string, default: dataset/label.txt
标签的路径.
- **image_size**
type: int, default: 224
输入模型的图像尺寸大小.
- **image_channel**
type:int, default: 3
输入模型的图像通道大小.(目前只支持三通道)
- **workers**
type: int, default: 4
pytorch中的dataloader中的workers数量.
- **batch_size**
type: int, default: 64
单次训练所选取的样本个数,如果设置为-1则程序会计算当前使用的gpu最大的batch_size个数(占满gpu的百分之80),如果使用者电脑上没有gpu或者安装了cpu版本的pytorch,自动设置batch_size=16.
- **epoch**
type: int, default: 100
训练次数.
- **save_path**
type: string, default: runs/exp
用于保存训练过程、计算指标、测试过程的路径.
- **resume**
default: False
是否在save_path参数的路径中继续训练未完成的任务.
- **loss**
type: string, default: CrossEntropyLoss, choices: ['PolyLoss', 'CrossEntropyLoss', 'FocalLoss']
损失函数类型.
- **optimizer**
type: string, default: AdamW, choices: ['SGD', 'AdamW', 'RMSProp']
优化器类型.
- **lr**
type: float, default: 1e-3
学习率大小.
- **label_smoothing**
type: float, default: 0.1
损失函数中的标签平滑的值.
- **class_balance**
default: False
是否采用标签平衡.(使用sklearn中的compute_class_weight进行实现)
- **weight_decay**
type: float, default: 5e-4
权重正则化.
- **momentum**
type: float, default: 0.9
优化器中的动量参数.
- **amp**
default: False
是否使用混合精度训练.
- **warmup**
default: False
是否采用学习率预热.
- **warmup_ratios**
type: float, default: 0.05
学习率预热中的预热比例.(warmup_epochs=int(warmup_ratios * epoch))
- **warmup_minlr**
type: float, default: 1e-6
默认学习率调整中学习率最小值,如果warmup设置了为True,也是warmup学习率的初始值.
- **metrice**
type: string, default: acc, choices:['loss', 'acc', 'mean_acc']
根据metrice选择的指标来进行保存best.pt.
- **patience**
type: int, default:30
早停法中的patience.(设置为0即为不使用早停法)
- **imagenet_meanstd**
default:False
是否采用imagenet的均值和方差,False则使用当前训练集的均值和方差.
- **mixup**
type:string, default: none, choices:['mixup', 'cutmix', 'none']
mixup数据增强及其变种的选择.
- **Augment**
type: string, default: none, choices: ['RandAugment', 'AutoAugment', 'TrivialAugmentWide', 'AugMix', 'none']
数据增强类型,none则为不增强.
- **test_tta**
default: False
是否采用测试阶段的数据增强.
- **kd**
default: False
是否进行知识蒸馏.
- **kd_method**
type: string, default: SoftTarget, choices: ['SoftTarget', 'MGD', 'SP', 'AT']
知识蒸馏类型.
- **kd_ratio**
type: float, default: 0.7
知识蒸馏损失系数.
- **teacher_path**
type: string, default: ''
知识蒸馏中老师模型的路径.
- **rdrop**
default: False
是否采用R-Drop.(不支持知识蒸馏)
- **ema**
default: False
是否采用EMA.(不支持知识蒸馏)
- **metrice.py**
实现计算指标的主要程序.
参数解释:
- **train_path**
type: string, default: dataset/train
训练集的路径.
- **val_path**
type: string, default: dataset/val
验证集的路径.
- **test_path**
type: string, default: dataset/test
测试集的路径.
- **label_path**
type: string, default: dataset/label.txt
标签的路径.
- **device**
type: string, default: ''
使�
没有合适的资源?快使用搜索试试~ 我知道了~
温馨提示
基于pytorch的垃圾分类,带训练模型和数据集的下载链接! 多达200类别-垃圾分类! 附带5种先进的图像分类网络! 代码支持知识蒸馏,里面有详细的教程! 代码里面还有50+种模型选择,支持对比实验,每个模型都支持Imagenet预训练权重,详细请看代码里面的Readme!!
资源推荐
资源详情
资源评论




















收起资源包目录

























































































共 80 条
- 1
资源评论

- m0_744044422023-05-24感谢资源主的分享,很值得参考学习,资源价值较高,支持!
- Plutoda2023-04-13资源和描述一致,质量不错,解决了我的问题,感谢资源主。


小风飞子
- 粉丝: 265
- 资源: 970

下载权益

C知道特权

VIP文章

课程特权

开通VIP
上传资源 快速赚钱
我的内容管理 收起
我的资源 快来上传第一个资源
我的收益
登录查看自己的收益我的积分 登录查看自己的积分
我的C币 登录后查看C币余额
我的收藏
我的下载
下载帮助


会员权益专享
安全验证
文档复制为VIP权益,开通VIP直接复制
