# 基于PyTorch实现的RankIQA
功能说明:
- 对图像质量进行对比排序(Rank)训练;
- 若图像质量有对应的单个质量评分,可对质量分进行回归预测(regress)训练;
- 若图像质量有多个质量评分,构成一个直方图分布,可对分布进行EMD拟合(emd)训练;
- 训练完毕的网络,可借助`application/convert.py`应用,将模型转为JIT格式;
- 最后,使用`demos/image_assessment.py`进行部署调用;
可用的骨干网络包括:
- [x] PyTorch自带的网络:resnet, shufflenet, densenet, mobilenet, mnasnet等;
- [x] MobileNet v3;
- [x] EfficientNet系列;
- [x] ResNeSt系列;
---
## 包含特性
- 图像的对比排序训练,损失函数为`RankingLoss`,调用了PyTorch的`nn.margin_rank_loss`;
- 图像预测评分和目标评分分布的推土距离损失,损失函数为`EMDLoss`,
损失函数参考了[NIMA: Neural Image Assessment](https://arxiv.org/abs/1709.05424)的目标函数;
- 图片单个评分的回归损失,这个支持两种预测结果(num_classes=1直接裸输出作为预测分 和 num_classes>1求softmax概率加权平均),
与单个评分的回归,用mse做损失函数。
---
## 文件结构说明
- `applications`: 包括`test.py, train.py, convert.py`等应用,提供给`main.py`调用;
- `checkpoints`: 训练好的模型文件保存目录(当前可能不存在);
- `criterions`: 自定义损失函数,目前主要包含 `ranking_loss`, `emd_loss`, `regress_loss`;
- `data`: 训练/测试/验证/预测等数据集存放的路径,数据集格式查看下面的`使用说明/数据准备`;
- `dataloader`: 自定义数据集`Dataset`子类和数据集加载`DataLoader`子类,保证在不同任务中加载不同格式的数据文件进行训练;
- `demos`: 模型使用的demo,目前`image_assessment.py`显示如何调用`jit`格式模型进行预测;
- `logs`: 训练过程中TensorBoard日志存放的文件(当前可能不存在);
- `models`: 自定义的模型结构;
- `optim`: 一些前沿的优化器,PyTorch官方还未实现,RAdam可尝试;
- `pretrained`: 预训练模型文件(当前可能不存在);
- `utils`: 工具脚本:图片数据校验`check_images.py`、日志`my_logger.py`、数据生成`data_generate`等;
- `config.py`: 配置文件;
- `main.py`: 总入口;
- `requirements.txt`: 工程依赖包列表;
---
## 使用说明
### 数据准备
在文件夹`data`下放数据,分成三个文件夹: `train/test/val`,对应 训练/测试/验证 数据文件夹;
每个子文件夹下,需要根据训练任务的不同放置训练图像(若需要修改数据读取方式等,可查看`data/my_dataloader`)。
数据准备完毕后,使用`utils/check_images.py`脚本,检查图像数据的有效性,防止在训练过程中遇到无效图片中止训练。
#### 排序对比损失任务
在每个数据集文件夹根目录下放置标签文件:如训练集`data/train/label.txt`,每行两张图像的项目相对路径,其中第一个图像质量>第二个图像质量,
内容举例如下:
```text
data/train/refimgs/carnivaldolls.bmp,data/train/gaussian_noise/gaussian_noise5/carnivaldolls.bmp
data/train/refimgs/carnivaldolls.bmp,data/train/gaussian_noise/gaussian_noise7/carnivaldolls.bmp
data/train/refimgs/carnivaldolls.bmp,data/train/gaussian_noise/gaussian_noise11/carnivaldolls.bmp
data/train/refimgs/carnivaldolls.bmp,data/train/gaussian_noise/gaussian_noise15/carnivaldolls.bmp
data/train/refimgs/carnivaldolls.bmp,data/train/gaussian_noise/gaussian_noise21/carnivaldolls.bmp
data/train/refimgs/carnivaldolls.bmp,data/train/white_noise/white_noise0/carnivaldolls.bmp
data/train/refimgs/carnivaldolls.bmp,data/train/white_noise/white_noise3/carnivaldolls.bmp
data/train/refimgs/carnivaldolls.bmp,data/train/white_noise/white_noise5/carnivaldolls.bmp
data/train/refimgs/carnivaldolls.bmp,data/train/white_noise/white_noise7/carnivaldolls.bmp
data/train/refimgs/carnivaldolls.bmp,data/train/white_noise/white_noise9/carnivaldolls.bmp
...
```
#### 分布的推土距离损失任务
在每个数据集文件夹根目录下放置标签文件:如训练集`data/train/label.txt`,每行:一张图像的项目相对路径,这张图的MOS概率分布,
内容举例如下:
```text
data/train/refimgs/carnivaldolls.bmp,0.,0.,0.,0.2,0.8
data/train/gaussian_noise/gaussian_noise7/carnivaldolls.bmp,0.,0.,0.,0.2,0.8
...
```
表示 图像`data/train/refimgs/carnivaldolls.bmp`的评分[1,2,3,4,5]对应的分布为[0.,0.,0.,0.2,0.8]。
注意评分是`1:N+1,N=num_classes`,这与我损失函数`criterions/emd_loss.py`的编写相关。
#### 回归损失任务
在每个数据集文件夹根目录下放置标签文件:如训练集`data/train/label.txt`,每行:一张图像的项目相对路径,这张图的MOS分值
内容举例如下:
```text
data/train/refimgs/carnivaldolls.bmp,7.6
data/train/gaussian_noise/gaussian_noise7/carnivaldolls.bmp,3.2
...
```
表示 图像`data/train/refimgs/carnivaldolls.bmp`的评分为3.2。
在训练时,如果配置`num_classes=1`,则损失函数计算为:||output - 3.2||;
配置`num_classes=N`,则损失函数计算为:||(output.softmax() * [1:N+1]).sum() - 3.2||;
详细可参考损失函数`criterions/regress_loss.py`。
### 部分重要配置参数说明
针对`config.py`里的部分重要参数说明如下:
- `--data`: 数据集根目录,下面包含`train`, `test`, `val`三个目录的数据集,默认当前文件夹下`data/`目录;
- `--image_size`: 输入应该为两个整数值,预训练模型的输入时正方形的,也就是[224, 224]之类的;
实际可以根据自己需要更改,数据预处理时,会将图像centercrop等比例部分,然后resize指定的输入尺寸。
- `--num_classes`: 模型的预测分支数;
- `-b`: 设置batch size大小,默认为256,可根据GPU显存设置;
- `-j`: 设置数据加载的进程数,默认为8,可根据CPU使用量设置;
- `--criterion`: 损失函数,`ranking_loss`, `emd_loss`, `regress_loss`三种;
- `--margin`: 配合`ranking_loss`使用,margin表示对比的差异间隔;
- `--lr`: 初始学习率,`main.py`里我默认使用Adam优化器;目前学习率的scheduler我使用的是`LambdaLR`接口,自定义函数规则如下,
详细可参考`main.py`的`adjust_learning_rate(epoch, args)`函数:
```
~ warmup: 0.1
~ warmup + int([1.5 * (epochs - warmup)]/4.0): 1,
~ warmup + int([2.5 * (epochs - warmup)]/4.0): 0.1
~ warmup + int([3.5 * (epochs - warmup)]/4.0) 0.01
~ epochs: 0.001
```
- `--warmup`: warmup的迭代次数,训练前warmup个epoch会将 初始学习率*0.1 作为warmup期间的学习率;
- `--epochs`: 训练的总迭代次数;
- `--resume`: 权重文件路径,模型文件将被加载以进行模型初始化,`--jit`和`--evaluation`时需要指定;
- `--jit`: 将模型转为JIT格式,利于部署;
- `--evaluation`: 在测试集上进行模型评估;
---
## 快速使用 —— 使用公开数据集AVA(aesthetic visual analysis)进行训练、测试、部署
下载数据[mtobeiyf/ava_downloader](https://github.com/mtobeiyf/ava_downloader)
---
## 使用说明
可参考对应的`z_task_shell/*.sh`文件
### 模型信息打印
打印分支数为`6`、输入图像分辨率`400x224`的`efficientnet-b0`网络的基本信息:
```shell
python main.py --arch efficientnet_b0 --num_classes 6 --image_size 244 224
```
### 模型训练
基于`data/`目录下的`train`数据集,使用分支数为`1`、输入图像分辨率`244x224`、损失函数为排序对比损失ranking loss(+ margin=0.1)
的`efficientnet-b0`网络,同时加载预训练模型、训练学习率warmup 5个epoch,batch size为384,
数据加载worker为16个,训练65个epoch:
```shell
python main.py --data data/ --train --arch efficientnet_b0 --num_classes 1 \
--criterion=r
没有合适的资源?快使用搜索试试~ 我知道了~
温馨提示
基于PyTorch实现的图像质量评估模型RankIQA源码+说明(课程设计).zip 这是95分以上高分必过课程设计项目,下载即用无需修改,确保可以运行。也可作为期末大作业。 基于PyTorch实现的图像质量评估模型RankIQA源码+说明(课程设计).zip 这是95分以上高分必过课程设计项目,下载即用无需修改,确保可以运行。也可作为期末大作业。基于PyTorch实现的图像质量评估模型RankIQA源码+说明(课程设计).zip 这是95分以上高分必过课程设计项目,下载即用无需修改,确保可以运行。也可作为期末大作业。基于PyTorch实现的图像质量评估模型RankIQA源码+说明(课程设计).zip 这是95分以上高分必过课程设计项目,下载即用无需修改,确保可以运行。也可作为期末大作业。基于PyTorch实现的图像质量评估模型RankIQA源码+说明(课程设计).zip 这是95分以上高分必过课程设计项目,下载即用无需修改,确保可以运行。也可作为期末大作业。基于PyTorch实现的图像质量评估模型RankIQA源码+说明(课程设计).zip 这是95分以上高分必过课程设计项目,下
资源推荐
资源详情
资源评论
收起资源包目录
基于PyTorch实现的图像质量评估模型RankIQA.zip (58个子文件)
基于PyTorch实现的图像质量评估模型RankIQA
main.py 8KB
data
README.md 29B
utils
__init__.py 161B
data_generate
live_generator.py 6KB
__init__.py 48B
tid2013_generator.py 87B
my_meters.py 1KB
image_metrics.py 489B
my_summary.py 8KB
check_images.py 977B
my_logger.py 928B
z_task_shell
0_generate_live_dataset.sh 174B
4_main_evaluate.sh 359B
5_main_convert_jit.sh 231B
3_main_train_distributed.sh 384B
2_main_train_gpu.sh 336B
1_main_info.sh 104B
dataloader
__init__.py 77B
my_dataloader.py 7KB
image_rescale.py 4KB
criterions
__init__.py 240B
regress_loss.py 2KB
emd_loss.py 3KB
ranking_loss.py 2KB
pretrained
README.md 35B
requirements.txt 270B
logs
README.md 37B
models
__init__.py 93B
torch_models.py 2KB
mobilenetv3
__init__.py 86B
mobilenetv3.py 6KB
factory.py 2KB
resnest
__init__.py 157B
splat.py 3KB
ablation.py 4KB
resnet.py 13KB
resnest.py 2KB
efficientnet
utils.py 2KB
__init__.py 92B
model.py 6KB
components.py 5KB
factory.py 3KB
config.py 7KB
checkpoints
README.md 56B
.gitignore 2KB
applications
__init__.py 303B
convert.py 625B
train.py 5KB
test.py 2KB
demos
__init__.py 58B
images
a0bkfkutkf6_11_kid.jpg 43KB
a0bdb8qe4dr_50.5_low.jpg 32KB
a0ffursj21n_47.2_adult.jpg 41KB
a0c3e0mtkc3_11_man.jpg 22KB
a0ahjorwcpv_7_normal.jpg 55KB
image_assessment.py 3KB
README.md 10KB
config.py 4KB
共 58 条
- 1
资源评论
- Zzzll_2024-02-05发现一个超赞的资源,赶紧学习起来,大家一起进步,支持!
- m0_743970512024-04-28资源很赞,希望多一些这类资源。
- 佯谬2024-03-13资源不错,对我启发很大,获得了新的灵感,受益匪浅。
- 2301_800036382024-03-18感谢大佬分享的资源给了我灵感,果断支持!感谢分享~
- 2301_798642682024-04-28资源内容总结地很全面,值得借鉴,对我来说很有用,解决了我的燃眉之急。
程序员张小妍
- 粉丝: 1w+
- 资源: 2678
上传资源 快速赚钱
- 我的内容管理 展开
- 我的资源 快来上传第一个资源
- 我的收益 登录查看自己的收益
- 我的积分 登录查看自己的积分
- 我的C币 登录后查看C币余额
- 我的收藏
- 我的下载
- 下载帮助
安全验证
文档复制为VIP权益,开通VIP直接复制
信息提交成功