-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathmetrics.py
More file actions
57 lines (44 loc) · 1.87 KB
/
metrics.py
File metadata and controls
57 lines (44 loc) · 1.87 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import numpy as np
import editdistance
from typing import List
from typing import Optional
from collections import namedtuple
# Field names of the statistics bundle produced by ``get_metrics``, in order.
metric_names = [
    "loss",
    "wer",
    "ter",
    "edit_distance",
    "normalised_edit_distance",
]

# Immutable record with one attribute per entry of ``metric_names``.
Metrics = namedtuple("Metrics", metric_names)
def get_metrics(predictions: List[List[str]], targets: List[List[str]],
                losses: Optional[List[float]] = None) -> Metrics:
    """Compute sequence-, token- and edit-distance metrics for a batch.

    Args:
        predictions: Predicted symbol sequences, one per example.
        targets: Reference symbol sequences, aligned index-by-index with
            ``predictions``.
        losses: Optional loss values; their mean is reported as ``loss``.

    Returns:
        A ``Metrics`` tuple with:

        * ``loss`` -- ``np.mean(losses)``, or ``None`` when ``losses`` is None.
        * ``wer`` -- percentage of examples whose prediction is not an exact
          match of the target (a sequence-level error rate).
        * ``ter`` -- percentage of position-wise symbol mismatches. ``nan``
          if any prediction/target pair differs in length (position-wise
          comparison is undefined then); ``None`` if there are no symbols
          at all.
        * ``edit_distance`` -- mean ``editdistance.distance`` over examples.
        * ``normalised_edit_distance`` -- mean of each distance divided by
          the target length. An empty target is divided by 1 instead, so an
          empty/empty pair scores 0.0 rather than raising ZeroDivisionError.

    Raises:
        AssertionError: if ``predictions`` and ``targets`` differ in length
            or are empty (checks are skipped under ``python -O``).
    """
    assert len(predictions) == len(targets), (
        f"got {len(predictions)} predictions but {len(targets)} targets")
    assert len(predictions) > 0, "cannot compute metrics for an empty batch"

    num_correct_sequences = 0
    total_num_sequences = 0
    num_correct_tokens = 0
    total_num_tokens = 0
    # Becomes False as soon as one pair has different lengths; token accuracy
    # is then undefined for the whole batch.
    lengths_match = True
    edit_distances = []
    normalised_edit_distances = []

    for predicted_symbols, target_symbols in zip(predictions, targets):
        total_num_sequences += 1

        # Sequence accuracy: the whole prediction must equal the target.
        if predicted_symbols == target_symbols:
            num_correct_sequences += 1

        # Token accuracy compares symbols position by position, which only
        # makes sense for equal-length sequences.
        if len(predicted_symbols) != len(target_symbols):
            lengths_match = False
        else:
            num_correct_tokens += sum(
                predicted == target
                for predicted, target in zip(predicted_symbols, target_symbols)
            )
            total_num_tokens += len(target_symbols)

        # (Normalised) edit distance; max(..., 1) guards against empty targets.
        dist = editdistance.distance(predicted_symbols, target_symbols)
        edit_distances.append(dist)
        normalised_edit_distances.append(dist / max(len(target_symbols), 1))

    wer = 100 * (1 - num_correct_sequences / total_num_sequences)

    if not lengths_match:
        ter = np.nan
    elif total_num_tokens > 0:
        ter = 100 * (1 - num_correct_tokens / total_num_tokens)
    else:
        ter = None

    loss = np.mean(losses) if losses is not None else None

    return Metrics(
        loss=loss,
        wer=wer,
        ter=ter,
        edit_distance=np.mean(edit_distances),
        normalised_edit_distance=np.mean(normalised_edit_distances),
    )