-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathconll.py
More file actions
120 lines (104 loc) · 5.44 KB
/
conll.py
File metadata and controls
120 lines (104 loc) · 5.44 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import re
import tempfile
import subprocess
import operator
import collections
import logging
logger = logging.getLogger(__name__)

# Header line that opens each document (e.g. "#begin document doc.1").
# Group 1 is passed to get_doc_key as the document id, group 2 as the part.
BEGIN_DOCUMENT_REGEX = re.compile(r"#begin document (.*)(?:\.|_)(\d+_?\d*)")  # First line at each document

# Pulls recall, precision and F1 (percent values) out of the scorer's
# "Coreference:" summary line. Because the pattern starts with a greedy ".*"
# under DOTALL, the last matching summary line in the output is the one captured.
COREF_RESULTS_REGEX = re.compile(
    r".*Coreference: Recall: \([0-9.]+ / [0-9.]+\) ([0-9.]+)%"
    r"\tPrecision: \([0-9.]+ / [0-9.]+\) ([0-9.]+)%"
    r"\tF1: ([0-9.]+)%.*",
    re.DOTALL,
)

# A run of digits optionally preceded by "(" and/or followed by ")".
# Substituting with r"\1" keeps only the digits, stripping that wrapping.
REMOVE_MENTION_MARKUP = re.compile(r"\(?(\d+)\)?")
def get_doc_key(doc_id, part):
    """Return the "<doc_id>_<part>" key used to index predictions.

    ``part`` is passed through ``int`` first, so zero-padded values such as
    "000" and "0" produce the same key.
    """
    part_number = int(part)  # normalises "007" -> 7
    return "{doc}_{num}".format(doc=doc_id, num=part_number)
def output_conll(input_file, output_file, predictions, subtoken_map, merge_overlapping_spans=False):
    """Copy a CoNLL file, replacing the last column of token lines with predictions.

    Lines that are empty after ``split()`` are written as a bare newline.
    Lines whose first field starts with ``#`` are copied through. Every other
    line is treated as a token line whose last column is overwritten with
    coreference markup built from ``predictions``:

    - ``(id`` opens a span.
    - ``id)`` closes a span.
    - ``(id)`` marks a single-word mention.

    Several markers on one token are joined with ``|``, and ``-`` is written
    when the token carries none.

    Args:
        input_file: Readable text file in CoNLL format (typically the gold
            file). Each "#begin document" header selects, via ``get_doc_key``,
            which entry of ``predictions`` is used for the following tokens.
        output_file: Writable text file that receives the rewritten lines.
        predictions: Mapping from document key (as built by ``get_doc_key``)
            to a list of clusters. Each cluster is an iterable of
            ``(start, end)`` subtoken indices, with ``end`` inclusive.
        subtoken_map: Mapping from document key to an indexable sequence that
            translates a subtoken index into a word index.
        merge_overlapping_spans: If true, overlapping or nested mentions of the
            same cluster are written as a single span. An opening marker is
            emitted only when no span of that cluster is active, and a closing
            marker only when the last active one ends. Single-word mentions
            that coincide with, or fall inside, a span of the same cluster are
            dropped. Defaults to False, the same default ``evaluate_conll``
            uses.

    Raises:
        KeyError: If a "#begin document" header names a document with no entry
            in ``predictions``.
    """
    # Build, per document, the cluster ids that open / close / fully occupy
    # each word index.
    prediction_map = {}
    for doc_key, clusters in predictions.items():
        start_map = collections.defaultdict(list)
        end_map = collections.defaultdict(list)
        word_map = collections.defaultdict(list)
        for cluster_id, mentions in enumerate(clusters):
            for start, end in mentions:
                start, end = subtoken_map[doc_key][start], subtoken_map[doc_key][end]
                if start == end:
                    word_map[start].append(cluster_id)
                else:
                    start_map[start].append((cluster_id, end))
                    end_map[end].append((cluster_id, start))
        # Spans opening at the same word: largest end first, so outer spans
        # open before inner ones.
        for k, v in start_map.items():
            start_map[k] = [cluster_id for cluster_id, _ in sorted(v, key=operator.itemgetter(1), reverse=True)]
        # Spans closing at the same word: largest start first, so inner spans
        # close before outer ones.
        for k, v in end_map.items():
            end_map[k] = [cluster_id for cluster_id, _ in sorted(v, key=operator.itemgetter(1), reverse=True)]
        prediction_map[doc_key] = (start_map, end_map, word_map)

    # NOTE(review): start_map/end_map/word_map are first bound by the loop
    # above, so token lines that appear before any "#begin document" header
    # would reuse the last document's maps. Confirm inputs always start with
    # a header.
    word_index = 0
    active_count = collections.Counter()  # open spans per cluster id
    for line in input_file:
        row = line.split()
        if not row:
            output_file.write("\n")
        elif row[0].startswith("#"):
            begin_match = BEGIN_DOCUMENT_REGEX.match(line)
            if begin_match:
                # Reset per-document state at the start of each document.
                active_count = collections.Counter()
                doc_key = get_doc_key(begin_match.group(1), begin_match.group(2))
                start_map, end_map, word_map = prediction_map[doc_key]
                word_index = 0
            # `line` normally keeps its trailing newline, so this adds a blank
            # line after each comment line. Kept to preserve existing output.
            output_file.write(line)
            output_file.write("\n")
        else:
            # Strip parentheses wrapped around bare numbers in columns 3 and 6
            # (0-based). NOTE(review): presumably so the scorer does not
            # mistake them for coreference markup. Confirm.
            row[3] = REMOVE_MENTION_MARKUP.sub(r"\1", row[3])
            row[6] = REMOVE_MENTION_MARKUP.sub(r"\1", row[6])

            # Use .get() rather than indexing: indexing a defaultdict would
            # insert empty entries into the maps stored in prediction_map.
            ends_here = end_map.get(word_index, [])
            singles_here = word_map.get(word_index, [])
            starts_here = start_map.get(word_index, [])

            # Closings first, then single-word mentions, then openings, so
            # markers on one token stay well nested.
            coref_list = []
            for cluster_id in ends_here:
                active_count[cluster_id] -= 1
                if active_count[cluster_id] == 0 or not merge_overlapping_spans:
                    coref_list.append("{})".format(cluster_id))
            if merge_overlapping_spans:
                for cluster_id in set(singles_here) - set(starts_here) - set(ends_here):
                    if active_count[cluster_id] == 0:
                        coref_list.append("({})".format(cluster_id))
            else:
                for cluster_id in set(singles_here):
                    coref_list.append("({})".format(cluster_id))
            for cluster_id in starts_here:
                active_count[cluster_id] += 1
                if active_count[cluster_id] == 1 or not merge_overlapping_spans:
                    coref_list.append("({}".format(cluster_id))

            row[-1] = "|".join(coref_list) if coref_list else "-"
            output_file.write(" ".join(row))
            output_file.write("\n")
            word_index += 1
def official_conll_eval(conll_scorer, gold_path, predicted_path, metric, official_stdout=True):
    """Run the external CoNLL scorer for one metric and parse its summary.

    The scorer is invoked directly (no shell) as
    ``conll_scorer metric gold_path predicted_path none``.

    Args:
        conll_scorer: Path to the scorer executable.
        gold_path: Path to the gold CoNLL file.
        predicted_path: Path to the predicted CoNLL file.
        metric: Metric name passed as the scorer's first argument
            (``evaluate_conll`` uses "muc", "bcub" and "ceafe").
        official_stdout: If true, log the scorer's full stdout at INFO level.

    Returns:
        A dict ``{"r": recall, "p": precision, "f": f1}`` of floats parsed from
        the scorer's "Coreference:" summary line. All three values are None
        when no such line is found in the output.
    """
    cmd = [conll_scorer, metric, gold_path, predicted_path, "none"]
    # Pipe stderr as well. Previously only stdout was piped, so `stderr` was
    # always None and scorer errors never reached the log.
    process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    stdout, stderr = process.communicate()  # also waits for the process to exit

    stdout = stdout.decode("utf-8")
    if stderr:
        logger.error(stderr.decode("utf-8", errors="replace"))
    if process.returncode != 0:
        logger.warning("CoNLL scorer exited with status %d for metric %s", process.returncode, metric)

    if official_stdout:
        logger.info("Official result for {}".format(metric))
        logger.info(stdout)

    coref_results_match = COREF_RESULTS_REGEX.match(stdout)
    if coref_results_match is None:
        # The scorer could not compute this metric (or failed outright).
        return {"r": None, "p": None, "f": None}
    recall, precision, f1 = (float(value) for value in coref_results_match.groups())
    return {"r": recall, "p": precision, "f": f1}
def evaluate_conll(conll_scorer, gold_path, predictions, subtoken_maps, out_file=None, official_stdout=True, merge_overlapping_spans=False):
    """Write predictions in CoNLL format and score them with the official scorer.

    Args:
        conll_scorer: Path to the scorer executable, forwarded to
            ``official_conll_eval``.
        gold_path: Path to the gold CoNLL file. It is used both as the template
            for the predicted file and as the scorer's gold input.
        predictions: Mapping from document key to clusters, forwarded to
            ``output_conll``.
        subtoken_maps: Mapping from document key to subtoken-to-word index
            maps, forwarded to ``output_conll``.
        out_file: Path where the predicted CoNLL file is written and kept. If
            None, a temporary file is used and deleted afterwards.
        official_stdout: If true, the scorer's full output is logged.
        merge_overlapping_spans: Forwarded to ``output_conll``.

    Returns:
        A dict mapping each of "muc", "bcub" and "ceafe" to the result dict
        returned by ``official_conll_eval``.
    """
    # NOTE(review): with out_file=None this relies on reopening an open
    # NamedTemporaryFile by name, which the tempfile docs say works on Unix
    # but not on Windows.
    with open(out_file, "w") if out_file is not None else tempfile.NamedTemporaryFile(delete=True, mode="w") as prediction_file:
        with open(gold_path, "r") as gold_file:
            output_conll(gold_file, prediction_file, predictions, subtoken_maps, merge_overlapping_spans=merge_overlapping_spans)
        # The scorer is a separate process that reads the file by name. Flush
        # first, or the tail of the predictions can still be sitting in
        # Python's write buffer and the scorer would see a truncated file.
        prediction_file.flush()
        results = {m: official_conll_eval(conll_scorer, gold_file.name, prediction_file.name, m, official_stdout) for m in ("muc", "bcub", "ceafe")}
        return results