#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
@author: czh
@email:
@date: 2022/7/5 16:19
"""
# 无监督对比学习ConSERT
# 参考https://github.com/yym6472/ConSERT/blob/master/main.py
import os
import sys
import math
import logging
import torch
import numpy as np
from torch.utils.data import DataLoader
from tensorboardX import SummaryWriter
sys.path.append('/data2/work2/chenzhihao/NLP')
from nlp.sentence_transformers import models, losses
from nlp.sentence_transformers import SentenceTransformer, LoggingHandler, SentencesDataset, InputExample
from nlp.sentence_transformers.evaluation import EmbeddingSimilarityEvaluator, SimilarityFunction
from nlp.processors.semantic_match_preprocessor import load_data, load_data_for_snli
# Route all log records through the library's LoggingHandler so output
# formatting matches what sentence-transformers itself emits.
logging.basicConfig(format='%(asctime)s - %(filename)s - %(levelname)s - %(message)s',
                    datefmt='%Y-%m-%d %H:%M:%S',
                    level=logging.INFO,
                    handlers=[LoggingHandler()])
def init_model(model_name, args):
    """Assemble a SentenceTransformer from a Huggingface checkpoint.

    Builds a transformer encoder plus mean pooling (and, when
    ``args['use_simsiam']`` is set, a SimSiam-style MLP3 projection head),
    sets the maximum sequence length and attaches a tensorboard writer
    to the returned model.
    """
    # Optionally zero out all dropout inside the encoder.
    if args['no_dropout']:
        encoder = models.Transformer(model_name,
                                     attention_probs_dropout_prob=0.0,
                                     hidden_dropout_prob=0.0)
    else:
        encoder = models.Transformer(model_name)

    # Mean pooling over token embeddings produces one fixed-size sentence vector.
    pooler = models.Pooling(encoder.get_word_embedding_dimension(),
                            pooling_mode_mean_tokens=True,
                            pooling_mode_cls_token=False,
                            pooling_mode_max_tokens=False)

    modules = [encoder, pooler]
    if args['use_simsiam']:
        # Extra projection head used by the SimSiam-style objective.
        modules.append(models.MLP3(hidden_dim=args['projection_hidden_dim'],
                                   norm=args['projection_norm_type']))
    model = SentenceTransformer(modules=modules, device=args['device'])
    model.max_seq_length = args['max_seq_length']

    # Tensorboard writer; default location is <model_save_path>/logs.
    log_dir = args['tensorboard_log_dir'] or os.path.join(args['model_save_path'], "logs")
    model.tensorboard_writer = SummaryWriter(log_dir)
    return model
def prepare_datasets(data_dir, args, need_label=True, seg_tag='\t'):
    """Load sentence-pair rows and wrap them as InputExample objects.

    For ``data_type == "STS-B"`` the gold score (0-5) is rescaled to
    [0, 1]; otherwise the label is kept as-is for classification or cast
    to float for regression.  With ``need_label=False`` examples are
    emitted unlabeled — as pairs, or as two single-sentence examples
    when ``args['no_pair']`` is set (unsupervised contrastive setting).
    """
    samples = []
    for row in load_data(data_dir, seg_tag=seg_tag):  # noqa
        if args['data_type'] == "STS-B":
            # STS-B similarity is annotated on a 0-5 scale.
            label = row[2] / 5.0
        elif args['object_type'] == "classification":
            label = row[2]
        else:
            label = float(row[2])

        if need_label:
            samples.append(InputExample(texts=[row[0], row[1]], label=label))
        elif args['no_pair']:
            # Each sentence becomes its own unlabeled example.
            samples.append(InputExample(texts=[row[0]]))
            samples.append(InputExample(texts=[row[1]]))
        else:
            samples.append(InputExample(texts=[row[0], row[1]]))
    return samples
def prepare_snli_datasets(data_dir, args, label2int, prefix='cnsd_snli_v1.0'):
    """Read CNSD-SNLI splits and build a shuffled list of train samples.

    Train, dev and test splits are pooled together.  With
    ``args['no_pair']`` every sentence becomes its own unlabeled example
    (only allowed when training with the contrastive loss alone);
    otherwise each premise/hypothesis pair is labeled via ``label2int``.
    """
    records = []
    for suffix in (".train.json", ".dev.json", ".test.json"):
        path = os.path.join(data_dir, prefix + suffix)
        records += load_data_for_snli(path, return_list=True)

    samples = []
    for rec in records:
        premise = rec['sentence1'].strip()
        hypothesis = rec['sentence2'].strip()
        gold = rec['gold_label'].strip()
        # Resolved unconditionally so an unknown label raises even in no-pair mode.
        gold_id = label2int[gold]
        if args['no_pair']:
            assert args['cl_loss_only'], "no pair texts only used when contrastive loss only"
            samples.append(InputExample(texts=[premise]))
            samples.append(InputExample(texts=[hypothesis]))
        else:
            samples.append(InputExample(texts=[premise, hypothesis], label=gold_id))

    np.random.shuffle(samples)
    return samples
def train(train_samples, model, dev_evaluator, args: dict): # noqa
train_dataset = SentencesDataset(train_samples, model=model)
train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=args['train_batch_size'])
if args['adv_training'] and args['add_cl']:
train_loss = losses.AdvCLSoftmaxLoss(model=model,
sentence_embedding_dimension=model.get_sentence_embedding_dimension(),
num_labels=args['num_labels'],
concatenation_sent_max_square=args['concatenation_sent_max_square'],
use_adversarial_training=args['adv_training'],
noise_norm=args['noise_norm'],
adv_loss_stop_grad=args['adv_loss_stop_grad'],
adversarial_loss_rate=args['adv_loss_rate'],
use_contrastive_loss=args['add_cl'],
contrastive_loss_type=args['cl_type'],
contrastive_loss_rate=args['cl_rate'],
temperature=args['temperature'],
contrastive_loss_stop_grad=args['contrastive_loss_stop_grad'],
mapping_to_small_space=args['mapping_to_small_space'],
add_contrastive_predictor=args['add_contrastive_predictor'],
projection_hidden_dim=args['projection_hidden_dim'],
projection_use_batch_norm=args['projection_use_batch_norm'],
add_projection=args['add_projection'],
projection_norm_type=args['projection_norm_type'],
contrastive_loss_only=args['cl_loss_only'],
data_augmentation_strategy=args['data_augmentation_strategy'],
cutoff_direction=args['cutoff_direction'],
cutoff_rate=args['cutoff_rate'],
regularization_term_rate=args['regularization_term_rate'],
loss_rate_scheduler=args['loss_rate_scheduler'])
elif args['adv_training']:
train_loss = losses.AdvCLSoftmaxLoss(model=model,
sentence_embedding_dimension=model.get_sentence_embedding_dimension(),
num_labels=args['num_labels'],
concatenation_sent_max_square=args['concatenation_sent_max_square'],
use_adversarial_training=args['adv_training'],
noise_norm=args['noise_norm'],
adv_loss_stop_grad=args['adv_loss_stop_grad'],
adversarial_loss_rate=args['adv_loss_rate'])
elif args['add_cl']:
train_loss = losses.AdvCLSoftmaxLoss(model=model,
sentence_embedding_dimension=model.get_sentence_embedding_dimension(),
num_labels=args['num_labels'],
concatenation_sent_max_square=args['concatenation_sent_max_sq
NLP语义匹配.zip
版权申诉
23 浏览量
2023-08-26
10:48:45
上传
评论
收藏 68KB ZIP 举报
sjx_alo
- 粉丝: 1w+
- 资源: 1216
最新资源
- FloEFD 2021版案例教程-03 多孔介质
- 一款极好用的 Office/WPS/Word/Excel/PPT/PDF工具箱软件 OfficeUtils 3.1
- FloEFD 2021版案例教程-02 共轭传热
- java毕业设计+扫雷(程序)
- 一款极好用的 Office/WPS/Word/Excel/PPT/PDF工具箱软件 OfficeUtils 2.8
- 轻松学51单片机-基于普中科技开发板练习蓝桥杯及机器人大赛等(6-蜂鸣器)
- strawberry-perl-5.38.2.2-64bit.msi
- FloEFD 2021版案例教程-01 球阀设计
- MeyboMail Web(Java)简化版
- java(结合lucene)版的公交搜索系统
资源上传下载、课程学习等过程中有任何疑问或建议,欢迎提出宝贵意见哦~我们会及时处理!
点击此处反馈