# bertorch: 基于 pytorch 的 bert 实现和下游任务微调
bertorch 是一个基于 pytorch 进行 bert 实现和下游任务微调的工具,支持常用的自然语言处理任务,包括文本分类、文本匹配、语义理解和序列标注等。
## 1. 依赖环境
- Python >= 3.6
- torch >= 1.1
- argparse
- json
- loguru
- numpy
- packaging
- re
## 2. 文本分类
本项目展示了以 BERT 为代表的预训练模型如何 Finetune 完成文本分类任务。我们以中文情感分类公开数据集 ChnSentiCorp 为例,运行如下的命令,基于 DistributedDataParallel 进行单机多卡分布式训练,在训练集 (train.tsv) 上进行模型训练,并在验证集 (dev.tsv) 上进行评估:
```shell
CUDA_VISIBLE_DEVICES=0,1 python -m torch.distributed.launch --nproc_per_node=2 run_classifier.py --train_data_file ./data/ChnSentiCorp/train.tsv --dev_data_file ./data/ChnSentiCorp/dev.tsv --label_file ./data/ChnSentiCorp/labels.txt --save_best_model --epochs 3 --batch_size 32
```
可支持的配置参数:
```
usage: run_classifier.py [-h] [--local_rank LOCAL_RANK]
[--pretrained_model_name_or_path PRETRAINED_MODEL_NAME_OR_PATH]
[--init_from_ckpt INIT_FROM_CKPT] --train_data_file
TRAIN_DATA_FILE [--dev_data_file DEV_DATA_FILE]
--label_file LABEL_FILE [--batch_size BATCH_SIZE]
[--scheduler {linear,cosine,cosine_with_restarts,polynomial,constant,constant_with_warmup}]
[--learning_rate LEARNING_RATE]
[--warmup_proportion WARMUP_PROPORTION] [--seed SEED]
[--save_steps SAVE_STEPS]
[--logging_steps LOGGING_STEPS]
[--weight_decay WEIGHT_DECAY] [--epochs EPOCHS]
[--max_seq_length MAX_SEQ_LENGTH]
[--saved_dir SAVED_DIR]
[--max_grad_norm MAX_GRAD_NORM] [--save_best_model]
[--is_text_pair]
```
- local_rank: 可选,分布式训练的节点编号,默认为 -1。
- pretrained_model_name_or_path: 可选,huggingface 中的预训练模型名称或路径,默认为 bert-base-chinese。
- train_data_file: 必选,训练集数据文件路径。
- dev_data_file: 可选,验证集数据文件路径,默认为 None。
- label_file: 必选,类别标签文件路径。
- batch_size: 可选,批处理大小,请结合显存情况进行调整,若出现显存不足,请适当调低这一参数。默认为 32。
- init_from_ckpt: 可选,要加载的模型参数路径,热启动模型训练。默认为None。
- scheduler: 可选,优化器学习率变化策略,默认为 linear。
- learning_rate: 可选,优化器的最大学习率,默认为 5e-5。
- warmup_proportion: 可选,学习率 warmup 策略的比例,如果为 0.1,则学习率会在前 10% 训练 step 的过程中从 0 慢慢增长到 learning_rate,而后再缓慢衰减。默认为 0。
- weight_decay: 可选,控制正则项力度的参数,用于防止过拟合,默认为 0.0。
- seed: 可选,随机种子,默认为1000。
- logging_steps: 可选,日志打印的间隔 steps,默认为 20。
- save_steps: 可选,保存模型参数的间隔 steps,默认为 100。
- epochs: 可选,训练轮次,默认为 3。
- max_seq_length: 可选,输入到预训练模型中的最大序列长度,最大不能超过 512,默认为 128。
- saved_dir: 可选,保存训练模型的文件夹路径,默认保存在当前目录的 checkpoint 文件夹下。
- max_grad_norm: 可选,训练过程中梯度裁剪的 max_norm 参数,默认为 1.0。
- save_best_model: 可选,是否在最佳验证集指标上保存模型,当训练命令中加入
--save_best_model 时,save_best_model 为 True,否则为 False。
- is_text_pair: 可选,是否进行文本对分类,当训练命令中加入 --is_text_pair 时,进行文本对的分类,否则进行普通文本分类。
模型训练的中间日志如下:
```python
2022-05-25 07:22:29.403 | INFO | __main__:train:301 - global step: 20, epoch: 1, batch: 20, loss: 0.23227, accuracy: 0.87500, speed: 2.12 step/s
2022-05-25 07:22:39.131 | INFO | __main__:train:301 - global step: 40, epoch: 1, batch: 40, loss: 0.30054, accuracy: 0.87500, speed: 2.06 step/s
2022-05-25 07:22:49.010 | INFO | __main__:train:301 - global step: 60, epoch: 1, batch: 60, loss: 0.23514, accuracy: 0.93750, speed: 2.02 step/s
2022-05-25 07:22:58.909 | INFO | __main__:train:301 - global step: 80, epoch: 1, batch: 80, loss: 0.12026, accuracy: 0.96875, speed: 2.02 step/s
2022-05-25 07:23:08.804 | INFO | __main__:train:301 - global step: 100, epoch: 1, batch: 100, loss: 0.21955, accuracy: 0.90625, speed: 2.02 step/s
2022-05-25 07:23:13.534 | INFO | __main__:train:307 - eval loss: 0.22564, accuracy: 0.91750
2022-05-25 07:23:25.222 | INFO | __main__:train:301 - global step: 120, epoch: 1, batch: 120, loss: 0.32157, accuracy: 0.90625, speed: 2.03 step/s
2022-05-25 07:23:35.104 | INFO | __main__:train:301 - global step: 140, epoch: 1, batch: 140, loss: 0.20107, accuracy: 0.87500, speed: 2.02 step/s
2022-05-25 07:23:44.978 | INFO | __main__:train:301 - global step: 160, epoch: 2, batch: 10, loss: 0.08750, accuracy: 0.96875, speed: 2.03 step/s
2022-05-25 07:23:54.869 | INFO | __main__:train:301 - global step: 180, epoch: 2, batch: 30, loss: 0.08308, accuracy: 1.00000, speed: 2.02 step/s
2022-05-25 07:24:04.754 | INFO | __main__:train:301 - global step: 200, epoch: 2, batch: 50, loss: 0.10256, accuracy: 0.93750, speed: 2.02 step/s
2022-05-25 07:24:09.480 | INFO | __main__:train:307 - eval loss: 0.22497, accuracy: 0.93083
2022-05-25 07:24:21.020 | INFO | __main__:train:301 - global step: 220, epoch: 2, batch: 70, loss: 0.23989, accuracy: 0.93750, speed: 2.03 step/s
2022-05-25 07:24:30.919 | INFO | __main__:train:301 - global step: 240, epoch: 2, batch: 90, loss: 0.00897, accuracy: 1.00000, speed: 2.02 step/s
2022-05-25 07:24:40.777 | INFO | __main__:train:301 - global step: 260, epoch: 2, batch: 110, loss: 0.13605, accuracy: 0.93750, speed: 2.03 step/s
2022-05-25 07:24:50.640 | INFO | __main__:train:301 - global step: 280, epoch: 2, batch: 130, loss: 0.14508, accuracy: 0.93750, speed: 2.03 step/s
2022-05-25 07:25:00.529 | INFO | __main__:train:301 - global step: 300, epoch: 2, batch: 150, loss: 0.04770, accuracy: 0.96875, speed: 2.02 step/s
2022-05-25 07:25:05.256 | INFO | __main__:train:307 - eval loss: 0.23039, accuracy: 0.93500
2022-05-25 07:25:16.818 | INFO | __main__:train:301 - global step: 320, epoch: 3, batch: 20, loss: 0.04312, accuracy: 0.96875, speed: 2.04 step/s
2022-05-25 07:25:26.700 | INFO | __main__:train:301 - global step: 340, epoch: 3, batch: 40, loss: 0.05103, accuracy: 0.96875, speed: 2.02 step/s
2022-05-25 07:25:36.588 | INFO | __main__:train:301 - global step: 360, epoch: 3, batch: 60, loss: 0.12114, accuracy: 0.87500, speed: 2.02 step/s
2022-05-25 07:25:46.443 | INFO | __main__:train:301 - global step: 380, epoch: 3, batch: 80, loss: 0.01080, accuracy: 1.00000, speed: 2.03 step/s
2022-05-25 07:25:56.228 | INFO | __main__:train:301 - global step: 400, epoch: 3, batch: 100, loss: 0.14839, accuracy: 0.96875, speed: 2.04 step/s
2022-05-25 07:26:00.953 | INFO | __main__:train:307 - eval loss: 0.22589, accuracy: 0.94083
2022-05-25 07:26:12.483 | INFO | __main__:train:301 - global step: 420, epoch: 3, batch: 120, loss: 0.14986, accuracy: 0.96875, speed: 2.05 step/s
2022-05-25 07:26:22.289 | INFO | __main__:train:301 - global step: 440, epoch: 3, batch: 140, loss: 0.00687, accuracy: 1.00000, speed: 2.04 step/s
```
当需要进行文本对分类时,仅需设置 is_text_pair 为 True。以 CLU
没有合适的资源?快使用搜索试试~ 我知道了~
资源推荐
资源详情
资源评论
收起资源包目录
100011822-基于 pytorch 的 bert 实现和下游任务微调.zip (79个子文件)
bertorch
run_simcse.py 13KB
data
ner
ontonote4
test.json 2.56MB
train.json 6.12MB
train.char.bmes 2.97MB
dev.char.bmes 1.22MB
labels.txt 74B
test.char.bmes 1.28MB
dev.json 2.46MB
msra
test.json 2.12MB
train.json 23.91MB
train.ner 11.88MB
test.ner 1.05MB
labels.txt 47B
dev.json 2.59MB
dev.ner 1.27MB
resume
.DS_Store 6KB
test.json 230KB
train.json 1.83MB
train.char.bmes 1.04MB
dev.char.bmes 117KB
labels.txt 128B
test.char.bmes 131KB
dev.json 208KB
weibo
test.all.bmes 92KB
test.json 184KB
train.json 911KB
dev.all.bmes 89KB
labels.txt 202B
train.all.bmes 451KB
dev.json 180KB
LCQMC
dev.txt 674KB
test.txt 758KB
train.txt 15.74MB
zhwiki
wiki_sents.txt 23.53MB
ChnSentiCorp
ChnSentiCorp.zip 1.86MB
train.tsv 2.94MB
dev.tsv 374KB
labels.txt 18B
test.tsv 364KB
STS-B
sts-b-dev.txt 181KB
CSNLI
dev.txt 1.03MB
labels.txt 33B
train.txt 57.69MB
tnews
train.tsv 3.91MB
dev.tsv 752KB
test.json 1.05MB
train.json 9.26MB
process.py 411B
labels.json 691B
test1.0.json 1.43MB
labels.txt 182B
dev.json 1.73MB
AFQMC
dev.txt 349KB
labels.txt 4B
test.txt 302KB
train.txt 2.72MB
CMNLI
dev.txt 1.88MB
labels.txt 33B
train.txt 61.93MB
batchneg.zip 56.89MB
LICENSE 1KB
predict.py 4KB
run_batchneg.py 13KB
run_ner.py 14KB
bertorch
semantic_model.py 9KB
utils.py 4KB
__init__.py 22B
pretraining.py 17KB
dataset.py 14KB
ner_utils.py 5KB
clas_model.py 3KB
modeling.py 30KB
activations.py 897B
optimization.py 15KB
tokenization.py 34KB
crf.py 20KB
README.md 31KB
run_classifier.py 12KB
run_sentencebert.py 13KB
共 79 条
- 1
资源评论
神仙别闹
- 粉丝: 2708
- 资源: 7668
下载权益
C知道特权
VIP文章
课程特权
开通VIP
上传资源 快速赚钱
- 我的内容管理 展开
- 我的资源 快来上传第一个资源
- 我的收益 登录查看自己的收益
- 我的积分 登录查看自己的积分
- 我的C币 登录后查看C币余额
- 我的收藏
- 我的下载
- 下载帮助
最新资源
- Judging LLM-as-a-Judge with MT-Bench and Chatbot Arena阅读笔记
- (优作)PID-小车类-两轮自平衡小车资料(L298N 模块原理图及使用说明+c源码)
- 发电系统simulink仿真模型风力光伏发电太阳能电池发电系统按键控制步进电机程序
- 360数字安全:2024年4月勒索软件流行态势分析报告
- 发电系统simulink仿真模型风力光伏发电太阳能电池发电系统U盘读写文件程序
- 基于FPGA的LCD1602的流动显示VHDL.zip
- 基于Javascript实现的3D GIS,支持谷歌地图+必应地图+OpenStreetMap+搜索地图+天地图+源码+界面展示
- cmatrix数字雨安装脚本
- 联想storage-V5030混合存储系统调试教程
- Eclipse+Tomcat+SQLServer2008开发基于MVC框架的房地产信息管理系统+运用gis技术实现地产和地图结合
资源上传下载、课程学习等过程中有任何疑问或建议,欢迎提出宝贵意见哦~我们会及时处理!
点击此处反馈
安全验证
文档复制为VIP权益,开通VIP直接复制
信息提交成功