test.py (forked from arixlin/tensorflow-wavenet)
#-*- coding:utf-8 -*-
from __future__ import print_function
from model import Model
from utils import SpeechLoader
import tensorflow as tf # 1.12.0
import numpy as np
import librosa
import os
# Speech recognition: batch_size is fixed to 1 so inference runs on one utterance at a time
def speech_to_text():
    n_mfcc = 60

    # load data
    wav_path = os.path.join(os.getcwd(), 'data', 'wav', 'train')
    label_file = os.path.join(os.getcwd(), 'data', 'doc', 'trans', 'train.word.txt')
    speech_loader = SpeechLoader(wav_path, label_file, batch_size=1, n_mfcc=n_mfcc)

    # load model
    model = Model(speech_loader.vocab_size, n_mfcc=n_mfcc, is_training=False)
    saver = tf.train.Saver(tf.trainable_variables())

    with tf.Session() as sess:
        saver.restore(sess, tf.train.latest_checkpoint('model'))

        # word dict: index -> word
        wmap = {value: key for key, value in speech_loader.wordmap.items()}

        # build the decoding ops once, outside the loop, so the graph does not grow per file;
        # ctc_beam_search_decoder expects time-major logits, hence the transpose
        logits = tf.transpose(model.logit, perm=[1, 0, 2])
        decoded, log_probs = tf.nn.ctc_beam_search_decoder(logits, model.seq_len, top_paths=1, merge_repeated=True)
        # +1 restores the vocabulary indices expected by the word map
        predict = tf.sparse_to_dense(decoded[0].indices, decoded[0].dense_shape, decoded[0].values) + 1

        for j in range(905, 915):
            # extract MFCC features, shape (1, time, n_mfcc)
            wav_file = os.path.join(os.getcwd(), 'data', 'wav', 'test', 'D12_' + str(j) + '.wav')
            wav, sr = librosa.load(wav_file, mono=True)
            mfcc = np.transpose(np.expand_dims(librosa.feature.mfcc(y=wav, sr=sr, n_mfcc=n_mfcc), axis=0), [0, 2, 1])
            mfcc = mfcc.tolist()

            # zero-pad up to the maximum length seen during training
            while len(mfcc[0]) < speech_loader.wav_max_len:
                mfcc[0].append([0] * n_mfcc)

            # recognition
            output, _ = sess.run([predict, log_probs], feed_dict={model.input_data: mfcc})

            # print result (unknown indices map to an empty string instead of crashing on int concatenation)
            words = ''
            for i in range(len(output[0])):
                words += wmap.get(output[0][i], '')
            print("---------------------------")
            print("Input: " + wav_file)
            print("Output: " + words)


if __name__ == '__main__':
    speech_to_text()
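
# Usage sketch, assuming a trained checkpoint exists under ./model and the test
# utterances D12_905.wav .. D12_914.wav are present under data/wav/test:
#     python test.py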