# pytorch-classifier
image classifier implement in pytoch.
# Directory
1. **[Introduction](#Introduction)**
2. **[How to use](#Howtouse)**
3. **[Argument Explanation](#ArgumentExplanation)**
4. **[Model Zoo](#ModelZoo)**
5. **[Some explanation](#Someexplanation)**
6. **[TODO](#TODO)**
7. **[Reference](#Reference)**
8. **[Update Log](#Reference)**
<a id="Introduction"></a>
## Introduction
为什么推荐你使用这个代码?
- **丰富的可视化功能**
1. 训练图像可视化.
2. 损失函数,精度,学习率迭代图像可视化.
3. 热力图可视化.
4. TSNE可视化.
5. 数据集识别情况可视化.(metrice.py文件中--visual参数,开启可以自动把识别正确和错误的文件路径,类别,概率保存到csv中,方便后续分析)
6. 类别精度可视化.(可视化训练集,验证集,测试集中的总精度,混淆矩阵,每个类别的precision,recall,accuracy,f0.5,f1,f2,auc,aupr)
7. 总体精度可视化.(kappa,precision,recll,f1,accuracy,mpa)
- **丰富的模型库**
1. 由作者整合的丰富模型库,主流的模型基本全部支持,支持的模型个数高达50+,其全部支持ImageNet的预训练权重,[详细请看Model Zoo.(变形金刚系列后续更新)](#ModelZoo)
2. 目前支持的模型都是通过作者从github和torchvision整合,因此支持修改、改进模型进行实验,并不是直接调用库创建模型.
- **丰富的训练策略**
1. 支持断点续训,只需要设定一个参数(--resume).
2. 支持多种常见的损失函数.(目前支持PolyLoss,CrossEntropyLoss,FocalLoss)
3. 支持一个参数即可设置类别平衡.
4. 支持混合精度训练.(使你的机器能支持更大的batchsize)
5. 支持[知识蒸馏](Knowledge_Distillation.md).
- **丰富的数据增强策略**
1. 支持RandAugment, AutoAugment, TrivialAugmentWide, AugMix, Mixup, CutMix, CutOut, TTA等强大的数据增强.
2. 支持添加torchvision中的数据增强.
3. 支持添加自定义数据增强.[详细看Some explanation第十四点](#2)
- **丰富的学习率调整策略**
本程序支持学习率预热,支持预热后的自定义学习率策略.[详细看Some explanation第五点](#1)
- **支持导出各种常用推理框架模型**
目前支持导出torchscript,onnx,tensorrt推理模型.
<a id="6"></a>
- **简单的安装过程**
1. 安装好pytorch, torchvision(pytorch==1.12.0+ torchvision==0.13.0+)
可以在[pytorch](https://pytorch.org/get-started/previous-versions/)官网找到对应的命令进行安装.
2. pip install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple
- **人性化的设定**
1. 大部分可视化数据(混淆矩阵,tsne,每个类别的指标)都会以csv或者log的格式保存到本地,方便后期美工图像.
2. 程序大部分输出信息使用PrettyTable进行美化输出,大大增加可观性.
<a id="Howtouse"></a>
## How to use
1. 安装程序所需的[环境](#6).
2. 根据[Some explanation中的第三点](#5)处理好数据集.
<a id="ArgumentExplanation"></a>
## Argument Explanation
- **main.py**
实现训练的主要程序.
参数解释:
- **model_name**
type: string, default: resnet18
选择的模型类型.
- **pretrained**
default: False
是否加载预训练权重.
- **weight**
type: string, default: ''
载入权重的路径.跟pretrained没有关系,pretrained是自动从网上下载权重载入到模型中.[详细解释请看Some explanation中的第十一点](#4)
- **config**
type: string, default: config/config.py
配置文件的路径.
- **device**
type: string, default: ''
使用的设备.(cuda device, i.e. 0 or 0,1,2,3 or cpu)
- **train_path**
type: string, default: dataset/train
训练集的路径.
- **val_path**
type: string, default: dataset/val
验证集的路径.
- **test_path**
type: string, default: dataset/test
测试集的路径.
- **label_path**
type: string, default: dataset/label.txt
标签的路径.
- **image_size**
type: int, default: 224
输入模型的图像尺寸大小.
- **image_channel**
type:int, default: 3
输入模型的图像通道大小.(目前只支持三通道)
- **workers**
type: int, default: 4
pytorch中的dataloader中的workers数量.
- **batch_size**
type: int, default: 64
单次训练所选取的样本个数,如果设置为-1则程序会计算当前使用的gpu最大的batch_size个数(占满gpu的百分之80),如果使用者电脑上没有gpu或者安装了cpu版本的pytorch,自动设置batch_size=16.
- **epoch**
type: int, default: 100
训练次数.
- **save_path**
type: string, default: runs/exp
用于保存训练过程、计算指标、测试过程的路径.
- **resume**
default: False
是否在save_path参数的路径中继续训练未完成的任务.
- **loss**
type: string, default: CrossEntropyLoss, choices: ['PolyLoss', 'CrossEntropyLoss', 'FocalLoss']
损失函数类型.
- **optimizer**
type: string, default: AdamW, choices: ['SGD', 'AdamW', 'RMSProp']
优化器类型.
- **lr**
type: float, default: 1e-3
学习率大小.
- **label_smoothing**
type: float, default: 0.1
损失函数中的标签平滑的值.
- **class_balance**
default: False
是否采用标签平衡.(使用sklearn中的compute_class_weight进行实现)
- **weight_decay**
type: float, default: 5e-4
权重正则化.
- **momentum**
type: float, default: 0.9
优化器中的动量参数.
- **amp**
default: False
是否使用混合精度训练.
- **warmup**
default: False
是否采用学习率预热.
- **warmup_ratios**
type: float, default: 0.05
学习率预热中的预热比例.(warmup_epochs=int(warmup_ratios * epoch))
- **warmup_minlr**
type: float, default: 1e-6
默认学习率调整中学习率最小值,如果warmup设置了为True,也是warmup学习率的初始值.
- **metrice**
type: string, default: acc, choices:['loss', 'acc', 'mean_acc']
根据metrice选择的指标来进行保存best.pt.
- **patience**
type: int, default:30
早停法中的patience.(设置为0即为不使用早停法)
- **imagenet_meanstd**
default:False
是否采用imagenet的均值和方差,False则使用当前训练集的均值和方差.
- **mixup**
type:string, default: none, choices:['mixup', 'cutmix', 'none']
mixup数据增强及其变种的选择.
- **Augment**
type: string, default: none, choices: ['RandAugment', 'AutoAugment', 'TrivialAugmentWide', 'AugMix', 'none']
数据增强类型,none则为不增强.
- **test_tta**
default: False
是否采用测试阶段的数据增强.
- **kd**
default: False
是否进行知识蒸馏.
- **kd_method**
type: string, default: SoftTarget, choices: ['SoftTarget', 'MGD', 'SP', 'AT']
知识蒸馏类型.
- **kd_ratio**
type: float, default: 0.7
知识蒸馏损失系数.
- **teacher_path**
type: string, default: ''
知识蒸馏中老师模型的路径.
- **rdrop**
default: False
是否采用R-Drop.(不支持知识蒸馏)
- **ema**
default: False
是否采用EMA.(不支持知识蒸馏)
- **metrice.py**
实现计算指标的主要程序.
参数解释:
- **train_path**
type: string, default: dataset/train
训练集的路径.
- **val_path**
type: string, default: dataset/val
验证集的路径.
- **test_path**
type: string, default: dataset/test
测试集的路径.
- **label_path**
type: string, default: dataset/label.txt
标签的路径.
- **device**
type: string, default: ''
使�
没有合适的资源?快使用搜索试试~ 我知道了~
温馨提示
基于pytorch的垃圾分类,带训练模型和数据集的下载链接! 多达200类别-垃圾分类! 附带5种先进的图像分类网络! 代码支持知识蒸馏,里面有详细的教程! 代码里面还有50+种模型选择,支持对比实验,每个模型都支持Imagenet预训练权重,详细请看代码里面的Readme!!
资源推荐
资源详情
资源评论
收起资源包目录
基于pytorch的垃圾分类,带训练模型和数据集的下载链接!.zip (80个子文件)
基于pytorch的垃圾分类,带训练模型和数据集的下载链接!
processing.py 2KB
repghostnet_1_0x.log 6.84MB
main.py 13KB
LICENSE 1KB
export.py 6KB
predict.py 5KB
vovnet39.log 6.76MB
utils
utils.py 38KB
__init__.py 0B
utils_loss.py 3KB
utils_model.py 11KB
utils_aug.py 6KB
utils_distill.py 4KB
utils_fit.py 5KB
__pycache__
utils_model.cpython-38.pyc 3KB
utils_aug.cpython-38.pyc 6KB
utils_fit.cpython-38.pyc 3KB
utils_fit.cpython-39.pyc 3KB
utils.cpython-39.pyc 21KB
utils.cpython-38.pyc 34KB
utils_loss.cpython-38.pyc 4KB
utils_distill.cpython-38.pyc 5KB
utils_aug.cpython-39.pyc 6KB
__init__.cpython-38.pyc 137B
utils_model.cpython-39.pyc 3KB
figure
custom_lr.png 22KB
warmup_lr.png 27KB
base_lr.png 24KB
warmup_custom_lr.png 26KB
metrice.py 6KB
commad.txt 2KB
model
__init__.py 404B
convnext.py 10KB
cspnet.py 39KB
repghost.py 19KB
dpn.py 12KB
vovnet.py 10KB
shufflenetv2.py 10KB
ghostnet.py 9KB
mobilenetv2.py 9KB
repvgg.py 18KB
mobilenetv3.py 14KB
densenet.py 14KB
resnet.py 17KB
sequencer.py 18KB
vgg.py 10KB
efficientnetv2.py 44KB
resnest.py 21KB
__pycache__
vovnet.cpython-38.pyc 8KB
ghostnet.cpython-38.pyc 8KB
resnet.cpython-38.pyc 14KB
densenet.cpython-38.pyc 12KB
mobilenetv2.cpython-38.pyc 7KB
vgg.cpython-38.pyc 9KB
sequencer.cpython-38.pyc 15KB
mobilenetv3.cpython-38.pyc 11KB
repvgg.cpython-38.pyc 13KB
__init__.cpython-39.pyc 466B
resnest.cpython-38.pyc 14KB
shufflenetv2.cpython-39.pyc 8KB
convnext.cpython-38.pyc 9KB
dpn.cpython-38.pyc 10KB
shufflenetv2.cpython-38.pyc 9KB
efficientnet.cpython-38.pyc 13KB
efficientnetv2.cpython-38.pyc 31KB
repghost.cpython-38.pyc 14KB
mnasnet.cpython-38.pyc 10KB
cspnet.cpython-38.pyc 27KB
__init__.cpython-38.pyc 479B
mnasnet.py 13KB
requirements.txt 627B
Knowledge_Distillation.md 19KB
efficientnet_v2_s.log 4.93MB
shufflenet_v2.log 7.04MB
README.md 34KB
convnext_base.log 5.23MB
config
sgd_config.py 830B
__pycache__
sgd_config.cpython-38.pyc 1KB
config.cpython-38.pyc 1KB
config.py 985B
共 80 条
- 1
小风飞子
- 粉丝: 321
- 资源: 1496
下载权益
C知道特权
VIP文章
课程特权
开通VIP
上传资源 快速赚钱
- 我的内容管理 展开
- 我的资源 快来上传第一个资源
- 我的收益 登录查看自己的收益
- 我的积分 登录查看自己的积分
- 我的C币 登录后查看C币余额
- 我的收藏
- 我的下载
- 下载帮助
最新资源
- 探索tecreate:软件开发的未来之星.zip
- 打标机项目C#源码连接扫码
- 基于SSM的房屋租赁系统的设计与实现
- xyctf:从入门到精通的实用指南.zip
- mmqrcode1714153659780.png
- Screenshot_2024-04-27-06-08-58-486_com.baidu.xin.aiqicha.jpg
- 基于Javaweb+Tomcat+MySQL的大学生公寓管理系统+sql文件.zip
- 实训作业基于javaweb的订单管理系统源码+数据库+实训报告.zip
- 多机调度问题贪心算法基于最小堆和贪心算法求解多机调度问题.zip
- 基于同态加密技术的匿名电子投票系统源码.zip
资源上传下载、课程学习等过程中有任何疑问或建议,欢迎提出宝贵意见哦~我们会及时处理!
点击此处反馈
安全验证
文档复制为VIP权益,开通VIP直接复制
信息提交成功
- 1
- 2
- 3
前往页