# 基于深度学习Pytorch框架的文本分类
## 1、修改配置文件
在config.py文件中对常规参数进行修改。
```python
# data preprocess
data_path = 'dataset'
train_file = 'train.csv'
valid_file = 'valid.csv'
test_file = 'test.csv'
fix_length = 256
batch_size = 64
# data label list
label_list = ['低危', '中危', '高危', '超危']
class_number = len(label_list)
# train details
epochs = 30
learning_rate = 1e-3
```
其中,fix_length需要自己统计出数据集中的文本长度分布比例,进行合适的参数选择。batch_size需要根据显卡的显存大小进行参数选择。label_list需要换成对应数据集的标签列表。
## 2、选择合适的模型
在train.py中修改模型和模型命名。分别在97和98行进行修改。
![](./imgs/fig_1.png)
## 3、模型
- [x] TextCNN
- [x] TextRNN
- [x] TextRNN+Attention
- [x] TextRCNN
- [x] Transformer
- [ ] Some other attention
- [ ] Bert之类的预训练模型开在另一个仓库下
模型都直接定义在model/目录下,在forward最后返回的out的形状应该是[batch size, num_classes]这样的。
## 4、训练
模型中的超参数,例如hidden_size, multi_heads, n_layers需要自行修改。
修改完train.py中上述2处直接run就行,训练好的模型将保存在done_model/目录下。
![](./imgs/fig_2.png)
## 5、使用训练好的模型进行文本分类
修改Classify.py中 model = getModel('Transformer') ,模型改为之前训练好的模型名称。
然后直接run Classify.py就行。
![](./imgs/fig_3.png)
没有合适的资源?快使用搜索试试~ 我知道了~
基于深度学习框架pytorch实现的中文文本分类,目前包括textcnn,textrnn,textrcnn,text.zip
共16个文件
py:9个
png:3个
csv:2个
需积分: 5 0 下载量 161 浏览量
2024-02-04
20:18:17
上传
评论
收藏 17.77MB ZIP 举报
温馨提示
基于深度学习框架pytorch实现的中文文本分类,目前包括textcnn,textrnn,textrcnn,text
资源推荐
资源详情
资源评论
收起资源包目录
基于深度学习框架pytorch实现的中文文本分类,目前包括textcnn,textrnn,textrcnn,text.zip (16个子文件)
ahao2
DataSet.py 2KB
LICENSE 11KB
readme.md 2KB
dataset
test_pinggu.csv 12.01MB
train_pinggu.csv 47.73MB
Classify.py 1KB
model
TextRNN_Attention.py 2KB
TextRNN.py 1KB
TextRCNN.py 2KB
Transformer.py 5KB
TextCNN.py 1KB
Config.py 318B
train.py 4KB
imgs
fig_2.png 41KB
fig_1.png 25KB
fig_3.png 144KB
共 16 条
- 1
资源评论
码农阿豪
- 粉丝: 1w+
- 资源: 1754
上传资源 快速赚钱
- 我的内容管理 展开
- 我的资源 快来上传第一个资源
- 我的收益 登录查看自己的收益
- 我的积分 登录查看自己的积分
- 我的C币 登录后查看C币余额
- 我的收藏
- 我的下载
- 下载帮助
安全验证
文档复制为VIP权益,开通VIP直接复制
信息提交成功