# -*- coding:utf8 -*-
# ==============================================================================
# Copyright 2017 Baidu.com, Inc. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
This module computes evaluation metrics for the DuReader dataset.
"""
import argparse
import json
import sys
import zipfile
from collections import Counter
# from .bleu_metric.bleu import Bleu
from .rouge import Rouge
# Empty-string constant (presumably a placeholder for missing answers;
# its usages are outside this view).
EMPTY = ''
# Presumably the permitted values of 'yesno_answers' entries -- confirm
# against the code that consumes it.
YESNO_LABELS = {'Yes', 'No', 'Depends'}
def normalize(s):
    """
    Turn each string into its non-whitespace characters separated by spaces.

    Args:
        s: a list of strings.
    Returns:
        ``s`` unchanged when it is falsy (e.g. ``None`` or ``[]``); otherwise
        a new list where every string has had its whitespace characters
        removed and the remaining characters joined by single spaces,
        e.g. ``'ab c'`` -> ``'a b c'``.
    """
    if not s:
        return s
    # A character is kept only if stripping it leaves something behind,
    # i.e. it is not whitespace.
    return [' '.join(ch for ch in text if ch.strip()) for text in s]
def data_check(obj, task):
    """
    Check that one decoded record carries the fields required for evaluation.

    Args:
        obj: a record decoded from one JSON line.
        task: evaluation task name; currently not used by any check.
    Raises:
        AssertionError: when the record is not legal. The checks raise
            explicitly rather than using ``assert`` so they still run under
            ``python -O``.
    """
    if 'question_id' not in obj:
        raise AssertionError("Missing 'question_id' field.")
    # Every later message identifies the record by its question_id, which is
    # guaranteed to exist past this point.
    qid = obj['question_id']
    if 'question_type' not in obj:
        raise AssertionError(
            "Missing 'question_type' field. question_id: {}".format(qid))
    if 'yesno_answers' not in obj:
        raise AssertionError(
            "Missing 'yesno_answers' field. question_id: {}".format(qid))
    if not isinstance(obj['yesno_answers'], list):
        raise AssertionError(
            ("'yesno_answers' field must be a list, if the 'question_type' is not\n"
             "'YES_NO', then this field should be an empty list.\n"
             "question_id: {}").format(qid))
    if 'entity_answers' not in obj:
        raise AssertionError(
            "Missing 'entity_answers' field. question_id: {}".format(qid))
    if not (isinstance(obj['entity_answers'], list)
            and len(obj['entity_answers']) > 0):
        raise AssertionError(
            ("'entity_answers' field must be a list, and has at least one element,\n"
             "which can be a empty list. question_id: {}").format(qid))
def read_file(file_name, task, is_ref=False):
    """
    Read predict answers or reference answers from a JSON-lines file.

    ``file_name`` is either a plain text file or, when its name ends with
    ``.zip``, a zip archive whose members are all read. Every non-blank line
    must be one JSON object; each object is validated with ``data_check``.

    Args:
        file_name: the name of the file containing predict result or reference
            result.
        task: task name, forwarded to ``data_check``.
        is_ref: when True, the 'source' field of every record is kept as well.
    Returns:
        A dictionary mapping question_id to the result information. The result
        information itself is also a dictionary with the keys:
            - question_type: type of the query.
            - yesno_answers: A list of yesno answers corresponding to 'answers'.
            - answers: A list of predicted answers.
            - entity_answers: A list, each element is also a list containing
              the entities tagged out from the corresponding answer string.
            - source: only present when ``is_ref`` is True.
    Raises:
        ValueError: if a non-blank line is not legal JSON.
        AssertionError: if a record fails ``data_check`` or a question_id
            occurs more than once.
    """
    import io  # io.open takes an ``encoding`` argument on Python 2 and 3.

    def _open(name, mode, zip_obj=None):
        # Zip members are byte streams; their lines are decoded below.
        if zip_obj is not None:
            return zip_obj.open(name, mode)
        # Decode explicitly instead of relying on the platform default encoding.
        return io.open(name, mode, encoding='utf-8')

    results = {}
    keys = ['answers', 'yesno_answers', 'entity_answers', 'question_type']
    if is_ref:
        keys += ['source']

    zf = zipfile.ZipFile(file_name, 'r') if file_name.endswith('.zip') else None
    try:
        file_list = [file_name] if zf is None else zf.namelist()
        for fn in file_list:
            # ``with`` guarantees each file/member is closed, even on errors.
            with _open(fn, 'r', zip_obj=zf) as fin:
                for line in fin:
                    if isinstance(line, bytes):
                        line = line.decode('utf-8')
                    line = line.strip()
                    if not line:
                        # Blank lines carry no record; skip them instead of
                        # reporting them as illegal JSON.
                        continue
                    try:
                        obj = json.loads(line)
                    except ValueError:
                        raise ValueError("Every line of data should be legal json")
                    data_check(obj, task)
                    qid = obj['question_id']
                    assert qid not in results, "Duplicate question_id: {}".format(qid)
                    results[qid] = dict((k, obj[k]) for k in keys)
    finally:
        if zf is not None:
            zf.close()
    return results
def compute_bleu_rouge(pred_dict, ref_dict, bleu_order=4):
    """
    Compute the ROUGE score reported as 'Rouge-L' (BLEU is currently disabled).

    Args:
        pred_dict: question_id -> predicted answers, in whatever format
            ``Rouge.compute_score`` expects.
        ref_dict: question_id -> reference answers; must contain exactly the
            same question ids as ``pred_dict``.
        bleu_order: maximum n-gram order for BLEU. Unused while the BLEU
            computation is commented out; kept for interface compatibility.
    Returns:
        A dict mapping metric name to score; currently only 'Rouge-L'.
    Raises:
        AssertionError: if the two dicts do not share the same key set.
    """
    pred_keys = set(pred_dict.keys())
    ref_keys = set(ref_dict.keys())
    if pred_keys != ref_keys:
        # Report both directions so the message is informative when the
        # predictions contain ids that are absent from the references.
        raise AssertionError("missing keys: {}, extra keys: {}".format(
            ref_keys - pred_keys, pred_keys - ref_keys))
    scores = {}
    # BLEU is disabled (its import is commented out at the top of the file).
    # bleu_scores, _ = Bleu(bleu_order).compute_score(ref_dict, pred_dict)
    # for i, bleu_score in enumerate(bleu_scores):
    #     scores['Bleu-%d' % (i + 1)] = bleu_score
    rouge_score, _ = Rouge().compute_score(ref_dict, pred_dict)
    scores['Rouge-L'] = rouge_score
    return scores
def local_prf(pred_list, ref_list):
    """
    Score one prediction list against one reference list.

    Overlap is counted with multiplicity (multiset intersection).

    Args:
        pred_list: predicted items.
        ref_list: reference items.
    Returns:
        A ``(precision, recall, f1)`` tuple; ``(0, 0, 0)`` when nothing overlaps.
    """
    overlap = sum((Counter(pred_list) & Counter(ref_list)).values())
    if not overlap:
        return 0, 0, 0
    precision = overlap / float(len(pred_list))
    recall = overlap / float(len(ref_list))
    return precision, recall, 2 * precision * recall / (precision + recall)
def compute_prf(pred_dict, ref_dict):
    """
    Compute corpus-level entity precision, recall and F1.

    For each question in ``ref_dict`` the single predicted entity list is
    compared with every candidate reference list; the reference with the
    highest local F1 (first one on ties) is chosen, or the shortest reference
    when none overlaps. Entity counts (as sets) are then summed over all
    questions before computing the scores.

    Args:
        pred_dict: question_id -> list holding exactly one predicted entity
            list. Questions missing here count as predicting no entities.
        ref_dict: question_id -> list of candidate reference entity lists.
    Returns:
        A dict with keys 'Precision', 'Recall' and 'F1' (all 0 when no
        entity is predicted correctly).
    Raises:
        AssertionError: if a prediction does not contain exactly one list.
    """
    n_correct = n_gold = n_pred = 0
    for qid in set(ref_dict.keys()):
        candidates = pred_dict.get(qid, [[]])
        assert len(candidates) == 1, \
            'the number of entity list for question_id {} is not 1.'.format(qid)
        predicted = candidates[0]
        references = ref_dict[qid]

        # Keep the first reference with a strictly higher local F1.
        best_ref, best_f1 = None, 0
        for ref in references:
            f1 = local_prf(predicted, ref)[2]
            if f1 > best_f1:
                best_ref, best_f1 = ref, f1
        if best_ref is None:
            # No overlap anywhere: fall back to the (first) shortest reference.
            best_ref = min(references, key=len) if references else []

        gold = set(best_ref)
        guessed = set(predicted)
        n_correct += len(gold & guessed)
        n_pred += len(guessed)
        n_gold += len(gold)

    if n_correct == 0:
        return {'Precision': 0, 'Recall': 0, 'F1': 0}
    precision = float(n_correct) / n_pred
    recall = float(n_correct) / n_gold
    return {'Precision': precision, 'Recall': recall,
            'F1': 2 * precision * recall / (precision + recall)}
def prepare_prf(pred_dict, ref_dict):
    """
    Extract the entity answers needed for the P/R/F1 computation.

    Args:
        pred_dict: question_id -> prediction record with an 'entity_answers' key.
        ref_dict: question_id -> reference record with an 'entity_answers' key.
    Returns:
        A ``(preds, refs)`` tuple of dicts mapping question_id to the record's
        'entity_answers' value.
    """
    def _entity_answers(records):
        return {qid: record['entity_answers'] for qid, record in records.items()}

    return _entity_answers(pred_dict), _entity_answers(ref_dict)
def filter_dict(result_dict, key_tag):
    """
    Select the items of ``result_dict`` whose keys end with ``key_tag``.

    Args:
        result_dict: dictionary with string keys.
        key_tag: the required key suffix.
    Returns:
        A new dict containing only the matching items, in their original order.
    """
    return {key: value for key, value in result_dict.items()
            if key.endswith(key_tag)}
def get_metrics(pred_result, ref_result, task, source):
"""
Computes metrics.
"""
metrics = {}
ref_result_filtered = {}
pred_result_filtered = {}
if source == 'both':
ref_result_filtered = ref_result
pred_result_filtered = pred_result
else:
for question_id, info in ref_result.items():
if info['source'] == source:
ref_result_filtered[question_id] = info
if question_id in pred_result:
pred_result_filtered[question_id] = pred_result[question_id]
if task == 'main' or task == 'all' \