-
Notifications
You must be signed in to change notification settings - Fork 21
/
decoder.py
147 lines (113 loc) · 5.73 KB
/
decoder.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
import numpy as np
import math
import itertools
from scipy.special import softmax
# Import-time side effect: tell NumPy to ignore divide-by-zero floating-point
# errors instead of warning about them.
# NOTE(review): presumably here because TrieDecoder.decode takes np.log of
# softmax outputs, which can underflow to exactly 0 -- confirm.
np.seterr(divide='ignore')
class DecodeResult:
    """Outcome of one decoding pass.

    Attributes set by the constructor:
        score: Overall score, stored exactly as passed in.
        words: Per-word entries, stored exactly as passed in; each entry must
            support ``entry["word"]`` and yield a string.
        text: The ``"word"`` values of all entries joined by single spaces.
    """

    def __init__(self, score, words):
        self.score = score
        self.words = words
        # Build the transcript from the entries, in order.
        spoken = [entry["word"] for entry in words]
        self.text = " ".join(spoken)
class GreedyDecoder:
    """Best-path decoder: takes the arg-max label of every frame, collapses
    repeated labels, drops blanks and cuts the result into words at ``"|"``.
    """

    def __init__(self, labels, blank_idx=0):
        """
        Args:
            labels: Sequence of output symbols indexed by class id. It must
                provide ``.index`` and contain the word delimiter ``"|"``
                (``labels.index("|")`` is called here and fails otherwise).
            blank_idx: Class id that is treated as the blank symbol.
        """
        self.labels, self.blank_idx = labels, blank_idx
        self.delim_idx = self.labels.index("|")

    def decode(self, output, start_timestamp=0, frame_time=0.02):
        """Decode a (frames x classes) score matrix into timed words.

        Args:
            output: 2-D NumPy array; row ``i`` holds the class scores of
                frame ``i``.
            start_timestamp: Time offset added to every word boundary.
            frame_time: Duration of one frame, in the same unit as
                ``start_timestamp``.

        Returns:
            DecodeResult whose ``words`` are dicts with the keys ``"word"``,
            ``"start"``, ``"end"`` and ``"confidence"``. A word is emitted
            only when its closing delimiter is seen, so a trailing word that
            is not followed by ``"|"`` is not reported. When no word was
            completed the overall score is ``0.0``.
        """
        best_path = np.argmax(output.astype(np.float32, copy=False), axis=1)
        # Every word score is taken relative to the largest value of the whole
        # matrix. That value does not change inside the loop, so compute it
        # once (an empty matrix has no maximum, hence the guard).
        peak = np.max(output) if output.size else 0.0

        words, score_chunks = [], []
        word_open = False
        current_word, word_start_time, start_idx = None, start_timestamp, 0
        words_len = 0
        frame = 0  # index of the first frame of the current run of equal labels
        for label_idx, run in itertools.groupby(best_path):
            run_length = sum(1 for _ in run)
            if label_idx != self.blank_idx:
                if label_idx == self.delim_idx:
                    # A delimiter closes a word only if one is actually open.
                    # Without this guard a leading delimiter would read an
                    # unset start index, and a delimiter repeated after a
                    # blank would emit the previous word a second time.
                    if word_open:
                        word_open = False
                        span = np.arange(start_idx, frame)
                        span_len = frame - start_idx
                        word_score = output[span, best_path[span]] - peak
                        score_chunks.append(word_score)
                        words_len += span_len
                        # NOTE(review): the mean is divided by the span length
                        # once more; kept unchanged to preserve the reported
                        # confidence values.
                        word_confidence = np.round(
                            np.exp(word_score.mean() / max(1, span_len)) * 100.0, 2
                        )
                        words.append({
                            "word": current_word,
                            "start": np.round(word_start_time, 2),
                            "end": np.round(frame_time * frame + start_timestamp, 2),
                            "confidence": word_confidence
                        })
                elif word_open:
                    current_word += self.labels[label_idx]
                else:
                    # First symbol of a new word: remember where it starts.
                    word_open = True
                    start_idx = frame
                    current_word = self.labels[label_idx]
                    word_start_time = frame_time * frame + start_timestamp
            frame += run_length

        if not score_chunks:
            # Nothing was completed (empty input, only blanks, or an unfinished
            # word): there is no score to average.
            return DecodeResult(0.0, words)

        all_scores = np.concatenate(score_chunks)
        score = np.round(np.exp(all_scores.mean() / max(1, words_len)) * 100.0, 2)

        return DecodeResult(score, words)
class TrieDecoder:
    """Lexicon-constrained beam-search decoder backed by the project-local
    ``trie_decoder`` package and a KenLM language model.
    """

    def __init__(self, lexicon, tokens, lm_path, beam_threshold=30):
        """
        Args:
            lexicon: Passed to ``trie_decoder.common.load_words``.
            tokens: Passed to ``trie_decoder.common.Dictionary``.
            lm_path: Passed to ``trie_decoder.decoder.KenLM``.
            beam_threshold: Third positional argument of ``DecoderOptions``.
        """
        # Imported lazily so that the module can be used (e.g. GreedyDecoder)
        # without trie_decoder being importable.
        from trie_decoder.common import Dictionary, create_word_dict, load_words
        from trie_decoder.decoder import CriterionType, DecoderOptions, KenLM, LexiconDecoder
        lexicon = load_words(lexicon)
        self.wordDict = create_word_dict(lexicon)
        self.tokenDict = Dictionary(tokens)
        self.lm = KenLM(lm_path, self.wordDict)

        trie, self.sil_idx, self.blank_idx, self.unk_idx = self.get_trie(lexicon)
        # Flat all-zero vector of length index_size ** 2.
        vocab_size = self.tokenDict.index_size()
        transitions = np.zeros((vocab_size, vocab_size)).flatten()

        # NOTE(review): positional arguments; presumably beam size, token beam
        # size, beam threshold, LM weight, word score, unknown-word score,
        # silence score, ..., log-add flag, criterion -- confirm against
        # trie_decoder's DecoderOptions signature.
        opts = DecoderOptions(
            2000, 100, beam_threshold, 1.4, 1.0, -math.inf, -1, 0, False, CriterionType.CTC
        )

        self.trieDecoder = LexiconDecoder(
            opts, trie, self.lm, self.sil_idx, self.blank_idx, self.unk_idx, transitions, False
        )
        self.delim_idx = self.tokenDict.get_index("|")

    def get_trie(self, lexicon):
        """Build the trie of word spellings, scored by the language model.

        Args:
            lexicon: Mapping of word -> iterable of spellings (the value
                returned by ``load_words`` in ``__init__``).

        Returns:
            Tuple ``(trie, sil_idx, blank_idx, unk_idx)``. ``sil_idx`` and
            ``blank_idx`` are both the token index of ``"#"``; ``unk_idx`` is
            the word index of ``"<unk>"``.
        """
        from trie_decoder.common import tkn_to_idx
        from trie_decoder.decoder import SmearingMode, Trie
        unk_idx = self.wordDict.get_index("<unk>")
        sil_idx = blank_idx = self.tokenDict.get_index("#")

        trie = Trie(self.tokenDict.index_size(), sil_idx)
        start_state = self.lm.start(False)

        for word, spellings in lexicon.items():
            usr_idx = self.wordDict.get_index(word)
            # Score of the word taken from the LM start state, rounded to two
            # decimals before being stored in the trie.
            _, score = self.lm.score(start_state, usr_idx)
            score = np.round(score, 2)
            for spelling in spellings:
                spelling_indices = tkn_to_idx(spelling, self.tokenDict, 0)
                trie.insert(spelling_indices, usr_idx, score)

        trie.smear(SmearingMode.MAX)
        return trie, sil_idx, blank_idx, unk_idx

    def decode(self, output, start_timestamp=0, frame_time=0.02):
        """Decode a (frames x classes) score matrix into timed words.

        Args:
            output: 2-D NumPy array of per-frame class scores; it is
                converted to log-softmax values before decoding.
            start_timestamp: Time offset added to every word boundary.
            frame_time: Duration of one frame, in the same unit as
                ``start_timestamp``.

        Returns:
            DecodeResult whose ``words`` are dicts with the keys ``"word"``,
            ``"start"``, ``"end"`` and ``"confidence"``. A word is emitted
            only when its closing delimiter token is seen.
        """
        log_probs = np.log(softmax(output[:, :].astype(np.float32, copy=False), axis=-1))
        # The decoder receives only a raw address plus the two dimensions, so
        # the buffer has to be one row-major float32 block. A transposed or
        # strided input would otherwise produce a differently laid-out result.
        # This is a no-op when the array is already C-contiguous float32.
        log_probs = np.ascontiguousarray(log_probs, dtype=np.float32)

        t, n = log_probs.shape
        # First hypothesis returned by the decoder (presumably the best one --
        # confirm against trie_decoder).
        result = self.trieDecoder.decode(log_probs.ctypes.data, t, n)[0]
        tokens = result.tokens

        words, new_word = [], True
        # current_word is None whenever there is no started-but-unemitted word.
        current_word, current_timestamp, start_idx = None, start_timestamp, 0
        lm_state = self.lm.start(False)
        words_len = 0
        for i, k in enumerate(tokens):
            if k == self.blank_idx:
                continue
            if i > 0 and k == tokens[i - 1]:
                # Repetition of the previous frame's token.
                continue
            if k == self.sil_idx:
                new_word = True
            elif k == self.delim_idx:
                if current_word is None:
                    # Delimiter with no pending word (leading or repeated
                    # delimiter): there is nothing to score or emit. Without
                    # this guard None would be looked up in the word
                    # dictionary, or the previous word would be emitted twice.
                    continue
                span_len = i - start_idx
                lm_state, word_lm_score = self.lm.score(lm_state, self.wordDict.get_index(current_word))
                words_len += span_len
                words.append({
                    "word": current_word,
                    "start": np.round(current_timestamp, 2),
                    "end": np.round(frame_time * i + start_timestamp, 2),
                    "confidence": np.round(np.exp(word_lm_score / max(1, span_len)) * 100, 2)
                })
                new_word, current_word = True, None
            elif new_word:
                # First token of a new word: remember where it starts.
                new_word = False
                current_word = self.tokenDict.get_entry(k)
                current_timestamp = frame_time * i + start_timestamp
                start_idx = i
            else:
                current_word += self.tokenDict.get_entry(k)

        # NOTE(review): unlike GreedyDecoder and the per-word confidence above,
        # this overall score is not multiplied by 100 -- confirm whether that
        # is intended before changing it.
        score = np.round(np.exp(result.score / max(1, words_len)), 2)

        return DecodeResult(score, words)