## 文本匹配任务(Text Matching)
文本匹配多用于计算两个文本之间的相似度,关于文本匹配的详细介绍在:[这里](https://zhuanlan.zhihu.com/p/585533302?)。
本项目对3种常用的文本匹配的方法进行实现:PointWise(单塔)、DSSM(双塔)、Sentence BERT(双塔)。
## 1. 环境安装
本项目基于 `pytorch` + `transformers` 实现,运行前请安装相关依赖包:
```sh
pip install -r ../../requirements.txt
```
## 2. 数据集准备
项目中提供了一部分示例数据,我们使用「商品评论」和「商品类别」来进行文本匹配任务,数据在 `data/comment_classify` 。
若想使用`自定义数据`训练,只需要仿照示例数据构建数据集即可:
```python
衣服:指穿在身上遮体御寒并起美化作用的物品。 为什么是开过的洗发水都流出来了、是用过的吗?是这样子包装的吗? 0
衣服:指穿在身上遮体御寒并起美化作用的物品。 开始买回来大很多 后来换了回来又小了 号码区别太不正规 建议各位谨慎 1
...
```
每一行用 `\t` 分隔符分开,第一部分部分为`商品类型(text1)`,中间部分为`商品评论(text2)`,最后一部分为`商品评论和商品类型是否一致(label)`。
## 3. 模型训练
### 3.1 PointWise(单塔)
#### 3.1.1 模型训练
修改训练脚本 `train_pointwise.sh` 里的对应参数, 开启模型训练:
```sh
python train_pointwise.py \
--model "nghuyong/ernie-3.0-base-zh" \ # backbone
--train_path "data/comment_classify/train.txt" \ # 训练集
--dev_path "data/comment_classify/dev.txt" \ #验证集
--save_dir "checkpoints/comment_classify" \ # 训练模型存放地址
--img_log_dir "logs/comment_classify" \ # loss曲线图保存位置
--img_log_name "ERNIE-PointWise" \ # loss曲线图保存文件夹
--batch_size 8 \
--max_seq_len 128 \
--valid_steps 50 \
--logging_steps 10 \
--num_train_epochs 10 \
--device "cuda:0"
```
正确开启训练后,终端会打印以下信息:
```sh
...
global step 10, epoch: 1, loss: 0.77517, speed: 3.43 step/s
global step 20, epoch: 1, loss: 0.67356, speed: 4.15 step/s
global step 30, epoch: 1, loss: 0.53567, speed: 4.15 step/s
global step 40, epoch: 1, loss: 0.47579, speed: 4.15 step/s
global step 50, epoch: 2, loss: 0.43162, speed: 4.41 step/s
Evaluation precision: 0.88571, recall: 0.87736, F1: 0.88152
best F1 performence has been updated: 0.00000 --> 0.88152
global step 60, epoch: 2, loss: 0.40301, speed: 4.08 step/s
global step 70, epoch: 2, loss: 0.37792, speed: 4.03 step/s
global step 80, epoch: 2, loss: 0.35343, speed: 4.04 step/s
global step 90, epoch: 2, loss: 0.33623, speed: 4.23 step/s
global step 100, epoch: 3, loss: 0.31319, speed: 4.01 step/s
Evaluation precision: 0.96970, recall: 0.90566, F1: 0.93659
best F1 performence has been updated: 0.88152 --> 0.93659
...
```
在 `logs/comment_classify` 文件下将会保存训练曲线图:
<img src='assets/pointwise_train_log.png'></img>
#### 3.1.2 模型推理
完成模型训练后,运行 `inference_pointwise.py` 以加载训练好的模型并应用:
```python
...
test_inference(
'手机:一种可以在较广范围内使用的便携式电话终端。', # 第一句话
'味道非常好,京东送货速度也非常快,特别满意。', # 第二句话
max_seq_len=128
)
...
```
运行推理程序:
```sh
python inference_pointwise.py
```
得到以下推理结果:
```sh
tensor([[ 1.8477, -2.0484]], device='cuda:0') # 两句话不相似(0)的概率更大
```
### 3.2 DSSM(双塔)
#### 3.2.1 模型训练
修改训练脚本 `train_dssm.sh` 里的对应参数, 开启模型训练:
```sh
python train_dssm.py \
--model "nghuyong/ernie-3.0-base-zh" \
--train_path "data/comment_classify/train.txt" \
--dev_path "data/comment_classify/dev.txt" \
--save_dir "checkpoints/comment_classify/dssm" \
--img_log_dir "logs/comment_classify" \
--img_log_name "ERNIE-DSSM" \
--batch_size 8 \
--max_seq_len 256 \
--valid_steps 50 \
--logging_steps 10 \
--num_train_epochs 10 \
--device "cuda:0"
```
正确开启训练后,终端会打印以下信息:
```sh
...
global step 0, epoch: 1, loss: 0.62319, speed: 15.16 step/s
Evaluation precision: 0.29912, recall: 0.96226, F1: 0.45638
best F1 performence has been updated: 0.00000 --> 0.45638
global step 10, epoch: 1, loss: 0.40931, speed: 3.64 step/s
global step 20, epoch: 1, loss: 0.36969, speed: 3.69 step/s
global step 30, epoch: 1, loss: 0.33927, speed: 3.69 step/s
global step 40, epoch: 1, loss: 0.31732, speed: 3.70 step/s
global step 50, epoch: 1, loss: 0.30996, speed: 3.68 step/s
...
```
在 `logs/comment_classify` 文件下将会保存训练曲线图:
<img src='assets/dssm_train_log.png'></img>
#### 3.2.2 模型推理
和单塔模型不一样的是,双塔模型可以事先计算所有候选类别的Embedding,当新来一个句子时,只需计算新句子的Embedding,并通过余弦相似度找到最优解即可。
因此,在推理之前,我们需要提前计算所有类别的Embedding并保存。
> 类别Embedding计算
运行 `get_embedding.py` 文件以计算对应类别embedding并存放到本地:
```python
...
text_file = 'data/comment_classify/types_desc.txt' # 候选文本存放地址
output_file = 'embeddings/comment_classify/dssm_type_embeddings.json' # embedding存放地址
device = 'cuda:0' # 指定GPU设备
model_type = 'dssm' # 使用DSSM还是Sentence Transformer
saved_model_path = './checkpoints/comment_classify/dssm/model_best/' # 训练模型存放地址
tokenizer = AutoTokenizer.from_pretrained(saved_model_path)
model = torch.load(os.path.join(saved_model_path, 'model.pt'))
model.to(device).eval()
...
```
其中,所有需要预先计算的内容都存放在 `types_desc.txt` 文件中。
文件用 `\t` 分隔,分别代表 `类别id`、`类别名称`、`类别描述`:
```txt
0 水果 指多汁且主要味觉为甜味和酸味,可食用的植物果实。
1 洗浴 洗浴用品。
2 平板 也叫便携式电脑,是一种小型、方便携带的个人电脑,以触摸屏作为基本的输入设备。
...
```
执行 `python get_embeddings.py` 命令后,会在代码中设置的embedding存放地址中找到对应的embedding文件:
```json
{
"0": {"label": "水果", "text": "水果:指多汁且主要味觉为甜味和酸味,可食用的植物果实。", "embedding": [0.3363891839981079, -0.8757723569869995, -0.4140555262565613, 0.8288457989692688, -0.8255823850631714, 0.9906797409057617, -0.9985526204109192, 0.9907819032669067, -0.9326567649841309, -0.9372553825378418, 0.11966298520565033, -0.7452883720397949,...]},
"1": ...,
...
}
```
> 模型推理
完成预计算后,接下来就可以开始推理了。
我们构建一条新评论:`这个破笔记本卡的不要不要的,差评`。
运行 `python inference_dssm.py`,得到下面结果:
```python
[
('平板', 0.9515482187271118),
('电脑', 0.8216977119445801),
('洗浴', 0.12220608443021774),
('衣服', 0.1199738010764122),
('手机', 0.07764233648777008),
('酒店', 0.044791921973228455),
('水果', -0.050112202763557434),
('电器', -0.07554933428764343),
('书籍', -0.08481660485267639),
('蒙牛', -0.16164332628250122)
]
```
函数将输出(类别,余弦相似度)的二元组,并按照相似度做倒排(相似度取值范围:[-1, 1])。
### 3.3 Sentence Transformer(双塔)
#### 3.3.1 模型训练
修改训练脚本 `train_sentence_transformer.sh` 里的对应参数, 开启模型训练:
```sh
python train_sentence_transformer.py \
--model "nghuyong/ernie-3.0-base-zh"
没有合适的资源?快使用搜索试试~ 我知道了~
温馨提示
<项目介绍> 文本匹配任务(Text Matching) 文本匹配多用于计算两个文本之间的相似度,关于文本匹配的详细介绍在:这里。 本项目对3种常用的文本匹配的方法进行实现:PointWise(单塔)、DSSM(双塔)、Sentence BERT(双塔)。 1. 环境安装 本项目基于 pytorch + transformers 实现,运行前请安装相关依赖包: pip install -r ../../requirements.txt 2. 数据集准备 项 - 不懂运行,下载完可以私聊问,可远程教学 该资源内项目源码是个人的毕设,代码都测试ok,都是运行成功后才上传资源,答辩评审平均分达到96分,放心下载使用! 1、该资源内项目代码都经过测试运行成功,功能ok的情况下才上传的,请放心下载使用! 2、本项目适合计算机相关专业(如计科、人工智能、通信工程、自动化、电子信息等)的在校学生、老师或者企业员工下载学习,也适合小白学习进阶,当然也可作为毕设项目、课程设计、作业、项目初期立项演示等。 3、如果基础还行,也可在此代码基础上进行修改,以实现其他功能,也可用于毕设、课设、作业等。
资源推荐
资源详情
资源评论
收起资源包目录
text_match-master.zip (34个子文件)
text_match-master
unsupervised
__init__.py 0B
simcse
utils.py 8KB
__init__.py 0B
assets
ERNIE-ESimCSE.png 220KB
data
LCQMC
dev.tsv 674KB
test.tsv 734KB
train.txt 15.29MB
readme.md 4KB
iTrainingLogger.py 7KB
model.py 7KB
inference.py 4KB
train.sh 449B
train.py 12KB
supervised
utils.py 7KB
__init__.py 0B
train_pointwise.py 10KB
get_embedding.py 5KB
train_dssm.py 11KB
assets
dssm_train_log.png 176KB
pointwise_train_log.png 217KB
sentence_transformer_train_log.png 181KB
data
comment_classify
dev.txt 91KB
types_desc.txt 807B
train.txt 338KB
train_pointwise.sh 459B
readme.md 11KB
inference_pointwise.py 3KB
train_sentence_transformer.py 11KB
iTrainingLogger.py 7KB
model.py 16KB
train_sentence_transformer.sh 469B
inference_sentence_transformer.py 4KB
train_dssm.sh 482B
inference_dssm.py 4KB
共 34 条
- 1
资源评论
.Android安卓科研室.
- 粉丝: 4417
- 资源: 2453
上传资源 快速赚钱
- 我的内容管理 展开
- 我的资源 快来上传第一个资源
- 我的收益 登录查看自己的收益
- 我的积分 登录查看自己的积分
- 我的C币 登录后查看C币余额
- 我的收藏
- 我的下载
- 下载帮助
安全验证
文档复制为VIP权益,开通VIP直接复制
信息提交成功