# 基于深度学习Pytorch框架的文本分类
## 1、修改配置文件
在config.py文件中对常规参数进行修改。
```python
# data preprocess
data_path = 'dataset'
train_file = 'train.csv'
valid_file = 'valid.csv'
test_file = 'test.csv'
fix_length = 256
batch_size = 64
# data label list
label_list = ['低危', '中危', '高危', '超危']
class_number = len(label_list)
# train details
epochs = 30
learning_rate = 1e-3
```
其中,fix_length需要自己统计出数据集中的文本长度分布比例,进行合适的参数选择。batch_size需要根据显卡的显存大小进行参数选择。label_list需要换成对应数据集的标签列表。
## 2、选择合适的模型
在train.py中修改模型和模型命名。分别在97和98行进行修改。
![](./imgs/fig_1.png)
## 3、模型
- [x] TextCNN
- [x] TextRNN
- [x] TextRNN+Attention
- [x] TextRCNN
- [x] Transformer
- [ ] Some other attention
- [ ] Bert之类的预训练模型开在另一个仓库下
模型都直接定义在model/目录下,在forward最后返回的out的形状应该是[batch size, num_classes]这样的。
## 4、训练
模型中的超参数,例如hidden_size, multi_heads, n_layers需要自行修改。
修改完train.py中上述2处直接run就行,训练好的模型将保存在done_model/目录下。
![](./imgs/fig_2.png)
## 5、使用训练好的模型进行文本分类
修改Classify.py中 model = getModel('Transformer') ,模型改为之前训练好的模型名称。
然后直接run Classify.py就行。
![](./imgs/fig_3.png)
没有合适的资源?快使用搜索试试~ 我知道了~
温馨提示
1、该资源内项目代码经过严格调试,下载即用确保可以运行! 2、该资源适合计算机相关专业(如计科、人工智能、大数据、数学、电子信息等)正在做课程设计、期末大作业和毕设项目的学生、或者相关技术学习者作为学习资料参考使用。 3、该资源包括全部源码,需要具备一定基础才能看懂并调试代码。 基于深度学习框架pytorch实现的中文文本分类源码+项目说明(包括textcnn,textrnn,textrcnn,textrnn+attention,transformer).zip
资源推荐
资源详情
资源评论
收起资源包目录
基于深度学习框架pytorch实现的中文文本分类源码+项目说明(包括textcnn,textrnn,textrcnn,textrnn+attention,transformer).zip (16个子文件)
project_code_0628
DataSet.py 2KB
LICENSE 11KB
readme.md 2KB
dataset
test_pinggu.csv 12.01MB
train_pinggu.csv 47.73MB
Classify.py 1KB
model
TextRNN_Attention.py 2KB
TextRNN.py 1KB
TextRCNN.py 2KB
Transformer.py 5KB
TextCNN.py 1KB
Config.py 318B
train.py 4KB
imgs
fig_2.png 41KB
fig_1.png 25KB
fig_3.png 144KB
共 16 条
- 1
资源评论
- weixin_440013012024-03-12支持这个资源,内容详细,主要是能解决当下的问题,感谢大佬分享~
辣椒种子
- 粉丝: 3425
- 资源: 5723
上传资源 快速赚钱
- 我的内容管理 展开
- 我的资源 快来上传第一个资源
- 我的收益 登录查看自己的收益
- 我的积分 登录查看自己的积分
- 我的C币 登录后查看C币余额
- 我的收藏
- 我的下载
- 下载帮助
安全验证
文档复制为VIP权益,开通VIP直接复制
信息提交成功