# 简单的中文文本情感分类
一个用 PyTorch 实现的中文文本情感分类网络,代码较简单,功能较丰富,包含了多种模型 baseline。
## 环境需求
* python == 3.6
* torch == 1.1.0
* Intel(R) Xeon(R) CPU E5-2620 v4 @ 2.10GHz
* NVIDIA TITAN Xp
其余的见 `requirements.txt`
## 使用方法
先预处理,`./run_preprocess_word2vec.sh` 或 `./run_preprocess_elmo.sh 3`(3 是 gpu 编号)
然后运行 `python3 main.py --config_path config_cnn.json`
## 预处理
将所给文本的每个词转换成预训练模型的词向量后存到文件里。我分别尝试了这两种 embedding:
* ELMo 中文预训练模型,1024d(https://github.com/HIT-SCIR/ELMoForManyLangs)
* Chinese-Word-Vectors,300d(https://github.com/Embedding/Chinese-Word-Vectors)
请自行下载相应的模型文件到 `data/word2vec/` 或 `data/zhs.model` 文件夹下。
具体细节见 `preprocess.py` 文件,若想使用自己的数据集,修改该文件即可。
## 实现的模型
### MLP (2 layer)
Linear + ReLU + Dropout + Linear + Softmax
### CNN (1 layer) + MLP (2 layer)
Conv1d + ReLU + Dropout + MaxPool1d + Linear + ReLU + Dropout + Linear + Softmax
见这篇 paper [https://www.aclweb.org/anthology/D14-1181](https://www.aclweb.org/anthology/D14-1181)
[1] Kim, Y. (2014). Convolutional Neural Networks for Sentence Classification. Proceedings of the 2014 Conference on Empirical Methods in Natural Language Processing (EMNLP 2014), 1746–1751.
### RNN (1 layer) + Self-Attention + MLP (2 layer)
RNN (GRU or LSTM or bi-GRU or bi-LSTM) + Self-Attention + Linear + ReLU + Dropout + Linear + Softmax
Self-Attention 见这篇 paper [https://arxiv.org/pdf/1703.03130.pdf](https://arxiv.org/pdf/1703.03130.pdf)
[2] Zhouhan Lin, Minwei Feng, Cicero Nogueira dos Santos, Mo Yu, Bing Xiang, Bowen Zhou, and Yoshua Bengio. 2017. A structured self-attentive sentence embedding. In Proceedings of International Conference on Learning Representations.
## 某些参数的解释
* `seed`:`20000125` (保证结果可复现)
* `gpu`:`false` (使用 cpu),`true` (使用 nvidia 系 gpu,推荐)
* `output_path`:运行模型会将日志文件、TensorBoard 文件、配置文件生成到该目录下
* `loss`:`l1` (L1Loss) ,`mse` (MSELoss),`cross_entropy` (CrossEntropyLoss,推荐)
* `optimizer`:`sgd`,`adagrad` (Adagrad 自带了 L2 regularization,推荐)
* `embedding_size`:`1024` (ELMo),`300` (Chinese-Word-Vectors,较小,推荐)
* `type`:`mlp`,`cnn`,`rnn`
具体见 `config_mlp.json`、`config_cnn.json`、`config_rnn.json` 这些文件。
## 数据集
数据集用的是 THU 计算机系《人工智能导论》作业三的数据集,在这里我不方便公开数据集及其介绍。大概的介绍一下就是,中文新闻,8 种情感分类。
这个数据集我必须要说一下,数据集比较小,网民标注数据不太准确,训练集和测试集分布不太一样(训练集是 2012 年 1 月至 2 月 发布的 2,342 篇新闻文章,测试集是 2012 年 3 月至 4 月发布的 2,228 篇新闻文章),所以某些模型可能达不到预期的效果。
我从所给训练集数据取出后 1/10 作为 dev 数据集,每训一个 epoch 都测试一次 dev 数据集,然后取一个准确率最高的那个 epoch 的模型作为最终测试的模型。
## 实验结果
先放最好的结果(总共训练了 300 epoch,取 dev 数据集准确率最高的那个 epoch 来对 test 数据集进行测试得到的下表,总用时指的是训完 300 个 epoch 后的用时):
| 模型 | Accuracy(%) | F1(%) | CORR | 总用时 | 参数 |
| :------: | :---------: | :---------: | :--: | :--------------: | :--------------------------------------------------------: |
| ![](./doc/col_mlp.png) MLP | 59.4 | 21.5 | 0.28 | 7m44s | [save/mlp_1/config.json](./save/mlp_1/config.json) |
| ![](./doc/col_cnn.png) CNN | 62.4 | 30.2 | 0.41 | 9m56s | [save/cnn_5/config.json](./save/cnn_5/config.json) |
| ![](./doc/col_bi-lstm.png) bi-LSTM | 58.1 | 30.8 | 0.27 | 37m47s | [save/bi-lstm_1/config.json](./save/bi-lstm_1/config.json) |
| ![](./doc/col_bi-gru.png) bi-GRU | 57.3 | 26.3 | 0.31 | 34m47s | [save/bi-gru_1/config.json](./save/bi-gru_1/config.json) |
| ![](./doc/col_lstm.png) LSTM | 55.56 | 25.3 | 0.26 | 21m44s | [save/lstm_1/config.json](./save/lstm_1/config.json) |
| ![](./doc/col_gru.png) GRU | 51.3 | 25.3 | 0.26 | 20m41s | [save/gru_1/config.json](./save/gru_1/config.json) |
| MLP-ELMo | 58.1 | 21.9 | 0.21 | 18m26s | [save/mlp_3/config.json](./save/mlp_3/config.json) |
| CNN-ELMo | 59.8 | 30.1 | 0.34 | 14m23s | [save/cnn_6/config.json](./save/cnn_6/config.json) |
下图为 dev Accuracy
![](./doc/dev_Accuracy.svg)
下图为 dev F1 (macro)
![](./doc/dev_F1_macro.svg)
下图为 dev CORR
![](./doc/dev_CORR.svg)
下图为 train Accuracy
![](./doc/train_Accuracy.svg)
若想查看更多图表请运行以下命令
```
$ tensorboard --logdir MLP:save/mlp_1/runs/,\
CNN:save/cnn_5/runs/,\
bi-LSTM:save/bi-lstm_1/runs/,\
bi-GRU:save/bi-gru_1/runs/,\
LSTM:save/lstm_1/runs/,\
GRU:save/gru_1/runs/,\
MLP-ELMo:save/mlp_3/runs/,\
CNN-ELMo:save/cnn_6/runs/
```
## 模型与参数比较
由上图来看,MLP 收敛最快,最早进入过拟合,测试结果一般;RNN 系的模型收敛速度较快(bi-GRU除外),测试结果不是很好;CNN 的模型收敛较慢,特别稳定,测试结果特别好。
MLP 作为一个入门模型,效果却非常不错,甚至吊打 RNN 系,这其实挺迷的,我能想到的原因只有我参数没找对或者数据集有毒。而 CNN 我曾经尝试过用 2 层卷积层,但效果不如 1 层的,所以后来就放弃了。
优化器基本上都是用的 `adagrad`,`sgd` 根据调参结果基本不会再用了(收敛太慢了)。
```json
"optimizer": "adagrad",
"lr": 0.01,
"lr_decay": 0,
"weight_decay": 0.0001,
```
`ELMo` 的词向量经测试效果不太好,感觉应该是 pre-trained model 训的数据集和我们这个数据集分布差的有点多,然后我又没将 pre-trained model 接到我网络前面继续 fine tune 才导致的效果差?比较 [save/cnn_5](./save/cnn_5/config.json)(Chinese-Word-Vectors,橙色)和 [save/cnn_6](./save/cnn_6/config.json)(ELMo,蓝色)的图能看出在这个数据集上这个 ELMo 效果确实太好。
| dev Accuracy | train Accuracy |
| :------------------------: | :--------------------------: |
| ![](./doc/embed_dev.svg) | ![](./doc/embed_train.svg) |
`Dropout` 是为了防止过拟合,比较 [save/cnn_1](./save/cnn_1/config.json)(dropout = 0.5,橙色)和 [save/cnn_4](./save/cnn_4/config.json)(dropout = 0.9,蓝色)的图能明显看出,当 dropout 越大,收敛越慢,但准确率更高,防止过拟合能力越强(毕竟解耦能力强)。
| dev Accuracy | train Accuracy |
| :------------------------: | :--------------------------: |
| ![](./doc/dropout_dev.svg) | ![](./doc/dropout_train.svg) |
`Batch Normalization` 可以防止过拟合,比较卷积层里无 BN 无 Drouput 的 CNN [save/cnn_8](./save/cnn_8/config.json)(蓝色)和卷积层里有 BN 无 Drouput 的 CNN [save/cnn_7](./save/cnn_7/config.json)(Conv1d + BatchNorm1D + ReLU + MaxPool1d + Linear + ReLU + Dropout + Linear + Softmax,橙色)可以非常明显的看出。
| dev Accuracy | train Accuracy |
| :------------------------: | :--------------------------: |
| ![](./doc/bn_dev.svg) | ![](./
没有合适的资源?快使用搜索试试~ 我知道了~
温馨提示
<项目介绍> 简单的中文文本情感分类 一个用 PyTorch 实现的中文文本情感分类网络,代码较简单,功能较丰富,包含了多种模型 baseline。 环境需求 python == 3.6 torch == 1.1.0 Intel(R) Xeon(R) CPU E5-2620 v4 @ 2.10GHz NVIDIA TITAN Xp 其余的见 requiremen - 不懂运行,下载完可以私聊问,可远程教学 该资源内项目源码是个人的毕设,代码都测试ok,都是运行成功后才上传资源,答辩评审平均分达到96分,放心下载使用! 1、该资源内项目代码都经过测试运行成功,功能ok的情况下才上传的,请放心下载使用! 2、本项目适合计算机相关专业(如计科、人工智能、通信工程、自动化、电子信息等)的在校学生、老师或者企业员工下载学习,也适合小白学习进阶,当然也可作为毕设项目、课程设计、作业、项目初期立项演示等。 3、如果基础还行,也可在此代码基础上进行修改,以实现其他功能,也可用于毕设、课设、作业等。 下载后请首先打开README.md文件(如有),仅供学习参考, 切勿用于商业用途。 --------
资源推荐
资源详情
资源评论
收起资源包目录
人工智能导论作业-基于python+PyTorch中文文本情感分类源码+文档说明(高分课程设计) (112个子文件)
events.out.tfevents.1559437928.gpu-theta.6850.0 402KB
events.out.tfevents.1559409302.gpu-theta.16485.0 402KB
events.out.tfevents.1559409399.gpu-theta.22755.0 402KB
events.out.tfevents.1559400227.gpu-theta.32041.0 402KB
events.out.tfevents.1559443831.gpu-theta.16155.0 402KB
events.out.tfevents.1559402877.gpu-theta.2069.0 402KB
events.out.tfevents.1559441032.gpu-theta.21582.0 402KB
events.out.tfevents.1559437980.gpu-theta.10531.0 402KB
events.out.tfevents.1559440265.gpu-theta.15460.0 402KB
events.out.tfevents.1559399753.gpu-theta.10281.0 402KB
events.out.tfevents.1559438821.gpu-theta.15937.0 402KB
events.out.tfevents.1559438048.gpu-theta.14643.0 402KB
events.out.tfevents.1559462040.gpu-theta.19995.0 402KB
events.out.tfevents.1559405473.gpu-theta.8556.0 402KB
events.out.tfevents.1559406000.gpu-theta.3523.0 402KB
events.out.tfevents.1559409541.gpu-theta.30631.0 402KB
events.out.tfevents.1559399750.gpu-theta.9979.0 402KB
events.out.tfevents.1559462006.gpu-theta.17687.0 402KB
sinanews.demo 10KB
.gitignore 3KB
.gitignore 71B
.gitignore 60B
ssc.iml 398B
config_cnn.json 627B
config_rnn.json 594B
config.json 519B
config.json 519B
config.json 518B
config.json 516B
config.json 514B
config.json 514B
config.json 512B
config.json 511B
config.json 510B
config.json 510B
config.json 509B
config.json 506B
config.json 499B
config.json 493B
config.json 493B
config_mlp.json 449B
config.json 395B
config.json 395B
config.json 392B
README.md 13KB
col_cnn.png 191B
col_bi-gru.png 189B
col_gru.png 184B
col_lstm.png 184B
col_bi-lstm.png 157B
col_mlp.png 134B
biLM.py 25KB
main.py 16KB
encoder_base.py 16KB
lstm_cell_with_projection.py 13KB
__main__.py 10KB
elmo.py 9KB
util.py 9KB
classify_layer.py 8KB
elmo.py 8KB
frontend.py 7KB
token_embedder.py 4KB
preprocess.py 4KB
highway.py 3KB
embedding_layer.py 2KB
dataloader.py 1KB
lstm.py 1KB
utils.py 408B
__init__.py 49B
__init__.py 0B
run_preprocess_word2vec.sh 257B
run_preprocess_elmo.sh 227B
run_mlp.sh 103B
run_rnn.sh 103B
run_cnn.sh 103B
train_Accuracy.svg 443KB
dropout_train.svg 157KB
self-attention_train.svg 157KB
bn_train.svg 156KB
embed_train.svg 156KB
dev_Accuracy.svg 150KB
dev_F1_macro.svg 150KB
dev_CORR.svg 146KB
bn_dev.svg 61KB
embed_dev.svg 57KB
dropout_dev.svg 57KB
self-attention_dev.svg 54KB
log.txt 554KB
log.txt 455KB
log.txt 313KB
log.txt 312KB
log.txt 312KB
log.txt 311KB
log.txt 311KB
log.txt 311KB
log.txt 310KB
log.txt 309KB
log.txt 306KB
log.txt 299KB
log.txt 299KB
共 112 条
- 1
- 2
资源评论
机智的程序员zero
- 粉丝: 2284
- 资源: 4469
上传资源 快速赚钱
- 我的内容管理 展开
- 我的资源 快来上传第一个资源
- 我的收益 登录查看自己的收益
- 我的积分 登录查看自己的积分
- 我的C币 登录后查看C币余额
- 我的收藏
- 我的下载
- 下载帮助
安全验证
文档复制为VIP权益,开通VIP直接复制
信息提交成功