# Copyright (c) 2022 Heiheiyoyo. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
import functools
import json
import logging
import math
import random
import re
import shutil
import threading
import time
from functools import partial
import colorlog
import numpy as np
import torch
from colorama import Back, Fore
from torch.utils.data import Dataset
from tqdm import tqdm
from tqdm.contrib.logging import logging_redirect_tqdm
# Module-level registry, empty at import time.
# NOTE(review): presumably filled with logger objects by code elsewhere
# in the file -- confirm against the code that writes to it.
loggers = dict()

# Maps each log-level name to its numeric level and the colour name used
# when rendering it. 'TRAIN' (21) and 'EVAL' (22) are custom levels that
# sit between the standard INFO (20) and WARNING (30) values.
log_config = {
    level_name: {'level': level_no, 'color': color_name}
    for level_name, level_no, color_name in (
        ('DEBUG', 10, 'purple'),
        ('INFO', 20, 'green'),
        ('TRAIN', 21, 'cyan'),
        ('EVAL', 22, 'blue'),
        ('WARNING', 30, 'yellow'),
        ('ERROR', 40, 'red'),
        ('CRITICAL', 50, 'bold_red'),
    )
}
def get_span(start_ids, end_ids, with_prob=False):
    """
    Pair start positions with end positions to form spans.

    Both lists are sorted by position and walked with two cursors. Each end
    position is paired with the closest start position that is not after it;
    a later candidate start for the same end overwrites an earlier one.

    Args:
        start_ids (List[int]/List[tuple]): Candidate start positions.
        end_ids (List[int]/List[tuple]): Candidate end positions.
        with_prob (bool): If True, every element of start_ids and end_ids is
            a tuple like (index, probability) and only the index is used for
            ordering and comparison; the whole tuples are returned.

    Returns:
        set: A set of (start, end) pairs, one pair at most per end element.
    """
    if with_prob:
        def position(item):
            return item[0]
    else:
        def position(item):
            return item

    starts = sorted(start_ids, key=position)
    ends = sorted(end_ids, key=position)
    start_count = len(starts)
    end_count = len(ends)

    # Keyed by the end element so that a closer start replaces a farther one.
    end_to_start = {}
    s_idx = 0
    e_idx = 0
    while s_idx < start_count and e_idx < end_count:
        start_item = starts[s_idx]
        end_item = ends[e_idx]
        start_pos = position(start_item)
        end_pos = position(end_item)
        if start_pos == end_pos:
            # Single-position span: consume both cursors.
            end_to_start[end_item] = start_item
            s_idx += 1
            e_idx += 1
        elif start_pos < end_pos:
            # Tentative pairing; the next start may still be closer.
            end_to_start[end_item] = start_item
            s_idx += 1
        elif start_pos > end_pos:
            # This end can no longer gain a closer start; move on.
            e_idx += 1

    return {(start_item, end_item) for end_item, start_item in end_to_start.items()}
def get_bool_ids_greater_than(probs, limit=0.5, return_prob=False):
    """
    Collect the indices along the last dimension whose value exceeds a limit.

    The input is converted to a NumPy array. Arrays with more than one
    dimension are processed row by row, recursively, so the nesting of the
    result mirrors the leading dimensions of the input.

    Args:
        probs (List[List[float]]): The input probability arrays.
        limit (float): Values strictly greater than this are selected.
        return_prob (bool): If True, return (index, value) tuples instead of
            bare indices.

    Returns:
        List[List[int]]: The selected indices (or (index, value) tuples) for
            each innermost row.
    """
    prob_array = np.array(probs)

    if prob_array.ndim > 1:
        return [
            get_bool_ids_greater_than(row, limit, return_prob)
            for row in prob_array
        ]

    if return_prob:
        return [(idx, value) for idx, value in enumerate(prob_array) if value > limit]
    return [idx for idx, value in enumerate(prob_array) if value > limit]
class SpanEvaluator:
    """
    Accumulates span counts and derives precision, recall and F1-score
    for span detection.
    """

    def __init__(self):
        super().__init__()
        self.num_infer_spans = 0
        self.num_label_spans = 0
        self.num_correct_spans = 0

    def compute(self, start_probs, end_probs, gold_start_ids, gold_end_ids):
        """
        Count correct, predicted and labelled spans for one batch.

        Predictions and gold labels are both turned into index lists with
        `get_bool_ids_greater_than` (default limit), then compared sample by
        sample with `eval_span`.

        Args:
            start_probs: Predicted start probabilities, one row per sample.
            end_probs: Predicted end probabilities, one row per sample.
            gold_start_ids: Gold start labels; must provide `.tolist()`.
            gold_end_ids: Gold end labels; must provide `.tolist()`.

        Returns:
            tuple: (num_correct_spans, num_infer_spans, num_label_spans)
                summed over the batch.
        """
        pred_starts = get_bool_ids_greater_than(start_probs)
        pred_ends = get_bool_ids_greater_than(end_probs)
        gold_starts = get_bool_ids_greater_than(gold_start_ids.tolist())
        gold_ends = get_bool_ids_greater_than(gold_end_ids.tolist())

        # Running totals in the order (correct, infer, label).
        totals = [0, 0, 0]
        for sample in zip(pred_starts, pred_ends, gold_starts, gold_ends):
            for slot, count in enumerate(self.eval_span(*sample)):
                totals[slot] += count
        return tuple(totals)

    def update(self, num_correct_spans, num_infer_spans, num_label_spans):
        """
        Add the counts of one batch to the accumulated state.

        Args:
            num_correct_spans (int): Number of correctly predicted spans.
            num_infer_spans (int): Number of predicted spans.
            num_label_spans (int): Number of gold spans.
        """
        self.num_infer_spans += num_infer_spans
        self.num_label_spans += num_label_spans
        self.num_correct_spans += num_correct_spans

    def eval_span(self, predict_start_ids, predict_end_ids, label_start_ids,
                  label_end_ids):
        """
        Evaluate span extraction for a single sample.

        Example:
            input: [1, 2, 10] [4, 12] [2, 10] [4, 11]
            output: (1, 2, 2)

        Returns:
            tuple: (num_correct, num_infer, num_label).
        """
        predicted = get_span(predict_start_ids, predict_end_ids)
        expected = get_span(label_start_ids, label_end_ids)
        return (len(predicted & expected), len(predicted), len(expected))

    def accumulate(self):
        """
        Compute the metrics over everything accumulated so far.

        Precision and recall fall back to 0. when their denominator is zero;
        the F1-score is 0. when no span has been predicted correctly.

        Returns:
            tuple: Returns tuple (`precision, recall, f1 score`).
        """
        correct = self.num_correct_spans
        precision = 0.
        recall = 0.
        f1_score = 0.
        if self.num_infer_spans:
            precision = float(correct / self.num_infer_spans)
        if self.num_label_spans:
            recall = float(correct / self.num_label_spans)
        if correct:
            f1_score = float(2 * precision * recall / (precision + recall))
        return precision, recall, f1_score

    def reset(self):
        """
        Clear all accumulated counts.
        """
        self.num_infer_spans = self.num_label_spans = self.num_correct_spans = 0

    def name(self):
        """
        Return the names of the metrics produced by `accumulate`.
        """
        return "precision", "recall", "f1"
class IEDataset(Dataset):
"""
Dataset for Information Extraction from jsonl file.
The line type is
{
content
result_list
prompt
}
"""
def __init__(self, file_path, tokenizer, max_seq_len) -> None:
super().__init__()
self.file_path = file_path
self.dataset = list(reader(file_path))
self.tokenizer = tokenizer
self.max_seq_len = max_seq_len
def __len__(self):
return len(self.dataset)
def __getitem__(self, index):
return convert_example(self.dataset[index], tokenizer=self.
基于百度uie的关系抽取.zip
版权申诉
142 浏览量
2024-01-18
21:38:31
上传
评论 2
收藏 1017KB ZIP 举报
博士僧小星
- 粉丝: 1711
- 资源: 5876
最新资源
- 农村信用社联合社计算机信息系统投产与变更管理办.docx
- 农村信用社联合社计算机信息系统数据管理办法.docx
- 利用SPSS作临床效度分析线上计算网站介绍-医学研究部统计谘.(医学PPT课件).ppt
- 利用Zabbix监控mysqldump定时备份数据库状态.docx
- 利用计算机解决问题的基本过程.doc
- 化工铁路通信工程总结.doc
- 北京大学网络教育软件工程作业.docx
- 医药公司(连锁店)计算机操作规程未新系统的自行按照旧制修改-新系统过制的编号加修模版.doc
- 医药公司(连锁店)计算机系统操作规程模版.doc
- 医药连锁门店计算机系统的操作和管理程序未新系统的自行按照旧制修改-新系统过制的编号加修模版.docx
资源上传下载、课程学习等过程中有任何疑问或建议,欢迎提出宝贵意见哦~我们会及时处理!
点击此处反馈