# -*- coding: utf8 -*-
"""
Project Name: NLP
File Name: mrc_for_ner
Author: czh
Create Date: 2022/2/23
Change Activity:
"""
import os
import sys
import json
from tqdm import tqdm
import codecs
from typing import List, Tuple
import numpy as np
import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.utils.data import DataLoader, Dataset
from transformers import BertConfig, BertTokenizerFast
from nlp.models.bert_for_ner import BertQueryNER
from nlp.tools.path import project_root_path
# Value returned by the project's path helper.
root_path = project_root_path()

# Checkpoint directory used below for both the tokenizer and the BERT config.
# bert_model_name_or_path = "/Users/czh/Downloads/chinese-roberta-ext"
bert_model_name_or_path = "/data/chenzhihao/chinese-roberta-ext"
data_dir = "datas/cluener"
output_dir = "output_file_dir/mrc"

# Training hyper-parameters.
max_sequence_length = 512
batch_size = 10
epochs = 30
lr_rate = 2e-5
gradient_accumulation_steps = 1
logging_steps = 500
num_worker = 0
warmup_ratio = 0.1

# Presumably the relative weights of the start / end / span loss terms;
# their use is outside this part of the file -- TODO confirm.
weight_start = 1.0
weight_end = 1.0
weight_span = 1.0

# Candidates used to compute the span loss; one of
# "all", "pred_and_gold", "pred_gold_random", "gold".
span_loss_candidates = "all"

# Prefer the second GPU. On a single-GPU host "cuda:1" is an invalid device
# ordinal, so fall back to "cuda:0" there, and to the CPU without CUDA.
if torch.cuda.is_available():
    device = torch.device('cuda:1' if torch.cuda.device_count() > 1 else 'cuda:0')
else:
    device = torch.device('cpu')

tokenizer = BertTokenizerFast.from_pretrained(bert_model_name_or_path, do_lower_case=False, add_special_tokens=True)
bert_config = BertConfig.from_pretrained(bert_model_name_or_path)

# reduction='none' keeps the per-element losses instead of averaging them.
bce_loss = nn.BCEWithLogitsLoss(reduction='none')
# Maps each entity label to the natural-language query that
# trans_data_to_mrc pairs with a context for that label.
query_map = {
    'address': "找出省、市、区、街道乡村等抽象或具体的地点",
    'book': "找出小说、杂志、习题集、教科书、教辅、地图册、食谱等具体的书名",
    'company': "找出公司、集团、银行(央行,中国人民银行除外,二者属于政府机构)等具体的公司名",
    'game': "找出常见的游戏名",
    'government': "找出中央行政机关和地方行政机关的名字",
    'movie': "找出电影、纪录片等放映或上线的影片名字",
    'name': "找出真实和虚构的人名",
    'organization': "找出包括篮球队、足球队、乐团、社团、小说里面的帮派等真实或虚构的组织机构名",
    'position': "找出现代和古时候的职称名",
    'scene': "找出常见的旅游景点",
}
def trans_data_to_mrc(input_file, queries=None):
    """Convert a JSON-lines NER file into MRC-style samples.

    Every non-empty line of ``input_file`` is parsed as a JSON object with a
    ``text`` string and a ``label`` mapping shaped like
    ``{label: {entity: [[start, end], ...]}}``.  For each label that has a
    query and at least one entity on that line, one sample is emitted that
    collects the start and end offsets of all entities of that label.

    NOTE(review): this function was rebuilt from a copy with missing lines
    (the ``continue`` guards, the position appends and the ``append({...})``
    wrapper were absent) -- verify against version control.
    NOTE(review): offsets are passed through unchanged.
    ``MRCNERDataset.__getitem__`` looks end positions up among the
    tokenizer's end offsets, so confirm whether the data's end index needs
    a ``+ 1`` here.

    Args:
        input_file: path to a UTF-8 file with one JSON object per line.
        queries: optional ``{label: query}`` mapping; defaults to the
            module-level ``query_map``.

    Returns:
        A list of dicts with the keys ``context``, ``start_position``,
        ``end_position`` and ``query``.
    """
    if queries is None:
        queries = query_map

    mrc_samples = []
    # Builtin open() only breaks lines on "\n" / "\r"; iterating a
    # codecs.open() stream also splits on U+2028/U+2029, which may appear
    # unescaped inside a JSON string and would cut a record in half.
    with open(input_file, encoding="utf8") as fr:
        for line in fr:
            line = line.strip()
            if not line:
                continue
            item = json.loads(line)
            text = item['text']
            label_dict = item['label']
            for label, query in queries.items():
                entity_dict = label_dict.get(label)
                if not entity_dict:
                    # No entity of this label on this line: no sample.
                    continue
                start_positions = []
                end_positions = []
                for offsets in entity_dict.values():
                    for s, e in offsets:
                        start_positions.append(s)
                        end_positions.append(e)
                mrc_samples.append({
                    "context": text,
                    "start_position": start_positions,
                    "end_position": end_positions,
                    "query": query
                })
    return mrc_samples
class MRCNERDataset(Dataset):
    """
    MRC NER Dataset

    Args:
        max_length: int, max length of query+context
        possible_only: if True, only use possible samples that contain answer for the query/context
        is_chinese: is chinese dataset
    """
def __init__(self, datasets, max_length: int = 512, possible_only=False,
             is_chinese=False, pad_to_maxlen=False):
    """Store the samples and options of the dataset.

    Args:
        datasets: sequence of sample dicts; each must have a
            ``"start_position"`` entry when ``possible_only`` is set.
        max_length: stored as ``self.max_length``.
        possible_only: if true, keep only the samples whose
            ``"start_position"`` is non-empty.
        is_chinese: stored as ``self.is_chinese``.
        pad_to_maxlen: stored as ``self.pad_to_maxlen``.
    """
    self.all_data = datasets
    self.max_length = max_length
    self.possible_only = possible_only
    if self.possible_only:
        # An empty start list means the context has no answer for the query.
        self.all_data = [
            x for x in self.all_data if x["start_position"]
        ]
    self.is_chinese = is_chinese
    self.pad_to_maxlen = pad_to_maxlen
def __len__(self) -> int:
    """Return the number of samples held in ``self.all_data``."""
    return len(self.all_data)
def __getitem__(self, item):
item: int, idx
tokens: tokens of query + context, [seq_len]
token_type_ids: token type ids, 0 for query, 1 for context, [seq_len]
start_labels: start labels of NER in tokens, [seq_len]
end_labels: end labels of NER in tokens, [seq_len]
label_mask: label mask, 1 for counting into loss, 0 for ignoring. [seq_len]
match_labels: match labels, [seq_len, seq_len]
sample_idx: sample id
label_idx: label id
data = self.all_data[item]
qas_id = data.get("qas_id", "0.0")
sample_idx, label_idx = qas_id.split(".")
sample_idx = torch.LongTensor([int(sample_idx)])
label_idx = torch.LongTensor([int(label_idx)])
query = data["query"]
context = data["context"]
start_positions = data["start_position"]
end_positions = data["end_position"]
# TODO: modify the tokenizer
query_context_tokens = tokenizer.encode_plus(query, context, add_special_tokens=True, return_offsets_mapping=True)
tokens = query_context_tokens['input_ids']
attention_masks = query_context_tokens['attention_mask']
type_ids = query_context_tokens['token_type_ids']
offsets = query_context_tokens['offset_mapping']
# find new start_positions/end_positions, considering
# 1. we add query tokens at the beginning
# 2. word-piece tokenize
origin_offset2token_idx_start = {}
origin_offset2token_idx_end = {}
for token_idx in range(len(tokens)):
# skip query tokens
if type_ids[token_idx] == 0:
token_start, token_end = offsets[token_idx][0], offsets[token_idx][1]
# skip [CLS] or [SEP]
if token_start == token_end == 0:
origin_offset2token_idx_start[token_start] = token_idx
origin_offset2token_idx_end[token_end] = token_idx
new_start_positions = []
new_end_positions = []
for s, e in zip(start_positions, end_positions):
new_s = origin_offset2token_idx_start[s]
new_e = origin_offset2token_idx_end[e]
except Exception as exc:
print((s, e), offsets)
print(tokenizer.tokenize(query, context, add_special_tokens=True))
if 0 < new_s <= new_e < len(offsets):
print((s, e), (new_s, new_e), query, context)
label_mask = [
(0 if type_ids[token_idx] == 0 or offsets[token_idx] == (0, 0) else 1)
for token_idx in range(len(tokens))
start_label_mask = label_mask.copy()
end_label_mask = label_mask.copy()
assert all(start_label_mask[p] != 0 for p in new_start_positions)
assert all(end_label_mask[p] != 0 for p in new_end_positions)
assert len(label_mask) == len(tokens)
start_labels = [(1 if idx in new_start_positions else 0)
for idx in range(len(tokens))]
end_labels = [(1 if idx in new_end_positions else 0)
for idx in range(len(tokens))]
# truncate
token_ids = tokens[: self.max_leng