# -*- coding: utf8 -*-
"""
======================================
Project Name: NLP
File Name: mrc_for_ner
Author: czh
Create Date: 2022/2/23
--------------------------------------
Change Activity:
======================================
"""
import os
import sys
import json
from tqdm import tqdm
import codecs
from typing import List, Tuple
sys.path.append("/data/chenzhihao/NLP")
import numpy as np
import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.utils.data import DataLoader, Dataset
from transformers import BertConfig, BertTokenizerFast
from nlp.models.bert_for_ner import BertQueryNER
from nlp.tools.path import project_root_path
root_path = project_root_path()
# bert_model_name_or_path = "/Users/czh/Downloads/chinese-roberta-ext"
bert_model_name_or_path = "/data/chenzhihao/chinese-roberta-ext"
data_dir = "datas/cluener"
output_dir = "output_file_dir/mrc"
max_sequence_length = 512
batch_size = 10
epochs = 30
lr_rate = 2e-5
gradient_accumulation_steps = 1
logging_steps = 500
num_worker = 0
warmup_ratio = 0.1
weight_start = 1.0
weight_end = 1.0
weight_span = 1.0
span_loss_candidates = "all"  # one of ["all", "pred_and_gold", "pred_gold_random", "gold"]; candidates used to compute the span loss
device = torch.device('cuda:1') if torch.cuda.is_available() else torch.device('cpu')
tokenizer = BertTokenizerFast.from_pretrained(bert_model_name_or_path, do_lower_case=False, add_special_tokens=True)
bert_config = BertConfig.from_pretrained(bert_model_name_or_path)
bce_loss = nn.BCEWithLogitsLoss(reduction='none')
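# Natural-language queries (in Chinese) for each CLUENER entity type; in the MRC
# formulation each query is paired with the sentence and the model extracts the
# answer spans of that entity type.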
query_map = {
'address': "找出省、市、区、街道乡村等抽象或具体的地点",
'book': "找出小说、杂志、习题集、教科书、教辅、地图册、食谱等具体的书名",
'company': "找出公司、集团、银行(央行,中国人民银行除外,二者属于政府机构)等具体的公司名",
'game': "找出常见的游戏名",
'government': "找出中央行政机关和地方行政机关的名字",
'movie': "找出电影、纪录片等放映或上线的影片名字",
'name': "找出真实和虚构的人名",
'organization': "找出包括篮球队、足球队、乐团、社团、小说里面的帮派等真实或虚构的组织机构名",
'position': "找出现代和古时候的职称名",
'scene': "找出常见的旅游景点"
}
def trans_data_to_mrc(input_file):
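    """
    Convert a CLUENER-style jsonl file into flat MRC samples, one per (sentence, entity type)
    pair that actually has annotations.

    Illustrative shape of one input line (field names follow CLUENER; the entity string and
    offsets are placeholders):
        {"text": "...", "label": {"company": {"<entity string>": [[0, 3]]}}}
    which becomes:
        {"context": "...", "query": query_map["company"], "start_position": [0], "end_position": [4]}
    CLUENER end offsets are inclusive, hence the shift by one to an exclusive end below.
    """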
mrc_samples = []
with codecs.open(input_file, encoding="utf8") as fr:
for line in fr:
line = line.strip()
if not line:
continue
item = json.loads(line)
text = item['text']
label_dict = item['label']
for label, query in query_map.items():
start_positions = []
end_positions = []
entity_dict = label_dict.get(label, None)
if not entity_dict:
continue
                for k, offsets in entity_dict.items():
                    for s, e in offsets:
                        start_positions.append(s)
                        # CLUENER character spans are inclusive [s, e]; shift the end by one
                        # so it is exclusive and lines up with the tokenizer's offset_mapping ends.
                        end_positions.append(e + 1)
mrc_samples.append(
{
"context": text,
"start_position": start_positions,
"end_position": end_positions,
"query": query
}
)
return mrc_samples
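# Note: entity types with no annotation in a sentence are skipped above, so every generated
# sample carries at least one answer span (MRCNERDataset's possible_only filter is effectively
# a no-op for data produced by this function).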
class MRCNERDataset(Dataset):
"""
MRC NER Dataset
Args:
datasets:
max_length: int, max length of query+context
possible_only: if True, only use possible samples that contain answer for the query/context
is_chinese: is chinese dataset
"""
def __init__(self, datasets, max_length: int = 512, possible_only=False,
is_chinese=False, pad_to_maxlen=False):
self.all_data = datasets
self.max_length = max_length
self.possible_only = possible_only
if self.possible_only:
self.all_data = [
x for x in self.all_data if x["start_position"]
]
self.is_chinese = is_chinese
self.pad_to_maxlen = pad_to_maxlen
def __len__(self):
return len(self.all_data)
def __getitem__(self, item):
"""
Args:
item: int, idx
Returns:
tokens: tokens of query + context, [seq_len]
token_type_ids: token type ids, 0 for query, 1 for context, [seq_len]
start_labels: start labels of NER in tokens, [seq_len]
            end_labels: end labels of NER in tokens, [seq_len]
label_mask: label mask, 1 for counting into loss, 0 for ignoring. [seq_len]
match_labels: match labels, [seq_len, seq_len]
sample_idx: sample id
label_idx: label id
"""
data = self.all_data[item]
qas_id = data.get("qas_id", "0.0")
sample_idx, label_idx = qas_id.split(".")
sample_idx = torch.LongTensor([int(sample_idx)])
label_idx = torch.LongTensor([int(label_idx)])
query = data["query"]
context = data["context"]
start_positions = data["start_position"]
end_positions = data["end_position"]
        # TODO: modify the tokenizer
query_context_tokens = tokenizer.encode_plus(query, context, add_special_tokens=True, return_offsets_mapping=True)
tokens = query_context_tokens['input_ids']
attention_masks = query_context_tokens['attention_mask']
type_ids = query_context_tokens['token_type_ids']
offsets = query_context_tokens['offset_mapping']
# find new start_positions/end_positions, considering
# 1. we add query tokens at the beginning
# 2. word-piece tokenize
origin_offset2token_idx_start = {}
origin_offset2token_idx_end = {}
for token_idx in range(len(tokens)):
# skip query tokens
if type_ids[token_idx] == 0:
continue
token_start, token_end = offsets[token_idx][0], offsets[token_idx][1]
# skip [CLS] or [SEP]
if token_start == token_end == 0:
continue
origin_offset2token_idx_start[token_start] = token_idx
origin_offset2token_idx_end[token_end] = token_idx
new_start_positions = []
new_end_positions = []
for s, e in zip(start_positions, end_positions):
try:
new_s = origin_offset2token_idx_start[s]
new_e = origin_offset2token_idx_end[e]
except Exception as exc:
print((s, e), offsets)
print(origin_offset2token_idx_start)
print(origin_offset2token_idx_end)
print(tokenizer.tokenize(query, context, add_special_tokens=True))
print(exc)
continue
if 0 < new_s <= new_e < len(offsets):
new_start_positions.append(new_s)
new_end_positions.append(new_e)
else:
print((s, e), (new_s, new_e), query, context)
label_mask = [
(0 if type_ids[token_idx] == 0 or offsets[token_idx] == (0, 0) else 1)
for token_idx in range(len(tokens))
]
start_label_mask = label_mask.copy()
end_label_mask = label_mask.copy()
assert all(start_label_mask[p] != 0 for p in new_start_positions)
assert all(end_label_mask[p] != 0 for p in new_end_positions)
assert len(label_mask) == len(tokens)
start_labels = [(1 if idx in new_start_positions else 0)
for idx in range(len(tokens))]
end_labels = [(1 if idx in new_end_positions else 0)
for idx in range(len(tokens))]
# truncate
        token_ids = tokens[: self.max_length]