# 训练评估与模型优化指南
**目录**
* [Analysis模块介绍](#Analysis模块介绍)
* [环境准备](#环境准备)
* [模型评估](#模型评估)
* [可解释性分析](#可解释性分析)
* [单词级别可解释性分析](#单词级别可解释性分析)
* [句子级别可解释性分析](#句子级别可解释性分析)
* [数据优化](#数据优化)
* [稀疏数据筛选方案](#稀疏数据筛选方案)
* [脏数据清洗方案](#脏数据清洗方案)
* [数据增强策略方案](#数据增强策略方案)
## Analysis模块介绍
Analysis模块提供了**模型评估、可解释性分析、数据优化**等功能,旨在帮助开发者更好地分析文本分类模型预测结果和对模型效果进行优化。
- **模型评估:** 对整体分类情况和每个类别分别进行评估,并打印预测错误样本,帮助开发者分析模型表现找到训练和预测数据中存在的问题。
- **可解释性分析:** 基于[TrustAI](https://github.com/PaddlePaddle/TrustAI)提供单词和句子级别的模型可解释性分析,帮助理解模型预测结果。
- **数据优化:** 结合[TrustAI](https://github.com/PaddlePaddle/TrustAI)和[数据增强API](https://github.com/PaddlePaddle/PaddleNLP/blob/develop/docs/dataaug.md)提供了**稀疏数据筛选、脏数据清洗、数据增强**三种优化策略,从多角度优化训练数据提升模型效果。
<div align="center">
<img src="https://user-images.githubusercontent.com/63761690/195241942-70068989-df17-4f53-9f71-c189d8c5c88d.png" width="600">
</div>
以下是本项目主要代码结构及说明:
```text
analysis/
├── evaluate.py # 评估脚本
├── sent_interpret.py # 句子级别可解释性分析脚本
├── word_interpret.py # 单词级别可解释性分析notebook
├── sparse.py # 稀疏数据筛选脚本
├── dirty.py # 脏数据清洗脚本
├── aug.py # 数据增强脚本
└── README.md # 训练评估与模型优化指南
```
## 环境准备
需要可解释性分析和数据优化需要安装相关环境。
- trustai >= 0.1.7
- interpretdl >= 0.7.0
**安装TrustAI**(可选)如果使用可解释性分析和数据优化中稀疏数据筛选和脏数据清洗需要安装TrustAI。
```shell
pip install trustai==0.1.7
```
**安装InterpretDL**(可选)如果使用词级别可解释性分析GradShap方法,需要安装InterpretDL
```shell
pip install interpretdl==0.7.0
```
## 模型评估
我们使用训练好的模型计算模型的在开发集的准确率,同时打印每个类别数据量及表现:
```shell
python evaluate.py \
--device "gpu" \
--dataset_dir "../data" \
--params_path "../checkpoint" \
--max_seq_length 128 \
--batch_size 32 \
--bad_case_file "bad_case.txt"
```
默认在GPU环境下使用,在CPU环境下修改参数配置为`--device "cpu"`
可支持配置的参数:
* `device`: 选用什么设备进行训练,可选择cpu、gpu、xpu、npu;默认为"gpu"。
* `dataset_dir`:必须,本地数据集路径,数据集路径中应包含dev.txt和label.txt文件;默认为None。
* `params_path`:保存训练模型的目录;默认为"../checkpoint/"。
* `max_seq_length`:分词器tokenizer使用的最大序列长度,ERNIE模型最大不能超过2048。请根据文本长度选择,通常推荐128、256或512,若出现显存不足,请适当调低这一参数;默认为128。
* `batch_size`:批处理大小,请结合显存情况进行调整,若出现显存不足,请适当调低这一参数;默认为32。
* `dev_file`:本地数据集中开发集文件名;默认为"dev.txt"。
* `label_file`:本地数据集中标签集文件名;默认为"label.txt"。
* `bad_case_path`:开发集中预测错误样本保存路径;默认为"/bad_case.txt"。
输出打印示例:
```text
[2022-08-11 03:10:14,058] [ INFO] - -----Evaluate model-------
[2022-08-11 03:10:14,059] [ INFO] - Dev dataset size: 1498
[2022-08-11 03:10:14,059] [ INFO] - Accuracy in dev dataset: 89.19%
[2022-08-11 03:10:14,059] [ INFO] - Macro avg in dev dataset: precision: 93.48 | recall: 93.26 | F1 score 93.22
[2022-08-11 03:10:14,059] [ INFO] - Micro avg in dev dataset: precision: 95.07 | recall: 95.46 | F1 score 95.26
[2022-08-11 03:10:14,095] [ INFO] - Level 1 Label Performance: Macro F1 score: 96.39 | Micro F1 score: 96.81 | Accuracy: 94.93
[2022-08-11 03:10:14,255] [ INFO] - Level 2 Label Performance: Macro F1 score: 92.79 | Micro F1 score: 93.90 | Accuracy: 89.72
[2022-08-11 03:10:14,256] [ INFO] - Class name: 交往
[2022-08-11 03:10:14,256] [ INFO] - Evaluation examples in dev dataset: 60(4.0%) | precision: 91.94 | recall: 95.00 | F1 score 93.44
[2022-08-11 03:10:14,256] [ INFO] - ----------------------------
[2022-08-11 03:10:14,256] [ INFO] - Class name: 交往##会见
[2022-08-11 03:10:14,256] [ INFO] - Evaluation examples in dev dataset: 12(0.8%) | precision: 92.31 | recall: 100.00 | F1 score 96.00
...
```
预测错误的样本保存在bad_case.txt文件中:
```text
Text Label Prediction
据猛龙随队记者JoshLewenberg报道,消息人士透露,猛龙已将前锋萨加巴-科纳特裁掉。此前他与猛龙签下了一份Exhibit10合同。在被裁掉后,科纳特下赛季大概率将前往猛龙的发展联盟球队效力。 组织关系,组织关系##加盟,组织关系##裁员 组织关系,组织关系##解雇
冠军射手被裁掉,欲加入湖人队,但湖人却无意,冠军射手何去何从 组织关系,组织关系##裁员 组织关系,组织关系##解雇
6月7日报道,IBM将裁员超过1000人。IBM周四确认,将裁减一千多人。据知情人士称,此次裁员将影响到约1700名员工,约占IBM全球逾34万员工中的0.5%。IBM股价今年累计上涨16%,但该公司4月发布的财报显示,一季度营收下降5%,低于市场预期。 组织关系,组织关系##裁员 组织关系,组织关系##裁员,财经/交易
有多名魅族员工表示,从6月份开始,魅族开始了新一轮裁员,重点裁员区域是营销和线下。裁员占比超过30%,剩余员工将不过千余人,魅族的知名工程师,爱讲真话的洪汉生已经从钉钉里退出了,外界传言说他去了OPPO。 组织关系,组织关系##退出,组织关系##裁员 组织关系,组织关系##裁员
...
```
## 可解释性分析
"模型为什么会预测出这个结果?"是文本分类任务开发者时常遇到的问题,如何分析错误样本(bad case)是文本分类任务落地中重要一环,本项目基于TrustAI开源了基于词级别和句子级别的模型可解释性分析方法,帮助开发者更好地理解文本分类模型与数据,有助于后续的模型优化与数据清洗标注。
### 单词级别可解释性分析
本项目开源模型的词级别可解释性分析Notebook,提供LIME、Integrated Gradient、GradShap 三种分析方法,支持分析微调后模型的预测结果,开发者可以通过更改**数据目录**和**模型目录**在自己的任务中使用Jupyter Notebook进行数据分析。
运行 [word_interpret.ipynb](https://github.com/PaddlePaddle/PaddleNLP/blob/develop/applications/text_classification/hierarchical/analysis/README.md) 代码,即可分析影响样本预测结果的关键词以及可视化所有词对预测结果的贡献情况,颜色越深代表这个词对预测结果影响越大:
<div align="center">
<img src="https://user-images.githubusercontent.com/63761690/195334753-78cc2dc8-a5ba-4460-9fde-3b1bb704c053.png" width="1000">
</div>
### 句子级别可解释性分析
本项目基于特征相似度([FeatureSimilarity](https://arxiv.org/abs/2104.04128))算法,计算对样本预测结果正影响的训练数据,帮助理解模型的预测结果与训练集数据的关系。
待分析数据文件`interpret_input_file`应为�