-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy patheval_functions.py
73 lines (58 loc) · 2.11 KB
/
eval_functions.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
from utils import levenshtein
def parse_krn_content(krn, ler_parsing=False, cer_parsing=False):
if cer_parsing:
krn = krn.replace("\n", " <b> ")
krn = krn.replace("\t", " <t> ")
tokens = krn.split(" ")
characters = []
for token in tokens:
if token not in ['<b>', '<t>']:
characters.append(token)
else:
for char in token:
characters.append(char)
return characters
elif ler_parsing:
krn_lines = krn.split("\n")
for i, line in enumerate(krn_lines):
line = line.replace("\n", " <b> ")
line = line.replace("\t", " <t> ")
krn_lines[i] = line
return krn_lines
else:
krn = krn.replace("\n", " <b> ")
krn = krn.replace("\t", " <t> ")
return krn.split(" ")
def compute_metric(a1, a2):
acc_ed_dist = 0
acc_len = 0
for (h, g) in zip(a1, a2):
acc_ed_dist += levenshtein(h, g)
acc_len += len(g)
return 100.*acc_ed_dist / acc_len
def get_metrics(hyp_array, gt_array):
hyp_cer = []
gt_cer = []
hyp_ser = []
gt_ser = []
hyp_ler = []
gt_ler = []
for h_string, gt_string in zip(hyp_array, gt_array):
hyp_ler.append(parse_krn_content(h_string, ler_parsing=True, cer_parsing=False))
gt_ler.append(parse_krn_content(gt_string, ler_parsing=True, cer_parsing=False))
hyp_ser.append(parse_krn_content(h_string, ler_parsing=False, cer_parsing=False))
gt_ser.append(parse_krn_content(gt_string, ler_parsing=False, cer_parsing=False))
hyp_cer.append(parse_krn_content(h_string, ler_parsing=False, cer_parsing=True))
gt_cer.append(parse_krn_content(gt_string, ler_parsing=False, cer_parsing=True))
acc_ed_dist = 0
acc_len = 0
cer = 0
ser = 0
ler = 0
for (h, g) in zip(hyp_cer, gt_cer):
acc_ed_dist += levenshtein(h, g)
acc_len += len(g)
cer = compute_metric(hyp_cer, gt_cer)
ser = compute_metric(hyp_ser, gt_ser)
ler = compute_metric(hyp_ler, gt_ler)
return cer, ser, ler