diff --git a/docs/conf.py b/docs/conf.py index 1a597982c..fc6a665b2 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -12,8 +12,8 @@ # import os import sys -sys.path.insert(0, os.path.abspath('..')) +sys.path.insert(0, os.path.abspath('..')) # -- Project information ----------------------------------------------------- @@ -21,7 +21,6 @@ copyright = '2020, wenet-team' author = 'wenet-team' - # -- General configuration --------------------------------------------------- # Add any Sphinx extension module names here, as strings. They can be @@ -43,7 +42,6 @@ # Add any paths that contain templates here, relative to this directory. templates_path = ['_templates'] - # The suffix(es) of source filenames. # You can specify multiple suffix as a list of string: source_suffix = { @@ -57,7 +55,6 @@ # This pattern also affects html_static_path and html_extra_path. exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] - # -- Options for HTML output ------------------------------------------------- # The theme to use for HTML and HTML Help pages. See the documentation for diff --git a/examples/aishell/NST/local/generate_data_list.py b/examples/aishell/NST/local/generate_data_list.py index 684e7cb68..7a8af6177 100644 --- a/examples/aishell/NST/local/generate_data_list.py +++ b/examples/aishell/NST/local/generate_data_list.py @@ -15,6 +15,7 @@ import os import random + def get_args(): parser = argparse.ArgumentParser(description='generate data.list file ') parser.add_argument('--tar_dir', help='path for tar file') @@ -23,8 +24,8 @@ def get_args(): parser.add_argument('--pseudo_data_ratio', type=float, help='ratio of pseudo data, ' - '0 means none pseudo data, ' - '1 means all using pseudo data.') + '0 means none pseudo data, ' + '1 means all using pseudo data.') parser.add_argument('--out_data_list', help='output path for data list') args = parser.parse_args() return args @@ -55,7 +56,9 @@ def main(): for i in range(len(pseudo_data_list)): pseudo_data_list[i] = target_dir + "/" + pseudo_data_list[i] + "\n" - fused_list = pseudo_data_list[:pseudo_len] + supervised_data_list[:supervised_len] + fused_list = pseudo_data_list[: + pseudo_len] + supervised_data_list[: + supervised_len] with open(output_file, "w") as writer: for line in fused_list: diff --git a/examples/aishell/NST/local/generate_filtered_pseudo_label.py b/examples/aishell/NST/local/generate_filtered_pseudo_label.py index 2a8ee83c3..e184ecd76 100644 --- a/examples/aishell/NST/local/generate_filtered_pseudo_label.py +++ b/examples/aishell/NST/local/generate_filtered_pseudo_label.py @@ -19,28 +19,41 @@ def get_args(): - parser = argparse.ArgumentParser(description='generate filter pseudo label') - parser.add_argument('--dir_num', required=True, help='split directory number') - parser.add_argument('--cer_hypo_dir', required=True, + parser = argparse.ArgumentParser( + description='generate filter pseudo label') + parser.add_argument('--dir_num', + required=True, + help='split directory number') + parser.add_argument('--cer_hypo_dir', + required=True, help='prefix for cer_hypo_dir') - parser.add_argument('--utter_time_file', required=True, + parser.add_argument('--utter_time_file', + required=True, help='the json file that contains audio time infos ') - parser.add_argument('--cer_hypo_threshold', required=True, type=float, + parser.add_argument('--cer_hypo_threshold', + required=True, + type=float, help='the cer-hypo threshold used to filter') - parser.add_argument('--speak_rate_threshold', type=float, + parser.add_argument('--speak_rate_threshold', + 
type=float, help='the cer threshold we use to filter') parser.add_argument('--dir', required=True, help='dir for the experiment ') # output untar and tar - parser.add_argument('--untar_dir', required=True, + parser.add_argument('--untar_dir', + required=True, help='the output path, ' - 'eg: data/train/wenet_untar_cer_hypo_nst1/') - parser.add_argument('--tar_dir', required=True, + 'eg: data/train/wenet_untar_cer_hypo_nst1/') + parser.add_argument('--tar_dir', + required=True, help='the tar file path, ' - 'eg: data/train/wenet_tar_cer_hypo_leq_10_nst1/') - parser.add_argument('--wav_dir', required=True, + 'eg: data/train/wenet_tar_cer_hypo_leq_10_nst1/') + parser.add_argument('--wav_dir', + required=True, help='dir to store wav files, ' - 'eg "data/train/wenet_1k_untar/"') - parser.add_argument('--start_tar_id', default=0 , type=int, + 'eg "data/train/wenet_1k_untar/"') + parser.add_argument('--start_tar_id', + default=0, + type=int, help='the initial tar id (for debugging)') args = parser.parse_args() return args @@ -118,11 +131,14 @@ def main(): utt_time = utter_time[utt_id] - cer_dict[utt_id] = [pred_no_lm, pred_lm, wer_pred_lm, - utt_time, n_hypo, prediction] + cer_dict[utt_id] = [ + pred_no_lm, pred_lm, wer_pred_lm, utt_time, n_hypo, + prediction + ] else: - cer_dict[utt_id] = [pred_no_lm, pred_lm, - wer_pred_lm, -1, -1, prediction] + cer_dict[utt_id] = [ + pred_no_lm, pred_lm, wer_pred_lm, -1, -1, prediction + ] c = 0 cer_preds = [] @@ -170,8 +186,8 @@ def main(): os.makedirs(out_put_dir, exist_ok=True) for i in range(cur_id, end_id): - print("dir:", dir_num, ", " "tar: ", tar_id, - ", ", "progress:", i / len_data) + print("dir:", dir_num, ", " + "tar: ", tar_id, ", ", "progress:", i / len_data) t_id, utter = data_filtered[i] @@ -184,8 +200,8 @@ def main(): with open(output_path, "w", encoding="utf-8") as writer: writer.write(utter) # update .wav - os.system("cp" + " " + wav_path + " " - + out_put_dir + t_id + ".wav") + os.system("cp" + " " + wav_path + " " + out_put_dir + t_id + + ".wav") else: print(" wav does not exists ! ", wav_path) not_exist.append(wav_path) diff --git a/examples/aishell/NST/local/get_wav_labels.py b/examples/aishell/NST/local/get_wav_labels.py index fb0c5c2b0..b02290070 100644 --- a/examples/aishell/NST/local/get_wav_labels.py +++ b/examples/aishell/NST/local/get_wav_labels.py @@ -16,17 +16,27 @@ def get_args(): parser = argparse.ArgumentParser(description='sum up prediction wer') - parser.add_argument('--job_num', type=int, default=8, + parser.add_argument('--job_num', + type=int, + default=8, help='number of total split dir') - parser.add_argument('--dir_split', required=True, + parser.add_argument('--dir_split', + required=True, help='the path to the data_list dir ' - 'eg data/train/wenet1k_good_split_60/') - parser.add_argument('--label', type=int, default=0, + 'eg data/train/wenet1k_good_split_60/') + parser.add_argument('--label', + type=int, + default=0, help='if ture, label file will also be considered.') - parser.add_argument('--hypo_name', type=str, required=True, + parser.add_argument('--hypo_name', + type=str, + required=True, help='the hypothesis path. eg. /hypothesis_0.txt ') - parser.add_argument('--wav_dir', type=str, required=True, - help='the wav dir path. eg. data/train/wenet_1k_untar/ ') + parser.add_argument( + '--wav_dir', + type=str, + required=True, + help='the wav dir path. eg. 
data/train/wenet_1k_untar/ ') args = parser.parse_args() return args @@ -53,7 +63,8 @@ def main(): output_wav = data_list_dir + "data_sublist" + str(i) + "/wav.scp" output_label = data_list_dir + "data_sublist" + str(i) + "/label.txt" # bad lines are just for debugging - output_bad_lines = data_list_dir + "data_sublist" + str(i) + "/bad_line.txt" + output_bad_lines = data_list_dir + "data_sublist" + str( + i) + "/bad_line.txt" with open(hypo_path, 'r', encoding="utf-8") as reader: hypo_lines = reader.readlines() diff --git a/examples/aishell/NST/local/split_data_list.py b/examples/aishell/NST/local/split_data_list.py index 17d507cb7..b0727f15a 100644 --- a/examples/aishell/NST/local/split_data_list.py +++ b/examples/aishell/NST/local/split_data_list.py @@ -17,13 +17,17 @@ def get_args(): parser = argparse.ArgumentParser(description='') - parser.add_argument('--job_nums', type=int, default=8, + parser.add_argument('--job_nums', + type=int, + default=8, help='number of total split jobs') - parser.add_argument('--data_list_path', required=True, + parser.add_argument('--data_list_path', + required=True, help='the path to the data.list file') - parser.add_argument('--output_dir', required=True, + parser.add_argument('--output_dir', + required=True, help='path to output dir, ' - 'eg --output_dir=data/train/aishell_split_60') + 'eg --output_dir=data/train/aishell_split_60') args = parser.parse_args() return args @@ -46,7 +50,7 @@ def main(): len_d = int(len(data_list_we) / num_lists) rest_lines = data_list_we[num_lists * len_d:] rest_len = len(rest_lines) - print("total num of lines", len(data_list_we) , "rest len is", rest_len) + print("total num of lines", len(data_list_we), "rest len is", rest_len) # generate N sublist for i in range(num_lists): @@ -57,7 +61,7 @@ def main(): with open(output_list, 'w', encoding="utf-8") as writer: - new_list = data_list_we[i * len_d: (i + 1) * len_d] + new_list = data_list_we[i * len_d:(i + 1) * len_d] if i < rest_len: new_list.append(rest_lines[i]) for x in new_list: diff --git a/examples/aishell4/s0/local/aishell4_process_textgrid.py b/examples/aishell4/s0/local/aishell4_process_textgrid.py index c4fdc5434..c6e4653c0 100755 --- a/examples/aishell4/s0/local/aishell4_process_textgrid.py +++ b/examples/aishell4/s0/local/aishell4_process_textgrid.py @@ -9,6 +9,7 @@ class Segment(object): + def __init__(self, uttid, spkr, stime, etime, text): self.uttid = uttid self.spkr = spkr diff --git a/examples/commonvoice/fr/local/create_scp_text.py b/examples/commonvoice/fr/local/create_scp_text.py index b3d94276e..8ad3d23c8 100755 --- a/examples/commonvoice/fr/local/create_scp_text.py +++ b/examples/commonvoice/fr/local/create_scp_text.py @@ -4,20 +4,25 @@ import sys import os import re + + def process(src_str): punc = '~`!#$%^&*()_+-=|\';":/.,?><~·!@#¥%……&*()——+-=“:’;、。,?》《{}' return re.sub(r"[{0}]+".format(punc), "", src_str).upper() + if __name__ == '__main__': src_dir = sys.argv[1] tsv_file = src_dir + "/" + sys.argv[2] + ".tsv" output_dir = sys.argv[3] for file_path in os.listdir(src_dir + "/clips"): - if(os.path.exists(src_dir + "/wavs/" + file_path.split('.')[0] + ".wav")): + if (os.path.exists(src_dir + "/wavs/" + file_path.split('.')[0] + + ".wav")): continue t_str = src_dir + "/clips/" + file_path tt_str = src_dir + "/wavs/" + file_path.split('.')[0] + ".wav" - os.system("ffmpeg -i {0} -ac 1 -ar 16000 -f wav {1}".format(t_str, tt_str)) + os.system("ffmpeg -i {0} -ac 1 -ar 16000 -f wav {1}".format( + t_str, tt_str)) import pandas tsv_content = 
pandas.read_csv(tsv_file, sep="\t") path_list = tsv_content["path"] diff --git a/examples/csj/s0/csj_tools/wn.0.parse.py b/examples/csj/s0/csj_tools/wn.0.parse.py index d916a2cf0..e13bb4dd4 100644 --- a/examples/csj/s0/csj_tools/wn.0.parse.py +++ b/examples/csj/s0/csj_tools/wn.0.parse.py @@ -1,4 +1,3 @@ - # parse xml files and output simplified version import xml.dom.minidom @@ -6,6 +5,7 @@ import sys import multiprocessing + def parsexml(afile, outpath): outfile = os.path.join(outpath, afile.split('/')[-1] + '.simp') @@ -40,7 +40,8 @@ def parsexml(afile, outpath): if suw.hasAttribute('OrthographicTranscription'): txt = suw.getAttribute('OrthographicTranscription') if suw.hasAttribute('PlainOrthographicTranscription'): - plaintxt = suw.getAttribute('PlainOrthographicTranscription') + plaintxt = suw.getAttribute( + 'PlainOrthographicTranscription') if suw.hasAttribute('PhoneticTranscription'): prontxt = suw.getAttribute('PhoneticTranscription') wlist.append(txt) @@ -63,10 +64,11 @@ def parsexml(afile, outpath): lemmasent = ' '.join(lemmalist) dictlemmasent = ' '.join(dictlemmalist) outrow = '{}\t{}\t{}\t{}\t{}\t{}\t{}\n'.format( - starttime, endtime, txtsent, plaintxtsent, - prontxtsent, lemmasent, dictlemmasent) + starttime, endtime, txtsent, plaintxtsent, prontxtsent, + lemmasent, dictlemmasent) bw.write(outrow) + def procfolder_orig(apath, outpath): count = 0 for afile in os.listdir(apath): @@ -77,6 +79,7 @@ def procfolder_orig(apath, outpath): count += 1 print('done: {} [{}]'.format(afile, count)) + def procfolder(apath, outpath): # count = 0 fnlist = list() @@ -98,6 +101,7 @@ def procfolder(apath, outpath): print('parallel {} threads done for {} files in total.'.format( nthreads, len(fnlist))) + if __name__ == '__main__': if len(sys.argv) < 3: print("Usage: {} ".format(sys.argv[0])) diff --git a/examples/csj/s0/csj_tools/wn.1.split_wav.py b/examples/csj/s0/csj_tools/wn.1.split_wav.py index ccdf04e9b..7daf01874 100644 --- a/examples/csj/s0/csj_tools/wn.1.split_wav.py +++ b/examples/csj/s0/csj_tools/wn.1.split_wav.py @@ -7,6 +7,7 @@ import librosa import soundfile as sf + # use .simp as the source for .wav file splitting def wavfn(apath): wavdict = dict() # key=id, value=full.path of .wav @@ -16,6 +17,7 @@ def wavfn(apath): wavdict[aid] = fullwavpath return wavdict + def xmlfn(apath): xmldict = dict() # key=id, value=full.path of .xml.simp for axmlfn in os.listdir(apath): @@ -27,6 +29,7 @@ def xmlfn(apath): xmldict[aid] = axmlfn2 return xmldict + def ch2to1(f1, outf1): wav1, _ = librosa.load(f1, sr=16000, mono=False) if wav1.ndim == 1: @@ -40,6 +43,7 @@ def ch2to1(f1, outf1): # overwrite the old .wav file which is 2ch # print(res, acmd) + def proc1file(fullxmlfn, fullwavfn, outwavpath): with open(fullxmlfn) as xmlbr: for axmlline in xmlbr.readlines(): @@ -58,7 +62,8 @@ def proc1file(fullxmlfn, fullwavfn, outwavpath): partwavfn = os.path.join(outwavpath, name2) dur = float(etime) - float(stime) - acmd = 'sox {} {} trim {} {}'.format(fullwavfn, partwavfn, stime, dur) + acmd = 'sox {} {} trim {} {}'.format(fullwavfn, partwavfn, stime, + dur) res = os.system(acmd) # print(res, acmd) @@ -67,6 +72,7 @@ def proc1file(fullxmlfn, fullwavfn, outwavpath): # otherwise, soundfile.write will give us error report! 
ch2to1(partwavfn, partwavfn1ch) + def procpath(atag, csjpath, xmlsimppath, outwavpath, idset): # atag = 'core' and 'noncore' axmlpath = xmlsimppath @@ -92,19 +98,19 @@ def procpath(atag, csjpath, xmlsimppath, outwavpath, idset): fullwavfn = wavdict[wavid] if wavid in xmldict: fullxmlfn = xmldict[wavid] - pool.apply_async(proc1file, (fullxmlfn, fullwavfn, outwavpath)) + pool.apply_async(proc1file, + (fullxmlfn, fullwavfn, outwavpath)) pool.close() pool.join() print('parallel {} threads done for {} files.'.format( - nthreads, - len(wavidlist))) + nthreads, len(wavidlist))) + if __name__ == '__main__': if len(sys.argv) < 4: - print( - "Usage: {}".format(sys.argv[0]) + - " [id.list.fn]") + print("Usage: {}".format(sys.argv[0]) + + " [id.list.fn]") exit(1) csjpath = sys.argv[1] diff --git a/examples/csj/s0/csj_tools/wn.2.prep.text.py b/examples/csj/s0/csj_tools/wn.2.prep.text.py index 2b132ad9d..a411347ce 100644 --- a/examples/csj/s0/csj_tools/wn.2.prep.text.py +++ b/examples/csj/s0/csj_tools/wn.2.prep.text.py @@ -3,6 +3,7 @@ # train test1 test2 test3 + def readtst(tstfn): outlist = list() with open(tstfn) as br: @@ -11,6 +12,7 @@ def readtst(tstfn): outlist.append(aline) return outlist + def split_train_tests_xml(xmlpath, test1fn, test2fn, test3fn): test1list = readtst(test1fn) test2list = readtst(test2fn) @@ -37,6 +39,7 @@ def split_train_tests_xml(xmlpath, test1fn, test2fn, test3fn): return outtrainlist, outt1list, outt2list, outt3list + def all_wavs(wavpath): wavlist = list() for afile in os.listdir(wavpath): @@ -46,6 +49,7 @@ def all_wavs(wavpath): wavlist.append(afile2) return wavlist + def gen_text(xmllist, outpath): # id \t text # e.g., /workspace/asr/wenet/examples/csj/s0/data/xml/S11M1689.xml.simp @@ -74,6 +78,7 @@ def gen_text(xmllist, outpath): aoutline = '{}\t{}\n'.format(afullid, atxt) bw.write(aoutline) + def parse_xml_set(xmllist): outset = set() for xml in xmllist: @@ -82,6 +87,7 @@ def parse_xml_set(xmllist): outset.add(aid2) return outset + def gen_wav_scp(xmllist, wavlist, outpath): # xmlset = pure id set, alike 'S04F1228' # can be from train, test1, test2, or test3 @@ -109,15 +115,11 @@ def gen_wav_scp(xmllist, wavlist, outpath): bw.write(aoutline) -def prep_text_wavscp( - xmlpath, wavpath, test1fn, test2fn, test3fn, - outtrainpath, out1path, out2path, out3path): +def prep_text_wavscp(xmlpath, wavpath, test1fn, test2fn, test3fn, outtrainpath, + out1path, out2path, out3path): trainlist, t1list, t2list, t3list = split_train_tests_xml( - xmlpath, - test1fn, - test2fn, - test3fn) + xmlpath, test1fn, test2fn, test3fn) wavlist = all_wavs(wavpath) gen_text(trainlist, outtrainpath) @@ -130,12 +132,12 @@ def prep_text_wavscp( gen_wav_scp(t2list, wavlist, out2path) gen_wav_scp(t3list, wavlist, out3path) + if __name__ == '__main__': if len(sys.argv) < 10: - print( - "Usage: {}".format(sys.argv[0]) + " " + - " " + - " ") + print("Usage: {}".format(sys.argv[0]) + " " + + " " + + " ") exit(1) xmlpath = sys.argv[1] @@ -149,6 +151,5 @@ def prep_text_wavscp( out2path = sys.argv[8] out3path = sys.argv[9] - prep_text_wavscp(xmlpath, wavpath, test1fn, - test2fn, test3fn, outtrainpath, + prep_text_wavscp(xmlpath, wavpath, test1fn, test2fn, test3fn, outtrainpath, out1path, out2path, out3path) diff --git a/examples/csj/s0/csj_tools/wn.3.mincut.py b/examples/csj/s0/csj_tools/wn.3.mincut.py index 39e8b8659..bd7231764 100644 --- a/examples/csj/s0/csj_tools/wn.3.mincut.py +++ b/examples/csj/s0/csj_tools/wn.3.mincut.py @@ -2,6 +2,7 @@ # import os import sys + def mincut(wavscpfn, minsec): outfn = 
wavscpfn + "_" + str(minsec) @@ -15,6 +16,7 @@ def mincut(wavscpfn, minsec): if dur >= minsec: bw.write(aline + '\n') + # wn.3.mincut.py if __name__ == '__main__': if len(sys.argv) < 3: diff --git a/examples/csj/s0/csj_tools/wn.4.make_raw_list.py b/examples/csj/s0/csj_tools/wn.4.make_raw_list.py index eb5aac28b..4424ef09d 100644 --- a/examples/csj/s0/csj_tools/wn.4.make_raw_list.py +++ b/examples/csj/s0/csj_tools/wn.4.make_raw_list.py @@ -58,7 +58,11 @@ if key in segments_table: wav_key, start, end = segments_table[key] wav = wav_table[wav_key] - line = dict(key=key, wav=wav, txt=txt, start=start, end=end) + line = dict(key=key, + wav=wav, + txt=txt, + start=start, + end=end) else: line = None if line: diff --git a/examples/gigaspeech/s0/local/gigaspeech_scoring.py b/examples/gigaspeech/s0/local/gigaspeech_scoring.py index e7679f4ab..78820a36c 100755 --- a/examples/gigaspeech/s0/local/gigaspeech_scoring.py +++ b/examples/gigaspeech/s0/local/gigaspeech_scoring.py @@ -14,6 +14,7 @@ non_scoring_words = conversational_filler + unk_tags + \ gigaspeech_punctuations + gigaspeech_garbage_utterance_tags + def asr_text_post_processing(text): # 1. convert to uppercase text = text.upper() diff --git a/examples/multi_cn/s0/local/primewords_parse_transcript.py b/examples/multi_cn/s0/local/primewords_parse_transcript.py index 772ab7f93..f2eb8e85d 100755 --- a/examples/multi_cn/s0/local/primewords_parse_transcript.py +++ b/examples/multi_cn/s0/local/primewords_parse_transcript.py @@ -14,8 +14,9 @@ def main(argv): metas[fname] = ele fWavScp = open(os.path.join(argv[2], 'wav.scp'), 'w') - fText = open(os.path.join( - argv[2], 'transcripts.txt'), 'w', encoding="utf-8") + fText = open(os.path.join(argv[2], 'transcripts.txt'), + 'w', + encoding="utf-8") fUtt2Spk = open(os.path.join(argv[2], 'utt2spk'), 'w') for line in open(argv[0]): fpath = line.strip('\r\n') diff --git a/examples/swbd/s0/local/format_acronyms_dict.py b/examples/swbd/s0/local/format_acronyms_dict.py index fa598dd03..bcadd7aec 100755 --- a/examples/swbd/s0/local/format_acronyms_dict.py +++ b/examples/swbd/s0/local/format_acronyms_dict.py @@ -17,13 +17,16 @@ parser = argparse.ArgumentParser(description="format acronyms to a._b._c.") parser.add_argument("-i", "--input", help="Input lexicon", required=True) parser.add_argument("-o", "--output", help="Output lexicon", required=True) -parser.add_argument( - "-L", "--Letter", help="Input single letter pronunciation", required=True -) -parser.add_argument("-M", "--Map", help="Output acronyms mapping", required=True) +parser.add_argument("-L", + "--Letter", + help="Input single letter pronunciation", + required=True) +parser.add_argument("-M", + "--Map", + help="Output acronyms mapping", + required=True) args = parser.parse_args() - fin_lex = open(args.input, "r") fin_Letter = open(args.Letter, "r") fout_lex = open(args.output, "w") @@ -33,14 +36,14 @@ dict_letter = {} for single_letter_lex in fin_Letter: items = single_letter_lex.split() - dict_letter[items[0]] = single_letter_lex[len(items[0]) + 1 :].strip() + dict_letter[items[0]] = single_letter_lex[len(items[0]) + 1:].strip() fin_Letter.close() # print dict_letter for lex in fin_lex: items = lex.split() word = items[0] - lexicon = lex[len(items[0]) + 1 :].strip() + lexicon = lex[len(items[0]) + 1:].strip() # find acronyms from words with only letters and ' pre_match = re.match(r"^[A-Za-z]+$|^[A-Za-z]+\'s$|^[A-Za-z]+s$", word) if pre_match: @@ -50,20 +53,20 @@ actual_lexicon = lexicon[:-2] acronym_lexicon = "" for w in actual_word: - 
acronym_lexicon = acronym_lexicon + dict_letter[w.upper()] + " " + acronym_lexicon = acronym_lexicon + dict_letter[ + w.upper()] + " " if acronym_lexicon.strip() == actual_lexicon: acronym_mapped = "" acronym_mapped_back = "" for w in actual_word[:-1]: acronym_mapped = acronym_mapped + w.lower() + "._" acronym_mapped_back = acronym_mapped_back + w.lower() + " " - acronym_mapped = acronym_mapped + actual_word[-1].lower() + ".'s" - acronym_mapped_back = ( - acronym_mapped_back + actual_word[-1].lower() + "'s" - ) - fout_map.write( - word + "\t" + acronym_mapped + "\t" + acronym_mapped_back + "\n" - ) + acronym_mapped = acronym_mapped + actual_word[-1].lower( + ) + ".'s" + acronym_mapped_back = (acronym_mapped_back + + actual_word[-1].lower() + "'s") + fout_map.write(word + "\t" + acronym_mapped + "\t" + + acronym_mapped_back + "\n") fout_lex.write(acronym_mapped + " " + lexicon + "\n") else: fout_lex.write(lex) @@ -74,20 +77,20 @@ actual_lexicon = lexicon[:-2] acronym_lexicon = "" for w in actual_word: - acronym_lexicon = acronym_lexicon + dict_letter[w.upper()] + " " + acronym_lexicon = acronym_lexicon + dict_letter[ + w.upper()] + " " if acronym_lexicon.strip() == actual_lexicon: acronym_mapped = "" acronym_mapped_back = "" for w in actual_word[:-1]: acronym_mapped = acronym_mapped + w.lower() + "._" acronym_mapped_back = acronym_mapped_back + w.lower() + " " - acronym_mapped = acronym_mapped + actual_word[-1].lower() + ".s" - acronym_mapped_back = ( - acronym_mapped_back + actual_word[-1].lower() + "'s" - ) - fout_map.write( - word + "\t" + acronym_mapped + "\t" + acronym_mapped_back + "\n" - ) + acronym_mapped = acronym_mapped + actual_word[-1].lower( + ) + ".s" + acronym_mapped_back = (acronym_mapped_back + + actual_word[-1].lower() + "'s") + fout_map.write(word + "\t" + acronym_mapped + "\t" + + acronym_mapped_back + "\n") fout_lex.write(acronym_mapped + " " + lexicon + "\n") else: fout_lex.write(lex) @@ -96,7 +99,8 @@ elif word.find("'") == -1 and word[-1] != "s": acronym_lexicon = "" for w in word: - acronym_lexicon = acronym_lexicon + dict_letter[w.upper()] + " " + acronym_lexicon = acronym_lexicon + dict_letter[ + w.upper()] + " " if acronym_lexicon.strip() == lexicon: acronym_mapped = "" acronym_mapped_back = "" @@ -105,9 +109,8 @@ acronym_mapped_back = acronym_mapped_back + w.lower() + " " acronym_mapped = acronym_mapped + word[-1].lower() + "." 
acronym_mapped_back = acronym_mapped_back + word[-1].lower() - fout_map.write( - word + "\t" + acronym_mapped + "\t" + acronym_mapped_back + "\n" - ) + fout_map.write(word + "\t" + acronym_mapped + "\t" + + acronym_mapped_back + "\n") fout_lex.write(acronym_mapped + " " + lexicon + "\n") else: fout_lex.write(lex) diff --git a/examples/swbd/s0/local/map_acronyms_transcripts.py b/examples/swbd/s0/local/map_acronyms_transcripts.py index ba02aaec3..be35eb4c8 100755 --- a/examples/swbd/s0/local/map_acronyms_transcripts.py +++ b/examples/swbd/s0/local/map_acronyms_transcripts.py @@ -14,7 +14,10 @@ parser = argparse.ArgumentParser(description="format acronyms to a._b._c.") parser.add_argument("-i", "--input", help="Input transcripts", required=True) parser.add_argument("-o", "--output", help="Output transcripts", required=True) -parser.add_argument("-M", "--Map", help="Input acronyms mapping", required=True) +parser.add_argument("-M", + "--Map", + help="Input acronyms mapping", + required=True) args = parser.parse_args() fin_map = open(args.Map, "r") @@ -28,7 +31,6 @@ del dict_acronym_noi["I"] del dict_acronym_noi["i"] - fin_trans = open(args.input, "r") fout_trans = open(args.output, "w") for line in fin_trans: diff --git a/examples/tedlium3/s0/local/join_suffix.py b/examples/tedlium3/s0/local/join_suffix.py index e496c4d07..51065bd2e 100755 --- a/examples/tedlium3/s0/local/join_suffix.py +++ b/examples/tedlium3/s0/local/join_suffix.py @@ -4,7 +4,6 @@ # 2016 Johns Hopkins University (author: Daniel Povey) # Apache 2.0 - import sys # This script joins together pairs of split-up words like "you 're" -> "you're". diff --git a/examples/wenetspeech/s0/local/extract_meta.py b/examples/wenetspeech/s0/local/extract_meta.py index ce2871d0b..c9625d6af 100755 --- a/examples/wenetspeech/s0/local/extract_meta.py +++ b/examples/wenetspeech/s0/local/extract_meta.py @@ -25,7 +25,8 @@ def get_args(): where the long wav is splitinto segments and data of wenet format is generated. 
""") - parser.add_argument('input_json', help="""Input json file of WenetSpeech""") + parser.add_argument('input_json', + help="""Input json file of WenetSpeech""") parser.add_argument('output_dir', help="""Output dir for prepared data""") args = parser.parse_args() @@ -92,6 +93,7 @@ def meta_analysis(input_json, output_dir): utt2subsets.write( f'{sid}\t{segment_sub_names}\n') + def main(): args = get_args() diff --git a/runtime/android/app/src/main/cpp/wenet.cc b/runtime/android/app/src/main/cpp/wenet.cc index 7c8e92a37..2f5088950 100644 --- a/runtime/android/app/src/main/cpp/wenet.cc +++ b/runtime/android/app/src/main/cpp/wenet.cc @@ -45,31 +45,30 @@ void init(JNIEnv* env, jobject, jstring jModelDir) { resource = std::make_shared(); resource->model = model; - resource->symbol_table = std::shared_ptr( - fst::SymbolTable::ReadText(dictPath)); + resource->symbol_table = + std::shared_ptr(fst::SymbolTable::ReadText(dictPath)); LOG(INFO) << "dict path: " << dictPath; PostProcessOptions post_process_opts; - resource->post_processor = - std::make_shared(post_process_opts); + resource->post_processor = std::make_shared(post_process_opts); feature_config = std::make_shared(80, 16000); feature_pipeline = std::make_shared(*feature_config); decode_config = std::make_shared(); decode_config->chunk_size = 16; - decoder = std::make_shared(feature_pipeline, resource, - *decode_config); + decoder = + std::make_shared(feature_pipeline, resource, *decode_config); } -void reset(JNIEnv *env, jobject) { +void reset(JNIEnv* env, jobject) { LOG(INFO) << "wenet reset"; decoder->Reset(); state = kEndBatch; total_result = ""; } -void accept_waveform(JNIEnv *env, jobject, jshortArray jWaveform) { +void accept_waveform(JNIEnv* env, jobject, jshortArray jWaveform) { jsize size = env->GetArrayLength(jWaveform); int16_t* waveform = env->GetShortArrayElements(jWaveform, 0); feature_pipeline->AcceptWaveform(waveform, size); @@ -114,7 +113,7 @@ void start_decode() { decode_thread.detach(); } -jboolean get_finished(JNIEnv *env, jobject) { +jboolean get_finished(JNIEnv* env, jobject) { if (state == kEndFeats) { LOG(INFO) << "wenet recognize finished"; return JNI_TRUE; @@ -122,7 +121,7 @@ jboolean get_finished(JNIEnv *env, jobject) { return JNI_FALSE; } -jstring get_result(JNIEnv *env, jobject) { +jstring get_result(JNIEnv* env, jobject) { std::string result; if (decoder->DecodedSomething()) { result = decoder->result()[0].sentence; @@ -132,9 +131,9 @@ jstring get_result(JNIEnv *env, jobject) { } } // namespace wenet -JNIEXPORT jint JNI_OnLoad(JavaVM *vm, void *) { - JNIEnv *env; - if (vm->GetEnv(reinterpret_cast(&env), JNI_VERSION_1_6) != JNI_OK) { +JNIEXPORT jint JNI_OnLoad(JavaVM* vm, void*) { + JNIEnv* env; + if (vm->GetEnv(reinterpret_cast(&env), JNI_VERSION_1_6) != JNI_OK) { return JNI_ERR; } @@ -144,16 +143,16 @@ JNIEXPORT jint JNI_OnLoad(JavaVM *vm, void *) { } static const JNINativeMethod methods[] = { - {"init", "(Ljava/lang/String;)V", reinterpret_cast(wenet::init)}, - {"reset", "()V", reinterpret_cast(wenet::reset)}, - {"acceptWaveform", "([S)V", - reinterpret_cast(wenet::accept_waveform)}, - {"setInputFinished", "()V", - reinterpret_cast(wenet::set_input_finished)}, - {"getFinished", "()Z", reinterpret_cast(wenet::get_finished)}, - {"startDecode", "()V", reinterpret_cast(wenet::start_decode)}, - {"getResult", "()Ljava/lang/String;", - reinterpret_cast(wenet::get_result)}, + {"init", "(Ljava/lang/String;)V", reinterpret_cast(wenet::init)}, + {"reset", "()V", reinterpret_cast(wenet::reset)}, + {"acceptWaveform", 
"([S)V", + reinterpret_cast(wenet::accept_waveform)}, + {"setInputFinished", "()V", + reinterpret_cast(wenet::set_input_finished)}, + {"getFinished", "()Z", reinterpret_cast(wenet::get_finished)}, + {"startDecode", "()V", reinterpret_cast(wenet::start_decode)}, + {"getResult", "()Ljava/lang/String;", + reinterpret_cast(wenet::get_result)}, }; int rc = env->RegisterNatives(c, methods, sizeof(methods) / sizeof(JNINativeMethod)); diff --git a/runtime/core/kaldi/decoder/lattice-faster-decoder.cc b/runtime/core/kaldi/decoder/lattice-faster-decoder.cc index a797f6727..774edc94d 100644 --- a/runtime/core/kaldi/decoder/lattice-faster-decoder.cc +++ b/runtime/core/kaldi/decoder/lattice-faster-decoder.cc @@ -368,7 +368,7 @@ void LatticeFasterDecoderTpl::PruneForwardLinks( prev_link = link; // move to next link link = link->next; } - } // for all outgoing links + } // for all outgoing links if (fabs(tok_extra_cost - tok->extra_cost) > delta) changed = true; // difference new minus old is bigger than delta tok->extra_cost = tok_extra_cost; @@ -472,7 +472,7 @@ void LatticeFasterDecoderTpl::PruneForwardLinksFinal() { tok->extra_cost = tok_extra_cost; // will be +infinity or <= lattice_beam_. } - } // while changed + } // while changed } template @@ -537,7 +537,7 @@ void LatticeFasterDecoderTpl::PruneActiveTokens(BaseFloat delta) { PruneForwardLinks(f, &extra_costs_changed, &links_pruned, delta); if (extra_costs_changed && f > 0) // any token has changed extra_cost active_toks_[f - 1].must_prune_forward_links = true; - if (links_pruned) // any link was pruned + if (links_pruned) // any link was pruned active_toks_[f].must_prune_tokens = true; active_toks_[f].must_prune_forward_links = false; // job done } @@ -585,7 +585,7 @@ void LatticeFasterDecoderTpl::ComputeFinalCosts( if (final_best_cost != NULL) { if (best_cost_with_final != infinity) { // final-state exists. *final_best_cost = best_cost_with_final; - } else { // no final-state exists. + } else { // no final-state exists. *final_best_cost = best_cost; } } diff --git a/runtime/core/kaldi/decoder/lattice-faster-decoder.h b/runtime/core/kaldi/decoder/lattice-faster-decoder.h index 34e71fcd8..d79302c2a 100644 --- a/runtime/core/kaldi/decoder/lattice-faster-decoder.h +++ b/runtime/core/kaldi/decoder/lattice-faster-decoder.h @@ -128,8 +128,8 @@ struct ForwardLink { BaseFloat graph_cost; // graph cost of traversing arc (contains LM, etc.) BaseFloat acoustic_cost; // acoustic cost (pre-scaled) of traversing arc float context_score; - ForwardLink* next; // next in singly-linked list of forward arcs (arcs - // in the state-level lattice) from a token. + ForwardLink* next; // next in singly-linked list of forward arcs (arcs + // in the state-level lattice) from a token. inline ForwardLink(Token* next_tok, Label ilabel, Label olabel, BaseFloat graph_cost, BaseFloat acoustic_cost, ForwardLink* next) diff --git a/runtime/gpu/client/client.py b/runtime/gpu/client/client.py index 0ca362e76..bf8eebdb0 100644 --- a/runtime/gpu/client/client.py +++ b/runtime/gpu/client/client.py @@ -38,7 +38,8 @@ type=str, required=False, default="localhost:8001", - help="Inference server URL. Default is " "localhost:8001.", + help="Inference server URL. 
Default is " + "localhost:8001.", ) parser.add_argument( "--model_name", @@ -157,12 +158,10 @@ def single_job(client_files): with grpcclient.InferenceServerClient( - url=FLAGS.url, verbose=FLAGS.verbose - ) as triton_client: + url=FLAGS.url, verbose=FLAGS.verbose) as triton_client: protocol_client = grpcclient - speech_client = speech_client_cls( - triton_client, FLAGS.model_name, protocol_client, FLAGS - ) + speech_client = speech_client_cls(triton_client, FLAGS.model_name, + protocol_client, FLAGS) idx, audio_files = client_files predictions = [] for li in audio_files: diff --git a/runtime/gpu/client/decode_manifest_triton.py b/runtime/gpu/client/decode_manifest_triton.py index 3a8d57fed..d412d36a6 100644 --- a/runtime/gpu/client/decode_manifest_triton.py +++ b/runtime/gpu/client/decode_manifest_triton.py @@ -79,8 +79,7 @@ def get_args(): parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) + formatter_class=argparse.ArgumentDefaultsHelpFormatter) parser.add_argument( "--server-addr", @@ -216,30 +215,30 @@ async def send( samples = np.zeros( ( 1, - 10 * sample_rate * (int(len(waveform) / sample_rate // 10) + 1), + 10 * sample_rate * + (int(len(waveform) / sample_rate // 10) + 1), ), dtype=np.float32, ) - samples[0, : len(waveform)] = waveform + samples[0, :len(waveform)] = waveform lengths = np.array([[len(waveform)]], dtype=np.int32) inputs = [ - protocol_client.InferInput( - "WAV", samples.shape, np_to_triton_dtype(samples.dtype) - ), - protocol_client.InferInput( - "WAV_LENS", lengths.shape, np_to_triton_dtype(lengths.dtype) - ), + protocol_client.InferInput("WAV", samples.shape, + np_to_triton_dtype(samples.dtype)), + protocol_client.InferInput("WAV_LENS", lengths.shape, + np_to_triton_dtype(lengths.dtype)), ] inputs[0].set_data_from_numpy(samples) inputs[1].set_data_from_numpy(lengths) outputs = [protocol_client.InferRequestedOutput("TRANSCRIPTS")] sequence_id = 10086 + i - response = await triton_client.infer( - model_name, inputs, request_id=str(sequence_id), outputs=outputs - ) + response = await triton_client.infer(model_name, + inputs, + request_id=str(sequence_id), + outputs=outputs) decoding_results = response.as_numpy("TRANSCRIPTS")[0] if type(decoding_results) == np.ndarray: @@ -257,13 +256,11 @@ async def send( hyp = list("".join(hyp)) results.append((c.id, ref, hyp)) else: - results.append( - ( - c.id, - c.supervisions[0].text.split(), - decoding_results.split(), - ) - ) # noqa + results.append(( + c.id, + c.supervisions[0].text.split(), + decoding_results.split(), + )) # noqa return total_duration, results @@ -298,10 +295,10 @@ async def send_streaming( while j < len(waveform): if j == 0: stride = int(first_chunk_in_secs * sample_rate) - wav_segs.append(waveform[j : j + stride]) + wav_segs.append(waveform[j:j + stride]) else: stride = int(other_chunk_in_secs * sample_rate) - wav_segs.append(waveform[j : j + stride]) + wav_segs.append(waveform[j:j + stride]) j += len(wav_segs[-1]) sequence_id = task_index + 10086 @@ -361,8 +358,7 @@ async def send_streaming( else: # For wenet decoding_results = response.as_numpy("TRANSCRIPTS")[0].decode( - "utf-8" - ) + "utf-8") chunk_end = time.time() - chunk_start latency_data.append((chunk_end, chunk_len / sample_rate)) @@ -375,13 +371,11 @@ async def send_streaming( hyp = list("".join(hyp)) results.append((c.id, ref, hyp)) else: - results.append( - ( - c.id, - c.supervisions[0].text.split(), - decoding_results.split(), - ) - ) # noqa + results.append(( + c.id, + c.supervisions[0].text.split(), 
+ decoding_results.split(), + )) # noqa return total_duration, results, latency_data @@ -407,19 +401,17 @@ async def main(): frame_shift_ms = 10 frame_length_ms = 25 add_frames = math.ceil( - (frame_length_ms - frame_shift_ms) / frame_shift_ms - ) + (frame_length_ms - frame_shift_ms) / frame_shift_ms) # decode_window_length: input sequence length of streaming encoder if args.context > 0: # decode window length calculation for wenet - decode_window_length = ( - args.chunk_size - 1 - ) * args.subsampling + args.context + decode_window_length = (args.chunk_size - + 1) * args.subsampling + args.context else: # decode window length calculation for icefall decode_window_length = ( - args.chunk_size + 2 + args.encoder_right_context - ) * args.subsampling + 3 + args.chunk_size + 2 + + args.encoder_right_context) * args.subsampling + 3 first_chunk_ms = (decode_window_length + add_frames) * frame_shift_ms @@ -437,13 +429,10 @@ async def main(): compute_cer=compute_cer, model_name=args.model_name, first_chunk_in_secs=first_chunk_ms / 1000, - other_chunk_in_secs=args.chunk_size - * args.subsampling - * frame_shift_ms - / 1000, + other_chunk_in_secs=args.chunk_size * args.subsampling * + frame_shift_ms / 1000, task_index=i, - ) - ) + )) elif args.simulate_streaming: task = asyncio.create_task( send_streaming( @@ -455,14 +444,11 @@ async def main(): compute_cer=compute_cer, model_name=args.model_name, first_chunk_in_secs=first_chunk_ms / 1000, - other_chunk_in_secs=args.chunk_size - * args.subsampling - * frame_shift_ms - / 1000, + other_chunk_in_secs=args.chunk_size * args.subsampling * + frame_shift_ms / 1000, task_index=i, simulate_mode=True, - ) - ) + )) else: task = asyncio.create_task( send( @@ -473,8 +459,7 @@ async def main(): log_interval=log_interval, compute_cer=compute_cer, model_name=args.model_name, - ) - ) + )) tasks.append(task) ans_list = await asyncio.gather(*tasks) @@ -496,10 +481,8 @@ async def main(): s = f"RTF: {rtf:.4f}\n" s += f"total_duration: {total_duration:.3f} seconds\n" s += f"({total_duration/3600:.2f} hours)\n" - s += ( - f"processing time: {elapsed:.3f} seconds " - f"({elapsed/3600:.2f} hours)\n" - ) + s += (f"processing time: {elapsed:.3f} seconds " + f"({elapsed/3600:.2f} hours)\n") if args.streaming or args.simulate_streaming: latency_list = [ @@ -530,9 +513,8 @@ async def main(): print(f.readline()) # Detailed errors if args.stats_file: - stats = await triton_client.get_inference_statistics( - model_name="", as_json=True - ) + stats = await triton_client.get_inference_statistics(model_name="", + as_json=True) with open(args.stats_file, "w") as f: json.dump(stats, f) diff --git a/runtime/gpu/client/generate_perf_input.py b/runtime/gpu/client/generate_perf_input.py index 682936e58..351458af8 100644 --- a/runtime/gpu/client/generate_perf_input.py +++ b/runtime/gpu/client/generate_perf_input.py @@ -14,15 +14,13 @@ def generate_offline_input(args): mat = np.array([waveform] * batch_size, dtype=np.float32) out_dict = { - "data": [ - { - "WAV_LENS": [len(waveform)], - "WAV": { - "shape": [len(waveform)], - "content": mat.flatten().tolist(), - }, - } - ] + "data": [{ + "WAV_LENS": [len(waveform)], + "WAV": { + "shape": [len(waveform)], + "content": mat.flatten().tolist(), + }, + }] } json.dump(out_dict, open("offline_input.json", "w")) @@ -47,10 +45,10 @@ def generate_online_input(args): while i < len(waveform): if i == 0: stride = int(first_chunk_s * sample_rate) - wav_segs.append(waveform[i : i + stride]) + wav_segs.append(waveform[i:i + stride]) else: stride = 
int(other_chunk_s * sample_rate) - wav_segs.append(waveform[i : i + stride]) + wav_segs.append(waveform[i:i + stride]) i += len(wav_segs[-1]) data = {"data": [[]]} @@ -68,7 +66,10 @@ def generate_online_input(args): flat_chunk = expect_input.flatten().astype(np.float32).tolist() seq = { - "WAV": {"content": flat_chunk, "shape": expect_input[0].shape}, + "WAV": { + "content": flat_chunk, + "shape": expect_input[0].shape + }, "WAV_LENS": [chunk_len], } data["data"][0].append(seq) @@ -78,9 +79,10 @@ def generate_online_input(args): if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument( - "--audio_file", type=str, default=None, help="single wav file" - ) + parser.add_argument("--audio_file", + type=str, + default=None, + help="single wav file") # below is only for streaming input parser.add_argument("--streaming", action="store_true", required=False) parser.add_argument( diff --git a/runtime/gpu/client/speech_client.py b/runtime/gpu/client/speech_client.py index cd6fb1cc5..1a33025cd 100644 --- a/runtime/gpu/client/speech_client.py +++ b/runtime/gpu/client/speech_client.py @@ -19,6 +19,7 @@ class OfflineSpeechClient(object): + def __init__(self, triton_client, model_name, protocol_client, args): self.triton_client = triton_client self.protocol_client = protocol_client @@ -36,12 +37,10 @@ def recognize(self, wav_file, idx=0): sequence_id = 10086 + idx result = "" inputs = [ - self.protocol_client.InferInput( - "WAV", samples.shape, np_to_triton_dtype(samples.dtype) - ), - self.protocol_client.InferInput( - "WAV_LENS", lengths.shape, np_to_triton_dtype(lengths.dtype) - ), + self.protocol_client.InferInput("WAV", samples.shape, + np_to_triton_dtype(samples.dtype)), + self.protocol_client.InferInput("WAV_LENS", lengths.shape, + np_to_triton_dtype(lengths.dtype)), ] inputs[0].set_data_from_numpy(samples) inputs[1].set_data_from_numpy(lengths) @@ -61,6 +60,7 @@ def recognize(self, wav_file, idx=0): class StreamingSpeechClient(object): + def __init__(self, triton_client, model_name, protocol_client, args): self.triton_client = triton_client self.protocol_client = protocol_client @@ -76,8 +76,7 @@ def __init__(self, triton_client, model_name, protocol_client, args): # since the subsampling will look ahead several frames first_chunk_length = (chunk_size - 1) * subsampling + context add_frames = math.ceil( - (frame_length_ms - frame_shift_ms) / frame_shift_ms - ) + (frame_length_ms - frame_shift_ms) / frame_shift_ms) first_chunk_ms = (first_chunk_length + add_frames) * frame_shift_ms other_chunk_ms = chunk_size * subsampling * frame_shift_ms self.first_chunk_in_secs = first_chunk_ms / 1000 @@ -90,10 +89,10 @@ def recognize(self, wav_file, idx=0): while i < len(waveform): if i == 0: stride = int(self.first_chunk_in_secs * sample_rate) - wav_segs.append(waveform[i : i + stride]) + wav_segs.append(waveform[i:i + stride]) else: stride = int(self.other_chunk_in_secs * sample_rate) - wav_segs.append(waveform[i : i + stride]) + wav_segs.append(waveform[i:i + stride]) i += len(wav_segs[-1]) sequence_id = idx + 10086 @@ -127,7 +126,9 @@ def recognize(self, wav_file, idx=0): inputs[0].set_data_from_numpy(input0_data) inputs[1].set_data_from_numpy(input1_data) - outputs = [self.protocol_client.InferRequestedOutput("TRANSCRIPTS")] + outputs = [ + self.protocol_client.InferRequestedOutput("TRANSCRIPTS") + ] end = False if idx == len(wav_segs) - 1: end = True diff --git a/runtime/gpu/client/stats_summary.py b/runtime/gpu/client/stats_summary.py index d16b7b4c4..eee0368a5 100644 --- 
a/runtime/gpu/client/stats_summary.py +++ b/runtime/gpu/client/stats_summary.py @@ -26,8 +26,7 @@ def get_args(): parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) + formatter_class=argparse.ArgumentDefaultsHelpFormatter) parser.add_argument( "--stats_file", @@ -51,9 +50,8 @@ def get_args(): if __name__ == "__main__": args = get_args() - with open(args.stats_file) as stats_f, open( - args.summary_file, "w" - ) as summary_f: + with open(args.stats_file) as stats_f, open(args.summary_file, + "w") as summary_f: stats = json.load(stats_f) model_stats = stats["model_stats"] for model_state in model_stats: @@ -61,20 +59,16 @@ def get_args(): continue summary_f.write(f"model name is {model_state['name']} \n") model_inference_stats = model_state["inference_stats"] - total_queue_time_s = ( - int(model_inference_stats["queue"]["ns"]) / 1e9 - ) + total_queue_time_s = (int(model_inference_stats["queue"]["ns"]) / + 1e9) total_infer_time_s = ( - int(model_inference_stats["compute_infer"]["ns"]) / 1e9 - ) + int(model_inference_stats["compute_infer"]["ns"]) / 1e9) total_input_time_s = ( - int(model_inference_stats["compute_input"]["ns"]) / 1e9 - ) + int(model_inference_stats["compute_input"]["ns"]) / 1e9) total_output_time_s = ( - int(model_inference_stats["compute_output"]["ns"]) / 1e9 - ) + int(model_inference_stats["compute_output"]["ns"]) / 1e9) summary_f.write( - f"queue {total_queue_time_s:<5.2f} s, infer {total_infer_time_s:<5.2f} s, input {total_input_time_s:<5.2f} s, output {total_output_time_s:<5.2f} s \n" # noqa + f"queue {total_queue_time_s:<5.2f} s, infer {total_infer_time_s:<5.2f} s, input {total_input_time_s:<5.2f} s, output {total_output_time_s:<5.2f} s \n" # noqa ) model_batch_stats = model_state["batch_stats"] for batch in model_batch_stats: @@ -83,20 +77,17 @@ def get_args(): compute_output = batch["compute_output"] compute_infer = batch["compute_infer"] batch_count = int(compute_infer["count"]) - assert ( - compute_infer["count"] - == compute_output["count"] - == compute_input["count"] - ) + assert (compute_infer["count"] == compute_output["count"] == + compute_input["count"]) compute_infer_time_ms = int(compute_infer["ns"]) / 1e6 compute_input_time_ms = int(compute_input["ns"]) / 1e6 compute_output_time_ms = int(compute_output["ns"]) / 1e6 summary_f.write( - f"Batch_size {batch_size:<2}, {batch_count:<5} times, infer {compute_infer_time_ms:<9.2f} ms, avg {compute_infer_time_ms/batch_count:.2f} ms, {compute_infer_time_ms/batch_count/batch_size:.2f} ms " # noqa + f"Batch_size {batch_size:<2}, {batch_count:<5} times, infer {compute_infer_time_ms:<9.2f} ms, avg {compute_infer_time_ms/batch_count:.2f} ms, {compute_infer_time_ms/batch_count/batch_size:.2f} ms " # noqa ) summary_f.write( - f"input {compute_input_time_ms:<9.2f} ms, avg {compute_input_time_ms/batch_count:.2f} ms, " # noqa + f"input {compute_input_time_ms:<9.2f} ms, avg {compute_input_time_ms/batch_count:.2f} ms, " # noqa ) summary_f.write( - f"output {compute_output_time_ms:<9.2f} ms, avg {compute_output_time_ms/batch_count:.2f} ms \n" # noqa + f"output {compute_output_time_ms:<9.2f} ms, avg {compute_output_time_ms/batch_count:.2f} ms \n" # noqa ) diff --git a/runtime/gpu/cuda_decoders/model_repo_cuda_decoder/feature_extractor/1/model.py b/runtime/gpu/cuda_decoders/model_repo_cuda_decoder/feature_extractor/1/model.py index c54f6ca3e..50406df2d 100755 --- a/runtime/gpu/cuda_decoders/model_repo_cuda_decoder/feature_extractor/1/model.py +++ 
b/runtime/gpu/cuda_decoders/model_repo_cuda_decoder/feature_extractor/1/model.py @@ -21,7 +21,9 @@ from typing import List import json + class Fbank(torch.nn.Module): + def __init__(self, opts): super(Fbank, self).__init__() self.fbank = kaldifeat.Fbank(opts) @@ -131,7 +133,8 @@ def execute(self, requests): batch_len.append(cur_len) for wav, wav_len in zip(cur_b_wav, cur_b_wav_lens): wav_len = wav_len[0] - wav = torch.tensor(wav[0:wav_len], dtype=torch.float32, + wav = torch.tensor(wav[0:wav_len], + dtype=torch.float32, device=self.device) total_waves.append(wav) @@ -139,12 +142,15 @@ def execute(self, requests): for b, l in zip(batch_count, batch_len): expect_feat_len = _kaldifeat.num_frames(l, self.opts.frame_opts) speech = torch.zeros((b, expect_feat_len, self.feature_size), - dtype=self.output0_dtype, device=self.device) - speech_lengths = torch.zeros((b, 1), dtype=torch.int32, device=self.device) + dtype=self.output0_dtype, + device=self.device) + speech_lengths = torch.zeros((b, 1), + dtype=torch.int32, + device=self.device) for i in range(b): f = features.pop(0) f_l = f.shape[0] - speech[i, 0: f_l, :] = f.to(self.output0_dtype) + speech[i, 0:f_l, :] = f.to(self.output0_dtype) speech_lengths[i][0] = f_l # put speech feature on device will cause empty output # we will follow this issue and now temporarily put it on cpu @@ -153,6 +159,7 @@ def execute(self, requests): out0 = pb_utils.Tensor.from_dlpack("speech", to_dlpack(speech)) out1 = pb_utils.Tensor.from_dlpack("speech_lengths", to_dlpack(speech_lengths)) - inference_response = pb_utils.InferenceResponse(output_tensors=[out0, out1]) + inference_response = pb_utils.InferenceResponse( + output_tensors=[out0, out1]) responses.append(inference_response) return responses diff --git a/runtime/gpu/cuda_decoders/model_repo_cuda_decoder/scoring/1/decoder.py b/runtime/gpu/cuda_decoders/model_repo_cuda_decoder/scoring/1/decoder.py index f3e520166..ba1aac066 100755 --- a/runtime/gpu/cuda_decoders/model_repo_cuda_decoder/scoring/1/decoder.py +++ b/runtime/gpu/cuda_decoders/model_repo_cuda_decoder/scoring/1/decoder.py @@ -5,6 +5,7 @@ BatchedMappedDecoderCudaConfig) from frame_reducer import FrameReducer + def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor: """Make mask tensor containing indices of padded part. See description of make_non_pad_mask. 
@@ -30,6 +31,7 @@ def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor: mask = seq_range_expand >= seq_length_expand return mask + def remove_duplicates_and_blank(hyp: List[int], eos: int, blank_id: int = 0) -> List[int]: @@ -43,6 +45,7 @@ def remove_duplicates_and_blank(hyp: List[int], cur += 1 return new_hyp + def ctc_greedy_search(ctc_probs, encoder_out_lens, vocabulary, blank_id, eos): batch_size, maxlen = ctc_probs.size()[:2] topk_prob, topk_index = ctc_probs.topk(1, dim=2) # (B, maxlen, 1) @@ -56,6 +59,7 @@ def ctc_greedy_search(ctc_probs, encoder_out_lens, vocabulary, blank_id, eos): total_hyps.append("".join([vocabulary[i] for i in hyp])) return total_hyps + def load_word_symbols(path): word_id_to_word_str = {} with open(path, "rt", encoding="utf-8") as fh: @@ -64,20 +68,29 @@ def load_word_symbols(path): word_id_to_word_str[int(word_id)] = word_str return word_id_to_word_str + class RivaWFSTDecoder: + def __init__(self, vocab_size, tlg_dir, config_dict, nbest=10): config = BatchedMappedDecoderCudaConfig() - config.online_opts.decoder_opts.lattice_beam = config_dict['lattice_beam'] - config.online_opts.lattice_postprocessor_opts.acoustic_scale = config_dict['acoustic_scale'] # noqa + config.online_opts.decoder_opts.lattice_beam = config_dict[ + 'lattice_beam'] + config.online_opts.lattice_postprocessor_opts.acoustic_scale = config_dict[ + 'acoustic_scale'] # noqa config.n_input_per_chunk = config_dict['n_input_per_chunk'] - config.online_opts.decoder_opts.default_beam = config_dict['default_beam'] + config.online_opts.decoder_opts.default_beam = config_dict[ + 'default_beam'] config.online_opts.decoder_opts.max_active = config_dict['max_active'] - config.online_opts.determinize_lattice = config_dict['determinize_lattice'] + config.online_opts.determinize_lattice = config_dict[ + 'determinize_lattice'] config.online_opts.max_batch_size = config_dict['max_batch_size'] config.online_opts.num_channels = config_dict['num_channels'] - config.online_opts.frame_shift_seconds = config_dict['frame_shift_seconds'] - config.online_opts.lattice_postprocessor_opts.lm_scale = config_dict['lm_scale'] - config.online_opts.lattice_postprocessor_opts.word_ins_penalty = config_dict['word_ins_penalty'] # noqa + config.online_opts.frame_shift_seconds = config_dict[ + 'frame_shift_seconds'] + config.online_opts.lattice_postprocessor_opts.lm_scale = config_dict[ + 'lm_scale'] + config.online_opts.lattice_postprocessor_opts.word_ins_penalty = config_dict[ + 'word_ins_penalty'] # noqa config.online_opts.num_decoder_copy_threads = 2 config.online_opts.num_post_processing_worker_threads = 4 @@ -87,9 +100,9 @@ def __init__(self, vocab_size, tlg_dir, config_dict, nbest=10): self.decoder = BatchedMappedDecoderCuda( config, os.path.join(tlg_dir, "TLG.fst"), - os.path.join(tlg_dir, "words.txt"), vocab_size - ) - self.word_id_to_word_str = load_word_symbols(os.path.join(tlg_dir, "words.txt")) + os.path.join(tlg_dir, "words.txt"), vocab_size) + self.word_id_to_word_str = load_word_symbols( + os.path.join(tlg_dir, "words.txt")) self.nbest = nbest self.vocab_size = vocab_size self.frame_reducer = FrameReducer(0.98) @@ -108,7 +121,9 @@ def decode_nbest(self, logits, length): # since fst decoder adds 1 to the label id hyp_ids = [label - 1 for label in sent.ilabels] # padding for hyps_pad_sos_eos - new_hyp = [self.vocab_size - 1] + remove_duplicates_and_blank(hyp_ids, eos=self.vocab_size - 1, blank_id=0) + [self.vocab_size - 1] # noqa + new_hyp = [self.vocab_size - 1] + remove_duplicates_and_blank( 
+ hyp_ids, eos=self.vocab_size - 1, + blank_id=0) + [self.vocab_size - 1] # noqa max_hyp_len = max(max_hyp_len, len(new_hyp)) nbest_id_list.append(new_hyp) @@ -118,7 +133,8 @@ def decode_nbest(self, logits, length): nbest_scores.append(sent.score) nbest_list += [""] * (self.nbest - len(nbest_list)) total_hyps.append(nbest_list) - nbest_id_list += [[self.vocab_size - 1, 0, self.vocab_size - 1]] * (self.nbest - len(nbest_id_list)) # noqa + nbest_id_list += [[self.vocab_size - 1, 0, self.vocab_size - 1] + ] * (self.nbest - len(nbest_id_list)) # noqa total_hyps_id.append(nbest_id_list) nbest_scores += [0.0] * (self.nbest - len(nbest_scores)) total_scores.append(nbest_scores) diff --git a/runtime/gpu/cuda_decoders/model_repo_cuda_decoder/scoring/1/frame_reducer.py b/runtime/gpu/cuda_decoders/model_repo_cuda_decoder/scoring/1/frame_reducer.py index 9c5e955be..453248188 100644 --- a/runtime/gpu/cuda_decoders/model_repo_cuda_decoder/scoring/1/frame_reducer.py +++ b/runtime/gpu/cuda_decoders/model_repo_cuda_decoder/scoring/1/frame_reducer.py @@ -25,6 +25,7 @@ import torch.nn as nn import torch.nn.functional as F + def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor: """ Args: @@ -52,7 +53,6 @@ def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor: return expaned_lengths >= lengths.unsqueeze(-1) - class FrameReducer(nn.Module): """The encoder output is first used to calculate the CTC posterior probability; then for each output frame, @@ -99,22 +99,17 @@ def forward( N, T, C = x.size() padding_mask = make_pad_mask(x_lens, x.size(1)) - non_blank_mask = (ctc_output[:, :, blank_id] < math.log(self.blank_threshlod)) * (~padding_mask) # noqa + non_blank_mask = (ctc_output[:, :, blank_id] < math.log( + self.blank_threshlod)) * (~padding_mask) # noqa if y_lens is not None: # Limit the maximum number of reduced frames limit_lens = T - y_lens max_limit_len = limit_lens.max().int() - fake_limit_indexes = torch.topk( - ctc_output[:, :, blank_id], max_limit_len - ).indices - T = ( - torch.arange(max_limit_len) - .expand_as( - fake_limit_indexes, - ) - .to(device=x.device) - ) + fake_limit_indexes = torch.topk(ctc_output[:, :, blank_id], + max_limit_len).indices + T = (torch.arange(max_limit_len).expand_as( + fake_limit_indexes, ).to(device=x.device)) T = torch.remainder(T, limit_lens.unsqueeze(1)) limit_indexes = torch.gather(fake_limit_indexes, 1, T) limit_mask = torch.full_like( @@ -127,20 +122,18 @@ def forward( out_lens = non_blank_mask.sum(dim=1) max_len = out_lens.max() - pad_lens_list = ( - torch.full_like( - out_lens, - max_len.item(), - device=x.device, - ) - - out_lens - ) + pad_lens_list = (torch.full_like( + out_lens, + max_len.item(), + device=x.device, + ) - out_lens) max_pad_len = pad_lens_list.max() out = F.pad(x, (0, 0, 0, max_pad_len)) valid_pad_mask = ~make_pad_mask(pad_lens_list) - total_valid_mask = torch.concat([non_blank_mask, valid_pad_mask], dim=1) + total_valid_mask = torch.concat([non_blank_mask, valid_pad_mask], + dim=1) out = out[total_valid_mask].reshape(N, -1, C) @@ -159,8 +152,7 @@ def forward( x_lens = torch.tensor([498] * 15, dtype=torch.int64, device=device) y_lens = torch.tensor([150] * 15, dtype=torch.int64, device=device) ctc_output = torch.log( - torch.randn(15, 498, 500, dtype=torch.float32, device=device), - ) + torch.randn(15, 498, 500, dtype=torch.float32, device=device), ) avg_time = 0 for i in range(test_times): diff --git a/runtime/gpu/cuda_decoders/model_repo_cuda_decoder/scoring/1/model.py 
b/runtime/gpu/cuda_decoders/model_repo_cuda_decoder/scoring/1/model.py index b84dcbdfc..2ab31c896 100755 --- a/runtime/gpu/cuda_decoders/model_repo_cuda_decoder/scoring/1/model.py +++ b/runtime/gpu/cuda_decoders/model_repo_cuda_decoder/scoring/1/model.py @@ -22,6 +22,7 @@ import yaml from decoder import RivaWFSTDecoder, ctc_greedy_search + class TritonPythonModel: """Your Python model must use the same class name. Every Python model that is created must have "TritonPythonModel" as the class name. @@ -124,8 +125,7 @@ def init_decoder(self, parameters): self.ignore_id = ignore_id if "tlg" in self.decoding_method: - self.decoder = RivaWFSTDecoder(len(self.vocabulary), - self.tlg_dir, + self.decoder = RivaWFSTDecoder(len(self.vocabulary), self.tlg_dir, self.tlg_decoding_config, self.beam_size) @@ -149,7 +149,8 @@ def collect_inputs(self, requests): for request in requests: # Perform inference on the request and append it to responses list... in_0 = pb_utils.get_input_tensor_by_name(request, "encoder_out") - in_1 = pb_utils.get_input_tensor_by_name(request, "encoder_out_lens") + in_1 = pb_utils.get_input_tensor_by_name(request, + "encoder_out_lens") in_2 = pb_utils.get_input_tensor_by_name(request, "ctc_log_probs") in_0_tensor = from_dlpack(in_0.to_dlpack()) @@ -163,9 +164,14 @@ def collect_inputs(self, requests): batch_count_list.append(in_0_tensor.shape[0]) encoder_tensors, logits_tensors = [], [] - for encoder_tensor, logits_tensor in zip(encoder_out_list, ctc_log_probs_list): - encoder_tensors += [item.squeeze(0) for item in encoder_tensor.split(1)] - logits_tensors += [item.squeeze(0) for item in logits_tensor.split(1)] + for encoder_tensor, logits_tensor in zip(encoder_out_list, + ctc_log_probs_list): + encoder_tensors += [ + item.squeeze(0) for item in encoder_tensor.split(1) + ] + logits_tensors += [ + item.squeeze(0) for item in logits_tensor.split(1) + ] encoder_out = torch.nn.utils.rnn.pad_sequence(encoder_tensors, batch_first=True, padding_value=0.0) @@ -175,8 +181,8 @@ def collect_inputs(self, requests): encoder_out_len = torch.cat(encoder_out_lens_list, dim=0) return encoder_out, encoder_out_len, logits, batch_count_list - def rescore_hyps(self, total_tokens, nbest_scores, - max_hyp_len, encoder_out, encoder_out_len): + def rescore_hyps(self, total_tokens, nbest_scores, max_hyp_len, + encoder_out, encoder_out_len): """ Rescore the hypotheses with attention rescoring """ @@ -193,7 +199,8 @@ def prepare_response(self, hyps, batch_count_list): for b in batch_count_list: sents = np.array(hyps[st:st + b]) out0 = pb_utils.Tensor("OUTPUT0", sents.astype(self.out0_dtype)) - inference_response = pb_utils.InferenceResponse(output_tensors=[out0]) + inference_response = pb_utils.InferenceResponse( + output_tensors=[out0]) responses.append(inference_response) st += b return responses @@ -221,25 +228,29 @@ def execute(self, requests): # as they will be overridden in subsequent inference requests. You can # make a copy of the underlying NumPy array and store it if it is # required. 
- encoder_out, encoder_out_len, ctc_log_probs, batch_count = self.collect_inputs(requests) # noqa + encoder_out, encoder_out_len, ctc_log_probs, batch_count = self.collect_inputs( + requests) # noqa ctc_log_probs = ctc_log_probs.cuda() if self.decoding_method == "tlg_mbr": - total_hyps = self.decoder.decode_mbr(ctc_log_probs, encoder_out_len) + total_hyps = self.decoder.decode_mbr(ctc_log_probs, + encoder_out_len) elif self.decoding_method == "ctc_greedy_search": total_hyps = ctc_greedy_search(ctc_log_probs, encoder_out_len, - self.vocabulary, self.blank_id, self.eos) + self.vocabulary, self.blank_id, + self.eos) elif self.decoding_method == "tlg": - nbest_hyps, nbest_ids, nbest_scores, max_hyp_len = self.decoder.decode_nbest(ctc_log_probs, encoder_out_len) # noqa + nbest_hyps, nbest_ids, nbest_scores, max_hyp_len = self.decoder.decode_nbest( \ + ctc_log_probs, encoder_out_len) # noqa total_hyps = [nbest[0] for nbest in nbest_hyps] if self.decoding_method == "tlg" and self.rescore: assert self.beam_size > 1, "Beam size must be greater than 1 for rescoring" - selected_ids = self.rescore_hyps(nbest_ids, - nbest_scores, - max_hyp_len, - encoder_out, + selected_ids = self.rescore_hyps(nbest_ids, nbest_scores, + max_hyp_len, encoder_out, encoder_out_len) - total_hyps = [nbest[i] for nbest, i in zip(nbest_hyps, selected_ids)] + total_hyps = [ + nbest[i] for nbest, i in zip(nbest_hyps, selected_ids) + ] responses = self.prepare_response(total_hyps, batch_count) return responses diff --git a/runtime/gpu/cuda_decoders/model_repo_stateful_cuda_decoder/feature_extractor/1/model.py b/runtime/gpu/cuda_decoders/model_repo_stateful_cuda_decoder/feature_extractor/1/model.py index ce1f340f8..459855e59 100644 --- a/runtime/gpu/cuda_decoders/model_repo_stateful_cuda_decoder/feature_extractor/1/model.py +++ b/runtime/gpu/cuda_decoders/model_repo_stateful_cuda_decoder/feature_extractor/1/model.py @@ -21,7 +21,9 @@ import json import numpy as np + class Fbank(torch.nn.Module): + def __init__(self, opts): super(Fbank, self).__init__() self.fbank = kaldifeat.Fbank(opts) @@ -29,9 +31,16 @@ def __init__(self, opts): def forward(self, waves: List[torch.Tensor]): return self.fbank(waves) + class Feat(object): - def __init__(self, seqid, offset_ms, sample_rate, - first_chunk_sz, frame_stride, device='cpu'): + + def __init__(self, + seqid, + offset_ms, + sample_rate, + first_chunk_sz, + frame_stride, + device='cpu'): self.seqid = seqid self.sample_rate = sample_rate self.wav = torch.tensor([], device=device) @@ -62,10 +71,11 @@ def add_frames(self, frames: torch.tensor): self.frames = torch.cat([self.frames, frames], axis=0) def get_frames(self, num_frames: int): - seg = self.frames[0: num_frames] + seg = self.frames[0:num_frames] self.frames = self.frames[self.frame_stride:] return seg + class TritonPythonModel: """Your Python model must use the same class name. Every Python model that is created must have "TritonPythonModel" as the class name. 
@@ -138,7 +148,8 @@ def initialize(self, args): cur_frames = _kaldifeat.num_frames(first_chunk_size, opts.frame_opts) while cur_frames < self.decoding_window: first_chunk_size += frame_shift_ms * sample_rate // 1000 - cur_frames = _kaldifeat.num_frames(first_chunk_size, opts.frame_opts) + cur_frames = _kaldifeat.num_frames(first_chunk_size, + opts.frame_opts) # self.pad_silence = first_chunk_size - self.chunk_size self.first_chunk_size = first_chunk_size self.offset_ms = self.get_offset(frame_length_ms, frame_shift_ms) @@ -157,7 +168,8 @@ def parse_model_params(self, model_params): "frame_length_ms": 25, "frame_shift_ms": 10, "sample_rate": 16000, - "chunk_size_s": 0.64} + "chunk_size_s": 0.64 + } # get parameter configurations for li in model_params.items(): key, value = li @@ -212,8 +224,7 @@ def execute(self, requests): self.seq_feat[corrid] = Feat(corrid, self.offset_ms, self.sample_rate, self.first_chunk_size, - self.frame_stride, - self.device) + self.frame_stride, self.device) if ready: self.seq_feat[corrid].add_wavs(wavs[0:wav_lens]) @@ -226,7 +237,8 @@ def execute(self, requests): wav = self.seq_feat[corrid].get_seg_wav() * 32768 if len(wav) < self.min_seg: - temp = torch.zeros(self.min_seg, dtype=torch.float32, + temp = torch.zeros(self.min_seg, + dtype=torch.float32, device=self.device) temp[0:len(wav)] = wav[:] wav = temp @@ -235,15 +247,16 @@ def execute(self, requests): features = self.feature_extractor(total_waves) batch_size = len(batch_seqid) - batch_speech = torch.zeros((batch_size, self.decoding_window, - self.feature_size), dtype=self.dtype) + batch_speech = torch.zeros( + (batch_size, self.decoding_window, self.feature_size), + dtype=self.dtype) batch_speech_lens = torch.zeros((batch_size, 1), dtype=torch.int32) i = 0 for corrid, frames in zip(batch_seqid, features): self.seq_feat[corrid].add_frames(frames) r_frames = self.seq_feat[corrid].get_frames(self.decoding_window) - speech = batch_speech[i: i + 1] - speech_lengths = batch_speech_lens[i: i + 1] + speech = batch_speech[i:i + 1] + speech_lengths = batch_speech_lens[i:i + 1] i += 1 speech_lengths[0] = r_frames.size(0) speech[0][0:r_frames.size(0)] = r_frames.to(speech.device) @@ -251,9 +264,11 @@ def execute(self, requests): # out_tensor1 = pb_utils.Tensor.from_dlpack("speech_lengths", # to_dlpack(speech_lengths)) out_tensor0 = pb_utils.Tensor("speech", speech.numpy()) - out_tensor1 = pb_utils.Tensor("speech_lengths", speech_lengths.numpy()) + out_tensor1 = pb_utils.Tensor("speech_lengths", + speech_lengths.numpy()) output_tensors = [out_tensor0, out_tensor1] - response = pb_utils.InferenceResponse(output_tensors=output_tensors) + response = pb_utils.InferenceResponse( + output_tensors=output_tensors) responses.append(response) if corrid in end_seqid: del self.seq_feat[corrid] diff --git a/runtime/gpu/cuda_decoders/model_repo_stateful_cuda_decoder/scoring/1/decoder.py b/runtime/gpu/cuda_decoders/model_repo_stateful_cuda_decoder/scoring/1/decoder.py index f59be2364..353310089 100644 --- a/runtime/gpu/cuda_decoders/model_repo_stateful_cuda_decoder/scoring/1/decoder.py +++ b/runtime/gpu/cuda_decoders/model_repo_stateful_cuda_decoder/scoring/1/decoder.py @@ -24,18 +24,19 @@ def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor: """ batch_size = lengths.size(0) max_len = max_len if max_len > 0 else lengths.max().item() - seq_range = torch.arange( - 0, max_len, dtype=torch.int64, device=lengths.device - ) + seq_range = torch.arange(0, + max_len, + dtype=torch.int64, + device=lengths.device) 
seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len) seq_length_expand = lengths.unsqueeze(-1) mask = seq_range_expand >= seq_length_expand return mask -def remove_duplicates_and_blank( - hyp: List[int], eos: int, blank_id: int = 0 -) -> List[int]: +def remove_duplicates_and_blank(hyp: List[int], + eos: int, + blank_id: int = 0) -> List[int]: new_hyp: List[int] = [] cur = 0 while cur < len(hyp): @@ -71,33 +72,27 @@ def load_word_symbols(path): class RivaWFSTOnlineDecoder: + def __init__(self, vocab_size, tlg_dir, config_dict): config = BatchedMappedDecoderCudaConfig() config.online_opts.decoder_opts.lattice_beam = config_dict[ - "lattice_beam" - ] + "lattice_beam"] config.online_opts.lattice_postprocessor_opts.acoustic_scale = config_dict[ - "acoustic_scale" - ] # noqa + "acoustic_scale"] # noqa config.n_input_per_chunk = config_dict["n_input_per_chunk"] config.online_opts.decoder_opts.default_beam = config_dict[ - "default_beam" - ] + "default_beam"] config.online_opts.decoder_opts.max_active = config_dict["max_active"] config.online_opts.determinize_lattice = config_dict[ - "determinize_lattice" - ] + "determinize_lattice"] config.online_opts.max_batch_size = config_dict["max_batch_size"] config.online_opts.num_channels = config_dict["num_channels"] config.online_opts.frame_shift_seconds = config_dict[ - "frame_shift_seconds" - ] + "frame_shift_seconds"] config.online_opts.lattice_postprocessor_opts.lm_scale = config_dict[ - "lm_scale" - ] + "lm_scale"] config.online_opts.lattice_postprocessor_opts.word_ins_penalty = config_dict[ - "word_ins_penalty" - ] # noqa + "word_ins_penalty"] # noqa config.online_opts.decoder_opts.ntokens_pre_allocated = 10_000_000 config.online_opts.num_decoder_copy_threads = 2 @@ -111,8 +106,7 @@ def __init__(self, vocab_size, tlg_dir, config_dict): vocab_size, ) self.word_id_to_word_str = load_word_symbols( - os.path.join(tlg_dir, "words.txt") - ) + os.path.join(tlg_dir, "words.txt")) # self.frame_reducer = FrameReducer(0.98) def decode_batch( @@ -135,15 +129,14 @@ def decode_batch( # log_probs_list = [t for t in torch.unbind(ctc_log_probs, dim=0)] log_probs_list = [] for i, ctc_log_prob in enumerate(ctc_log_probs): - log_probs_list.append(ctc_log_prob[: encoder_out_lens[i]]) - _, hypos = self.decoder.decode_batch( - corr_ids, log_probs_list, is_first_chunk_list, is_last_chunk_list - ) + log_probs_list.append(ctc_log_prob[:encoder_out_lens[i]]) + _, hypos = self.decoder.decode_batch(corr_ids, log_probs_list, + is_first_chunk_list, + is_last_chunk_list) total_hyps = [] for sent in hypos: - hyp = sep_symbol.join( - self.word_id_to_word_str[word] for word in sent.words - ) + hyp = sep_symbol.join(self.word_id_to_word_str[word] + for word in sent.words) total_hyps.append(hyp) return total_hyps diff --git a/runtime/gpu/cuda_decoders/model_repo_stateful_cuda_decoder/scoring/1/frame_reducer.py b/runtime/gpu/cuda_decoders/model_repo_stateful_cuda_decoder/scoring/1/frame_reducer.py index 048387c0e..fb2ad41e4 100644 --- a/runtime/gpu/cuda_decoders/model_repo_stateful_cuda_decoder/scoring/1/frame_reducer.py +++ b/runtime/gpu/cuda_decoders/model_repo_stateful_cuda_decoder/scoring/1/frame_reducer.py @@ -25,6 +25,7 @@ import torch.nn as nn import torch.nn.functional as F + def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor: """ Args: @@ -52,7 +53,6 @@ def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor: return expaned_lengths >= lengths.unsqueeze(-1) - class FrameReducer(nn.Module): """The encoder output is 
first used to calculate the CTC posterior probability; then for each output frame, @@ -99,22 +99,17 @@ def forward( N, T, C = x.size() padding_mask = make_pad_mask(x_lens, x.size(1)) - non_blank_mask = (ctc_output[:, :, blank_id] < math.log(self.blank_threshlod)) * (~padding_mask) # noqa + non_blank_mask = (ctc_output[:, :, blank_id] < math.log( + self.blank_threshlod)) * (~padding_mask) # noqa if y_lens is not None: # Limit the maximum number of reduced frames limit_lens = T - y_lens max_limit_len = limit_lens.max().int() - fake_limit_indexes = torch.topk( - ctc_output[:, :, blank_id], max_limit_len - ).indices - T = ( - torch.arange(max_limit_len) - .expand_as( - fake_limit_indexes, - ) - .to(device=x.device) - ) + fake_limit_indexes = torch.topk(ctc_output[:, :, blank_id], + max_limit_len).indices + T = (torch.arange(max_limit_len).expand_as( + fake_limit_indexes, ).to(device=x.device)) T = torch.remainder(T, limit_lens.unsqueeze(1)) limit_indexes = torch.gather(fake_limit_indexes, 1, T) limit_mask = torch.full_like( @@ -127,20 +122,18 @@ def forward( out_lens = non_blank_mask.sum(dim=1) max_len = out_lens.max() - pad_lens_list = ( - torch.full_like( - out_lens, - max_len.item(), - device=x.device, - ) - - out_lens - ) + pad_lens_list = (torch.full_like( + out_lens, + max_len.item(), + device=x.device, + ) - out_lens) max_pad_len = pad_lens_list.max() out = F.pad(x, (0, 0, 0, max_pad_len)) valid_pad_mask = ~make_pad_mask(pad_lens_list) - total_valid_mask = torch.concat([non_blank_mask, valid_pad_mask], dim=1) + total_valid_mask = torch.concat([non_blank_mask, valid_pad_mask], + dim=1) out = out[total_valid_mask].reshape(N, -1, C) @@ -160,8 +153,7 @@ def forward( x_lens = torch.tensor([seq_len] * 15, dtype=torch.int64, device=device) y_lens = torch.tensor([150] * 15, dtype=torch.int64, device=device) ctc_output = torch.log( - torch.randn(15, seq_len, 500, dtype=torch.float32, device=device), - ) + torch.randn(15, seq_len, 500, dtype=torch.float32, device=device), ) avg_time = 0 for i in range(test_times): diff --git a/runtime/gpu/cuda_decoders/model_repo_stateful_cuda_decoder/scoring/1/model.py b/runtime/gpu/cuda_decoders/model_repo_stateful_cuda_decoder/scoring/1/model.py index 12e4cb9ed..6bf20737e 100644 --- a/runtime/gpu/cuda_decoders/model_repo_stateful_cuda_decoder/scoring/1/model.py +++ b/runtime/gpu/cuda_decoders/model_repo_stateful_cuda_decoder/scoring/1/model.py @@ -59,13 +59,11 @@ def initialize(self, args): # Get OUTPUT0 configuration output0_config = pb_utils.get_output_config_by_name( - model_config, "OUTPUT0" - ) + model_config, "OUTPUT0") # Convert Triton types to numpy types self.output0_dtype = pb_utils.triton_string_to_numpy( - output0_config["data_type"] - ) + output0_config["data_type"]) self.init_decoder(self.model_config["parameters"]) @@ -96,9 +94,9 @@ def init_decoder(self, parameters): self.sos = self.eos = len(vocab) - 1 if "tlg" in self.decoding_method: - self.decoder = RivaWFSTOnlineDecoder( - len(self.vocabulary), self.tlg_dir, self.tlg_decoding_config - ) + self.decoder = RivaWFSTOnlineDecoder(len(self.vocabulary), + self.tlg_dir, + self.tlg_decoding_config) def load_vocab(self, vocab_file): """ @@ -179,10 +177,10 @@ def execute(self, requests): responses = [] for sentence in total_hyps: sent = np.array(sentence) - out_tensor_0 = pb_utils.Tensor( - "OUTPUT0", sent.astype(self.output0_dtype) - ) - response = pb_utils.InferenceResponse(output_tensors=[out_tensor_0]) + out_tensor_0 = pb_utils.Tensor("OUTPUT0", + sent.astype(self.output0_dtype)) + response = 
pb_utils.InferenceResponse( + output_tensors=[out_tensor_0]) responses.append(response) assert len(requests) == len(responses) return responses diff --git a/runtime/gpu/model_repo/feature_extractor/1/model.py b/runtime/gpu/model_repo/feature_extractor/1/model.py index 4a2c258cc..0623335c0 100644 --- a/runtime/gpu/model_repo/feature_extractor/1/model.py +++ b/runtime/gpu/model_repo/feature_extractor/1/model.py @@ -7,7 +7,9 @@ from typing import List import json + class Fbank(torch.nn.Module): + def __init__(self, opts): super(Fbank, self).__init__() self.fbank = kaldifeat.Fbank(opts) @@ -117,7 +119,8 @@ def execute(self, requests): batch_len.append(cur_len) for wav, wav_len in zip(cur_b_wav, cur_b_wav_lens): wav_len = wav_len[0] - wav = torch.tensor(wav[0:wav_len], dtype=torch.float32, + wav = torch.tensor(wav[0:wav_len], + dtype=torch.float32, device=self.device) total_waves.append(wav) @@ -126,12 +129,15 @@ def execute(self, requests): for b, l in zip(batch_count, batch_len): expect_feat_len = _kaldifeat.num_frames(l, self.opts.frame_opts) speech = torch.zeros((b, expect_feat_len, self.feature_size), - dtype=self.output0_dtype, device=self.device) - speech_lengths = torch.zeros((b, 1), dtype=torch.int32, device=self.device) + dtype=self.output0_dtype, + device=self.device) + speech_lengths = torch.zeros((b, 1), + dtype=torch.int32, + device=self.device) for i in range(b): f = features[idx] f_l = f.shape[0] - speech[i, 0: f_l, :] = f.to(self.output0_dtype) + speech[i, 0:f_l, :] = f.to(self.output0_dtype) speech_lengths[i][0] = f_l idx += 1 # put speech feature on device will cause empty output @@ -141,6 +147,7 @@ def execute(self, requests): out0 = pb_utils.Tensor.from_dlpack("speech", to_dlpack(speech)) out1 = pb_utils.Tensor.from_dlpack("speech_lengths", to_dlpack(speech_lengths)) - inference_response = pb_utils.InferenceResponse(output_tensors=[out0, out1]) + inference_response = pb_utils.InferenceResponse( + output_tensors=[out0, out1]) responses.append(inference_response) return responses diff --git a/runtime/gpu/model_repo/scoring/1/model.py b/runtime/gpu/model_repo/scoring/1/model.py index 63c1a36fa..98b331ad5 100644 --- a/runtime/gpu/model_repo/scoring/1/model.py +++ b/runtime/gpu/model_repo/scoring/1/model.py @@ -23,6 +23,7 @@ import os import yaml + class TritonPythonModel: """Your Python model must use the same class name. Every Python model that is created must have "TritonPythonModel" as the class name. @@ -114,11 +115,14 @@ def init_ctc_rescore(self, parameters): if self.hotwords is not None: for w in self.hotwords: max_order = max(max_order, len(w)) - self.hotwords_scorer = HotWordsScorer(self.hotwords, vocab, + self.hotwords_scorer = HotWordsScorer(self.hotwords, + vocab, window_length=max_order, SPACE_ID=-2, is_character_based=True) - print(f"Successfully load hotwords! Hotwords orders = {max_order}") + print( + f"Successfully load hotwords! Hotwords orders = {max_order}" + ) self.vocabulary = vocab self.bidecoder = bidecoder sos = eos = len(vocab) - 1 @@ -188,12 +192,16 @@ def execute(self, requests): for request in requests: # Perform inference on the request and append it to responses list... 
in_0 = pb_utils.get_input_tensor_by_name(request, "encoder_out") - in_1 = pb_utils.get_input_tensor_by_name(request, "encoder_out_lens") - in_2 = pb_utils.get_input_tensor_by_name(request, "batch_log_probs") - in_3 = pb_utils.get_input_tensor_by_name(request, "batch_log_probs_idx") + in_1 = pb_utils.get_input_tensor_by_name(request, + "encoder_out_lens") + in_2 = pb_utils.get_input_tensor_by_name(request, + "batch_log_probs") + in_3 = pb_utils.get_input_tensor_by_name(request, + "batch_log_probs_idx") batch_encoder_out.append(in_0.as_numpy()) - encoder_max_len = max(encoder_max_len, batch_encoder_out[-1].shape[1]) + encoder_max_len = max(encoder_max_len, + batch_encoder_out[-1].shape[1]) cur_b_lens = in_1.as_numpy() batch_encoder_lens.append(cur_b_lens) @@ -204,8 +212,10 @@ def execute(self, requests): cur_b_log_probs_idx = in_3.as_numpy() for i in range(cur_batch): cur_len = cur_b_lens[i] - cur_probs = cur_b_log_probs[i][0:cur_len, :].tolist() # T X Beam - cur_idx = cur_b_log_probs_idx[i][0:cur_len, :].tolist() # T x Beam + cur_probs = cur_b_log_probs[i][ + 0:cur_len, :].tolist() # T X Beam + cur_idx = cur_b_log_probs_idx[i][ + 0:cur_len, :].tolist() # T x Beam batch_log_probs.append(cur_probs) batch_log_probs_idx.append(cur_idx) root_dict[total] = PathTrie() @@ -213,17 +223,18 @@ def execute(self, requests): batch_start.append(True) total += 1 - score_hyps = ctc_beam_search_decoder_batch(batch_log_probs, - batch_log_probs_idx, - batch_root, - batch_start, - self.beam_size, - min(total, self.num_processes), - blank_id=self.blank_id, - space_id=-2, - cutoff_prob=self.cutoff_prob, - ext_scorer=self.lm, - hotwords_scorer=self.hotwords_scorer) + score_hyps = ctc_beam_search_decoder_batch( + batch_log_probs, + batch_log_probs_idx, + batch_root, + batch_start, + self.beam_size, + min(total, self.num_processes), + blank_id=self.blank_id, + space_id=-2, + cutoff_prob=self.cutoff_prob, + ext_scorer=self.lm, + hotwords_scorer=self.hotwords_scorer) all_hyps = [] all_ctc_score = [] max_seq_len = 0 @@ -231,7 +242,8 @@ def execute(self, requests): # if candidates less than beam size if len(seq_cand) != self.beam_size: seq_cand = list(seq_cand) - seq_cand += (self.beam_size - len(seq_cand)) * [(-float("INF"), (0,))] + seq_cand += (self.beam_size - len(seq_cand)) * [(-float("INF"), + (0, ))] for score, hyps in seq_cand: all_hyps.append(list(hyps)) @@ -277,7 +289,8 @@ def execute(self, requests): in_tensor_3 = pb_utils.Tensor("hyps_lens_sos", in_hyps_lens_sos) input_tensors = [in_tensor_0, in_tensor_1, in_tensor_2, in_tensor_3] if self.bidecoder: - in_tensor_4 = pb_utils.Tensor("r_hyps_pad_sos_eos", in_r_hyps_pad_sos_eos) + in_tensor_4 = pb_utils.Tensor("r_hyps_pad_sos_eos", + in_r_hyps_pad_sos_eos) input_tensors.append(in_tensor_4) in_tensor_5 = pb_utils.Tensor("ctc_score", in_ctc_score) input_tensors.append(in_tensor_5) @@ -289,11 +302,12 @@ def execute(self, requests): inference_response = inference_request.exec() if inference_response.has_error(): - raise pb_utils.TritonModelException(inference_response.error().message()) + raise pb_utils.TritonModelException( + inference_response.error().message()) else: # Extract the output tensors from the inference response. 
- best_index = pb_utils.get_output_tensor_by_name(inference_response, - 'best_index') + best_index = pb_utils.get_output_tensor_by_name( + inference_response, 'best_index') if best_index.is_cpu(): best_index = best_index.as_numpy() else: @@ -304,17 +318,20 @@ def execute(self, requests): for cands, cand_lens in zip(in_hyps_pad_sos_eos, in_hyps_lens_sos): best_idx = best_index[idx][0] best_cand_len = cand_lens[best_idx] - 1 # remove sos - best_cand = cands[best_idx][1: 1 + best_cand_len].tolist() + best_cand = cands[best_idx][1:1 + best_cand_len].tolist() hyps.append(best_cand) idx += 1 - hyps = map_batch(hyps, self.vocabulary, - min(multiprocessing.cpu_count(), len(in_ctc_score))) + hyps = map_batch( + hyps, self.vocabulary, + min(multiprocessing.cpu_count(), len(in_ctc_score))) st = 0 for b in batch_count: sents = np.array(hyps[st:st + b]) - out0 = pb_utils.Tensor("OUTPUT0", sents.astype(self.out0_dtype)) - inference_response = pb_utils.InferenceResponse(output_tensors=[out0]) + out0 = pb_utils.Tensor("OUTPUT0", + sents.astype(self.out0_dtype)) + inference_response = pb_utils.InferenceResponse( + output_tensors=[out0]) responses.append(inference_response) st += b return responses diff --git a/runtime/gpu/model_repo_stateful/feature_extractor/1/model.py b/runtime/gpu/model_repo_stateful/feature_extractor/1/model.py index ce1f340f8..459855e59 100644 --- a/runtime/gpu/model_repo_stateful/feature_extractor/1/model.py +++ b/runtime/gpu/model_repo_stateful/feature_extractor/1/model.py @@ -21,7 +21,9 @@ import json import numpy as np + class Fbank(torch.nn.Module): + def __init__(self, opts): super(Fbank, self).__init__() self.fbank = kaldifeat.Fbank(opts) @@ -29,9 +31,16 @@ def __init__(self, opts): def forward(self, waves: List[torch.Tensor]): return self.fbank(waves) + class Feat(object): - def __init__(self, seqid, offset_ms, sample_rate, - first_chunk_sz, frame_stride, device='cpu'): + + def __init__(self, + seqid, + offset_ms, + sample_rate, + first_chunk_sz, + frame_stride, + device='cpu'): self.seqid = seqid self.sample_rate = sample_rate self.wav = torch.tensor([], device=device) @@ -62,10 +71,11 @@ def add_frames(self, frames: torch.tensor): self.frames = torch.cat([self.frames, frames], axis=0) def get_frames(self, num_frames: int): - seg = self.frames[0: num_frames] + seg = self.frames[0:num_frames] self.frames = self.frames[self.frame_stride:] return seg + class TritonPythonModel: """Your Python model must use the same class name. Every Python model that is created must have "TritonPythonModel" as the class name. 
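A note on the Feat helper above: get_frames() hands out a full decoding window but only drops frame_stride frames from the buffer, so consecutive windows overlap by (window - stride) frames. A tiny sketch with made-up sizes (decoding_window and frame_stride values are assumptions for illustration):

import torch

frames = torch.arange(10.0).unsqueeze(1)  # 10 buffered frames, feature dim 1
decoding_window, frame_stride = 6, 4

first = frames[0:decoding_window]   # frames 0..5
frames = frames[frame_stride:]      # advance the buffer by the stride only
second = frames[0:decoding_window]  # frames 4..9, overlapping frames 4 and 5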
@@ -138,7 +148,8 @@ def initialize(self, args): cur_frames = _kaldifeat.num_frames(first_chunk_size, opts.frame_opts) while cur_frames < self.decoding_window: first_chunk_size += frame_shift_ms * sample_rate // 1000 - cur_frames = _kaldifeat.num_frames(first_chunk_size, opts.frame_opts) + cur_frames = _kaldifeat.num_frames(first_chunk_size, + opts.frame_opts) # self.pad_silence = first_chunk_size - self.chunk_size self.first_chunk_size = first_chunk_size self.offset_ms = self.get_offset(frame_length_ms, frame_shift_ms) @@ -157,7 +168,8 @@ def parse_model_params(self, model_params): "frame_length_ms": 25, "frame_shift_ms": 10, "sample_rate": 16000, - "chunk_size_s": 0.64} + "chunk_size_s": 0.64 + } # get parameter configurations for li in model_params.items(): key, value = li @@ -212,8 +224,7 @@ def execute(self, requests): self.seq_feat[corrid] = Feat(corrid, self.offset_ms, self.sample_rate, self.first_chunk_size, - self.frame_stride, - self.device) + self.frame_stride, self.device) if ready: self.seq_feat[corrid].add_wavs(wavs[0:wav_lens]) @@ -226,7 +237,8 @@ def execute(self, requests): wav = self.seq_feat[corrid].get_seg_wav() * 32768 if len(wav) < self.min_seg: - temp = torch.zeros(self.min_seg, dtype=torch.float32, + temp = torch.zeros(self.min_seg, + dtype=torch.float32, device=self.device) temp[0:len(wav)] = wav[:] wav = temp @@ -235,15 +247,16 @@ def execute(self, requests): features = self.feature_extractor(total_waves) batch_size = len(batch_seqid) - batch_speech = torch.zeros((batch_size, self.decoding_window, - self.feature_size), dtype=self.dtype) + batch_speech = torch.zeros( + (batch_size, self.decoding_window, self.feature_size), + dtype=self.dtype) batch_speech_lens = torch.zeros((batch_size, 1), dtype=torch.int32) i = 0 for corrid, frames in zip(batch_seqid, features): self.seq_feat[corrid].add_frames(frames) r_frames = self.seq_feat[corrid].get_frames(self.decoding_window) - speech = batch_speech[i: i + 1] - speech_lengths = batch_speech_lens[i: i + 1] + speech = batch_speech[i:i + 1] + speech_lengths = batch_speech_lens[i:i + 1] i += 1 speech_lengths[0] = r_frames.size(0) speech[0][0:r_frames.size(0)] = r_frames.to(speech.device) @@ -251,9 +264,11 @@ def execute(self, requests): # out_tensor1 = pb_utils.Tensor.from_dlpack("speech_lengths", # to_dlpack(speech_lengths)) out_tensor0 = pb_utils.Tensor("speech", speech.numpy()) - out_tensor1 = pb_utils.Tensor("speech_lengths", speech_lengths.numpy()) + out_tensor1 = pb_utils.Tensor("speech_lengths", + speech_lengths.numpy()) output_tensors = [out_tensor0, out_tensor1] - response = pb_utils.InferenceResponse(output_tensors=output_tensors) + response = pb_utils.InferenceResponse( + output_tensors=output_tensors) responses.append(response) if corrid in end_seqid: del self.seq_feat[corrid] diff --git a/runtime/gpu/model_repo_stateful/wenet/1/model.py b/runtime/gpu/model_repo_stateful/wenet/1/model.py index 91e008dc9..bef3bd8b2 100644 --- a/runtime/gpu/model_repo_stateful/wenet/1/model.py +++ b/runtime/gpu/model_repo_stateful/wenet/1/model.py @@ -26,6 +26,7 @@ from torch.utils.dlpack import from_dlpack + class TritonPythonModel: """Your Python model must use the same class name. Every Python model that is created must have "TritonPythonModel" as the class name. 
@@ -146,14 +147,20 @@ def execute(self, requests): batch_idx += 1 - batch_states = [trieVector, batch_start, batch_encoder_hist, cur_encoder_out] - res_sents, new_states = self.model.infer(batch_log_probs, batch_log_probs_idx, - batch_len, rescore_index, batch_states) + batch_states = [ + trieVector, batch_start, batch_encoder_hist, cur_encoder_out + ] + res_sents, new_states = self.model.infer(batch_log_probs, + batch_log_probs_idx, + batch_len, rescore_index, + batch_states) cur_encoder_out = new_states for i in range(len(res_sents)): sent = np.array(res_sents[i]) - out_tensor_0 = pb_utils.Tensor("OUTPUT0", sent.astype(self.output0_dtype)) - response = pb_utils.InferenceResponse(output_tensors=[out_tensor_0]) + out_tensor_0 = pb_utils.Tensor("OUTPUT0", + sent.astype(self.output0_dtype)) + response = pb_utils.InferenceResponse( + output_tensors=[out_tensor_0]) responses.append(response) corr = batch_idx2_corrid[i] if i in rescore_index: @@ -164,8 +171,9 @@ def execute(self, requests): if self.seq_states[corr][1] is None: self.seq_states[corr][1] = cur_encoder_out[i] else: - new_hist = torch.cat([self.seq_states[corr][1], - cur_encoder_out[i]], axis=0) + new_hist = torch.cat( + [self.seq_states[corr][1], cur_encoder_out[i]], + axis=0) self.seq_states[corr][1] = new_hist assert len(requests) == len(responses) diff --git a/runtime/gpu/model_repo_stateful/wenet/1/wenet_onnx_model.py b/runtime/gpu/model_repo_stateful/wenet/1/wenet_onnx_model.py index 07303712d..6ea503591 100644 --- a/runtime/gpu/model_repo_stateful/wenet/1/wenet_onnx_model.py +++ b/runtime/gpu/model_repo_stateful/wenet/1/wenet_onnx_model.py @@ -13,7 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. - import multiprocessing import numpy as np import os @@ -24,7 +23,9 @@ Scorer, HotWordsScorer, map_batch import yaml + class WenetModel(object): + def __init__(self, model_config, device): params = self.parse_model_parameters(model_config['parameters']) @@ -61,11 +62,14 @@ def __init__(self, model_config, device): if self.hotwords is not None: for w in self.hotwords: max_order = max(max_order, len(w)) - self.hotwords_scorer = HotWordsScorer(self.hotwords, self.vocab, + self.hotwords_scorer = HotWordsScorer(self.hotwords, + self.vocab, window_length=max_order, SPACE_ID=-2, is_character_based=True) - print(f"Successfully load hotwords! Hotwords orders = {max_order}") + print( + f"Successfully load hotwords! 
Hotwords orders = {max_order}" + ) self.bidecoder = params.get('bidecoder') # rescore setting @@ -119,15 +123,17 @@ def load_hotwords(self, hotwords_file): return configs def parse_model_parameters(self, model_parameters): - model_p = {"beam_size": 10, - "cutoff_prob": 0.999, - "vocab_path": None, - "lm_path": None, - "hotwords_path": None, - "alpha": 2.0, - "beta": 1.0, - "rescoring": 0, - "bidecoder": 1} + model_p = { + "beam_size": 10, + "cutoff_prob": 0.999, + "vocab_path": None, + "lm_path": None, + "hotwords_path": None, + "alpha": 2.0, + "beta": 1.0, + "rescoring": 0, + "bidecoder": 1 + } # get parameter configurations for li in model_parameters.items(): key, value = li @@ -142,8 +148,8 @@ def parse_model_parameters(self, model_parameters): assert model_p["vocab_path"] is not None return model_p - def infer(self, batch_log_probs, batch_log_probs_idx, - seq_lens, rescore_index, batch_states): + def infer(self, batch_log_probs, batch_log_probs_idx, seq_lens, + rescore_index, batch_states): """ batch_states = [trieVector, batch_start, batch_encoder_hist, cur_encoder_out] @@ -151,19 +157,20 @@ def infer(self, batch_log_probs, batch_log_probs_idx, trie_vector, batch_start, batch_encoder_hist, cur_encoder_out = batch_states num_processes = min(multiprocessing.cpu_count(), len(batch_log_probs)) - score_hyps = self.batch_ctc_prefix_beam_search_cpu(batch_log_probs, - batch_log_probs_idx, - seq_lens, - trie_vector, - batch_start, - self.beam_size, - self.blank_id, - self.space_id, - self.cutoff_prob, - num_processes, - self.scorer, - self.hotwords_scorer, - ) + score_hyps = self.batch_ctc_prefix_beam_search_cpu( + batch_log_probs, + batch_log_probs_idx, + seq_lens, + trie_vector, + batch_start, + self.beam_size, + self.blank_id, + self.space_id, + self.cutoff_prob, + num_processes, + self.scorer, + self.hotwords_scorer, + ) if self.rescoring and len(rescore_index) != 0: # find the end of sequence @@ -177,14 +184,16 @@ def infer(self, batch_log_probs, batch_log_probs_idx, if hist_enc is None: cur_enc = cur_encoder_out[idx] else: - cur_enc = torch.cat([hist_enc, cur_encoder_out[idx]], axis=0) + cur_enc = torch.cat([hist_enc, cur_encoder_out[idx]], + axis=0) rescore_encoder_hist.append(cur_enc) cur_mask_len = int(len(hist_enc) + seq_lens[idx]) rescore_encoder_lens.append(cur_mask_len) rescore_hyps.append(score_hyps[idx]) if cur_enc.shape[0] > max_length: max_length = cur_enc.shape[0] - best_index = self.batch_rescoring(rescore_hyps, rescore_encoder_hist, + best_index = self.batch_rescoring(rescore_hyps, + rescore_encoder_hist, rescore_encoder_lens, max_length) best_sent = [] @@ -201,12 +210,10 @@ def infer(self, batch_log_probs, batch_log_probs_idx, return final_result, cur_encoder_out def batch_ctc_prefix_beam_search_cpu(self, batch_log_probs_seq, - batch_log_probs_idx, - batch_len, batch_root, - batch_start, beam_size, - blank_id, space_id, - cutoff_prob, num_processes, - scorer, + batch_log_probs_idx, batch_len, + batch_root, batch_start, beam_size, + blank_id, space_id, cutoff_prob, + num_processes, scorer, hotwords_scorer): """ Return: Batch x Beam_size elements, each element is a tuple @@ -218,19 +225,14 @@ def batch_ctc_prefix_beam_search_cpu(self, batch_log_probs_seq, batch_log_probs_idx_list = [] for i in range(len(batch_len_list)): cur_len = int(batch_len_list[i]) - batch_log_probs_seq_list.append(batch_log_probs_seq[i][0:cur_len].tolist()) - batch_log_probs_idx_list.append(batch_log_probs_idx[i][0:cur_len].tolist()) - score_hyps = 
ctc_beam_search_decoder_batch(batch_log_probs_seq_list, - batch_log_probs_idx_list, - batch_root, - batch_start, - beam_size, - num_processes, - blank_id, - space_id, - cutoff_prob, - scorer, - hotwords_scorer) + batch_log_probs_seq_list.append( + batch_log_probs_seq[i][0:cur_len].tolist()) + batch_log_probs_idx_list.append( + batch_log_probs_idx[i][0:cur_len].tolist()) + score_hyps = ctc_beam_search_decoder_batch( + batch_log_probs_seq_list, batch_log_probs_idx_list, batch_root, + batch_start, beam_size, num_processes, blank_id, space_id, + cutoff_prob, scorer, hotwords_scorer) return score_hyps def batch_rescoring(self, score_hyps, hist_enc, hist_mask_len, max_len): @@ -267,10 +269,12 @@ def batch_rescoring(self, score_hyps, hist_enc, hist_mask_len, max_len): max_seq_len = len(hyps[-1]) max_seq_len += 2 - hyps_pad_sos_eos = np.ones((bz, beam_size, max_seq_len), dtype=np.int64) + hyps_pad_sos_eos = np.ones((bz, beam_size, max_seq_len), + dtype=np.int64) hyps_pad_sos_eos = hyps_pad_sos_eos * self.eos # fill eos if self.bidecoder: - r_hyps_pad_sos_eos = np.ones((bz, beam_size, max_seq_len), dtype=np.int64) + r_hyps_pad_sos_eos = np.ones((bz, beam_size, max_seq_len), + dtype=np.int64) r_hyps_pad_sos_eos = r_hyps_pad_sos_eos * self.eos hyps_lens_sos = np.ones((bz, beam_size), dtype=np.int32) @@ -280,12 +284,13 @@ def batch_rescoring(self, score_hyps, hist_enc, hist_mask_len, max_len): length = len(cand) + 2 bz_offset = idx % beam_size pad_cand = [self.sos] + cand + [self.eos] - hyps_pad_sos_eos[bz_id][bz_offset][0 : length] = pad_cand + hyps_pad_sos_eos[bz_id][bz_offset][0:length] = pad_cand if self.bidecoder: r_pad_cand = [self.sos] + cand[::-1] + [self.eos] r_hyps_pad_sos_eos[bz_id][bz_offset][0:length] = r_pad_cand hyps_lens_sos[bz_id][idx % beam_size] = len(cand) + 1 - in0 = pb_utils.Tensor.from_dlpack("encoder_out", to_dlpack(encoder_out)) + in0 = pb_utils.Tensor.from_dlpack("encoder_out", + to_dlpack(encoder_out)) in1 = pb_utils.Tensor("encoder_out_lens", encoder_lens) in2 = pb_utils.Tensor("hyps_pad_sos_eos", hyps_pad_sos_eos) in3 = pb_utils.Tensor("hyps_lens_sos", hyps_lens_sos) @@ -295,9 +300,10 @@ def batch_rescoring(self, score_hyps, hist_enc, hist_mask_len, max_len): input_tensors.append(in4) in5 = pb_utils.Tensor.from_dlpack("ctc_score", to_dlpack(ctc_score)) input_tensors.append(in5) - request = pb_utils.InferenceRequest(model_name='decoder', - requested_output_names=['best_index'], - inputs=input_tensors) + request = pb_utils.InferenceRequest( + model_name='decoder', + requested_output_names=['best_index'], + inputs=input_tensors) response = request.exec() best_index = pb_utils.get_output_tensor_by_name(response, 'best_index') best_index = from_dlpack(best_index.to_dlpack()).clone() diff --git a/runtime/gpu/scripts/benchmark_onnx_throughput.py b/runtime/gpu/scripts/benchmark_onnx_throughput.py index 077cec68e..066f0cba3 100755 --- a/runtime/gpu/scripts/benchmark_onnx_throughput.py +++ b/runtime/gpu/scripts/benchmark_onnx_throughput.py @@ -15,7 +15,6 @@ # limitations under the License. 
# Modified from below: # https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/python/tools/transformers/onnx_exporter.py - """ Usage: export CUDA_VISIBLE_DEVICES="0" @@ -37,8 +36,7 @@ def get_parser(): parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) + formatter_class=argparse.ArgumentDefaultsHelpFormatter) parser.add_argument( "--batch_sizes", @@ -57,7 +55,7 @@ def get_parser(): parser.add_argument( "--onnxFile", type=str, - default="/mnt/samsung-t7/yuekai/benchmark/wenet/wenet/bin/u2pp_aishell2_onnx/encoder_fp16.onnx", + default="wenet/bin/u2pp_aishell2_onnx/encoder_fp16.onnx", help="Path to the onnx file", ) @@ -96,9 +94,10 @@ def get_parser(): return parser -def allocateOutputBuffers( - output_buffers, output_buffer_max_sizes, device, data_type=torch.float32 -): +def allocateOutputBuffers(output_buffers, + output_buffer_max_sizes, + device, + data_type=torch.float32): # Allocate output tensors with the largest test size needed. # So the allocated memory can be reused # for each test run. @@ -114,20 +113,22 @@ def get_latency_result(latency_list, batch_size): throughput_trt = 1000.0 / latency_ms return { - "test_times": len(latency_list), - "latency_variance": "{:.2f}".format(latency_variance), - "latency_90_percentile": "{:.2f}".format( - numpy.percentile(latency_list, 90) * 1000.0 - ), - "latency_95_percentile": "{:.2f}".format( - numpy.percentile(latency_list, 95) * 1000.0 - ), - "latency_99_percentile": "{:.2f}".format( - numpy.percentile(latency_list, 99) * 1000.0 - ), - "average_latency_ms": "{:.2f}".format(latency_ms), - "QPS": "{:.2f}".format(throughput), - f"QPS_trt_batch{batch_size}": "{:.2f}".format(throughput_trt), + "test_times": + len(latency_list), + "latency_variance": + "{:.2f}".format(latency_variance), + "latency_90_percentile": + "{:.2f}".format(numpy.percentile(latency_list, 90) * 1000.0), + "latency_95_percentile": + "{:.2f}".format(numpy.percentile(latency_list, 95) * 1000.0), + "latency_99_percentile": + "{:.2f}".format(numpy.percentile(latency_list, 99) * 1000.0), + "average_latency_ms": + "{:.2f}".format(latency_ms), + "QPS": + "{:.2f}".format(throughput), + f"QPS_trt_batch{batch_size}": + "{:.2f}".format(throughput_trt), } @@ -149,17 +150,16 @@ def create_onnxruntime_input( d_k = int(output_size / head) cnn_module_kernel = 7 - chunk_xs = torch.randn( - batch_size, sequence_length, feature_size, dtype=data_type - ).numpy() + chunk_xs = torch.randn(batch_size, + sequence_length, + feature_size, + dtype=data_type).numpy() inputs["chunk_xs"] = chunk_xs - chunk_lens = ( - torch.ones(batch_size, dtype=torch.int32).numpy() * sequence_length - ) + chunk_lens = (torch.ones(batch_size, dtype=torch.int32).numpy() * + sequence_length) inputs["chunk_lens"] = chunk_lens - offset = ( - torch.arange(0, batch_size, dtype=torch.int64).unsqueeze(1).numpy() - ) + offset = (torch.arange(0, batch_size, + dtype=torch.int64).unsqueeze(1).numpy()) inputs["offset"] = offset att_cache = torch.randn( batch_size, @@ -178,9 +178,10 @@ def create_onnxruntime_input( dtype=data_type, ).numpy() inputs["cnn_cache"] = cnn_cache - cache_mask = torch.ones( - batch_size, 1, required_cache_size, dtype=data_type - ).numpy() + cache_mask = torch.ones(batch_size, + 1, + required_cache_size, + dtype=data_type).numpy() inputs["cache_mask"] = cache_mask else: @@ -202,9 +203,9 @@ def inference_ort( number=1, repeat=warm_up_repeat, ) # Dry run - latency_list = timeit.repeat( - lambda: ort_session.run(None, ort_inputs), number=1, repeat=repeat_times - ) 
+ latency_list = timeit.repeat(lambda: ort_session.run(None, ort_inputs), + number=1, + repeat=repeat_times) result.update(result_template) result.update({"io_binding": False}) result.update(get_latency_result(latency_list, batch_size)) @@ -241,11 +242,9 @@ def inference_ort_with_io_binding( # Bind inputs to device for name in ort_inputs.keys(): np_input = torch.from_numpy(ort_inputs[name]).to(device) - input_type = ( - IO_BINDING_DATA_TYPE_MAP[str(ort_inputs[name].dtype)] - if str(ort_inputs[name].dtype) in IO_BINDING_DATA_TYPE_MAP - else data_type - ) + input_type = (IO_BINDING_DATA_TYPE_MAP[str(ort_inputs[name].dtype)] + if str(ort_inputs[name].dtype) + in IO_BINDING_DATA_TYPE_MAP else data_type) io_binding.bind_input( name, @@ -260,11 +259,9 @@ def inference_ort_with_io_binding( allocateOutputBuffers(output_buffers, output_buffer_max_sizes, device) for i, ort_output_name in enumerate(ort_output_names): - output_type = ( - IO_BINDING_DATA_TYPE_MAP[str(ort_outputs[i].dtype)] - if str(ort_outputs[i].dtype) in IO_BINDING_DATA_TYPE_MAP - else data_type - ) + output_type = (IO_BINDING_DATA_TYPE_MAP[str(ort_outputs[i].dtype)] + if str(ort_outputs[i].dtype) in IO_BINDING_DATA_TYPE_MAP + else data_type) io_binding.bind_output( ort_output_name, output_buffers[i].device.type, @@ -306,12 +303,10 @@ def create_onnxruntime_session( if enable_all_optimization: sess_options.graph_optimization_level = ( - onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL - ) + onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL) else: sess_options.graph_optimization_level = ( - onnxruntime.GraphOptimizationLevel.ORT_ENABLE_BASIC - ) + onnxruntime.GraphOptimizationLevel.ORT_ENABLE_BASIC) # sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_DISABLE_ALL # noqa if enable_profiling: @@ -341,13 +336,14 @@ def create_onnxruntime_session( if provider_options: providers = [ - (name, provider_options[name]) if name in provider_options else name + (name, + provider_options[name]) if name in provider_options else name for name in providers ] - session = onnxruntime.InferenceSession( - onnx_model_path, sess_options, providers=providers - ) + session = onnxruntime.InferenceSession(onnx_model_path, + sess_options, + providers=providers) return session @@ -403,10 +399,8 @@ def create_onnxruntime_session( results = [] for batch_size in batch_sizes: for sequence_length in sequence_lengths: - if ( - max_sequence_length is not None - and sequence_length > max_sequence_length - ): + if (max_sequence_length is not None + and sequence_length > max_sequence_length): continue ort_inputs = create_onnxruntime_input( @@ -423,11 +417,8 @@ def create_onnxruntime_session( "sequence_length": sequence_length, } - print( - "Run onnxruntime on {} with input shape {}".format( - args.onnxFile, [batch_size, sequence_length] - ) - ) + print("Run onnxruntime on {} with input shape {}".format( + args.onnxFile, [batch_size, sequence_length])) if args.disable_ort_io_binding: result = inference_ort( @@ -444,8 +435,7 @@ def create_onnxruntime_session( output_buffer_max_sizes = [] for i in range(len(ort_outputs)): output_buffer_max_sizes.append( - numpy.prod(ort_outputs[i].shape) - ) + numpy.prod(ort_outputs[i].shape)) data_type = numpy.intc output_buffers = [] diff --git a/runtime/gpu/scripts/compute_hotwords_f1.py b/runtime/gpu/scripts/compute_hotwords_f1.py index 05e821381..f2f031933 100644 --- a/runtime/gpu/scripts/compute_hotwords_f1.py +++ b/runtime/gpu/scripts/compute_hotwords_f1.py @@ -14,15 +14,19 @@ 
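The usage block above pairs a reference transcript with one or more decoding logs plus a hotword list, and the script reports hotword F1. As a reminder of the metric itself (the script's own cal_f1, defined further down, may count matches slightly differently), a minimal sketch of the formula:

def f1_from_counts(tp: int, fp: int, fn: int) -> float:
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    if precision + recall == 0.0:
        return 0.0
    return 2 * precision * recall / (precision + recall)

# e.g. 8 hotwords recovered, 2 spurious, 4 missed -> F1 ~= 0.727
print(f1_from_counts(tp=8, fp=2, fn=4))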
--hotword="data/hotwords.yaml" """ + def _sorted_iteritems(d): return sorted(d.items()) + def _iteritems(d): return iter(d.items()) + def _iterkeys(d): return iter(d.keys()) + _basestring = str _SENTINEL = object() @@ -163,6 +167,7 @@ def __setstate__(self, state): class CharTrie(collectionsAbc.MutableMapping): """A trie implementation with dict interface plus some extensions. """ + def __init__(self, *args, **kwargs): self._root = _Node() self._sorted = False @@ -384,8 +389,8 @@ def popitem(self): step = next(_iterkeys(node.children)) node = node.children[step] trace.append((step, node)) - return (self._key_from_path((step for step, _ in trace[1:])), - self._pop_from_node(node, trace)) + return (self._key_from_path( + (step for step, _ in trace[1:])), self._pop_from_node(node, trace)) def __delitem__(self, key_or_slice): """Deletes value associated with given key or raises KeyError. @@ -531,8 +536,9 @@ def cal_f1(ner_dict_pred, ner_dict_label): def get_args(): parser = argparse.ArgumentParser() parser.add_argument("--label", default="data/aishell1_text_hotwords") - parser.add_argument("--preds", - default="data/with_hotwords_ali.log;data/without_hotwords_ali.log") + parser.add_argument( + "--preds", + default="data/with_hotwords_ali.log;data/without_hotwords_ali.log") parser.add_argument("--hotword", default="data/hotwords.yaml") return parser.parse_args() diff --git a/runtime/gpu/scripts/convert.py b/runtime/gpu/scripts/convert.py index 2a72f45a8..9ae25c87a 100755 --- a/runtime/gpu/scripts/convert.py +++ b/runtime/gpu/scripts/convert.py @@ -23,13 +23,21 @@ parser = argparse.ArgumentParser( description='generate config.pbtxt for model_repo') parser.add_argument('--config', required=True, help='config file') - parser.add_argument('--vocab', required=True, + parser.add_argument('--vocab', + required=True, help='vocabulary file, units.txt') - parser.add_argument('--model_repo', required=True, + parser.add_argument('--model_repo', + required=True, help='model repo directory') - parser.add_argument('--onnx_model_dir', default=True, type=str, required=False, + parser.add_argument('--onnx_model_dir', + default=True, + type=str, + required=False, help="onnx model path") - parser.add_argument('--lm_path', default=None, type=str, required=False, + parser.add_argument('--lm_path', + default=None, + type=str, + required=False, help="the additional language model path") args = parser.parse_args() with open(args.config, 'r') as fin: @@ -39,9 +47,9 @@ onnx_configs = yaml.load(fin, Loader=yaml.FullLoader) params = [("#beam_size", 10), ("#num_mel_bins", 80), ("#frame_shift", 10), - ("#frame_length", 25), ("#sample_rate", 16000), ("#output_size", 256), - ("#lm_path", ""), ("#bidecoder", 0), ("#vocabulary_path", ""), - ("#DTYPE", "FP32")] + ("#frame_length", 25), ("#sample_rate", 16000), + ("#output_size", 256), ("#lm_path", ""), ("#bidecoder", 0), + ("#vocabulary_path", ""), ("#DTYPE", "FP32")] model_params = dict(params) # fill values model_params["#beam_size"] = onnx_configs["beam_size"] @@ -74,7 +82,8 @@ model_params["#chunk_size_in_seconds"] = chunk_seconds model_params["#num_layers"] = configs["encoder_conf"]["num_blocks"] model_params["#context"] = onnx_configs["context"] - model_params["#cnn_module_cache"] = onnx_configs["cnn_module_kernel_cache"] + model_params["#cnn_module_cache"] = onnx_configs[ + "cnn_module_kernel_cache"] model_params["#decoding_window"] = onnx_configs["decoding_window"] head = configs["encoder_conf"]["attention_heads"] model_params["#num_head"] = head @@ -87,7 +96,8 @@ if 
"decoder" == model and model_params["#bidecoder"] == 0: template = "config_template2.pbtxt" # streaming transformer encoder - if "encoder" == model and model_params.get("#cnn_module_cache", -1) == 0: + if "encoder" == model and model_params.get("#cnn_module_cache", + -1) == 0: template = "config_template2.pbtxt" model_dir = os.path.join(args.model_repo, model) @@ -101,8 +111,8 @@ model_name = model + ".onnx" source_model = os.path.join(args.onnx_model_dir, model_name) target_model = os.path.join(model_dir, "1", model + ".onnx") - res = subprocess.call( - ["cp", source_model, target_model], shell=False) + res = subprocess.call(["cp", source_model, target_model], + shell=False) if model == "encoder": # currently, with torch 1.10, the # exported conformer encoder output size is -1 @@ -113,11 +123,13 @@ encoder_out = model.graph.output[2] else: encoder_out = model.graph.output[0] - output_dim = encoder_out.type.tensor_type.shape.dim[2].dim_param + output_dim = encoder_out.type.tensor_type.shape.dim[ + 2].dim_param if output_dim.startswith("Add"): model_params["#encoder_output_size"] = -1 - with open(os.path.join(model_dir, template), "r", encoding="utf-8") as f: + with open(os.path.join(model_dir, template), "r", + encoding="utf-8") as f: for line in f: if line.startswith("#"): continue diff --git a/runtime/gpu/tensorrt/LayerNormPlugin/LayerNormPlugin.cu b/runtime/gpu/tensorrt/LayerNormPlugin/LayerNormPlugin.cu old mode 100755 new mode 100644 index 61d7e8a81..8ca3d8a3c --- a/runtime/gpu/tensorrt/LayerNormPlugin/LayerNormPlugin.cu +++ b/runtime/gpu/tensorrt/LayerNormPlugin/LayerNormPlugin.cu @@ -13,7 +13,7 @@ // limitations under the License. #include "LayerNormPlugin.h" -using namespace nvinfer1; // NOLINT +using namespace nvinfer1; // NOLINT PluginFieldCollection LayerNormPluginCreator::fc_{}; std::vector LayerNormPluginCreator::attr_; @@ -23,106 +23,107 @@ struct BytesToType; template <> struct BytesToType<2> { - using type = uint16_t; + using type = uint16_t; }; template <> struct BytesToType<4> { - using type = uint32_t; + using type = uint32_t; }; template <> struct BytesToType<8> { - using type = uint64_t; + using type = uint64_t; }; template <> struct BytesToType<16> { - using type = float4; + using type = float4; }; template __device__ inline void copy(const void* local, void* data) { - using T = typename BytesToType::type; + using T = typename BytesToType::type; - const T* in = static_cast(local); - T* out = static_cast(data); - *out = *in; + const T* in = static_cast(local); + T* out = static_cast(data); + *out = *in; } struct mySum { - __host__ __device__ __forceinline__ float2 operator()( - const float2 &a, const float2 &b) const { - return make_float2(a.x + b.x, a.y + b.y); - } + __host__ __device__ __forceinline__ float2 operator()(const float2& a, + const float2& b) const { + return make_float2(a.x + b.x, a.y + b.y); + } }; template -__global__ void layerNormKernel(const T* input, const T* gamma, - const T* beta, T* output) { - const int idx = blockIdx.x * 256 + threadIdx.x * VPT; - T localX[VPT], localGamma[VPT], localBeta[VPT]; +__global__ void layerNormKernel(const T* input, const T* gamma, const T* beta, + T* output) { + const int idx = blockIdx.x * 256 + threadIdx.x * VPT; + T localX[VPT], localGamma[VPT], localBeta[VPT]; - copy(&input[idx], localX); - float2 localFloat2 = {0.f, 0.f}; + copy(&input[idx], localX); + float2 localFloat2 = {0.f, 0.f}; - const float rld = float(1)/ float(256); // NOLINT + const float rld = float(1) / float(256); // NOLINT #pragma unroll - for 
(int it = 0; it < VPT; it++) { - const float tmp = rld * (float)localX[it]; // NOLINT - localFloat2.x += tmp; - localFloat2.y += tmp * (float)localX[it]; // NOLINT - } - - copy(&beta[threadIdx.x * VPT], localBeta); - copy(&gamma[threadIdx.x * VPT], localGamma); - - using BlockReduce = cub::BlockReduce; - __shared__ typename BlockReduce::TempStorage temp_storage; - __shared__ float mu; // mean - __shared__ float rsigma; // 1 / std.dev. - - // const float2 sumKV = - // BlockReduce(temp_storage).Reduce(localFloat2, cub::Sum()); - const float2 sumKV = BlockReduce(temp_storage).Reduce(localFloat2, mySum()); - - if (threadIdx.x == 0) { - mu = sumKV.x; - rsigma = rsqrt(sumKV.y - mu * mu + 1e-6); - } - __syncthreads(); + for (int it = 0; it < VPT; it++) { + const float tmp = rld * (float)localX[it]; // NOLINT + localFloat2.x += tmp; + localFloat2.y += tmp * (float)localX[it]; // NOLINT + } + + copy(&beta[threadIdx.x * VPT], localBeta); + copy(&gamma[threadIdx.x * VPT], localGamma); + + using BlockReduce = cub::BlockReduce; + __shared__ typename BlockReduce::TempStorage temp_storage; + __shared__ float mu; // mean + __shared__ float rsigma; // 1 / std.dev. + + // const float2 sumKV = + // BlockReduce(temp_storage).Reduce(localFloat2, cub::Sum()); + const float2 sumKV = BlockReduce(temp_storage).Reduce(localFloat2, mySum()); + + if (threadIdx.x == 0) { + mu = sumKV.x; + rsigma = rsqrt(sumKV.y - mu * mu + 1e-6); + } + __syncthreads(); #pragma unroll - for (int it = 0; it < VPT; it++) { - localX[it] = (float)localGamma[it] * ((float)localX[it] - mu) * rsigma // NOLINT - + (float)localBeta[it]; // NOLINT - } + for (int it = 0; it < VPT; it++) { + localX[it] = + (float)localGamma[it] * ((float)localX[it] - mu) * rsigma // NOLINT + + (float)localBeta[it]; // NOLINT + } - copy(localX, &output[idx]); // NOLINT + copy(localX, &output[idx]); // NOLINT } template __global__ void layerNormKernel(const float*, - const float*, const float*, float*); -template __global__ void layerNormKernel(const half*, - const half*, const half*, half*); + const float*, + const float*, float*); +template __global__ void layerNormKernel(const half*, const half*, + const half*, half*); int LayerNormPlugin::enqueue(const PluginTensorDesc* inputDesc, const PluginTensorDesc* outputDesc, const void* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) noexcept { - const int gridSize = inputDesc[0].dims.d[0] * inputDesc[0].dims.d[1]; - - if (inputDesc[0].type == DataType::kFLOAT) { - constexpr int VPT = 16 / sizeof(float); - constexpr int TPB = 256 / VPT; - (layerNormKernel) - <<>> ((const float*)inputs[0], - (const float*)inputs[1], (const float*)inputs[2], (float*)outputs[0]); // NOLINT - } else { - constexpr int VPT = 16 / sizeof(half); - constexpr int TPB = 256 / VPT; - (layerNormKernel) <<>> (( - const half*)inputs[0], (const half*)inputs[1], - (const half*)inputs[2], (half*)outputs[0]); // NOLINT - } - return 0; + const int gridSize = inputDesc[0].dims.d[0] * inputDesc[0].dims.d[1]; + + if (inputDesc[0].type == DataType::kFLOAT) { + constexpr int VPT = 16 / sizeof(float); + constexpr int TPB = 256 / VPT; + (layerNormKernel)<<>>( + (const float*)inputs[0], (const float*)inputs[1], + (const float*)inputs[2], (float*)outputs[0]); // NOLINT + } else { + constexpr int VPT = 16 / sizeof(half); + constexpr int TPB = 256 / VPT; + (layerNormKernel)<<>>( + (const half*)inputs[0], (const half*)inputs[1], (const half*)inputs[2], + (half*)outputs[0]); // NOLINT + } + return 0; } 
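For reference, the computation the fused kernel above performs, written in NumPy: per row of 256 elements it takes the mean, the biased variance via E[x^2] - mu^2, normalizes with the kernel's 1e-6 epsilon, then applies gamma and beta. A sketch that can serve as a quick numerical check of the plugin (input shapes are illustrative):

import numpy as np

def layer_norm_ref(x, gamma, beta, eps=1e-6):
    # normalize over the last axis (the 256-wide hidden dimension)
    mu = x.mean(axis=-1, keepdims=True)
    var = (x * x).mean(axis=-1, keepdims=True) - mu * mu
    return gamma * (x - mu) / np.sqrt(var + eps) + beta

x = np.random.randn(4, 16, 256).astype(np.float32)
g = np.ones(256, dtype=np.float32)
b = np.zeros(256, dtype=np.float32)
y = layer_norm_ref(x, g, b)  # compare against the plugin's output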
REGISTER_TENSORRT_PLUGIN(LayerNormPluginCreator); - diff --git a/runtime/gpu/tensorrt/LayerNormPlugin/LayerNormPlugin.h b/runtime/gpu/tensorrt/LayerNormPlugin/LayerNormPlugin.h old mode 100755 new mode 100644 index 267666cf4..0c17f770d --- a/runtime/gpu/tensorrt/LayerNormPlugin/LayerNormPlugin.h +++ b/runtime/gpu/tensorrt/LayerNormPlugin/LayerNormPlugin.h @@ -14,38 +14,41 @@ #ifndef RUNTIME_GPU_TENSORRT_LAYERNORMPLUGIN_LAYERNORMPLUGIN_H_ #define RUNTIME_GPU_TENSORRT_LAYERNORMPLUGIN_LAYERNORMPLUGIN_H_ -#include -#include -#include #include -#include // NOLINT -#include // NOLINT +#include // NOLINT +#include +#include // NOLINT +#include +#include -#define CEIL_DIVIDE(X, Y) (((X)+(Y)-1)/(Y)) -#define CEIL_TO(X, Y) (((X)+(Y)-1)/(Y)*(Y)) +#define CEIL_DIVIDE(X, Y) (((X) + (Y)-1) / (Y)) +#define CEIL_TO(X, Y) (((X) + (Y)-1) / (Y) * (Y)) template __device__ T epsilon(); template <> __device__ float epsilon() { - return (float)6.0e-12; // NOLINT + return (float)6.0e-12; // NOLINT } template <> __device__ half epsilon() { - return (half)6.0e-6; + return (half)6.0e-6; } // +------- Debug wrapper ----------------------------------- #if DEBUG -#define WHERE_AM_I() do {printf("[%s]:this=->%p\n", __func__, this);} while (0); +#define WHERE_AM_I() \ + do { \ + printf("[%s]:this=->%p\n", __func__, this); \ + } while (0); #else #define WHERE_AM_I() #endif // DEBUG // +------- Plguin ------------------------------------------- -namespace { // NOLINT +namespace { // NOLINT static const char* PLUGIN_NAME{"LayerNorm"}; static const char* PLUGIN_VERSION{"1"}; } // namespace @@ -53,177 +56,170 @@ static const char* PLUGIN_VERSION{"1"}; namespace nvinfer1 { // +------- Plugin body --------------------------------------- -class LayerNormPlugin: public IPluginV2DynamicExt { +class LayerNormPlugin : public IPluginV2DynamicExt { private: - std::string name_; - std::string namespace_; + std::string name_; + std::string namespace_; public: - LayerNormPlugin(const std::string& name) : name_(name) { // NOLINT - WHERE_AM_I(); - } - - LayerNormPlugin(const std::string& name, - const void* data, size_t length) : name_(name) { - WHERE_AM_I(); - } - - LayerNormPlugin() = delete; - - ~LayerNormPlugin() { - WHERE_AM_I(); - } - - size_t getSerializationSize() const noexcept override { - WHERE_AM_I(); - return 0; - } - - void serialize(void *buffer) const noexcept override { - WHERE_AM_I(); - } - - IPluginV2DynamicExt* clone() const noexcept override { - WHERE_AM_I(); - return new LayerNormPlugin(name_); - } - - int getNbOutputs() const noexcept override { - WHERE_AM_I(); - return 1; - } - - DimsExprs getOutputDimensions(int32_t outputIndex, const DimsExprs* inputs, - int32_t nbInputs, - IExprBuilder& exprBuilder) noexcept override { - WHERE_AM_I(); - return inputs[0]; - } - - bool supportsFormatCombination(int32_t pos, const PluginTensorDesc* inOut, - int32_t nbInputs, - int32_t nbOutputs) noexcept override { - WHERE_AM_I(); - if (inOut[pos].format != TensorFormat::kLINEAR) { - return false; - } - - bool res = false; - switch (pos) { - case 0: - res = (inOut[pos].type == DataType::kFLOAT - || inOut[pos].type == DataType::kHALF); break; - case 1: - case 2: - case 3: - res = inOut[pos].type == inOut[0].type; break; - default: // should NOT be here - res = false; break; - } - - return res; - } - - DataType getOutputDataType(int outputIndex, - const DataType* inputTypes, - int nbInputs) const noexcept override { - WHERE_AM_I(); - return inputTypes[0]; - } - - void configurePlugin(const DynamicPluginTensorDesc* in, int32_t 
nbInputs, - const DynamicPluginTensorDesc* out, - int32_t nbOutputs) noexcept override { - WHERE_AM_I(); - } - - size_t getWorkspaceSize(const PluginTensorDesc* inputs, int32_t nbInputs, - const PluginTensorDesc* outputs, - int32_t nbOutputs) const noexcept override { - WHERE_AM_I(); - return 0; - } - - void setPluginNamespace(const char* szNamespace) noexcept override { - WHERE_AM_I(); - namespace_ = szNamespace; - } - const char* getPluginNamespace() const noexcept override { - WHERE_AM_I(); - return namespace_.c_str(); - } - const char* getPluginType() const noexcept override { - WHERE_AM_I(); - return PLUGIN_NAME; - } - const char* getPluginVersion() const noexcept override { - WHERE_AM_I(); - return PLUGIN_VERSION; - } - int initialize() noexcept override { - WHERE_AM_I(); - return 0; - } - void terminate() noexcept override { - WHERE_AM_I(); - return; - } - - void destroy() noexcept override { - WHERE_AM_I(); - } - - int32_t enqueue(const PluginTensorDesc* inputDesc, - const PluginTensorDesc* outputDesc, - const void* const* inputs, - void* const* outputs, void* workspace, - cudaStream_t stream) noexcept override; + LayerNormPlugin(const std::string& name) : name_(name) { // NOLINT + WHERE_AM_I(); + } + + LayerNormPlugin(const std::string& name, const void* data, size_t length) + : name_(name) { + WHERE_AM_I(); + } + + LayerNormPlugin() = delete; + + ~LayerNormPlugin() { WHERE_AM_I(); } + + size_t getSerializationSize() const noexcept override { + WHERE_AM_I(); + return 0; + } + + void serialize(void* buffer) const noexcept override { WHERE_AM_I(); } + + IPluginV2DynamicExt* clone() const noexcept override { + WHERE_AM_I(); + return new LayerNormPlugin(name_); + } + + int getNbOutputs() const noexcept override { + WHERE_AM_I(); + return 1; + } + + DimsExprs getOutputDimensions(int32_t outputIndex, const DimsExprs* inputs, + int32_t nbInputs, + IExprBuilder& exprBuilder) noexcept override { + WHERE_AM_I(); + return inputs[0]; + } + + bool supportsFormatCombination(int32_t pos, const PluginTensorDesc* inOut, + int32_t nbInputs, + int32_t nbOutputs) noexcept override { + WHERE_AM_I(); + if (inOut[pos].format != TensorFormat::kLINEAR) { + return false; + } + + bool res = false; + switch (pos) { + case 0: + res = (inOut[pos].type == DataType::kFLOAT || + inOut[pos].type == DataType::kHALF); + break; + case 1: + case 2: + case 3: + res = inOut[pos].type == inOut[0].type; + break; + default: // should NOT be here + res = false; + break; + } + + return res; + } + + DataType getOutputDataType(int outputIndex, const DataType* inputTypes, + int nbInputs) const noexcept override { + WHERE_AM_I(); + return inputTypes[0]; + } + + void configurePlugin(const DynamicPluginTensorDesc* in, int32_t nbInputs, + const DynamicPluginTensorDesc* out, + int32_t nbOutputs) noexcept override { + WHERE_AM_I(); + } + + size_t getWorkspaceSize(const PluginTensorDesc* inputs, int32_t nbInputs, + const PluginTensorDesc* outputs, + int32_t nbOutputs) const noexcept override { + WHERE_AM_I(); + return 0; + } + + void setPluginNamespace(const char* szNamespace) noexcept override { + WHERE_AM_I(); + namespace_ = szNamespace; + } + const char* getPluginNamespace() const noexcept override { + WHERE_AM_I(); + return namespace_.c_str(); + } + const char* getPluginType() const noexcept override { + WHERE_AM_I(); + return PLUGIN_NAME; + } + const char* getPluginVersion() const noexcept override { + WHERE_AM_I(); + return PLUGIN_VERSION; + } + int initialize() noexcept override { + WHERE_AM_I(); + return 0; + } + void 
terminate() noexcept override { + WHERE_AM_I(); + return; + } + + void destroy() noexcept override { WHERE_AM_I(); } + + int32_t enqueue(const PluginTensorDesc* inputDesc, + const PluginTensorDesc* outputDesc, const void* const* inputs, + void* const* outputs, void* workspace, + cudaStream_t stream) noexcept override; }; // class LayerNormPlugin class LayerNormPluginCreator : public IPluginCreator { private: - static PluginFieldCollection fc_; - static std::vector attr_; - std::string namespace_; + static PluginFieldCollection fc_; + static std::vector attr_; + std::string namespace_; public: - LayerNormPluginCreator() { - fc_.nbFields = attr_.size(); - fc_.fields = attr_.data(); - } + LayerNormPluginCreator() { + fc_.nbFields = attr_.size(); + fc_.fields = attr_.data(); + } - ~LayerNormPluginCreator() {} + ~LayerNormPluginCreator() {} - IPluginV2* createPlugin(const char* name, - const PluginFieldCollection* fc) noexcept override { - WHERE_AM_I(); - return new LayerNormPlugin(name); - } + IPluginV2* createPlugin(const char* name, + const PluginFieldCollection* fc) noexcept override { + WHERE_AM_I(); + return new LayerNormPlugin(name); + } - IPluginV2* deserializePlugin(const char* name, const void* serialData, - size_t serialLength) noexcept override { - return new LayerNormPlugin(name, serialData, serialLength); - } + IPluginV2* deserializePlugin(const char* name, const void* serialData, + size_t serialLength) noexcept override { + return new LayerNormPlugin(name, serialData, serialLength); + } - void setPluginNamespace(const char* szNamespace) noexcept override { - namespace_ = szNamespace; - } + void setPluginNamespace(const char* szNamespace) noexcept override { + namespace_ = szNamespace; + } - const char* getPluginNamespace() const noexcept override { - return namespace_.c_str(); - } + const char* getPluginNamespace() const noexcept override { + return namespace_.c_str(); + } - const char* getPluginName() const noexcept override { - return PLUGIN_NAME; - } + const char* getPluginName() const noexcept override { return PLUGIN_NAME; } - const char* getPluginVersion() const noexcept override { - return PLUGIN_VERSION; - } + const char* getPluginVersion() const noexcept override { + return PLUGIN_VERSION; + } - const PluginFieldCollection* getFieldNames() noexcept override { - return &fc_; - } + const PluginFieldCollection* getFieldNames() noexcept override { + return &fc_; + } }; // class LayerNormPluginCreator } // namespace nvinfer1 diff --git a/runtime/gpu/tensorrt/LayerNormPlugin/testLayerNormPlugin.py b/runtime/gpu/tensorrt/LayerNormPlugin/testLayerNormPlugin.py index 43e953f43..407e76a85 100755 --- a/runtime/gpu/tensorrt/LayerNormPlugin/testLayerNormPlugin.py +++ b/runtime/gpu/tensorrt/LayerNormPlugin/testLayerNormPlugin.py @@ -17,7 +17,7 @@ import numpy as np from time import time_ns import tensorrt as trt -import pycuda.autoinit # noqa +import pycuda.autoinit # noqa import pycuda.driver as cuda useFile = False @@ -73,9 +73,8 @@ def layerNormCPU(bufferH): _15 = _14 - _12 # a-bμ/sqrt(...) _16 = _x * _11 # bx/sqrt(...) 
_17 = _15 + _16 # b(x-μ)/sqrt(...)+a - _18 = _17.reshape( - bufferH[0].shape[0], bufferH[0].shape[1], bufferH[0].shape[2] - ) + _18 = _17.reshape(bufferH[0].shape[0], bufferH[0].shape[1], + bufferH[0].shape[2]) return _18 @@ -88,32 +87,25 @@ def testLayerNormCPU(): bufferH.append(io["(Unnamed Layer* 13) [Constant]_output"]) temp1 = layerNormCPU(bufferH) - print( - "outputCPU: %s,SumAbs=%.5e,Var=%.5f,Max=%.5f,Min=%.5f,SAD=%.5f" - % ( - str(temp1.shape), - np.sum(abs(temp1)), - np.var(temp1), - np.max(temp1), - np.min(temp1), - np.sum(np.abs(np.diff(temp1.reshape(-1)))), - ) - ) + print("outputCPU: %s,SumAbs=%.5e,Var=%.5f,Max=%.5f,Min=%.5f,SAD=%.5f" % ( + str(temp1.shape), + np.sum(abs(temp1)), + np.var(temp1), + np.max(temp1), + np.min(temp1), + np.sum(np.abs(np.diff(temp1.reshape(-1)))), + )) # print(temp1) temp2 = io[ - "seq2seq/encoder_1/layer_0/multi_head/conv1d/conv1d/ExpandDims:0" - ] - print( - "outputRef: %s,SumAbs=%.5e,Var=%.5f,Max=%.5f,Min=%.5f,SAD=%.5f" - % ( - str(temp2.shape), - np.sum(abs(temp2)), - np.var(temp2), - np.max(temp2), - np.min(temp2), - np.sum(np.abs(np.diff(temp2.reshape(-1)))), - ) - ) + "seq2seq/encoder_1/layer_0/multi_head/conv1d/conv1d/ExpandDims:0"] + print("outputRef: %s,SumAbs=%.5e,Var=%.5f,Max=%.5f,Min=%.5f,SAD=%.5f" % ( + str(temp2.shape), + np.sum(abs(temp2)), + np.var(temp2), + np.max(temp2), + np.min(temp2), + np.sum(np.abs(np.diff(temp2.reshape(-1)))), + )) # print(temp2) print("check result:") print(check(temp1, temp2, True)) @@ -147,13 +139,13 @@ def run(): inputTensorList = [] inputTensorList.append( - network.add_input("inputT", trt.float32, [-1, -1, 256]) - ) + network.add_input("inputT", trt.float32, [-1, -1, 256])) inputTensorList.append(network.add_input("inputB", trt.float32, [256])) inputTensorList.append(network.add_input("inputA", trt.float32, [256])) profile = builder.create_optimization_profile() - profile.set_shape("inputT", [1, 4, 256], [1024, 256, 256], [1024, 256, 256]) + profile.set_shape("inputT", [1, 4, 256], [1024, 256, 256], + [1024, 256, 256]) config.add_optimization_profile(profile) pluginLayer = network.add_plugin_v2(inputTensorList, getLayerNormPlugin()) @@ -167,15 +159,12 @@ def run(): context.set_binding_shape(0, [nBS, nSL, nEmbedding]) context.set_binding_shape(1, [nEmbedding]) context.set_binding_shape(2, [nEmbedding]) - print( - "Binding all? %s" - % (["No", "Yes"][int(context.all_binding_shapes_specified)]) - ) + print("Binding all? 
%s" % + (["No", "Yes"][int(context.all_binding_shapes_specified)])) stream = cuda.Stream() nInput = np.sum( - [engine.binding_is_input(i) for i in range(engine.num_bindings)] - ) + [engine.binding_is_input(i) for i in range(engine.num_bindings)]) nOutput = engine.num_bindings - nInput for i in range(engine.num_bindings): print( @@ -187,35 +176,31 @@ def run(): bufferH = [] bufferH.append( - np.random.rand(nBS, nSL, nEmbedding) - .astype(np.float32) - .reshape(nBS, nSL, nEmbedding) - * 2 - - 1 - ) + np.random.rand(nBS, nSL, nEmbedding).astype(np.float32).reshape( + nBS, nSL, nEmbedding) * 2 - 1) bufferH.append(np.ones(nEmbedding).astype(np.float32)) bufferH.append(np.zeros(nEmbedding).astype(np.float32)) bufferH.append( np.empty( context.get_binding_shape(3), dtype=trt.nptype(engine.get_binding_dtype(3)), - ) - ) + )) bufferD = [] for i in range(engine.num_bindings): bufferD.append(cuda.mem_alloc(bufferH[i].nbytes)) for i in range(nInput): - cuda.memcpy_htod_async( - bufferD[i], np.ascontiguousarray(bufferH[i].reshape(-1)), stream - ) + cuda.memcpy_htod_async(bufferD[i], + np.ascontiguousarray(bufferH[i].reshape(-1)), + stream) context.execute_async_v2(bufferD, stream.handle) stream.synchronize() for i in range(nOutput): - cuda.memcpy_dtoh_async(bufferH[nInput + i], bufferD[nInput + i], stream) + cuda.memcpy_dtoh_async(bufferH[nInput + i], bufferD[nInput + i], + stream) stream.synchronize() for i in range(nInput): @@ -254,10 +239,8 @@ def run(): context.execute_async_v2(bufferD, stream.handle) stream.synchronize() time1 = time_ns() - print( - testCase - + "average %fms per inference\n" % ((time1 - time0) / nTime / 1000000) - ) + print(testCase + "average %fms per inference\n" % + ((time1 - time0) / nTime / 1000000)) print("check result:") temp1 = bufferH[-1] diff --git a/runtime/gpu/tensorrt/export_streaming_conformer_trt.py b/runtime/gpu/tensorrt/export_streaming_conformer_trt.py index 5dfb4c084..daef661b1 100755 --- a/runtime/gpu/tensorrt/export_streaming_conformer_trt.py +++ b/runtime/gpu/tensorrt/export_streaming_conformer_trt.py @@ -43,8 +43,7 @@ def get_parser(): parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) + formatter_class=argparse.ArgumentDefaultsHelpFormatter) parser.add_argument( "--chunk_xs", @@ -144,20 +143,22 @@ def get_latency_result(latency_list, batch_size): throughput_trt = 1000.0 / latency_ms return { - "test_times": len(latency_list), - "latency_variance": "{:.2f}".format(latency_variance), - "latency_90_percentile": "{:.2f}".format( - np.percentile(latency_list, 90) * 1000.0 - ), - "latency_95_percentile": "{:.2f}".format( - np.percentile(latency_list, 95) * 1000.0 - ), - "latency_99_percentile": "{:.2f}".format( - np.percentile(latency_list, 99) * 1000.0 - ), - "average_latency_ms": "{:.2f}".format(latency_ms), - "QPS": "{:.2f}".format(throughput), - f"QPS_trt_batch{batch_size}": "{:.2f}".format(throughput_trt), + "test_times": + len(latency_list), + "latency_variance": + "{:.2f}".format(latency_variance), + "latency_90_percentile": + "{:.2f}".format(np.percentile(latency_list, 90) * 1000.0), + "latency_95_percentile": + "{:.2f}".format(np.percentile(latency_list, 95) * 1000.0), + "latency_99_percentile": + "{:.2f}".format(np.percentile(latency_list, 99) * 1000.0), + "average_latency_ms": + "{:.2f}".format(latency_ms), + "QPS": + "{:.2f}".format(throughput), + f"QPS_trt_batch{batch_size}": + "{:.2f}".format(throughput_trt), } @@ -183,8 +184,7 @@ def test(engine, context, nBatchSize, batch_threshold=8): 
context.set_binding_shape(bindingBias + 5, [nBatchSize, 1, 80]) nInput = np.sum( - [engine.binding_is_input(i) for i in range(engine.num_bindings)] - ) + [engine.binding_is_input(i) for i in range(engine.num_bindings)]) nOutput = engine.num_bindings - nInput nInput = nInput // nProfile @@ -197,9 +197,12 @@ def test(engine, context, nBatchSize, batch_threshold=8): # (elayers, b, head, cache_t1, d_k * 2) head = 4 d_k = 64 - att_cache = torch.randn( - nBatchSize, 12, head, 80, d_k * 2, dtype=torch.float32 - ).numpy() + att_cache = torch.randn(nBatchSize, + 12, + head, + 80, + d_k * 2, + dtype=torch.float32).numpy() cnn_cache = torch.randn(nBatchSize, 12, 256, 7, dtype=torch.float32) cache_mask = torch.ones(nBatchSize, 1, 80, dtype=torch.float32) @@ -220,8 +223,7 @@ def test(engine, context, nBatchSize, batch_threshold=8): np.empty( context.get_binding_shape(bindingBias + i), dtype=trt.nptype(engine.get_binding_dtype(bindingBias + i)), - ) - ) + )) bufferD = [] for i in range(nInput + nOutput): bufferD.append(cudart.cudaMalloc(bufferH[i].nbytes)[1]) @@ -240,12 +242,11 @@ def test(engine, context, nBatchSize, batch_threshold=8): ) nWarm, nTest = 5, 10 - timeit.repeat( - lambda: context.execute_v2(bufferD), number=1, repeat=nWarm - ) # Dry run - latency_list = timeit.repeat( - lambda: context.execute_v2(bufferD), number=1, repeat=nTest - ) + timeit.repeat(lambda: context.execute_v2(bufferD), number=1, + repeat=nWarm) # Dry run + latency_list = timeit.repeat(lambda: context.execute_v2(bufferD), + number=1, + repeat=nTest) print(get_latency_result(latency_list, nBatchSize)) if nProfile == 1 or nBatchSize > batch_threshold: @@ -280,8 +281,7 @@ def main(): else: builder = trt.Builder(logger) network = builder.create_network( - 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH) - ) + 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)) config = builder.create_builder_config() if args.useTimeCache: diff --git a/runtime/gpu/tensorrt/model_repo_stateful_trt/feature_extractor/1/model.py b/runtime/gpu/tensorrt/model_repo_stateful_trt/feature_extractor/1/model.py index ce1f340f8..459855e59 100644 --- a/runtime/gpu/tensorrt/model_repo_stateful_trt/feature_extractor/1/model.py +++ b/runtime/gpu/tensorrt/model_repo_stateful_trt/feature_extractor/1/model.py @@ -21,7 +21,9 @@ import json import numpy as np + class Fbank(torch.nn.Module): + def __init__(self, opts): super(Fbank, self).__init__() self.fbank = kaldifeat.Fbank(opts) @@ -29,9 +31,16 @@ def __init__(self, opts): def forward(self, waves: List[torch.Tensor]): return self.fbank(waves) + class Feat(object): - def __init__(self, seqid, offset_ms, sample_rate, - first_chunk_sz, frame_stride, device='cpu'): + + def __init__(self, + seqid, + offset_ms, + sample_rate, + first_chunk_sz, + frame_stride, + device='cpu'): self.seqid = seqid self.sample_rate = sample_rate self.wav = torch.tensor([], device=device) @@ -62,10 +71,11 @@ def add_frames(self, frames: torch.tensor): self.frames = torch.cat([self.frames, frames], axis=0) def get_frames(self, num_frames: int): - seg = self.frames[0: num_frames] + seg = self.frames[0:num_frames] self.frames = self.frames[self.frame_stride:] return seg + class TritonPythonModel: """Your Python model must use the same class name. Every Python model that is created must have "TritonPythonModel" as the class name. 
@@ -138,7 +148,8 @@ def initialize(self, args): cur_frames = _kaldifeat.num_frames(first_chunk_size, opts.frame_opts) while cur_frames < self.decoding_window: first_chunk_size += frame_shift_ms * sample_rate // 1000 - cur_frames = _kaldifeat.num_frames(first_chunk_size, opts.frame_opts) + cur_frames = _kaldifeat.num_frames(first_chunk_size, + opts.frame_opts) # self.pad_silence = first_chunk_size - self.chunk_size self.first_chunk_size = first_chunk_size self.offset_ms = self.get_offset(frame_length_ms, frame_shift_ms) @@ -157,7 +168,8 @@ def parse_model_params(self, model_params): "frame_length_ms": 25, "frame_shift_ms": 10, "sample_rate": 16000, - "chunk_size_s": 0.64} + "chunk_size_s": 0.64 + } # get parameter configurations for li in model_params.items(): key, value = li @@ -212,8 +224,7 @@ def execute(self, requests): self.seq_feat[corrid] = Feat(corrid, self.offset_ms, self.sample_rate, self.first_chunk_size, - self.frame_stride, - self.device) + self.frame_stride, self.device) if ready: self.seq_feat[corrid].add_wavs(wavs[0:wav_lens]) @@ -226,7 +237,8 @@ def execute(self, requests): wav = self.seq_feat[corrid].get_seg_wav() * 32768 if len(wav) < self.min_seg: - temp = torch.zeros(self.min_seg, dtype=torch.float32, + temp = torch.zeros(self.min_seg, + dtype=torch.float32, device=self.device) temp[0:len(wav)] = wav[:] wav = temp @@ -235,15 +247,16 @@ def execute(self, requests): features = self.feature_extractor(total_waves) batch_size = len(batch_seqid) - batch_speech = torch.zeros((batch_size, self.decoding_window, - self.feature_size), dtype=self.dtype) + batch_speech = torch.zeros( + (batch_size, self.decoding_window, self.feature_size), + dtype=self.dtype) batch_speech_lens = torch.zeros((batch_size, 1), dtype=torch.int32) i = 0 for corrid, frames in zip(batch_seqid, features): self.seq_feat[corrid].add_frames(frames) r_frames = self.seq_feat[corrid].get_frames(self.decoding_window) - speech = batch_speech[i: i + 1] - speech_lengths = batch_speech_lens[i: i + 1] + speech = batch_speech[i:i + 1] + speech_lengths = batch_speech_lens[i:i + 1] i += 1 speech_lengths[0] = r_frames.size(0) speech[0][0:r_frames.size(0)] = r_frames.to(speech.device) @@ -251,9 +264,11 @@ def execute(self, requests): # out_tensor1 = pb_utils.Tensor.from_dlpack("speech_lengths", # to_dlpack(speech_lengths)) out_tensor0 = pb_utils.Tensor("speech", speech.numpy()) - out_tensor1 = pb_utils.Tensor("speech_lengths", speech_lengths.numpy()) + out_tensor1 = pb_utils.Tensor("speech_lengths", + speech_lengths.numpy()) output_tensors = [out_tensor0, out_tensor1] - response = pb_utils.InferenceResponse(output_tensors=output_tensors) + response = pb_utils.InferenceResponse( + output_tensors=output_tensors) responses.append(response) if corrid in end_seqid: del self.seq_feat[corrid] diff --git a/runtime/gpu/tensorrt/model_repo_stateful_trt/wenet/1/model.py b/runtime/gpu/tensorrt/model_repo_stateful_trt/wenet/1/model.py index 91e008dc9..bef3bd8b2 100644 --- a/runtime/gpu/tensorrt/model_repo_stateful_trt/wenet/1/model.py +++ b/runtime/gpu/tensorrt/model_repo_stateful_trt/wenet/1/model.py @@ -26,6 +26,7 @@ from torch.utils.dlpack import from_dlpack + class TritonPythonModel: """Your Python model must use the same class name. Every Python model that is created must have "TritonPythonModel" as the class name. 
@@ -146,14 +147,20 @@ def execute(self, requests): batch_idx += 1 - batch_states = [trieVector, batch_start, batch_encoder_hist, cur_encoder_out] - res_sents, new_states = self.model.infer(batch_log_probs, batch_log_probs_idx, - batch_len, rescore_index, batch_states) + batch_states = [ + trieVector, batch_start, batch_encoder_hist, cur_encoder_out + ] + res_sents, new_states = self.model.infer(batch_log_probs, + batch_log_probs_idx, + batch_len, rescore_index, + batch_states) cur_encoder_out = new_states for i in range(len(res_sents)): sent = np.array(res_sents[i]) - out_tensor_0 = pb_utils.Tensor("OUTPUT0", sent.astype(self.output0_dtype)) - response = pb_utils.InferenceResponse(output_tensors=[out_tensor_0]) + out_tensor_0 = pb_utils.Tensor("OUTPUT0", + sent.astype(self.output0_dtype)) + response = pb_utils.InferenceResponse( + output_tensors=[out_tensor_0]) responses.append(response) corr = batch_idx2_corrid[i] if i in rescore_index: @@ -164,8 +171,9 @@ def execute(self, requests): if self.seq_states[corr][1] is None: self.seq_states[corr][1] = cur_encoder_out[i] else: - new_hist = torch.cat([self.seq_states[corr][1], - cur_encoder_out[i]], axis=0) + new_hist = torch.cat( + [self.seq_states[corr][1], cur_encoder_out[i]], + axis=0) self.seq_states[corr][1] = new_hist assert len(requests) == len(responses) diff --git a/runtime/gpu/tensorrt/model_repo_stateful_trt/wenet/1/wenet_onnx_model.py b/runtime/gpu/tensorrt/model_repo_stateful_trt/wenet/1/wenet_onnx_model.py index d9db48839..e2dc2cd49 100644 --- a/runtime/gpu/tensorrt/model_repo_stateful_trt/wenet/1/wenet_onnx_model.py +++ b/runtime/gpu/tensorrt/model_repo_stateful_trt/wenet/1/wenet_onnx_model.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
- import multiprocessing import numpy as np import os @@ -21,7 +20,9 @@ from torch.utils.dlpack import to_dlpack, from_dlpack from swig_decoders import ctc_beam_search_decoder_batch, Scorer, map_batch + class WenetModel(object): + def __init__(self, model_config, device): params = self.parse_model_parameters(model_config['parameters']) @@ -93,14 +94,16 @@ def load_vocab(self, vocab_file): return (id2vocab, vocab, space_id, blank_id, sos_eos) def parse_model_parameters(self, model_parameters): - model_p = {"beam_size": 10, - "cutoff_prob": 0.999, - "vocab_path": None, - "lm_path": None, - "alpha": 2.0, - "beta": 1.0, - "rescoring": 0, - "bidecoder": 1} + model_p = { + "beam_size": 10, + "cutoff_prob": 0.999, + "vocab_path": None, + "lm_path": None, + "alpha": 2.0, + "beta": 1.0, + "rescoring": 0, + "bidecoder": 1 + } # get parameter configurations for li in model_parameters.items(): key, value = li @@ -115,8 +118,8 @@ def parse_model_parameters(self, model_parameters): assert model_p["vocab_path"] is not None return model_p - def infer(self, batch_log_probs, batch_log_probs_idx, - seq_lens, rescore_index, batch_states): + def infer(self, batch_log_probs, batch_log_probs_idx, seq_lens, + rescore_index, batch_states): """ batch_states = [trieVector, batch_start, batch_encoder_hist, cur_encoder_out] @@ -124,17 +127,10 @@ def infer(self, batch_log_probs, batch_log_probs_idx, trie_vector, batch_start, batch_encoder_hist, cur_encoder_out = batch_states num_processes = min(multiprocessing.cpu_count(), len(batch_log_probs)) - score_hyps = self.batch_ctc_prefix_beam_search_cpu(batch_log_probs, - batch_log_probs_idx, - seq_lens, - trie_vector, - batch_start, - self.beam_size, - self.blank_id, - self.space_id, - self.cutoff_prob, - num_processes, - self.scorer) + score_hyps = self.batch_ctc_prefix_beam_search_cpu( + batch_log_probs, batch_log_probs_idx, seq_lens, trie_vector, + batch_start, self.beam_size, self.blank_id, self.space_id, + self.cutoff_prob, num_processes, self.scorer) if self.rescoring and len(rescore_index) != 0: # find the end of sequence @@ -148,14 +144,16 @@ def infer(self, batch_log_probs, batch_log_probs_idx, if hist_enc is None: cur_enc = cur_encoder_out[idx] else: - cur_enc = torch.cat([hist_enc, cur_encoder_out[idx]], axis=0) + cur_enc = torch.cat([hist_enc, cur_encoder_out[idx]], + axis=0) rescore_encoder_hist.append(cur_enc) cur_mask_len = int(len(hist_enc) + seq_lens[idx]) rescore_encoder_lens.append(cur_mask_len) rescore_hyps.append(score_hyps[idx]) if cur_enc.shape[0] > max_length: max_length = cur_enc.shape[0] - best_index = self.batch_rescoring(rescore_hyps, rescore_encoder_hist, + best_index = self.batch_rescoring(rescore_hyps, + rescore_encoder_hist, rescore_encoder_lens, max_length) best_sent = [] @@ -172,12 +170,10 @@ def infer(self, batch_log_probs, batch_log_probs_idx, return final_result, cur_encoder_out def batch_ctc_prefix_beam_search_cpu(self, batch_log_probs_seq, - batch_log_probs_idx, - batch_len, batch_root, - batch_start, beam_size, - blank_id, space_id, - cutoff_prob, num_processes, - scorer): + batch_log_probs_idx, batch_len, + batch_root, batch_start, beam_size, + blank_id, space_id, cutoff_prob, + num_processes, scorer): """ Return: Batch x Beam_size elements, each element is a tuple (score, list of ids), @@ -188,18 +184,16 @@ def batch_ctc_prefix_beam_search_cpu(self, batch_log_probs_seq, batch_log_probs_idx_list = [] for i in range(len(batch_len_list)): cur_len = int(batch_len_list[i]) - 
batch_log_probs_seq_list.append(batch_log_probs_seq[i][0:cur_len].tolist()) - batch_log_probs_idx_list.append(batch_log_probs_idx[i][0:cur_len].tolist()) + batch_log_probs_seq_list.append( + batch_log_probs_seq[i][0:cur_len].tolist()) + batch_log_probs_idx_list.append( + batch_log_probs_idx[i][0:cur_len].tolist()) score_hyps = ctc_beam_search_decoder_batch(batch_log_probs_seq_list, batch_log_probs_idx_list, - batch_root, - batch_start, - beam_size, - num_processes, - blank_id, - space_id, - cutoff_prob, - scorer) + batch_root, batch_start, + beam_size, num_processes, + blank_id, space_id, + cutoff_prob, scorer) return score_hyps def batch_rescoring(self, score_hyps, hist_enc, hist_mask_len, max_len): @@ -236,10 +230,12 @@ def batch_rescoring(self, score_hyps, hist_enc, hist_mask_len, max_len): max_seq_len = len(hyps[-1]) max_seq_len += 2 - hyps_pad_sos_eos = np.ones((bz, beam_size, max_seq_len), dtype=np.int64) + hyps_pad_sos_eos = np.ones((bz, beam_size, max_seq_len), + dtype=np.int64) hyps_pad_sos_eos = hyps_pad_sos_eos * self.eos # fill eos if self.bidecoder: - r_hyps_pad_sos_eos = np.ones((bz, beam_size, max_seq_len), dtype=np.int64) + r_hyps_pad_sos_eos = np.ones((bz, beam_size, max_seq_len), + dtype=np.int64) r_hyps_pad_sos_eos = r_hyps_pad_sos_eos * self.eos hyps_lens_sos = np.ones((bz, beam_size), dtype=np.int32) @@ -249,12 +245,13 @@ def batch_rescoring(self, score_hyps, hist_enc, hist_mask_len, max_len): length = len(cand) + 2 bz_offset = idx % beam_size pad_cand = [self.sos] + cand + [self.eos] - hyps_pad_sos_eos[bz_id][bz_offset][0 : length] = pad_cand + hyps_pad_sos_eos[bz_id][bz_offset][0:length] = pad_cand if self.bidecoder: r_pad_cand = [self.sos] + cand[::-1] + [self.eos] r_hyps_pad_sos_eos[bz_id][bz_offset][0:length] = r_pad_cand hyps_lens_sos[bz_id][idx % beam_size] = len(cand) + 1 - in0 = pb_utils.Tensor.from_dlpack("encoder_out", to_dlpack(encoder_out)) + in0 = pb_utils.Tensor.from_dlpack("encoder_out", + to_dlpack(encoder_out)) in1 = pb_utils.Tensor("encoder_out_lens", encoder_lens) in2 = pb_utils.Tensor("hyps_pad_sos_eos", hyps_pad_sos_eos) in3 = pb_utils.Tensor("hyps_lens_sos", hyps_lens_sos) @@ -264,9 +261,10 @@ def batch_rescoring(self, score_hyps, hist_enc, hist_mask_len, max_len): input_tensors.append(in4) in5 = pb_utils.Tensor.from_dlpack("ctc_score", to_dlpack(ctc_score)) input_tensors.append(in5) - request = pb_utils.InferenceRequest(model_name='decoder', - requested_output_names=['best_index'], - inputs=input_tensors) + request = pb_utils.InferenceRequest( + model_name='decoder', + requested_output_names=['best_index'], + inputs=input_tensors) response = request.exec() best_index = pb_utils.get_output_tensor_by_name(response, 'best_index') best_index = from_dlpack(best_index.to_dlpack()).clone() diff --git a/runtime/gpu/tensorrt/replace_layernorm.py b/runtime/gpu/tensorrt/replace_layernorm.py index c6c126e82..060991103 100755 --- a/runtime/gpu/tensorrt/replace_layernorm.py +++ b/runtime/gpu/tensorrt/replace_layernorm.py @@ -20,45 +20,40 @@ if __name__ == "__main__": parser = argparse.ArgumentParser( - description="process onnx file for trt engine generation" - ) - parser.add_argument( - "--input_onnx", type=str, required=True, help="input onnx model path" - ) - parser.add_argument( - "--output_onnx", type=str, required=True, help="output .npy file path" - ) + description="process onnx file for trt engine generation") + parser.add_argument("--input_onnx", + type=str, + required=True, + help="input onnx model path") + parser.add_argument("--output_onnx", + 
type=str, + required=True, + help="output .npy file path") args = parser.parse_args() sourceOnnx = args.input_onnx destinationOnnx = args.output_onnx graph = gs.import_onnx( - onnx.shape_inference.infer_shapes(onnx.load(sourceOnnx)) - ) + onnx.shape_inference.infer_shapes(onnx.load(sourceOnnx))) nLayerNormPlugin = 0 for node in graph.nodes: - if ( - node.op == "ReduceMean" - and node.o().op == "Sub" - and node.o().inputs[0] == node.inputs[0] - and node.o().o(0).op == "Pow" - and node.o().o(1).op == "Div" - and node.o().o(0).o().op == "ReduceMean" - and node.o().o(0).o().o().op == "Add" - and node.o().o(0).o().o().o().op == "Sqrt" - and node.o().o(0).o().o().o().o().op == "Div" - and node.o().o(0).o().o().o().o() == node.o().o(1) - and node.o().o(0).o().o().o().o().o().op == "Mul" - and node.o().o(0).o().o().o().o().o().o().op == "Add" - ): + if (node.op == "ReduceMean" and node.o().op == "Sub" + and node.o().inputs[0] == node.inputs[0] + and node.o().o(0).op == "Pow" and node.o().o(1).op == "Div" + and node.o().o(0).o().op == "ReduceMean" + and node.o().o(0).o().o().op == "Add" + and node.o().o(0).o().o().o().op == "Sqrt" + and node.o().o(0).o().o().o().o().op == "Div" + and node.o().o(0).o().o().o().o() == node.o().o(1) + and node.o().o(0).o().o().o().o().o().op == "Mul" + and node.o().o(0).o().o().o().o().o().o().op == "Add"): inputTensor = node.inputs[0] lastMultipyNode = node.o().o(0).o().o().o().o().o() - index = ["weight" in i.name for i in lastMultipyNode.inputs].index( - True - ) + index = ["weight" in i.name + for i in lastMultipyNode.inputs].index(True) b = np.array( deepcopy(lastMultipyNode.inputs[index].values.tolist()), dtype=np.float32, @@ -104,8 +99,7 @@ index = graph.outputs.index(lastAddNode.outputs[0]) # TODO: FIX ME YUEKAI, for offline asr encoder_out dtype graph.outputs[index] = layerNormN.outputs[0].to_variable( - np.float16 - ) + np.float16) # graph.outputs[index] = layerNormN.outputs[0] else: # other LayerNorm contain the subsequent Squeeze operation for n in graph.nodes: diff --git a/runtime/gpu/tensorrt_fastertransformer/extract_weights.py b/runtime/gpu/tensorrt_fastertransformer/extract_weights.py index 78397f50b..3647bd056 100755 --- a/runtime/gpu/tensorrt_fastertransformer/extract_weights.py +++ b/runtime/gpu/tensorrt_fastertransformer/extract_weights.py @@ -36,7 +36,8 @@ def export_GetAllWeight(model, gsg): if 'encoder' in str(w.name) or 'ctc' in str(w.name): print("export ", w.name, w.dims, w.data_type) dtype = utils.onnx2np_type(w.data_type) - res[w.name] = np.frombuffer(w.raw_data, dtype=dtype).reshape(w.dims) + res[w.name] = np.frombuffer(w.raw_data, + dtype=dtype).reshape(w.dims) exported_name.append(w.name) if w.name.endswith("bias"): new_name = w.name[0:len(w.name) - 4] + "weight" @@ -45,11 +46,11 @@ def export_GetAllWeight(model, gsg): continue w = utils.onnx_GetWeight(model, wname) dtype = utils.onnx2np_type(w.data_type) - res[new_name] = np.frombuffer( - w.raw_data, dtype=dtype).reshape(w.dims) + res[new_name] = np.frombuffer(w.raw_data, + dtype=dtype).reshape(w.dims) res[new_name] = np.transpose(res[new_name], (1, 0)) - print("export ", w.name, w.dims, w.data_type, - " -> ", new_name, res[new_name].shape) + print("export ", w.name, w.dims, w.data_type, " -> ", new_name, + res[new_name].shape) exported_name.append(w.name) not_name = get_not(model, exported_name) @@ -63,14 +64,14 @@ def export_GetAllWeight(model, gsg): and node.inputs[1].name == w.name: new_name = "encoder.encoders." 
+ \ str(cur_idx) + ".self_attn.linear_pos.weight" - print("export ", w.name, w.dims, - w.data_type, " -> ", new_name) + print("export ", w.name, w.dims, w.data_type, " -> ", + new_name) dtype = utils.onnx2np_type(w.data_type) - res[new_name] = np.frombuffer( - w.raw_data, dtype=dtype).reshape(w.dims) + res[new_name] = np.frombuffer(w.raw_data, + dtype=dtype).reshape(w.dims) res[new_name] = np.transpose(res[new_name], (1, 0)) - print("export ", w.name, w.dims, w.data_type, - " -> ", new_name, res[new_name].shape) + print("export ", w.name, w.dims, w.data_type, " -> ", + new_name, res[new_name].shape) exported_name.append(w.name) cur_idx += 1 @@ -83,22 +84,22 @@ def export_GetAllWeight(model, gsg): and node.inputs[1].name == w.name: new_name = "encoder.encoders." + \ str(cur_idx) + ".conv_module.depthwise_conv.weight" - print("export ", w.name, w.dims, - w.data_type, " -> ", new_name) + print("export ", w.name, w.dims, w.data_type, " -> ", + new_name) dtype = utils.onnx2np_type(w.data_type) - res[new_name] = np.frombuffer( - w.raw_data, dtype=dtype).reshape(w.dims) + res[new_name] = np.frombuffer(w.raw_data, + dtype=dtype).reshape(w.dims) exported_name.append(w.name) bname = node.inputs[2].name w = utils.onnx_GetWeight(model, bname) new_name = "encoder.encoders." + \ str(cur_idx) + ".conv_module.depthwise_conv.bias" - print("export ", w.name, w.dims, - w.data_type, " -> ", new_name) + print("export ", w.name, w.dims, w.data_type, " -> ", + new_name) dtype = utils.onnx2np_type(w.data_type) - res[new_name] = np.frombuffer( - w.raw_data, dtype=dtype).reshape(w.dims) + res[new_name] = np.frombuffer(w.raw_data, + dtype=dtype).reshape(w.dims) exported_name.append(w.name) cur_idx += 1 @@ -124,7 +125,8 @@ def export_decoder_GetAllWeight(model, gsg): if len(str(w.name)) > 4: print("export ", w.name, w.dims, w.data_type) dtype = utils.onnx2np_type(w.data_type) - res[w.name] = np.frombuffer(w.raw_data, dtype=dtype).reshape(w.dims) + res[w.name] = np.frombuffer(w.raw_data, + dtype=dtype).reshape(w.dims) exported_name.append(w.name) if w.name.endswith("bias"): new_name = w.name[0:len(w.name) - 4] + "weight" @@ -133,11 +135,11 @@ def export_decoder_GetAllWeight(model, gsg): continue w = utils.onnx_GetWeight(model, wname) dtype = utils.onnx2np_type(w.data_type) - res[new_name] = np.frombuffer( - w.raw_data, dtype=dtype).reshape(w.dims) + res[new_name] = np.frombuffer(w.raw_data, + dtype=dtype).reshape(w.dims) res[new_name] = np.transpose(res[new_name], (1, 0)) - print("export ", w.name, w.dims, w.data_type, - " -> ", new_name, res[new_name].shape) + print("export ", w.name, w.dims, w.data_type, " -> ", new_name, + res[new_name].shape) exported_name.append(w.name) for node in gsg.nodes: @@ -163,10 +165,14 @@ def export_decoder_GetAllWeight(model, gsg): if __name__ == "__main__": parser = argparse.ArgumentParser( description='process onnx file for trt engine generation') - parser.add_argument('--input_onnx', type=str, - required=True, help="input onnx model path") - parser.add_argument('--output_dir', type=str, - required=True, help="output weights dir") + parser.add_argument('--input_onnx', + type=str, + required=True, + help="input onnx model path") + parser.add_argument('--output_dir', + type=str, + required=True, + help="output weights dir") args = parser.parse_args() diff --git a/runtime/gpu/tensorrt_fastertransformer/model_repo_ft/feature_extractor/1/model.py b/runtime/gpu/tensorrt_fastertransformer/model_repo_ft/feature_extractor/1/model.py index c54f6ca3e..50406df2d 100755 --- 
a/runtime/gpu/tensorrt_fastertransformer/model_repo_ft/feature_extractor/1/model.py +++ b/runtime/gpu/tensorrt_fastertransformer/model_repo_ft/feature_extractor/1/model.py @@ -21,7 +21,9 @@ from typing import List import json + class Fbank(torch.nn.Module): + def __init__(self, opts): super(Fbank, self).__init__() self.fbank = kaldifeat.Fbank(opts) @@ -131,7 +133,8 @@ def execute(self, requests): batch_len.append(cur_len) for wav, wav_len in zip(cur_b_wav, cur_b_wav_lens): wav_len = wav_len[0] - wav = torch.tensor(wav[0:wav_len], dtype=torch.float32, + wav = torch.tensor(wav[0:wav_len], + dtype=torch.float32, device=self.device) total_waves.append(wav) @@ -139,12 +142,15 @@ def execute(self, requests): for b, l in zip(batch_count, batch_len): expect_feat_len = _kaldifeat.num_frames(l, self.opts.frame_opts) speech = torch.zeros((b, expect_feat_len, self.feature_size), - dtype=self.output0_dtype, device=self.device) - speech_lengths = torch.zeros((b, 1), dtype=torch.int32, device=self.device) + dtype=self.output0_dtype, + device=self.device) + speech_lengths = torch.zeros((b, 1), + dtype=torch.int32, + device=self.device) for i in range(b): f = features.pop(0) f_l = f.shape[0] - speech[i, 0: f_l, :] = f.to(self.output0_dtype) + speech[i, 0:f_l, :] = f.to(self.output0_dtype) speech_lengths[i][0] = f_l # put speech feature on device will cause empty output # we will follow this issue and now temporarily put it on cpu @@ -153,6 +159,7 @@ def execute(self, requests): out0 = pb_utils.Tensor.from_dlpack("speech", to_dlpack(speech)) out1 = pb_utils.Tensor.from_dlpack("speech_lengths", to_dlpack(speech_lengths)) - inference_response = pb_utils.InferenceResponse(output_tensors=[out0, out1]) + inference_response = pb_utils.InferenceResponse( + output_tensors=[out0, out1]) responses.append(inference_response) return responses diff --git a/runtime/gpu/tensorrt_fastertransformer/model_repo_ft/scoring/1/model.py b/runtime/gpu/tensorrt_fastertransformer/model_repo_ft/scoring/1/model.py index 3a4dd5040..685bb8646 100755 --- a/runtime/gpu/tensorrt_fastertransformer/model_repo_ft/scoring/1/model.py +++ b/runtime/gpu/tensorrt_fastertransformer/model_repo_ft/scoring/1/model.py @@ -170,16 +170,17 @@ def execute(self, requests): for request in requests: # Perform inference on the request and append it to responses list... 
in_0 = pb_utils.get_input_tensor_by_name(request, "encoder_out") - in_1 = pb_utils.get_input_tensor_by_name( - request, "encoder_out_lens") - in_2 = pb_utils.get_input_tensor_by_name(request, "batch_log_probs") - in_3 = pb_utils.get_input_tensor_by_name( - request, "batch_log_probs_idx") + in_1 = pb_utils.get_input_tensor_by_name(request, + "encoder_out_lens") + in_2 = pb_utils.get_input_tensor_by_name(request, + "batch_log_probs") + in_3 = pb_utils.get_input_tensor_by_name(request, + "batch_log_probs_idx") batch_encoder_out.append(in_0.as_numpy()) - encoder_max_len = max( - encoder_max_len, batch_encoder_out[-1].shape[1]) + encoder_max_len = max(encoder_max_len, + batch_encoder_out[-1].shape[1]) cur_b_lens = in_1.as_numpy() batch_encoder_lens.append(cur_b_lens) @@ -201,16 +202,17 @@ def execute(self, requests): batch_start.append(True) total += 1 - score_hyps = ctc_beam_search_decoder_batch(batch_log_probs, - batch_log_probs_idx, - batch_root, - batch_start, - self.beam_size, - min(total, self.num_processes), - blank_id=self.blank_id, - space_id=-2, - cutoff_prob=self.cutoff_prob, - ext_scorer=self.lm) + score_hyps = ctc_beam_search_decoder_batch( + batch_log_probs, + batch_log_probs_idx, + batch_root, + batch_start, + self.beam_size, + min(total, self.num_processes), + blank_id=self.blank_id, + space_id=-2, + cutoff_prob=self.cutoff_prob, + ext_scorer=self.lm) all_hyps = [] all_ctc_score = [] @@ -236,8 +238,8 @@ def execute(self, requests): in_hyps_pad_sos_eos = np.ones( (total, beam_size, hyps_max_len), dtype=np.int32) * self.ignore_id if self.bidecoder: - in_r_hyps_pad_sos_eos = np.ones( - (total, beam_size, hyps_max_len), dtype=np.int32) * self.ignore_id + in_r_hyps_pad_sos_eos = np.ones((total, beam_size, hyps_max_len), + dtype=np.int32) * self.ignore_id in_hyps_lens_sos = np.ones((total, beam_size), dtype=np.int32) @@ -270,8 +272,8 @@ def execute(self, requests): in_tensor_3 = pb_utils.Tensor("hyps_lens_sos", in_hyps_lens_sos) input_tensors = [in_tensor_0, in_tensor_1, in_tensor_2, in_tensor_3] if self.bidecoder: - in_tensor_4 = pb_utils.Tensor( - "r_hyps_pad_sos_eos", in_r_hyps_pad_sos_eos) + in_tensor_4 = pb_utils.Tensor("r_hyps_pad_sos_eos", + in_r_hyps_pad_sos_eos) input_tensors.append(in_tensor_4) in_tensor_5 = pb_utils.Tensor("ctc_score", in_ctc_score) input_tensors.append(in_tensor_5) @@ -287,8 +289,8 @@ def execute(self, requests): inference_response.error().message()) else: # Extract the output tensors from the inference response. 
- best_index = pb_utils.get_output_tensor_by_name(inference_response, - 'best_index') + best_index = pb_utils.get_output_tensor_by_name( + inference_response, 'best_index') best_index = from_dlpack(best_index.to_dlpack()) best_index = best_index.cpu().numpy() @@ -297,16 +299,18 @@ def execute(self, requests): for cands, cand_lens in zip(in_hyps_pad_sos_eos, in_hyps_lens_sos): best_idx = best_index[idx][0] best_cand_len = cand_lens[best_idx] - 1 # remove sos - best_cand = cands[best_idx][1: 1 + best_cand_len].tolist() + best_cand = cands[best_idx][1:1 + best_cand_len].tolist() hyps.append(best_cand) idx += 1 - hyps = map_batch(hyps, self.vocabulary, - min(multiprocessing.cpu_count(), len(in_ctc_score))) + hyps = map_batch( + hyps, self.vocabulary, + min(multiprocessing.cpu_count(), len(in_ctc_score))) st = 0 for b in batch_count: sents = np.array(hyps[st:st + b]) - out0 = pb_utils.Tensor("OUTPUT0", sents.astype(self.out0_dtype)) + out0 = pb_utils.Tensor("OUTPUT0", + sents.astype(self.out0_dtype)) inference_response = pb_utils.InferenceResponse( output_tensors=[out0]) responses.append(inference_response) diff --git a/runtime/gpu/tensorrt_fastertransformer/replace_plugin.py b/runtime/gpu/tensorrt_fastertransformer/replace_plugin.py index 93e8eb0f6..4fa242f0d 100755 --- a/runtime/gpu/tensorrt_fastertransformer/replace_plugin.py +++ b/runtime/gpu/tensorrt_fastertransformer/replace_plugin.py @@ -30,39 +30,63 @@ def replace_plugin(self, inputs, outputs, op, name, attrs): inputs=inputs, outputs=outputs, name=name, - attrs=attrs - ) + attrs=attrs) if __name__ == "__main__": parser = argparse.ArgumentParser( description='process onnx file for trt engine generation') - parser.add_argument('--input_onnx', type=str, - required=True, help="input onnx model path") - parser.add_argument('--output_onnx', type=str, - required=True, help="output .npy file path") - parser.add_argument('--max_len', type=int, default=5000, + parser.add_argument('--input_onnx', + type=str, + required=True, + help="input onnx model path") + parser.add_argument('--output_onnx', + type=str, + required=True, + help="output .npy file path") + parser.add_argument('--max_len', + type=int, + default=5000, help="Max seq for pos embedding, TODO: remove this") - parser.add_argument('--head_num', type=int, default=4, - choices=[4, 8], help="") + parser.add_argument('--head_num', + type=int, + default=4, + choices=[4, 8], + help="") parser.add_argument('--feature_size', type=int, default=80, help="") parser.add_argument('--inter_size', type=int, default=2048, help="") - parser.add_argument('--d_model', type=int, default=256, - choices=[256, 512], help="") + parser.add_argument('--d_model', + type=int, + default=256, + choices=[256, 512], + help="") parser.add_argument('--num_layer', type=int, default=12, help="") parser.add_argument('--vocab_size', type=int, default=4233, help="") - parser.add_argument('--conv_module_kernel_size', type=int, default=15, - choices=[15, 31], help="kernel size for conv module") + parser.add_argument('--conv_module_kernel_size', + type=int, + default=15, + choices=[15, 31], + help="kernel size for conv module") # TODO: hard-coding below encoder decoder weight path, pls don't change it for now - parser.add_argument('--decoder_weight_path', type=str, - default="/weight/dec/", help="decoder weights path") - parser.add_argument('--encoder_weight_path', type=str, - default="/weight/enc/", help="encoder weights path") - parser.add_argument('--useFP16', type=bool, - default=True, help="using fp16 mode") - 
parser.add_argument('--use_layernorm_in_conv_module', action='store_true', - default=False, help="using layernorm in conformer conv module") - parser.add_argument('--q_scaling', type=float, default=1.0, + parser.add_argument('--decoder_weight_path', + type=str, + default="/weight/dec/", + help="decoder weights path") + parser.add_argument('--encoder_weight_path', + type=str, + default="/weight/enc/", + help="encoder weights path") + parser.add_argument('--useFP16', + type=bool, + default=True, + help="using fp16 mode") + parser.add_argument('--use_layernorm_in_conv_module', + action='store_true', + default=False, + help="using layernorm in conformer conv module") + parser.add_argument('--q_scaling', + type=float, + default=1.0, help="please hard-coding it for now") args = parser.parse_args() @@ -73,8 +97,10 @@ def replace_plugin(self, inputs, outputs, op, name, attrs): if 'encoder' in args.input_onnx: inputs = [tmap[i] for i in ["speech", "speech_lengths"]] - outputs = [tmap[i] - for i in ["encoder_out", "encoder_out_lens", "ctc_log_probs"]] + outputs = [ + tmap[i] + for i in ["encoder_out", "encoder_out_lens", "ctc_log_probs"] + ] op = "WenetEncoderPlugin" name = "WenetEncoder" attrs = { @@ -94,8 +120,12 @@ def replace_plugin(self, inputs, outputs, op, name, attrs): } elif 'decoder' in args.input_onnx: - inputs = [tmap[i] for i in ["hyps_pad_sos_eos", "hyps_lens_sos", - "encoder_out", "encoder_out_lens", "ctc_score"]] + inputs = [ + tmap[i] for i in [ + "hyps_pad_sos_eos", "hyps_lens_sos", "encoder_out", + "encoder_out_lens", "ctc_score" + ] + ] outputs = [tmap[i] for i in ["decoder_out", "best_index"]] op = "WenetDecoderPlugin" name = "WenetDecoder" diff --git a/runtime/gpu/tensorrt_fastertransformer/utils.py b/runtime/gpu/tensorrt_fastertransformer/utils.py index 6bdab7bf7..4c9b2bae5 100755 --- a/runtime/gpu/tensorrt_fastertransformer/utils.py +++ b/runtime/gpu/tensorrt_fastertransformer/utils.py @@ -41,11 +41,7 @@ def onnx_GetAllWeight(model): def onnx2np_type(dtype): - maps = { - 1: np.float32, - 6: np.int32, - 7: np.int64 - } + maps = {1: np.float32, 6: np.int32, 7: np.int64} return maps[dtype] diff --git a/runtime/horizonbpu/bpu/bpu_asr_model.cc b/runtime/horizonbpu/bpu/bpu_asr_model.cc index 6d6cfedbe..98388f811 100644 --- a/runtime/horizonbpu/bpu/bpu_asr_model.cc +++ b/runtime/horizonbpu/bpu/bpu_asr_model.cc @@ -12,7 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. - #include "bpu/bpu_asr_model.h" #include @@ -30,16 +29,20 @@ void BPUAsrModel::GetInputOutputInfo( for (size_t i = 0; i < input.size(); ++i) { auto& shapes = input[i]->properties.validShape.dimensionSize; std::string layout = (input[i]->properties.tensorLayout == - hbDNNTensorLayout::HB_DNN_LAYOUT_NHWC ? "NHWC" : "NCHW"); - LOG(INFO) << "\tInput-" << i << ": Shape [" << shapes[0] << "," - << shapes[1] << "," << shapes[2] << "," << shapes[3] - << "], Layout [" << layout << "]"; + hbDNNTensorLayout::HB_DNN_LAYOUT_NHWC + ? "NHWC" + : "NCHW"); + LOG(INFO) << "\tInput-" << i << ": Shape [" << shapes[0] << "," << shapes[1] + << "," << shapes[2] << "," << shapes[3] << "], Layout [" << layout + << "]"; } // Output info for (size_t i = 0; i < output.size(); ++i) { auto& shapes = output[i]->properties.validShape.dimensionSize; std::string layout = (output[i]->properties.tensorLayout == - hbDNNTensorLayout::HB_DNN_LAYOUT_NHWC ? "NHWC" : "NCHW"); + hbDNNTensorLayout::HB_DNN_LAYOUT_NHWC + ? 
"NHWC" + : "NCHW"); LOG(INFO) << "\tOutput-" << i << ": Shape [" << shapes[0] << "," << shapes[1] << "," << shapes[2] << "," << shapes[3] << "], Layout [" << layout << "]"; @@ -86,13 +89,14 @@ void BPUAsrModel::Read(const std::string& model_dir) { sos_ = ctc_output_[0]->properties.validShape.dimensionSize[1] - 1; eos_ = sos_; chunk_size_ = ctc_input_[0]->properties.validShape.dimensionSize[3]; - num_left_chunks_ = encoder_input_[3]->properties.validShape.dimensionSize[3] - / chunk_size_ - 1; + num_left_chunks_ = + encoder_input_[3]->properties.validShape.dimensionSize[3] / chunk_size_ - + 1; hidden_dim_ = ctc_input_[0]->properties.validShape.dimensionSize[1]; int frames = (chunk_size_ - 1) * subsampling_rate_ + right_context_ + 1; - CHECK_EQ(frames, encoder_input_[0]->properties.validShape.dimensionSize[2]) << - "NOTE(xcsong): Only support 1/8 subsample, since 1/4 subsample" << - " is too slow on edge-devices."; + CHECK_EQ(frames, encoder_input_[0]->properties.validShape.dimensionSize[2]) + << "NOTE(xcsong): Only support 1/8 subsample, since 1/4 subsample" + << " is too slow on edge-devices."; LOG(INFO) << "Bpu Model Info:"; LOG(INFO) << "\tchunk_size " << chunk_size_; LOG(INFO) << "\tnum_left_chunks " << num_left_chunks_; @@ -126,9 +130,9 @@ std::shared_ptr BPUAsrModel::Copy() const { auto asr_model = std::make_shared(*this); // Reset the inner states for new decoding asr_model->AllocMemory(encoder_model_, &(asr_model->encoder_input_), - &(asr_model->encoder_output_)); + &(asr_model->encoder_output_)); asr_model->AllocMemory(ctc_model_, &(asr_model->ctc_input_), - &(asr_model->ctc_output_)); + &(asr_model->ctc_output_)); asr_model->Reset(); return asr_model; } @@ -246,7 +250,7 @@ void BPUAsrModel::PrepareEncoderInput( cached_feature_[i].size() * sizeof(float)); addr_shift += cached_feature_[i].size(); } - for (size_t i = 0; i < chunk_feats.size(); ++i) { // copy chunk_feats + for (size_t i = 0; i < chunk_feats.size(); ++i) { // copy chunk_feats memcpy(feat_ptr + addr_shift, chunk_feats[i].data(), chunk_feats[i].size() * sizeof(float)); addr_shift += chunk_feats[i].size(); @@ -270,7 +274,9 @@ void BPUAsrModel::PrepareEncoderInput( int head = encoder_input_[3]->properties.validShape.dimensionSize[1]; if (valid_len <= total_len) { std::vector padding(total_len, 1.0f); - for (size_t i = 0; i < total_len - valid_len; ++i) { padding[i] = 0.0f;} + for (size_t i = 0; i < total_len - valid_len; ++i) { + padding[i] = 0.0f; + } for (size_t i = 0; i < head * chunk_size_; ++i) { float* start_ptr = reinterpret_cast(att_mask->sysMem[0].virAddr) + total_len * i; @@ -307,9 +313,17 @@ void BPUAsrModel::AttentionRescoring(const std::vector>& hyps, } BPUAsrModel::~BPUAsrModel() { - for (auto& tensor : encoder_input_) { hbSysFreeMem(tensor->sysMem); } - for (auto& tensor : encoder_output_) { hbSysFreeMem(tensor->sysMem); } - for (auto& tensor : ctc_input_) { hbSysFreeMem(tensor->sysMem); } - for (auto& tensor : ctc_output_) { hbSysFreeMem(tensor->sysMem); } + for (auto& tensor : encoder_input_) { + hbSysFreeMem(tensor->sysMem); + } + for (auto& tensor : encoder_output_) { + hbSysFreeMem(tensor->sysMem); + } + for (auto& tensor : ctc_input_) { + hbSysFreeMem(tensor->sysMem); + } + for (auto& tensor : ctc_output_) { + hbSysFreeMem(tensor->sysMem); + } } } // namespace wenet diff --git a/runtime/horizonbpu/bpu/bpu_asr_model.h b/runtime/horizonbpu/bpu/bpu_asr_model.h index 2eee1a8c7..dab5a83ac 100644 --- a/runtime/horizonbpu/bpu/bpu_asr_model.h +++ b/runtime/horizonbpu/bpu/bpu_asr_model.h @@ -12,7 +12,6 @@ // See 
the License for the specific language governing permissions and // limitations under the License. - #ifndef RUNTIME_HORIZONBPU_BPU_BPU_ASR_MODEL_H_ #define RUNTIME_HORIZONBPU_BPU_BPU_ASR_MODEL_H_ @@ -20,19 +19,19 @@ #include #include +#include "easy_dnn/data_structure.h" #include "easy_dnn/model.h" #include "easy_dnn/model_manager.h" #include "easy_dnn/task_manager.h" -#include "easy_dnn/data_structure.h" #include "decoder/asr_model.h" #include "utils/log.h" #include "utils/utils.h" -using hobot::easy_dnn::Model; using hobot::easy_dnn::DNNTensor; -using hobot::easy_dnn::TaskManager; +using hobot::easy_dnn::Model; using hobot::easy_dnn::ModelManager; +using hobot::easy_dnn::TaskManager; namespace wenet { @@ -47,10 +46,9 @@ class BPUAsrModel : public AsrModel { float reverse_weight, std::vector* rescoring_score) override; std::shared_ptr Copy() const override; - static void AllocMemory( - const std::shared_ptr& model, - std::vector>* input, - std::vector>* output); + static void AllocMemory(const std::shared_ptr& model, + std::vector>* input, + std::vector>* output); void GetInputOutputInfo( const std::vector>& input_tensors, const std::vector>& output_tensors); @@ -76,7 +74,7 @@ class BPUAsrModel : public AsrModel { // input/output tensors std::vector> encoder_input_, encoder_output_; std::vector> ctc_input_, ctc_output_; - std::vector > encoder_outs_; + std::vector> encoder_outs_; }; } // namespace wenet diff --git a/runtime/libtorch/web/app.py b/runtime/libtorch/web/app.py index b880cf7ff..40d8f3b9d 100644 --- a/runtime/libtorch/web/app.py +++ b/runtime/libtorch/web/app.py @@ -14,9 +14,11 @@ app = Flask(__name__) + @app.route('/') def index(): return render_template('index.html') + if __name__ == '__main__': app.run(host='0.0.0.0', port=args.port, debug=True) diff --git a/runtime/openvino/ov/ov_asr_model.cc b/runtime/openvino/ov/ov_asr_model.cc index aa79ec71d..a1a164ac1 100644 --- a/runtime/openvino/ov/ov_asr_model.cc +++ b/runtime/openvino/ov/ov_asr_model.cc @@ -129,10 +129,10 @@ void OVAsrModel::Read(const std::string& model_dir) { num_left_chunks_ = -1; } - encoder_compile_model_ = std::make_shared(std::move( - core_->compile_model(encoder_model, "CPU"))); - // {{"PERF_COUNT", "NO"} /* YES for profile */ - // }))); + encoder_compile_model_ = std::make_shared( + std::move(core_->compile_model(encoder_model, "CPU"))); + // {{"PERF_COUNT", "NO"} /* YES for profile */ + // }))); auto inputs = encoder_compile_model_->inputs(); for (auto& input : inputs) { @@ -144,10 +144,10 @@ void OVAsrModel::Read(const std::string& model_dir) { } std::shared_ptr ctc_model = core_->read_model(ctc_ir_path); if (ctc_model) { - ctc_compile_model_ = std::make_shared(std::move( - core_->compile_model(ctc_model, "CPU"))); - // {{"PERFORMANCE_HINT", "THROUGHPUT"}, - // {"PERFORMANCE_HINT_NUM_REQUESTS", 1}}))); + ctc_compile_model_ = std::make_shared( + std::move(core_->compile_model(ctc_model, "CPU"))); + // {{"PERFORMANCE_HINT", "THROUGHPUT"}, + // {"PERFORMANCE_HINT_NUM_REQUESTS", 1}}))); ctc_infer_ = std::make_shared( std::move(ctc_compile_model_->create_infer_request())); @@ -160,10 +160,10 @@ void OVAsrModel::Read(const std::string& model_dir) { std::shared_ptr rescore_model = core_->read_model(rescore_ir_path); if (rescore_model) { - rescore_compile_model_ = std::make_shared(std::move( - core_->compile_model(rescore_model, "CPU"))); - // {{"PERFORMANCE_HINT", "THROUGHPUT"}, - // {"PERFORMANCE_HINT_NUM_REQUESTS", 1}}))); + rescore_compile_model_ = std::make_shared( + 
std::move(core_->compile_model(rescore_model, "CPU"))); + // {{"PERFORMANCE_HINT", "THROUGHPUT"}, + // {"PERFORMANCE_HINT_NUM_REQUESTS", 1}}))); rescore_infer_ = std::make_shared( std::move(rescore_compile_model_->create_infer_request())); @@ -217,11 +217,11 @@ OVAsrModel::OVAsrModel(const OVAsrModel& other) { ctc_compile_model_ = other.ctc_compile_model_; rescore_compile_model_ = other.rescore_compile_model_; encoder_infer_ = std::make_shared( - std::move(encoder_compile_model_->create_infer_request())); + std::move(encoder_compile_model_->create_infer_request())); ctc_infer_ = std::make_shared( - std::move(ctc_compile_model_->create_infer_request())); + std::move(ctc_compile_model_->create_infer_request())); rescore_infer_ = std::make_shared( - std::move(rescore_compile_model_->create_infer_request())); + std::move(rescore_compile_model_->create_infer_request())); } std::shared_ptr OVAsrModel::Copy() const { @@ -315,16 +315,15 @@ void OVAsrModel::ForwardEncoderFunc( // set input tensor size_t idx = 0; - std::map>::iterator it \ - = encoder_inputs_map_.begin(); + std::map>::iterator it = + encoder_inputs_map_.begin(); while (it != encoder_inputs_map_.end()) { if (it->first == "chunk") { encoder_infer_->set_tensor(it->second, feats_ov); } else if (it->first == "offset") { encoder_infer_->set_tensor(it->second, offset_ov); } else if (it->first == "required_cache_size") { - encoder_infer_->set_tensor(it->second, - required_cache_size_ov); + encoder_infer_->set_tensor(it->second, required_cache_size_ov); } else if (it->first == "att_cache") { encoder_infer_->set_tensor(it->second, att_cache_ov_); } else if (it->first == "cnn_cache") { diff --git a/runtime/openvino/ov/ov_asr_model.h b/runtime/openvino/ov/ov_asr_model.h index be8b4e6d0..f7a2df651 100644 --- a/runtime/openvino/ov/ov_asr_model.h +++ b/runtime/openvino/ov/ov_asr_model.h @@ -5,14 +5,14 @@ #ifndef RUNTIME_OPENVINO_OV_OV_ASR_MODEL_H_ #define RUNTIME_OPENVINO_OV_OV_ASR_MODEL_H_ +#include #include #include #include -#include -#include "openvino/openvino.hpp" #include "decoder/asr_model.h" -#include "utils/utils.h" +#include "openvino/openvino.hpp" #include "utils/log.h" +#include "utils/utils.h" namespace wenet { diff --git a/test/test_file_utils.py b/test/test_file_utils.py index cc38ae3bc..55d0a216e 100644 --- a/test/test_file_utils.py +++ b/test/test_file_utils.py @@ -7,13 +7,10 @@ from wenet.utils.file_utils import read_non_lang_symbols -@pytest.mark.parametrize( - "non_lang_symbol_table_path", - [ - "test/resources/non-linguistic-symbols.valid", - "test/resources/non-linguistic-symbols.invalid" - ] -) +@pytest.mark.parametrize("non_lang_symbol_table_path", [ + "test/resources/non-linguistic-symbols.valid", + "test/resources/non-linguistic-symbols.invalid" +]) def test_read_non_lang_symbols(non_lang_symbol_table_path): path = non_lang_symbol_table_path try: @@ -21,8 +18,11 @@ def test_read_non_lang_symbols(non_lang_symbol_table_path): assert syms[0] == "{~!@#$%^&*()_+`1234567890-=[]|\\\\:;\"'<>,./?}" assert syms[1] == "[~!@#$%^&*()_+`1234567890-={}|\\\\:;\"'<>,./?]" assert syms[2] == "<~!@#$%^&*()_+`1234567890-={}|\\\\:;\"'[],./?>" - assert syms[3] == "{qwertyuiopasdfghjklzxcvbnmQWERTYUIOPASDFGHJKLZXCVBNM}" - assert syms[4] == "[qwertyuiopasdfghjklzxcvbnmQWERTYUIOPASDFGHJKLZXCVBNM]" - assert syms[5] == "" + assert syms[ + 3] == "{qwertyuiopasdfghjklzxcvbnmQWERTYUIOPASDFGHJKLZXCVBNM}" + assert syms[ + 4] == "[qwertyuiopasdfghjklzxcvbnmQWERTYUIOPASDFGHJKLZXCVBNM]" + assert syms[ + 5] == "" except Exception as e: assert path == 
"test/resources/non-linguistic-symbols.invalid" diff --git a/test/wenet/transformer/test_grad_ckpt.py b/test/wenet/transformer/test_grad_ckpt.py index d6d950367..ef0fa95a0 100644 --- a/test/wenet/transformer/test_grad_ckpt.py +++ b/test/wenet/transformer/test_grad_ckpt.py @@ -10,16 +10,22 @@ from wenet.transformer.decoder import TransformerDecoder, BiTransformerDecoder -@pytest.mark.parametrize( - "module", [TransformerEncoder, ConformerEncoder, - TransformerDecoder, BiTransformerDecoder] -) +@pytest.mark.parametrize("module", [ + TransformerEncoder, ConformerEncoder, TransformerDecoder, + BiTransformerDecoder +]) def test_encoder_gradient_checkpointing(module): torch.manual_seed(777) # Init model - model = module(80, 256, dropout_rate=0.0, positional_dropout_rate=0.0, + model = module(80, + 256, + dropout_rate=0.0, + positional_dropout_rate=0.0, gradient_checkpointing=False) - model_grad_ckpt = module(80, 256, dropout_rate=0.0, positional_dropout_rate=0.0, + model_grad_ckpt = module(80, + 256, + dropout_rate=0.0, + positional_dropout_rate=0.0, gradient_checkpointing=True) model_grad_ckpt.load_state_dict(model.state_dict(), strict=True) model.train() @@ -37,26 +43,30 @@ def test_encoder_gradient_checkpointing(module): logits_grad_ckpt = model_grad_ckpt(xs, xs_lens)[0] elif module in [TransformerDecoder, BiTransformerDecoder]: logits = model(memory, memory_mask, tgt, tgt_lens, r_tgt)[0] - logits_grad_ckpt = model_grad_ckpt(memory, memory_mask, tgt, tgt_lens, r_tgt)[0] + logits_grad_ckpt = model_grad_ckpt(memory, memory_mask, tgt, tgt_lens, + r_tgt)[0] else: raise NotImplementedError np.testing.assert_allclose(logits.detach().numpy(), logits_grad_ckpt.detach().numpy(), - rtol=1e-7, atol=1e-10) + rtol=1e-7, + atol=1e-10) # Backward model.zero_grad() logits.sum().backward() model_grad_ckpt.zero_grad() logits_grad_ckpt.sum().backward() - for (name1, param1), (name2, param2) in zip( - model.named_parameters(), model_grad_ckpt.named_parameters()): + for (name1, param1), (name2, + param2) in zip(model.named_parameters(), + model_grad_ckpt.named_parameters()): assert name1 == name2 if param1.grad is None or param2.grad is None: print("Ignore {}, due to grad = None".format(name1)) elif not param1.requires_grad or not param2.requires_grad: print("Ignore {}, due to requires_grad = False".format(name1)) else: - np.testing.assert_allclose( - param1.grad.detach().numpy(), param2.grad.detach().numpy(), - rtol=1e-7, atol=1e-10) + np.testing.assert_allclose(param1.grad.detach().numpy(), + param2.grad.detach().numpy(), + rtol=1e-7, + atol=1e-10) print("Pass {}".format(name1)) diff --git a/test/wenet/whisper/test_whisper.py b/test/wenet/whisper/test_whisper.py index a8cec25ae..0eb0df9a7 100644 --- a/test/wenet/whisper/test_whisper.py +++ b/test/wenet/whisper/test_whisper.py @@ -18,13 +18,11 @@ from wenet.text.whisper_tokenizer import WhisperTokenizer from wenet.transformer.embedding import WhisperPositionalEncoding from wenet.whisper.convert_whisper_to_wenet_config_and_ckpt import ( - convert_to_wenet_yaml, convert_to_wenet_state_dict, convert_to_wenet_units -) + convert_to_wenet_yaml, convert_to_wenet_state_dict, convert_to_wenet_units) from wenet.utils.common import add_whisper_tokens from wenet.utils.init_model import init_model from wenet.utils.mask import make_pad_mask, subsequent_mask - torch.manual_seed(777) np.random.seed(777) @@ -35,95 +33,102 @@ class DummyArguments: checkpoint = None -@pytest.mark.parametrize( - "audio_path", - [ - "test/resources/aishell-BAC009S0724W0121.wav", - 
"test/resources/librispeech-1995-1837-0001.wav" - ] -) +@pytest.mark.parametrize("audio_path", [ + "test/resources/aishell-BAC009S0724W0121.wav", + "test/resources/librispeech-1995-1837-0001.wav" +]) def test_load_audio(audio_path): waveform_wenet, sample_rate = torchaudio.load(audio_path) waveform_wenet = waveform_wenet.numpy().flatten().astype(np.float32) wavform_whisper = whisper.load_audio(audio_path) - np.testing.assert_allclose(waveform_wenet, wavform_whisper, - rtol=1e-7, atol=1e-10) + np.testing.assert_allclose(waveform_wenet, + wavform_whisper, + rtol=1e-7, + atol=1e-10) -@pytest.mark.parametrize( - "audio_path", - [ - "test/resources/aishell-BAC009S0724W0121.wav", - "test/resources/librispeech-1995-1837-0001.wav" - ] -) +@pytest.mark.parametrize("audio_path", [ + "test/resources/aishell-BAC009S0724W0121.wav", + "test/resources/librispeech-1995-1837-0001.wav" +]) def test_log_mel_spectrogram(audio_path): waveform_wenet, sample_rate = torchaudio.load(audio_path) - sample = {"wav": waveform_wenet, "sample_rate": sample_rate, - "key": audio_path, "label": ""} - log_spec_wenet = next(compute_log_mel_spectrogram( - [sample], n_fft=N_FFT, hop_length=HOP_LENGTH, num_mel_bins=128, padding=0 - ))["feat"] + sample = { + "wav": waveform_wenet, + "sample_rate": sample_rate, + "key": audio_path, + "label": "" + } + log_spec_wenet = next( + compute_log_mel_spectrogram([sample], + n_fft=N_FFT, + hop_length=HOP_LENGTH, + num_mel_bins=128, + padding=0))["feat"] log_spec_wenet = log_spec_wenet.transpose(0, 1).numpy().astype(np.float32) - log_spec_whisper = whisper.log_mel_spectrogram(audio_path, n_mels=128, padding=0) - np.testing.assert_allclose(log_spec_wenet, log_spec_whisper, - rtol=1e-7, atol=1e-10) + log_spec_whisper = whisper.log_mel_spectrogram(audio_path, + n_mels=128, + padding=0) + np.testing.assert_allclose(log_spec_wenet, + log_spec_whisper, + rtol=1e-7, + atol=1e-10) @pytest.mark.parametrize( - "length,channels", [(512, 80), (1024, 128), (2048, 256), (4096, 512)], + "length,channels", + [(512, 80), (1024, 128), (2048, 256), (4096, 512)], ) def test_sinusoids(length, channels): - sinusoids_whisper = whisper.model.sinusoids(length, channels, max_timescale=10000) - sinusoids_wenet = WhisperPositionalEncoding(d_model=channels, dropout_rate=0.0, + sinusoids_whisper = whisper.model.sinusoids(length, + channels, + max_timescale=10000) + sinusoids_wenet = WhisperPositionalEncoding(d_model=channels, + dropout_rate=0.0, max_len=length) np.testing.assert_allclose(sinusoids_wenet.pe.squeeze(0).numpy(), sinusoids_whisper.numpy(), - rtol=1e-7, atol=1e-10) + rtol=1e-7, + atol=1e-10) -@pytest.mark.parametrize( - "model,audio_path", - [ - ("tiny", "test/resources/aishell-BAC009S0724W0121.wav"), - ("base", "test/resources/librispeech-1995-1837-0001.wav"), - ("small", "test/resources/aishell-BAC009S0724W0121.wav"), - ("medium", "test/resources/librispeech-1995-1837-0001.wav"), - ] -) +@pytest.mark.parametrize("model,audio_path", [ + ("tiny", "test/resources/aishell-BAC009S0724W0121.wav"), + ("base", "test/resources/librispeech-1995-1837-0001.wav"), + ("small", "test/resources/aishell-BAC009S0724W0121.wav"), + ("medium", "test/resources/librispeech-1995-1837-0001.wav"), +]) def test_model(model, audio_path): default = os.path.join(os.path.expanduser("~"), ".cache") download_root = os.path.join(os.getenv("XDG_CACHE_HOME", default), - "whisper", - "{}".format(model)) + "whisper", "{}".format(model)) language = "zh" task = "transcribe" # 1. 
Init whisper - whisper_model = whisper.load_model(model, device="cpu", + whisper_model = whisper.load_model(model, + device="cpu", download_root=download_root).float() whisper_model.eval() # 2. Init wenet - checkpoint = torch.load("{}/{}.pt".format(download_root, model), map_location="cpu") + checkpoint = torch.load("{}/{}.pt".format(download_root, model), + map_location="cpu") multilingual = checkpoint["dims"]['n_vocab'] >= 51865 num_languages = checkpoint["dims"]['n_vocab'] - 51765 - int(multilingual) - tokenizer = WhisperTokenizer(multilingual, num_languages=num_languages, - language=language, task=task) + tokenizer = WhisperTokenizer(multilingual, + num_languages=num_languages, + language=language, + task=task) tokenizer._build_tiktoken() convert_to_wenet_state_dict( checkpoint["model_state_dict"], - os.path.join(download_root, 'wenet_whisper.pt') - ) - convert_to_wenet_units( - tokenizer.tokenizer, - os.path.join(download_root, 'units.txt') - ) - convert_to_wenet_yaml( - tokenizer.tokenizer, checkpoint["dims"], - os.path.join(download_root, 'train.yaml') - ) + os.path.join(download_root, 'wenet_whisper.pt')) + convert_to_wenet_units(tokenizer.tokenizer, + os.path.join(download_root, 'units.txt')) + convert_to_wenet_yaml(tokenizer.tokenizer, checkpoint["dims"], + os.path.join(download_root, 'train.yaml')) with open("{}/train.yaml".format(download_root), 'r') as fin: configs = yaml.load(fin, Loader=yaml.FullLoader) configs['cmvn_file'] = None @@ -137,8 +142,9 @@ def test_model(model, audio_path): _, dummy_tokens = tokenizer.tokenize("WeNet x OpenAI") # 3. Forward whisper.encoder - mel1 = whisper.log_mel_spectrogram( - audio_path, whisper_model.dims.n_mels, padding=N_SAMPLES).unsqueeze(0) + mel1 = whisper.log_mel_spectrogram(audio_path, + whisper_model.dims.n_mels, + padding=N_SAMPLES).unsqueeze(0) whisper_mel = pad_or_trim(mel1, N_FRAMES, axis=-1) x = F.gelu(whisper_model.encoder.conv1(whisper_mel)) x = F.gelu(whisper_model.encoder.conv2(x)) @@ -149,139 +155,222 @@ def test_model(model, audio_path): for i, layer in enumerate(whisper_model.encoder.blocks): prev_x = x.clone() attn_ln_x = layer.attn_ln(x) - whisper_layers_ouput.append({"name": "enc.layer{}.attn_ln".format(i), - "value": attn_ln_x.clone()}) + whisper_layers_ouput.append({ + "name": "enc.layer{}.attn_ln".format(i), + "value": attn_ln_x.clone() + }) attn_x = layer.attn(attn_ln_x, mask=None, kv_cache=None)[0] - whisper_layers_ouput.append({"name": "enc.layer{}.attn".format(i), - "value": attn_x.clone()}) + whisper_layers_ouput.append({ + "name": "enc.layer{}.attn".format(i), + "value": attn_x.clone() + }) x = x + attn_x - whisper_layers_ouput.append({"name": "enc.layer{}.attn_residul".format(i), - "value": x.clone()}) + whisper_layers_ouput.append({ + "name": + "enc.layer{}.attn_residul".format(i), + "value": + x.clone() + }) mlp_ln_x = layer.mlp_ln(x) - whisper_layers_ouput.append({"name": "enc.layer{}.mlp_ln".format(i), - "value": mlp_ln_x.clone()}) + whisper_layers_ouput.append({ + "name": "enc.layer{}.mlp_ln".format(i), + "value": mlp_ln_x.clone() + }) mlp_x = layer.mlp(mlp_ln_x) - whisper_layers_ouput.append({"name": "enc.layer{}.mlp".format(i), - "value": mlp_x.clone()}) + whisper_layers_ouput.append({ + "name": "enc.layer{}.mlp".format(i), + "value": mlp_x.clone() + }) x = x + mlp_x - whisper_layers_ouput.append({"name": "enc.layer{}.mlp_residul".format(i), - "value": x.clone()}) + whisper_layers_ouput.append({ + "name": + "enc.layer{}.mlp_residul".format(i), + "value": + x.clone() + }) 
np.testing.assert_allclose(x.numpy(), layer(prev_x).numpy(), - rtol=1e-7, atol=1e-10) + rtol=1e-7, + atol=1e-10) whisper_encoder_out = whisper_model.encoder.ln_post(x) np.testing.assert_allclose(whisper_encoder_out.numpy(), whisper_model.encoder(whisper_mel).numpy(), - rtol=1e-7, atol=1e-10) + rtol=1e-7, + atol=1e-10) # 4. Forward whisper.decoder - whisper_tokens = torch.tensor(list(tokenizer.tokenizer.sot_sequence) - + [tokenizer.tokenizer.no_timestamps] - + dummy_tokens, - dtype=torch.long).unsqueeze(0) # (B=1, 9) - whisper_decoder_embed = whisper_model.decoder.token_embedding(whisper_tokens) - whisper_decoder_pos = whisper_model.decoder.positional_embedding[ - :whisper_decoder_embed.shape[1], :].unsqueeze(0) # (B=1, 9, d_model) + whisper_tokens = torch.tensor( + list(tokenizer.tokenizer.sot_sequence) + + [tokenizer.tokenizer.no_timestamps] + dummy_tokens, + dtype=torch.long).unsqueeze(0) # (B=1, 9) + whisper_decoder_embed = whisper_model.decoder.token_embedding( + whisper_tokens) + pos_func = whisper_model.decoder.positional_embedding + whisper_decoder_pos = pos_func[:whisper_decoder_embed. + shape[1], :].unsqueeze(0) whisper_decoder_embed_posed = whisper_decoder_embed + whisper_decoder_pos x = whisper_decoder_embed_posed.clone() for i, layer in enumerate(whisper_model.decoder.blocks): prev_x = x.clone() attn_ln_x = layer.attn_ln(x) - whisper_layers_ouput.append({"name": "dec.layer{}.attn_ln".format(i), - "value": attn_ln_x.clone()}) - attn_x = layer.attn(attn_ln_x, mask=whisper_model.decoder.mask, + whisper_layers_ouput.append({ + "name": "dec.layer{}.attn_ln".format(i), + "value": attn_ln_x.clone() + }) + attn_x = layer.attn(attn_ln_x, + mask=whisper_model.decoder.mask, kv_cache=None)[0] - whisper_layers_ouput.append({"name": "dec.layer{}.attn".format(i), - "value": attn_x.clone()}) + whisper_layers_ouput.append({ + "name": "dec.layer{}.attn".format(i), + "value": attn_x.clone() + }) x = x + attn_x - whisper_layers_ouput.append({"name": "dec.layer{}.attn_residul".format(i), - "value": x.clone()}) + whisper_layers_ouput.append({ + "name": + "dec.layer{}.attn_residul".format(i), + "value": + x.clone() + }) cross_attn_ln_x = layer.cross_attn_ln(x) - whisper_layers_ouput.append({"name": "dec.layer{}.cross_attn_ln".format(i), - "value": cross_attn_ln_x.clone()}) - cross_attn_x = layer.cross_attn(cross_attn_ln_x, whisper_encoder_out, - mask=None, kv_cache=None)[0] - whisper_layers_ouput.append({"name": "dec.layer{}.cross_attn".format(i), - "value": cross_attn_x.clone()}) + whisper_layers_ouput.append({ + "name": + "dec.layer{}.cross_attn_ln".format(i), + "value": + cross_attn_ln_x.clone() + }) + cross_attn_x = layer.cross_attn(cross_attn_ln_x, + whisper_encoder_out, + mask=None, + kv_cache=None)[0] + whisper_layers_ouput.append({ + "name": + "dec.layer{}.cross_attn".format(i), + "value": + cross_attn_x.clone() + }) x = x + cross_attn_x - whisper_layers_ouput.append({"name": f"dec.layer{i}.cross_attn_residul", - "value": x.clone()}) + whisper_layers_ouput.append({ + "name": f"dec.layer{i}.cross_attn_residul", + "value": x.clone() + }) mlp_ln_x = layer.mlp_ln(x) - whisper_layers_ouput.append({"name": "dec.layer{}.mlp_ln".format(i), - "value": mlp_ln_x.clone()}) + whisper_layers_ouput.append({ + "name": "dec.layer{}.mlp_ln".format(i), + "value": mlp_ln_x.clone() + }) mlp_x = layer.mlp(mlp_ln_x) - whisper_layers_ouput.append({"name": "dec.layer{}.mlp".format(i), - "value": mlp_x.clone()}) + whisper_layers_ouput.append({ + "name": "dec.layer{}.mlp".format(i), + "value": mlp_x.clone() + }) x = x 
+ mlp_x - whisper_layers_ouput.append({"name": "dec.layer{}.mlp_residul".format(i), - "value": x.clone()}) + whisper_layers_ouput.append({ + "name": + "dec.layer{}.mlp_residul".format(i), + "value": + x.clone() + }) np.testing.assert_allclose(x.numpy(), - layer(prev_x, whisper_encoder_out, + layer(prev_x, + whisper_encoder_out, mask=whisper_model.decoder.mask, kv_cache=None).numpy(), - rtol=1e-7, atol=1e-10) + rtol=1e-7, + atol=1e-10) x = whisper_model.decoder.ln(x) - whisper_logits = ( - x @ torch.transpose(whisper_model.decoder.token_embedding.weight, 0, 1) - ) - np.testing.assert_allclose( - whisper_logits.numpy(), - whisper_model.decoder(whisper_tokens, whisper_encoder_out).numpy(), - rtol=1e-7, atol=1e-10) + whisper_logits = (x @ torch.transpose( + whisper_model.decoder.token_embedding.weight, 0, 1)) + np.testing.assert_allclose(whisper_logits.numpy(), + whisper_model.decoder( + whisper_tokens, + whisper_encoder_out).numpy(), + rtol=1e-7, + atol=1e-10) # 5. Forward wenet.encoder waveform, sample_rate = torchaudio.load(audio_path) - sample = {"wav": waveform, "sample_rate": sample_rate, - "key": audio_path, "label": ""} - mel2 = next(compute_log_mel_spectrogram( - [sample], n_fft=N_FFT, hop_length=HOP_LENGTH, - num_mel_bins=whisper_model.dims.n_mels, padding=N_SAMPLES - ))["feat"].unsqueeze(0) + sample = { + "wav": waveform, + "sample_rate": sample_rate, + "key": audio_path, + "label": "" + } + mel2 = next( + compute_log_mel_spectrogram( + [sample], + n_fft=N_FFT, + hop_length=HOP_LENGTH, + num_mel_bins=whisper_model.dims.n_mels, + padding=N_SAMPLES))["feat"].unsqueeze(0) wenet_mel = pad_or_trim(mel2, N_FRAMES, axis=-2) T = wenet_mel.size(1) - masks = ~make_pad_mask( - torch.tensor([T], dtype=torch.long), T).unsqueeze(1) # (B=1, 1, T) - wenet_embed, pos_emb, masks = wenet_model.encoder.embed(wenet_mel, masks) + masks = ~make_pad_mask(torch.tensor([T], dtype=torch.long), + T).unsqueeze(1) # (B=1, 1, T) + wenet_embed, pos_emb, masks = wenet_model.encoder.embed( + wenet_mel, masks) wenet_subed = wenet_embed - pos_emb x = wenet_embed wenet_layers_output = [] for i, layer in enumerate(wenet_model.encoder.encoders): prev_x = x attn_ln_x = layer.norm1(x) - wenet_layers_output.append({"name": "enc.layer{}.attn_ln".format(i), - "value": attn_ln_x.clone()}) - x_att, _ = layer.self_attn( - attn_ln_x, attn_ln_x, attn_ln_x, - masks, cache=torch.zeros((0, 0, 0, 0))) - wenet_layers_output.append({"name": "enc.layer{}.attn".format(i), - "value": x_att.clone()}) + wenet_layers_output.append({ + "name": "enc.layer{}.attn_ln".format(i), + "value": attn_ln_x.clone() + }) + x_att, _ = layer.self_attn(attn_ln_x, + attn_ln_x, + attn_ln_x, + masks, + cache=torch.zeros((0, 0, 0, 0))) + wenet_layers_output.append({ + "name": "enc.layer{}.attn".format(i), + "value": x_att.clone() + }) x = x + x_att - wenet_layers_output.append({"name": "enc.layer{}.attn_residul".format(i), - "value": x.clone()}) + wenet_layers_output.append({ + "name": + "enc.layer{}.attn_residul".format(i), + "value": + x.clone() + }) mlp_ln_x = layer.norm2(x) - wenet_layers_output.append({"name": "enc.layer{}.mlp_ln".format(i), - "value": mlp_ln_x.clone()}) + wenet_layers_output.append({ + "name": "enc.layer{}.mlp_ln".format(i), + "value": mlp_ln_x.clone() + }) mlp_x = layer.feed_forward(mlp_ln_x) - wenet_layers_output.append({"name": "enc.layer{}.mlp".format(i), - "value": mlp_x.clone()}) + wenet_layers_output.append({ + "name": "enc.layer{}.mlp".format(i), + "value": mlp_x.clone() + }) x = x + mlp_x - wenet_layers_output.append({"name": 
"enc.layer{}.mlp_residul".format(i), - "value": x.clone()}) + wenet_layers_output.append({ + "name": + "enc.layer{}.mlp_residul".format(i), + "value": + x.clone() + }) np.testing.assert_allclose(x.numpy(), - layer(prev_x, masks, pos_emb, masks)[0].numpy(), - rtol=1e-7, atol=1e-10) + layer(prev_x, masks, pos_emb, + masks)[0].numpy(), + rtol=1e-7, + atol=1e-10) wenet_encoder_out = wenet_model.encoder.after_norm(x) # 6. Forward wenet.decoder wenet_tokens, _ = add_whisper_tokens( configs['model_conf']['special_tokens'], - torch.tensor([dummy_tokens], dtype=torch.long), ignore_id=-1, - task=task, no_timestamp=True, language=language, use_prev=False - ) + torch.tensor([dummy_tokens], dtype=torch.long), + ignore_id=-1, + task=task, + no_timestamp=True, + language=language, + use_prev=False) L = wenet_tokens.size(1) - tgt_mask = ~make_pad_mask( - torch.tensor([L], dtype=torch.long), L).unsqueeze(1) # (B=1, 1, L) + tgt_mask = ~make_pad_mask(torch.tensor([L], dtype=torch.long), + L).unsqueeze(1) # (B=1, 1, L) m = subsequent_mask(tgt_mask.size(-1), device=tgt_mask.device).unsqueeze(0) # (B=1, L, L) tgt_mask = tgt_mask & m # (B=1, L, L) @@ -293,68 +382,113 @@ def test_model(model, audio_path): prev_x = x.clone() assert layer.normalize_before attn_ln_x = layer.norm1(x) - wenet_layers_output.append({"name": "dec.layer{}.attn_ln".format(i), - "value": attn_ln_x.clone()}) - attn_x = layer.self_attn(attn_ln_x, attn_ln_x, attn_ln_x, tgt_mask)[0] - wenet_layers_output.append({"name": "dec.layer{}.attn".format(i), - "value": attn_x.clone()}) + wenet_layers_output.append({ + "name": "dec.layer{}.attn_ln".format(i), + "value": attn_ln_x.clone() + }) + attn_x = layer.self_attn(attn_ln_x, attn_ln_x, attn_ln_x, + tgt_mask)[0] + wenet_layers_output.append({ + "name": "dec.layer{}.attn".format(i), + "value": attn_x.clone() + }) x = x + attn_x - wenet_layers_output.append({"name": "dec.layer{}.attn_residul".format(i), - "value": x.clone()}) + wenet_layers_output.append({ + "name": + "dec.layer{}.attn_residul".format(i), + "value": + x.clone() + }) assert layer.src_attn is not None assert layer.normalize_before cross_attn_ln_x = layer.norm2(x) - wenet_layers_output.append({"name": "dec.layer{}.cross_attn_ln".format(i), - "value": cross_attn_ln_x.clone()}) - cross_attn_x = layer.src_attn( - cross_attn_ln_x, wenet_encoder_out, wenet_encoder_out, masks)[0] - wenet_layers_output.append({"name": "dec.layer{}.cross_attn".format(i), - "value": cross_attn_x.clone()}) + wenet_layers_output.append({ + "name": + "dec.layer{}.cross_attn_ln".format(i), + "value": + cross_attn_ln_x.clone() + }) + cross_attn_x = layer.src_attn(cross_attn_ln_x, wenet_encoder_out, + wenet_encoder_out, masks)[0] + wenet_layers_output.append({ + "name": + "dec.layer{}.cross_attn".format(i), + "value": + cross_attn_x.clone() + }) x = x + cross_attn_x - wenet_layers_output.append({"name": f"dec.layer{i}.cross_attn_residul", - "value": x.clone()}) + wenet_layers_output.append({ + "name": f"dec.layer{i}.cross_attn_residul", + "value": x.clone() + }) assert layer.normalize_before mlp_ln_x = layer.norm3(x) - wenet_layers_output.append({"name": "dec.layer{}.mlp_ln".format(i), - "value": mlp_ln_x.clone()}) + wenet_layers_output.append({ + "name": "dec.layer{}.mlp_ln".format(i), + "value": mlp_ln_x.clone() + }) mlp_x = layer.feed_forward(mlp_ln_x) - wenet_layers_output.append({"name": "dec.layer{}.mlp".format(i), - "value": mlp_x.clone()}) + wenet_layers_output.append({ + "name": "dec.layer{}.mlp".format(i), + "value": mlp_x.clone() + }) x = x + mlp_x - 
wenet_layers_output.append({"name": "dec.layer{}.mlp_residul".format(i), - "value": x.clone()}) - np.testing.assert_allclose( - x.numpy(), layer(prev_x, tgt_mask, wenet_encoder_out, masks)[0].numpy(), - rtol=1e-7, atol=1e-10) + wenet_layers_output.append({ + "name": + "dec.layer{}.mlp_residul".format(i), + "value": + x.clone() + }) + np.testing.assert_allclose(x.numpy(), + layer(prev_x, tgt_mask, + wenet_encoder_out, + masks)[0].numpy(), + rtol=1e-7, + atol=1e-10) assert wenet_model.decoder.normalize_before x = wenet_model.decoder.after_norm(x) assert wenet_model.decoder.use_output_layer x = wenet_model.decoder.output_layer(x) wenet_logits = x - np.testing.assert_allclose(whisper_mel.numpy(), wenet_mel.transpose(1, 2).numpy(), - rtol=1e-7, atol=1e-10) - np.testing.assert_allclose(whisper_model.encoder.positional_embedding.numpy(), - pos_emb.squeeze(0).numpy(), - rtol=1e-7, atol=1e-10) - np.testing.assert_allclose(whisper_subed.numpy(), wenet_subed.numpy(), - rtol=1e-7, atol=3e-7) - np.testing.assert_allclose(whisper_embed.numpy(), wenet_embed.numpy(), - rtol=1e-7, atol=1e-10) + np.testing.assert_allclose(whisper_mel.numpy(), + wenet_mel.transpose(1, 2).numpy(), + rtol=1e-7, + atol=1e-10) + np.testing.assert_allclose( + whisper_model.encoder.positional_embedding.numpy(), + pos_emb.squeeze(0).numpy(), + rtol=1e-7, + atol=1e-10) + np.testing.assert_allclose(whisper_subed.numpy(), + wenet_subed.numpy(), + rtol=1e-7, + atol=3e-7) + np.testing.assert_allclose(whisper_embed.numpy(), + wenet_embed.numpy(), + rtol=1e-7, + atol=1e-10) for i, (whisper_layer, wenet_layer) in enumerate( - zip(whisper_layers_ouput, wenet_layers_output) - ): + zip(whisper_layers_ouput, wenet_layers_output)): assert whisper_layer["name"] == wenet_layer["name"] print("check layer {}".format(whisper_layer["name"])) np.testing.assert_allclose(whisper_layer["value"].numpy(), wenet_layer["value"].numpy(), - rtol=1e-7, atol=6e-3) - np.testing.assert_allclose(whisper_encoder_out.numpy(), wenet_encoder_out.numpy(), - rtol=1e-7, atol=6e-03) - np.testing.assert_allclose(whisper_tokens.numpy(), wenet_tokens.numpy(), - rtol=1e-7, atol=1e-10) - np.testing.assert_allclose(whisper_logits.numpy(), wenet_logits.numpy(), - rtol=1e-7, atol=6e-02) + rtol=1e-7, + atol=6e-3) + np.testing.assert_allclose(whisper_encoder_out.numpy(), + wenet_encoder_out.numpy(), + rtol=1e-7, + atol=6e-03) + np.testing.assert_allclose(whisper_tokens.numpy(), + wenet_tokens.numpy(), + rtol=1e-7, + atol=1e-10) + np.testing.assert_allclose(whisper_logits.numpy(), + wenet_logits.numpy(), + rtol=1e-7, + atol=6e-02) np.testing.assert_allclose(F.softmax(whisper_logits).numpy(), F.softmax(wenet_logits).numpy(), - rtol=1e-7, atol=1e-10) + rtol=1e-7, + atol=1e-10) diff --git a/tools/analyze_dataset.py b/tools/analyze_dataset.py index d4373b065..f5b8fab4b 100755 --- a/tools/analyze_dataset.py +++ b/tools/analyze_dataset.py @@ -13,7 +13,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """ Analyze Dataset, Duration/TextLength/Speed etc. 
@@ -47,16 +46,19 @@ def get_args(): default='wav_scp', choices=['wav_scp', 'raw', 'shard'], help='dataset type') - parser.add_argument('--output_dir', type=str, - default="exp", help='write info to output dir') - parser.add_argument('--data_list', default=None, + parser.add_argument('--output_dir', + type=str, + default="exp", + help='write info to output dir') + parser.add_argument('--data_list', + default=None, help='used in raw/shard mode') - parser.add_argument('--wav_scp', default=None, - help='used in wav_scp mode') - parser.add_argument('--text', default=None, - help='used in wav_scp mode') - parser.add_argument('--num_thread', type=int, - default=4, help='number of threads') + parser.add_argument('--wav_scp', default=None, help='used in wav_scp mode') + parser.add_argument('--text', default=None, help='used in wav_scp mode') + parser.add_argument('--num_thread', + type=int, + default=4, + help='number of threads') args = parser.parse_args() print(args) return args @@ -109,8 +111,7 @@ def read_tar(file): data['txt'] = file_obj.read().decode( 'utf8').strip() elif postfix in AUDIO_FORMAT_SETS: - waveform, sample_rate = torchaudio.load( - file_obj) + waveform, sample_rate = torchaudio.load(file_obj) # single channel data['wav'] = waveform.numpy()[0, :] data['sample_rate'] = sample_rate @@ -118,16 +119,16 @@ def read_tar(file): data[postfix] = file_obj.read() except Exception as ex: valid = False - logging.warning( - 'error: {} when parse {}'.format(ex, name)) + logging.warning('error: {} when parse {}'.format( + ex, name)) prev_prefix = prefix # The last data in tar if prev_prefix is not None: data['key'] = prev_prefix yield data except Exception as ex: - logging.warning( - 'tar_file error: {} when processing {}'.format(ex, file)) + logging.warning('tar_file error: {} when processing {}'.format( + ex, file)) def main(): @@ -173,16 +174,18 @@ def main(): # partition for i, (key1, key2) in enumerate(zip(wavs, texts)): assert key1 == key2 - datas[i % args.num_thread].append( - {'key': key1, "wav": wavs[key1], "txt": texts[key1]} - ) + datas[i % args.num_thread].append({ + 'key': key1, + "wav": wavs[key1], + "txt": texts[key1] + }) logging.info("Stage-2: Start Analyze") # threads pool = multiprocessing.Pool(processes=args.num_thread) for i in range(args.num_thread): - output_file = os.path.join( - args.output_dir, "partition", "part-{}".format(i)) + output_file = os.path.join(args.output_dir, "partition", + "part-{}".format(i)) pool.apply_async(analyze, (datas[i], output_file, i)) pool.close() pool.join() @@ -190,8 +193,8 @@ def main(): logging.info("Stage-3: Sort and Write Result") datas = [] for i in range(args.num_thread): - output_file = os.path.join( - args.output_dir, "partition", "part-{}".format(i)) + output_file = os.path.join(args.output_dir, "partition", + "part-{}".format(i)) with open(output_file, "r", encoding='utf8') as f: for line in f.readlines(): data = json.loads(line) @@ -201,26 +204,36 @@ def main(): total_leading_sil = sum([x['leading_sil'] for x in datas]) total_trailing_sil = sum([x['trailing_sil'] for x in datas]) num_datas = len(datas) - names = ['key', 'dur', 'txt_length', 'speed', - 'leading_sil', 'trailing_sil'] + names = [ + 'key', 'dur', 'txt_length', 'speed', 'leading_sil', 'trailing_sil' + ] units = ['', 's', '', 'char/s', 'ms', 'ms'] - avgs = [0, total_dur / num_datas, total_len / num_datas, - total_len / total_dur, total_leading_sil / num_datas, - total_trailing_sil / num_datas] - stds = [0, sum([(x['dur'] - avgs[1])**2 for x in datas]), - 
sum([(x['txt_length'] - avgs[2])**2 for x in datas]), - sum([(x['txt_length'] / x['dur'] - avgs[3])**2 for x in datas]), - sum([(x['leading_sil'] - avgs[4])**2 for x in datas]), - sum([(x['trailing_sil'] - avgs[5])**2 for x in datas])] + avgs = [ + 0, total_dur / num_datas, total_len / num_datas, total_len / total_dur, + total_leading_sil / num_datas, total_trailing_sil / num_datas + ] + stds = [ + 0, + sum([(x['dur'] - avgs[1])**2 for x in datas]), + sum([(x['txt_length'] - avgs[2])**2 for x in datas]), + sum([(x['txt_length'] / x['dur'] - avgs[3])**2 for x in datas]), + sum([(x['leading_sil'] - avgs[4])**2 for x in datas]), + sum([(x['trailing_sil'] - avgs[5])**2 for x in datas]) + ] stds = [math.sqrt(x / num_datas) for x in stds] parts = ['max', 'P99', 'P75', 'P50', 'P25', 'min'] - index = [num_datas - 1, int(num_datas * 0.99), int(num_datas * 0.75), - int(num_datas * 0.50), int(num_datas * 0.25), 0] + index = [ + num_datas - 1, + int(num_datas * 0.99), + int(num_datas * 0.75), + int(num_datas * 0.50), + int(num_datas * 0.25), 0 + ] - with open(args.output_dir + "/analyze_result_brief", - "w", encoding='utf8') as f: - for i, (name, unit, avg, std) in enumerate( - zip(names, units, avgs, stds)): + with open(args.output_dir + "/analyze_result_brief", "w", + encoding='utf8') as f: + for i, (name, unit, avg, + std) in enumerate(zip(names, units, avgs, stds)): if name == 'key': continue f.write("==================\n") @@ -229,10 +242,8 @@ def main(): for p, j in zip(parts, index): f.write("{} {}: {:.3f} {} (wav_id: {})\n".format( p, name, datas[j][name], unit, datas[j]['key'])) - f.write("avg {}: {:.3f} {}\n".format( - name, avg, unit)) - f.write("std {}: {:.3f}\n".format( - name, std)) + f.write("avg {}: {:.3f} {}\n".format(name, avg, unit)) + f.write("std {}: {:.3f}\n".format(name, std)) os.system("cat {}".format(args.output_dir + "/analyze_result_brief")) datas.sort(key=lambda x: x['dur']) diff --git a/tools/cmvn_kaldi2json.py b/tools/cmvn_kaldi2json.py index 9966046c9..35adf2bbd 100755 --- a/tools/cmvn_kaldi2json.py +++ b/tools/cmvn_kaldi2json.py @@ -4,6 +4,7 @@ import sys import json + def kaldi2json(kaldi_cmvn_file): means = [] variance = [] @@ -26,11 +27,10 @@ def kaldi2json(kaldi_cmvn_file): for i in range(feat_dim + 2, 2 * feat_dim + 2): variance.append(float(arr[i])) - cmvn_info = {'mean_stat:' : means, - 'var_stat' : variance, - 'frame_num' : count} + cmvn_info = {'mean_stat:': means, 'var_stat': variance, 'frame_num': count} return cmvn_info + if __name__ == '__main__': with open(sys.argv[2], 'w') as fout: cmvn = kaldi2json(sys.argv[1]) diff --git a/tools/compute-cer.py b/tools/compute-cer.py index a0a8f8fe1..c49ff88d9 100755 --- a/tools/compute-cer.py +++ b/tools/compute-cer.py @@ -1,18 +1,19 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- - import sys import unicodedata import codecs remove_tag = True spacelist = [' ', '\t', '\r', '\n'] -puncts = ['!', ',', '?', - '、', '。', '!', ',', ';', '?', - ':', '「', '」', '︰', '『', '』', '《', '》'] +puncts = [ + '!', ',', '?', '、', '。', '!', ',', ';', '?', ':', '「', '」', '︰', '『', '』', + '《', '》' +] + -def characterize(string) : +def characterize(string): res = [] i = 0 while i < len(string): @@ -45,6 +46,7 @@ def characterize(string) : i = j return res + def stripoff_tags(x): if not x: return '' @@ -85,8 +87,10 @@ def normalize(sentence, ignore_words, cs, split=None): new_sentence.append(x) return new_sentence -class Calculator : - def __init__(self) : + +class Calculator: + + def __init__(self): self.data = {} self.space = [] self.cost = 
{} @@ -95,69 +99,86 @@ def __init__(self) : self.cost['del'] = 1 self.cost['ins'] = 1 - def calculate(self, lab, rec) : + def calculate(self, lab, rec): # Initialization lab.insert(0, '') rec.insert(0, '') - while len(self.space) < len(lab) : + while len(self.space) < len(lab): self.space.append([]) - for row in self.space : - for element in row : + for row in self.space: + for element in row: element['dist'] = 0 element['error'] = 'non' - while len(row) < len(rec) : - row.append({'dist' : 0, 'error' : 'non'}) - for i in range(len(lab)) : + while len(row) < len(rec): + row.append({'dist': 0, 'error': 'non'}) + for i in range(len(lab)): self.space[i][0]['dist'] = i self.space[i][0]['error'] = 'del' - for j in range(len(rec)) : + for j in range(len(rec)): self.space[0][j]['dist'] = j self.space[0][j]['error'] = 'ins' self.space[0][0]['error'] = 'non' - for token in lab : - if token not in self.data and len(token) > 0 : - self.data[token] = {'all' : 0, 'cor' : 0, 'sub' : 0, - 'ins' : 0, 'del' : 0} - for token in rec : - if token not in self.data and len(token) > 0 : - self.data[token] = {'all' : 0, 'cor' : 0, 'sub' : 0, - 'ins' : 0, 'del' : 0} + for token in lab: + if token not in self.data and len(token) > 0: + self.data[token] = { + 'all': 0, + 'cor': 0, + 'sub': 0, + 'ins': 0, + 'del': 0 + } + for token in rec: + if token not in self.data and len(token) > 0: + self.data[token] = { + 'all': 0, + 'cor': 0, + 'sub': 0, + 'ins': 0, + 'del': 0 + } # Computing edit distance - for i, lab_token in enumerate(lab) : - for j, rec_token in enumerate(rec) : - if i == 0 or j == 0 : + for i, lab_token in enumerate(lab): + for j, rec_token in enumerate(rec): + if i == 0 or j == 0: continue min_dist = sys.maxsize min_error = 'none' dist = self.space[i - 1][j]['dist'] + self.cost['del'] error = 'del' - if dist < min_dist : + if dist < min_dist: min_dist = dist min_error = error dist = self.space[i][j - 1]['dist'] + self.cost['ins'] error = 'ins' - if dist < min_dist : + if dist < min_dist: min_dist = dist min_error = error - if lab_token == rec_token : + if lab_token == rec_token: dist = self.space[i - 1][j - 1]['dist'] + self.cost['cor'] error = 'cor' - else : + else: dist = self.space[i - 1][j - 1]['dist'] + self.cost['sub'] error = 'sub' - if dist < min_dist : + if dist < min_dist: min_dist = dist min_error = error self.space[i][j]['dist'] = min_dist self.space[i][j]['error'] = min_error # Tracing back - result = {'lab': [], 'rec': [], 'all': 0, 'cor': 0, 'sub': 0, - 'ins': 0, 'del': 0} + result = { + 'lab': [], + 'rec': [], + 'all': 0, + 'cor': 0, + 'sub': 0, + 'ins': 0, + 'del': 0 + } i = len(lab) - 1 j = len(rec) - 1 - while True : - if self.space[i][j]['error'] == 'cor' : # correct - if len(lab[i]) > 0 : + while True: + if self.space[i][j]['error'] == 'cor': # correct + if len(lab[i]) > 0: self.data[lab[i]]['all'] = self.data[lab[i]]['all'] + 1 self.data[lab[i]]['cor'] = self.data[lab[i]]['cor'] + 1 result['all'] = result['all'] + 1 @@ -166,8 +187,8 @@ def calculate(self, lab, rec) : result['rec'].insert(0, rec[j]) i = i - 1 j = j - 1 - elif self.space[i][j]['error'] == 'sub' : # substitution - if len(lab[i]) > 0 : + elif self.space[i][j]['error'] == 'sub': # substitution + if len(lab[i]) > 0: self.data[lab[i]]['all'] = self.data[lab[i]]['all'] + 1 self.data[lab[i]]['sub'] = self.data[lab[i]]['sub'] + 1 result['all'] = result['all'] + 1 @@ -176,8 +197,8 @@ def calculate(self, lab, rec) : result['rec'].insert(0, rec[j]) i = i - 1 j = j - 1 - elif self.space[i][j]['error'] == 'del' : # deletion - if 
len(lab[i]) > 0 : + elif self.space[i][j]['error'] == 'del': # deletion + if len(lab[i]) > 0: self.data[lab[i]]['all'] = self.data[lab[i]]['all'] + 1 self.data[lab[i]]['del'] = self.data[lab[i]]['del'] + 1 result['all'] = result['all'] + 1 @@ -185,24 +206,25 @@ def calculate(self, lab, rec) : result['lab'].insert(0, lab[i]) result['rec'].insert(0, "") i = i - 1 - elif self.space[i][j]['error'] == 'ins' : # insertion - if len(rec[j]) > 0 : + elif self.space[i][j]['error'] == 'ins': # insertion + if len(rec[j]) > 0: self.data[rec[j]]['ins'] = self.data[rec[j]]['ins'] + 1 result['ins'] = result['ins'] + 1 result['lab'].insert(0, "") result['rec'].insert(0, rec[j]) j = j - 1 - elif self.space[i][j]['error'] == 'non' : # starting point + elif self.space[i][j]['error'] == 'non': # starting point break - else : # shouldn't reach here + else: # shouldn't reach here print('this should not happen , i={i} , j={j} , \ - error={error}'. - format(i=i, j=j, error=self.space[i][j]['error'])) + error={error}'.format(i=i, + j=j, + error=self.space[i][j]['error'])) return result - def overall(self) : + def overall(self): result = {'all': 0, 'cor': 0, 'sub': 0, 'ins': 0, 'del': 0} - for token in self.data : + for token in self.data: result['all'] = result['all'] + self.data[token]['all'] result['cor'] = result['cor'] + self.data[token]['cor'] result['sub'] = result['sub'] + self.data[token]['sub'] @@ -210,10 +232,10 @@ def overall(self) : result['del'] = result['del'] + self.data[token]['del'] return result - def cluster(self, data) : + def cluster(self, data): result = {'all': 0, 'cor': 0, 'sub': 0, 'ins': 0, 'del': 0} - for token in data : - if token in self.data : + for token in data: + if token in self.data: result['all'] = result['all'] + self.data[token]['all'] result['cor'] = result['cor'] + self.data[token]['cor'] result['sub'] = result['sub'] + self.data[token]['sub'] @@ -221,60 +243,64 @@ def cluster(self, data) : result['del'] = result['del'] + self.data[token]['del'] return result - def keys(self) : + def keys(self): return list(self.data.keys()) + def width(string): return sum(1 + (unicodedata.east_asian_width(c) in "AFW") for c in string) -def default_cluster(word) : + +def default_cluster(word): unicode_names = [unicodedata.name(char) for char in word] - for i in reversed(range(len(unicode_names))) : - if unicode_names[i].startswith('DIGIT') : # 1 + for i in reversed(range(len(unicode_names))): + if unicode_names[i].startswith('DIGIT'): # 1 unicode_names[i] = 'Number' # 'DIGIT' - elif (unicode_names[i].startswith('CJK UNIFIED IDEOGRAPH') or - unicode_names[i].startswith('CJK COMPATIBILITY IDEOGRAPH')) : + elif (unicode_names[i].startswith('CJK UNIFIED IDEOGRAPH') + or unicode_names[i].startswith('CJK COMPATIBILITY IDEOGRAPH')): # 明 / 郎 unicode_names[i] = 'Mandarin' # 'CJK IDEOGRAPH' - elif (unicode_names[i].startswith('LATIN CAPITAL LETTER') or - unicode_names[i].startswith('LATIN SMALL LETTER')) : + elif (unicode_names[i].startswith('LATIN CAPITAL LETTER') + or unicode_names[i].startswith('LATIN SMALL LETTER')): # A / a unicode_names[i] = 'English' # 'LATIN LETTER' - elif unicode_names[i].startswith('HIRAGANA LETTER') : # は こ め + elif unicode_names[i].startswith('HIRAGANA LETTER'): # は こ め unicode_names[i] = 'Japanese' # 'GANA LETTER' - elif (unicode_names[i].startswith('AMPERSAND') or - unicode_names[i].startswith('APOSTROPHE') or - unicode_names[i].startswith('COMMERCIAL AT') or - unicode_names[i].startswith('DEGREE CELSIUS') or - unicode_names[i].startswith('EQUALS SIGN') or - 
unicode_names[i].startswith('FULL STOP') or - unicode_names[i].startswith('HYPHEN-MINUS') or - unicode_names[i].startswith('LOW LINE') or - unicode_names[i].startswith('NUMBER SIGN') or - unicode_names[i].startswith('PLUS SIGN') or - unicode_names[i].startswith('SEMICOLON')) : + elif (unicode_names[i].startswith('AMPERSAND') + or unicode_names[i].startswith('APOSTROPHE') + or unicode_names[i].startswith('COMMERCIAL AT') + or unicode_names[i].startswith('DEGREE CELSIUS') + or unicode_names[i].startswith('EQUALS SIGN') + or unicode_names[i].startswith('FULL STOP') + or unicode_names[i].startswith('HYPHEN-MINUS') + or unicode_names[i].startswith('LOW LINE') + or unicode_names[i].startswith('NUMBER SIGN') + or unicode_names[i].startswith('PLUS SIGN') + or unicode_names[i].startswith('SEMICOLON')): # & / ' / @ / ℃ / = / . / - / _ / # / + / ; del unicode_names[i] - else : + else: return 'Other' - if len(unicode_names) == 0 : + if len(unicode_names) == 0: return 'Other' - if len(unicode_names) == 1 : + if len(unicode_names) == 1: return unicode_names[0] - for i in range(len(unicode_names) - 1) : - if unicode_names[i] != unicode_names[i + 1] : + for i in range(len(unicode_names) - 1): + if unicode_names[i] != unicode_names[i + 1]: return 'Other' return unicode_names[0] -def usage() : + +def usage(): print("compute-wer.py : compute word error rate (WER) \ and align recognition results and references.") print(" usage : python compute-wer.py [--cs={0,1}] \ [--cluster=foo] [--ig=ignore_file] [--char={0,1}] [--v={0,1}] \ [--padding-symbol={space,underline}] test.ref test.hyp > test.wer") + if __name__ == '__main__': - if len(sys.argv) == 1 : + if len(sys.argv) == 1: usage() sys.exit(0) calculator = Calculator() @@ -390,11 +416,11 @@ def usage() : if len(array) == 0: continue fid = array[0] - rec_set[fid] = normalize(array[1:], ignore_words, - case_sensitive, split) + rec_set[fid] = normalize(array[1:], ignore_words, case_sensitive, + split) # compute error rate on the interaction of reference file and hyp file - for line in open(ref_file, 'r', encoding='utf-8') : + for line in open(ref_file, 'r', encoding='utf-8'): if tochar: array = characterize(line) else: @@ -409,30 +435,30 @@ def usage() : if verbose: print('\nutt: %s' % fid) - for word in rec + lab : - if word not in default_words : + for word in rec + lab: + if word not in default_words: default_cluster_name = default_cluster(word) - if default_cluster_name not in default_clusters : + if default_cluster_name not in default_clusters: default_clusters[default_cluster_name] = {} - if word not in default_clusters[default_cluster_name] : + if word not in default_clusters[default_cluster_name]: default_clusters[default_cluster_name][word] = 1 default_words[word] = default_cluster_name result = calculator.calculate(lab, rec) if verbose: - if result['all'] != 0 : + if result['all'] != 0: wer = float(result['ins'] + result['sub'] + result['del']) * 100.0 / result['all'] - else : + else: wer = 0.0 print('WER: %4.2f %%' % wer, end=' ') print('N=%d C=%d S=%d D=%d I=%d' % - (result['all'], result['cor'], result['sub'], - result['del'], result['ins'])) + (result['all'], result['cor'], result['sub'], result['del'], + result['ins'])) space = {} space['lab'] = [] space['rec'] = [] - for idx in range(len(result['lab'])) : + for idx in range(len(result['lab'])): len_lab = width(result['lab'][idx]) len_rec = width(result['rec'][idx]) length = max(len_lab, len_rec) @@ -450,7 +476,7 @@ def usage() : for idx in range(lab1, lab2): token = result['lab'][idx] 
print('{token}'.format(token=token), end='') - for n in range(space['lab'][idx]) : + for n in range(space['lab'][idx]): print(padding_symbol, end='') print(' ', end='') print() @@ -462,7 +488,7 @@ def usage() : for idx in range(rec1, rec2): token = result['rec'][idx] print('{token}'.format(token=token), end='') - for n in range(space['rec'][idx]) : + for n in range(space['rec'][idx]): print(padding_symbol, end='') print(' ', end='') print('\n', end='\n') @@ -475,43 +501,44 @@ def usage() : print() result = calculator.overall() - if result['all'] != 0 : + if result['all'] != 0: wer = float(result['ins'] + result['sub'] + result['del']) * 100.0 / result['all'] - else : + else: wer = 0.0 print('Overall -> %4.2f %%' % wer, end=' ') print('N=%d C=%d S=%d D=%d I=%d' % - (result['all'], result['cor'], result['sub'], - result['del'], result['ins'])) + (result['all'], result['cor'], result['sub'], result['del'], + result['ins'])) if not verbose: print() if verbose: - for cluster_id in default_clusters : - result = calculator.cluster(k for k in default_clusters[cluster_id]) - if result['all'] != 0 : + for cluster_id in default_clusters: + result = calculator.cluster(k + for k in default_clusters[cluster_id]) + if result['all'] != 0: wer = float(result['ins'] + result['sub'] + result['del']) * 100.0 / result['all'] - else : + else: wer = 0.0 print('%s -> %4.2f %%' % (cluster_id, wer), end=' ') print('N=%d C=%d S=%d D=%d I=%d' % - (result['all'], result['cor'], result['sub'], - result['del'], result['ins'])) - if len(cluster_file) > 0 : # compute separated WERs for word clusters + (result['all'], result['cor'], result['sub'], result['del'], + result['ins'])) + if len(cluster_file) > 0: # compute separated WERs for word clusters cluster_id = '' cluster = [] - for line in open(cluster_file, 'r', encoding='utf-8') : - for token in line.decode('utf-8').rstrip('\n').split() : + for line in open(cluster_file, 'r', encoding='utf-8'): + for token in line.decode('utf-8').rstrip('\n').split(): # end of cluster reached, like </Keyword> if token[0:2] == '</' and \ token.lstrip('</').rstrip('>') == cluster_id : result = calculator.cluster(cluster) - if result['all'] != 0 : + if result['all'] != 0: wer = float(result['ins'] + result['sub'] + result['del']) * 100.0 / result['all'] - else : + else: wer = 0.0 print('%s -> %4.2f %%' % (cluster_id, wer), end=' ') print('N=%d C=%d S=%d D=%d I=%d' % @@ -520,12 +547,12 @@ def usage() : cluster_id = '' cluster = [] # begin of cluster reached, like <Keyword> - elif (token[0] == '<' and token[len(token) - 1] == '>' and - cluster_id == ''): + elif (token[0] == '<' and token[len(token) - 1] == '>' + and cluster_id == ''): cluster_id = token.lstrip('<').rstrip('>') cluster = [] # general terms, like WEATHER / CAR / ...
- else : + else: cluster.append(token) print() print('=======================================' diff --git a/tools/compute-wer.py b/tools/compute-wer.py index a3eefc0dc..e413a2749 100755 --- a/tools/compute-wer.py +++ b/tools/compute-wer.py @@ -1,61 +1,64 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- - import re, sys, unicodedata import codecs remove_tag = True -spacelist= [' ', '\t', '\r', '\n'] -puncts = ['!', ',', '?', - '、', '。', '!', ',', ';', '?', - ':', '「', '」', '︰', '『', '』', '《', '》'] - -def characterize(string) : - res = [] - i = 0 - while i < len(string): - char = string[i] - if char in puncts: - i += 1 - continue - cat1 = unicodedata.category(char) - #https://unicodebook.readthedocs.io/unicode.html#unicode-categories - if cat1 == 'Zs' or cat1 == 'Cn' or char in spacelist: # space or not assigned - i += 1 - continue - if cat1 == 'Lo': # letter-other - res.append(char) - i += 1 - else: - # some input looks like: , we want to separate it to two words. - sep = ' ' - if char == '<': sep = '>' - j = i+1 - while j < len(string): - c = string[j] - if ord(c) >= 128 or (c in spacelist) or (c==sep): - break - j += 1 - if j < len(string) and string[j] == '>': - j += 1 - res.append(string[i:j]) - i = j - return res +spacelist = [' ', '\t', '\r', '\n'] +puncts = [ + '!', ',', '?', '、', '。', '!', ',', ';', '?', ':', '「', '」', '︰', '『', '』', + '《', '》' +] + + +def characterize(string): + res = [] + i = 0 + while i < len(string): + char = string[i] + if char in puncts: + i += 1 + continue + cat1 = unicodedata.category(char) + #https://unicodebook.readthedocs.io/unicode.html#unicode-categories + if cat1 == 'Zs' or cat1 == 'Cn' or char in spacelist: # space or not assigned + i += 1 + continue + if cat1 == 'Lo': # letter-other + res.append(char) + i += 1 + else: + # some input looks like: , we want to separate it to two words. 
+ sep = ' ' + if char == '<': sep = '>' + j = i + 1 + while j < len(string): + c = string[j] + if ord(c) >= 128 or (c in spacelist) or (c == sep): + break + j += 1 + if j < len(string) and string[j] == '>': + j += 1 + res.append(string[i:j]) + i = j + return res + def stripoff_tags(x): - if not x: return '' - chars = [] - i = 0; T=len(x) - while i < T: - if x[i] == '<': - while i < T and x[i] != '>': - i += 1 - i += 1 - else: - chars.append(x[i]) - i += 1 - return ''.join(chars) + if not x: return '' + chars = [] + i = 0 + T = len(x) + while i < T: + if x[i] == '<': + while i < T and x[i] != '>': + i += 1 + i += 1 + else: + chars.append(x[i]) + i += 1 + return ''.join(chars) def normalize(sentence, ignore_words, cs, split=None): @@ -65,436 +68,486 @@ def normalize(sentence, ignore_words, cs, split=None): for token in sentence: x = token if not cs: - x = x.upper() + x = x.upper() if x in ignore_words: - continue + continue if remove_tag: - x = stripoff_tags(x) + x = stripoff_tags(x) if not x: - continue + continue if split and x in split: - new_sentence += split[x] + new_sentence += split[x] else: - new_sentence.append(x) + new_sentence.append(x) return new_sentence -class Calculator : - def __init__(self) : - self.data = {} - self.space = [] - self.cost = {} - self.cost['cor'] = 0 - self.cost['sub'] = 1 - self.cost['del'] = 1 - self.cost['ins'] = 1 - def calculate(self, lab, rec) : - # Initialization - lab.insert(0, '') - rec.insert(0, '') - while len(self.space) < len(lab) : - self.space.append([]) - for row in self.space : - for element in row : - element['dist'] = 0 - element['error'] = 'non' - while len(row) < len(rec) : - row.append({'dist' : 0, 'error' : 'non'}) - for i in range(len(lab)) : - self.space[i][0]['dist'] = i - self.space[i][0]['error'] = 'del' - for j in range(len(rec)) : - self.space[0][j]['dist'] = j - self.space[0][j]['error'] = 'ins' - self.space[0][0]['error'] = 'non' - for token in lab : - if token not in self.data and len(token) > 0 : - self.data[token] = {'all' : 0, 'cor' : 0, 'sub' : 0, 'ins' : 0, 'del' : 0} - for token in rec : - if token not in self.data and len(token) > 0 : - self.data[token] = {'all' : 0, 'cor' : 0, 'sub' : 0, 'ins' : 0, 'del' : 0} - # Computing edit distance - for i, lab_token in enumerate(lab) : - for j, rec_token in enumerate(rec) : - if i == 0 or j == 0 : - continue - min_dist = sys.maxsize - min_error = 'none' - dist = self.space[i-1][j]['dist'] + self.cost['del'] - error = 'del' - if dist < min_dist : - min_dist = dist - min_error = error - dist = self.space[i][j-1]['dist'] + self.cost['ins'] - error = 'ins' - if dist < min_dist : - min_dist = dist - min_error = error - if lab_token == rec_token : - dist = self.space[i-1][j-1]['dist'] + self.cost['cor'] - error = 'cor' - else : - dist = self.space[i-1][j-1]['dist'] + self.cost['sub'] - error = 'sub' - if dist < min_dist : - min_dist = dist - min_error = error - self.space[i][j]['dist'] = min_dist - self.space[i][j]['error'] = min_error - # Tracing back - result = {'lab':[], 'rec':[], 'all':0, 'cor':0, 'sub':0, 'ins':0, 'del':0} - i = len(lab) - 1 - j = len(rec) - 1 - while True : - if self.space[i][j]['error'] == 'cor' : # correct - if len(lab[i]) > 0 : - self.data[lab[i]]['all'] = self.data[lab[i]]['all'] + 1 - self.data[lab[i]]['cor'] = self.data[lab[i]]['cor'] + 1 - result['all'] = result['all'] + 1 - result['cor'] = result['cor'] + 1 - result['lab'].insert(0, lab[i]) - result['rec'].insert(0, rec[j]) - i = i - 1 - j = j - 1 - elif self.space[i][j]['error'] == 'sub' : # substitution 
- if len(lab[i]) > 0 : - self.data[lab[i]]['all'] = self.data[lab[i]]['all'] + 1 - self.data[lab[i]]['sub'] = self.data[lab[i]]['sub'] + 1 - result['all'] = result['all'] + 1 - result['sub'] = result['sub'] + 1 - result['lab'].insert(0, lab[i]) - result['rec'].insert(0, rec[j]) - i = i - 1 - j = j - 1 - elif self.space[i][j]['error'] == 'del' : # deletion - if len(lab[i]) > 0 : - self.data[lab[i]]['all'] = self.data[lab[i]]['all'] + 1 - self.data[lab[i]]['del'] = self.data[lab[i]]['del'] + 1 - result['all'] = result['all'] + 1 - result['del'] = result['del'] + 1 - result['lab'].insert(0, lab[i]) - result['rec'].insert(0, "") - i = i - 1 - elif self.space[i][j]['error'] == 'ins' : # insertion - if len(rec[j]) > 0 : - self.data[rec[j]]['ins'] = self.data[rec[j]]['ins'] + 1 - result['ins'] = result['ins'] + 1 - result['lab'].insert(0, "") - result['rec'].insert(0, rec[j]) - j = j - 1 - elif self.space[i][j]['error'] == 'non' : # starting point - break - else : # shouldn't reach here - print('this should not happen , i = {i} , j = {j} , error = {error}'.format(i = i, j = j, error = self.space[i][j]['error'])) - return result - def overall(self) : - result = {'all':0, 'cor':0, 'sub':0, 'ins':0, 'del':0} - for token in self.data : - result['all'] = result['all'] + self.data[token]['all'] - result['cor'] = result['cor'] + self.data[token]['cor'] - result['sub'] = result['sub'] + self.data[token]['sub'] - result['ins'] = result['ins'] + self.data[token]['ins'] - result['del'] = result['del'] + self.data[token]['del'] - return result - def cluster(self, data) : - result = {'all':0, 'cor':0, 'sub':0, 'ins':0, 'del':0} - for token in data : - if token in self.data : - result['all'] = result['all'] + self.data[token]['all'] - result['cor'] = result['cor'] + self.data[token]['cor'] - result['sub'] = result['sub'] + self.data[token]['sub'] - result['ins'] = result['ins'] + self.data[token]['ins'] - result['del'] = result['del'] + self.data[token]['del'] - return result - def keys(self) : - return list(self.data.keys()) + +class Calculator: + + def __init__(self): + self.data = {} + self.space = [] + self.cost = {} + self.cost['cor'] = 0 + self.cost['sub'] = 1 + self.cost['del'] = 1 + self.cost['ins'] = 1 + + def calculate(self, lab, rec): + # Initialization + lab.insert(0, '') + rec.insert(0, '') + while len(self.space) < len(lab): + self.space.append([]) + for row in self.space: + for element in row: + element['dist'] = 0 + element['error'] = 'non' + while len(row) < len(rec): + row.append({'dist': 0, 'error': 'non'}) + for i in range(len(lab)): + self.space[i][0]['dist'] = i + self.space[i][0]['error'] = 'del' + for j in range(len(rec)): + self.space[0][j]['dist'] = j + self.space[0][j]['error'] = 'ins' + self.space[0][0]['error'] = 'non' + for token in lab: + if token not in self.data and len(token) > 0: + self.data[token] = { + 'all': 0, + 'cor': 0, + 'sub': 0, + 'ins': 0, + 'del': 0 + } + for token in rec: + if token not in self.data and len(token) > 0: + self.data[token] = { + 'all': 0, + 'cor': 0, + 'sub': 0, + 'ins': 0, + 'del': 0 + } + # Computing edit distance + for i, lab_token in enumerate(lab): + for j, rec_token in enumerate(rec): + if i == 0 or j == 0: + continue + min_dist = sys.maxsize + min_error = 'none' + dist = self.space[i - 1][j]['dist'] + self.cost['del'] + error = 'del' + if dist < min_dist: + min_dist = dist + min_error = error + dist = self.space[i][j - 1]['dist'] + self.cost['ins'] + error = 'ins' + if dist < min_dist: + min_dist = dist + min_error = error + if lab_token == 
rec_token: + dist = self.space[i - 1][j - 1]['dist'] + self.cost['cor'] + error = 'cor' + else: + dist = self.space[i - 1][j - 1]['dist'] + self.cost['sub'] + error = 'sub' + if dist < min_dist: + min_dist = dist + min_error = error + self.space[i][j]['dist'] = min_dist + self.space[i][j]['error'] = min_error + # Tracing back + result = { + 'lab': [], + 'rec': [], + 'all': 0, + 'cor': 0, + 'sub': 0, + 'ins': 0, + 'del': 0 + } + i = len(lab) - 1 + j = len(rec) - 1 + while True: + if self.space[i][j]['error'] == 'cor': # correct + if len(lab[i]) > 0: + self.data[lab[i]]['all'] = self.data[lab[i]]['all'] + 1 + self.data[lab[i]]['cor'] = self.data[lab[i]]['cor'] + 1 + result['all'] = result['all'] + 1 + result['cor'] = result['cor'] + 1 + result['lab'].insert(0, lab[i]) + result['rec'].insert(0, rec[j]) + i = i - 1 + j = j - 1 + elif self.space[i][j]['error'] == 'sub': # substitution + if len(lab[i]) > 0: + self.data[lab[i]]['all'] = self.data[lab[i]]['all'] + 1 + self.data[lab[i]]['sub'] = self.data[lab[i]]['sub'] + 1 + result['all'] = result['all'] + 1 + result['sub'] = result['sub'] + 1 + result['lab'].insert(0, lab[i]) + result['rec'].insert(0, rec[j]) + i = i - 1 + j = j - 1 + elif self.space[i][j]['error'] == 'del': # deletion + if len(lab[i]) > 0: + self.data[lab[i]]['all'] = self.data[lab[i]]['all'] + 1 + self.data[lab[i]]['del'] = self.data[lab[i]]['del'] + 1 + result['all'] = result['all'] + 1 + result['del'] = result['del'] + 1 + result['lab'].insert(0, lab[i]) + result['rec'].insert(0, "") + i = i - 1 + elif self.space[i][j]['error'] == 'ins': # insertion + if len(rec[j]) > 0: + self.data[rec[j]]['ins'] = self.data[rec[j]]['ins'] + 1 + result['ins'] = result['ins'] + 1 + result['lab'].insert(0, "") + result['rec'].insert(0, rec[j]) + j = j - 1 + elif self.space[i][j]['error'] == 'non': # starting point + break + else: # shouldn't reach here + print( + 'this should not happen , i = {i} , j = {j} , error = {error}' + .format(i=i, j=j, error=self.space[i][j]['error'])) + return result + + def overall(self): + result = {'all': 0, 'cor': 0, 'sub': 0, 'ins': 0, 'del': 0} + for token in self.data: + result['all'] = result['all'] + self.data[token]['all'] + result['cor'] = result['cor'] + self.data[token]['cor'] + result['sub'] = result['sub'] + self.data[token]['sub'] + result['ins'] = result['ins'] + self.data[token]['ins'] + result['del'] = result['del'] + self.data[token]['del'] + return result + + def cluster(self, data): + result = {'all': 0, 'cor': 0, 'sub': 0, 'ins': 0, 'del': 0} + for token in data: + if token in self.data: + result['all'] = result['all'] + self.data[token]['all'] + result['cor'] = result['cor'] + self.data[token]['cor'] + result['sub'] = result['sub'] + self.data[token]['sub'] + result['ins'] = result['ins'] + self.data[token]['ins'] + result['del'] = result['del'] + self.data[token]['del'] + return result + + def keys(self): + return list(self.data.keys()) + def width(string): - return sum(1 + (unicodedata.east_asian_width(c) in "AFW") for c in string) - -def default_cluster(word) : - unicode_names = [ unicodedata.name(char) for char in word ] - for i in reversed(range(len(unicode_names))) : - if unicode_names[i].startswith('DIGIT') : # 1 - unicode_names[i] = 'Number' # 'DIGIT' - elif (unicode_names[i].startswith('CJK UNIFIED IDEOGRAPH') or - unicode_names[i].startswith('CJK COMPATIBILITY IDEOGRAPH')) : - # 明 / 郎 - unicode_names[i] = 'Mandarin' # 'CJK IDEOGRAPH' - elif (unicode_names[i].startswith('LATIN CAPITAL LETTER') or - unicode_names[i].startswith('LATIN 
SMALL LETTER')) : - # A / a - unicode_names[i] = 'English' # 'LATIN LETTER' - elif unicode_names[i].startswith('HIRAGANA LETTER') : # は こ め - unicode_names[i] = 'Japanese' # 'GANA LETTER' - elif (unicode_names[i].startswith('AMPERSAND') or - unicode_names[i].startswith('APOSTROPHE') or - unicode_names[i].startswith('COMMERCIAL AT') or - unicode_names[i].startswith('DEGREE CELSIUS') or - unicode_names[i].startswith('EQUALS SIGN') or - unicode_names[i].startswith('FULL STOP') or - unicode_names[i].startswith('HYPHEN-MINUS') or - unicode_names[i].startswith('LOW LINE') or - unicode_names[i].startswith('NUMBER SIGN') or - unicode_names[i].startswith('PLUS SIGN') or - unicode_names[i].startswith('SEMICOLON')) : - # & / ' / @ / ℃ / = / . / - / _ / # / + / ; - del unicode_names[i] - else : - return 'Other' - if len(unicode_names) == 0 : - return 'Other' - if len(unicode_names) == 1 : - return unicode_names[0] - for i in range(len(unicode_names)-1) : - if unicode_names[i] != unicode_names[i+1] : - return 'Other' - return unicode_names[0] - -def usage() : - print("compute-wer.py : compute word error rate (WER) and align recognition results and references.") - print(" usage : python compute-wer.py [--cs={0,1}] [--cluster=foo] [--ig=ignore_file] [--char={0,1}] [--v={0,1}] [--padding-symbol={space,underline}] test.ref test.hyp > test.wer") + return sum(1 + (unicodedata.east_asian_width(c) in "AFW") for c in string) + + +def default_cluster(word): + unicode_names = [unicodedata.name(char) for char in word] + for i in reversed(range(len(unicode_names))): + if unicode_names[i].startswith('DIGIT'): # 1 + unicode_names[i] = 'Number' # 'DIGIT' + elif (unicode_names[i].startswith('CJK UNIFIED IDEOGRAPH') + or unicode_names[i].startswith('CJK COMPATIBILITY IDEOGRAPH')): + # 明 / 郎 + unicode_names[i] = 'Mandarin' # 'CJK IDEOGRAPH' + elif (unicode_names[i].startswith('LATIN CAPITAL LETTER') + or unicode_names[i].startswith('LATIN SMALL LETTER')): + # A / a + unicode_names[i] = 'English' # 'LATIN LETTER' + elif unicode_names[i].startswith('HIRAGANA LETTER'): # は こ め + unicode_names[i] = 'Japanese' # 'GANA LETTER' + elif (unicode_names[i].startswith('AMPERSAND') + or unicode_names[i].startswith('APOSTROPHE') + or unicode_names[i].startswith('COMMERCIAL AT') + or unicode_names[i].startswith('DEGREE CELSIUS') + or unicode_names[i].startswith('EQUALS SIGN') + or unicode_names[i].startswith('FULL STOP') + or unicode_names[i].startswith('HYPHEN-MINUS') + or unicode_names[i].startswith('LOW LINE') + or unicode_names[i].startswith('NUMBER SIGN') + or unicode_names[i].startswith('PLUS SIGN') + or unicode_names[i].startswith('SEMICOLON')): + # & / ' / @ / ℃ / = / . / - / _ / # / + / ; + del unicode_names[i] + else: + return 'Other' + if len(unicode_names) == 0: + return 'Other' + if len(unicode_names) == 1: + return unicode_names[0] + for i in range(len(unicode_names) - 1): + if unicode_names[i] != unicode_names[i + 1]: + return 'Other' + return unicode_names[0] + + +def usage(): + print( + "compute-wer.py : compute word error rate (WER) and align recognition results and references." 
+ ) + print( + " usage : python compute-wer.py [--cs={0,1}] [--cluster=foo] [--ig=ignore_file] [--char={0,1}] [--v={0,1}] [--padding-symbol={space,underline}] test.ref test.hyp > test.wer" + ) + if __name__ == '__main__': - if len(sys.argv) == 1 : - usage() - sys.exit(0) - calculator = Calculator() - cluster_file = '' - ignore_words = set() - tochar = False - verbose= 1 - padding_symbol= ' ' - case_sensitive = False - max_words_per_line = sys.maxsize - split = None - while len(sys.argv) > 3: - a = '--maxw=' - if sys.argv[1].startswith(a): - b = sys.argv[1][len(a):] - del sys.argv[1] - max_words_per_line = int(b) - continue - a = '--rt=' - if sys.argv[1].startswith(a): - b = sys.argv[1][len(a):].lower() - del sys.argv[1] - remove_tag = (b == 'true') or (b != '0') - continue - a = '--cs=' - if sys.argv[1].startswith(a): - b = sys.argv[1][len(a):].lower() - del sys.argv[1] - case_sensitive = (b == 'true') or (b != '0') - continue - a = '--cluster=' - if sys.argv[1].startswith(a): - cluster_file = sys.argv[1][len(a):] - del sys.argv[1] - continue - a = '--splitfile=' - if sys.argv[1].startswith(a): - split_file = sys.argv[1][len(a):] - del sys.argv[1] - split = dict() - with codecs.open(split_file, 'r', 'utf-8') as fh: - for line in fh: # line in unicode - words = line.strip().split() - if len(words) >= 2: - split[words[0]] = words[1:] - continue - a = '--ig=' - if sys.argv[1].startswith(a): - ignore_file = sys.argv[1][len(a):] - del sys.argv[1] - with codecs.open(ignore_file, 'r', 'utf-8') as fh: - for line in fh: # line in unicode - line = line.strip() - if len(line) > 0: - ignore_words.add(line) - continue - a = '--char=' - if sys.argv[1].startswith(a): - b = sys.argv[1][len(a):].lower() - del sys.argv[1] - tochar = (b == 'true') or (b != '0') - continue - a = '--v=' - if sys.argv[1].startswith(a): - b = sys.argv[1][len(a):].lower() - del sys.argv[1] - verbose=0 - try: - verbose=int(b) - except: - if b == 'true' or b != '0': - verbose = 1 - continue - a = '--padding-symbol=' - if sys.argv[1].startswith(a): - b = sys.argv[1][len(a):].lower() - del sys.argv[1] - if b == 'space': - padding_symbol= ' ' - elif b == 'underline': - padding_symbol= '_' - continue - if True or sys.argv[1].startswith('-'): - #ignore invalid switch - del sys.argv[1] - continue - - if not case_sensitive: - ig=set([w.upper() for w in ignore_words]) - ignore_words = ig - - default_clusters = {} - default_words = {} - - ref_file = sys.argv[1] - hyp_file = sys.argv[2] - rec_set = {} - if split and not case_sensitive: - newsplit = dict() - for w in split: - words = split[w] - for i in range(len(words)): - words[i] = words[i].upper() - newsplit[w.upper()] = words - split = newsplit - - with codecs.open(hyp_file, 'r', 'utf-8') as fh: - for line in fh: + if len(sys.argv) == 1: + usage() + sys.exit(0) + calculator = Calculator() + cluster_file = '' + ignore_words = set() + tochar = False + verbose = 1 + padding_symbol = ' ' + case_sensitive = False + max_words_per_line = sys.maxsize + split = None + while len(sys.argv) > 3: + a = '--maxw=' + if sys.argv[1].startswith(a): + b = sys.argv[1][len(a):] + del sys.argv[1] + max_words_per_line = int(b) + continue + a = '--rt=' + if sys.argv[1].startswith(a): + b = sys.argv[1][len(a):].lower() + del sys.argv[1] + remove_tag = (b == 'true') or (b != '0') + continue + a = '--cs=' + if sys.argv[1].startswith(a): + b = sys.argv[1][len(a):].lower() + del sys.argv[1] + case_sensitive = (b == 'true') or (b != '0') + continue + a = '--cluster=' + if sys.argv[1].startswith(a): + cluster_file = 
sys.argv[1][len(a):] + del sys.argv[1] + continue + a = '--splitfile=' + if sys.argv[1].startswith(a): + split_file = sys.argv[1][len(a):] + del sys.argv[1] + split = dict() + with codecs.open(split_file, 'r', 'utf-8') as fh: + for line in fh: # line in unicode + words = line.strip().split() + if len(words) >= 2: + split[words[0]] = words[1:] + continue + a = '--ig=' + if sys.argv[1].startswith(a): + ignore_file = sys.argv[1][len(a):] + del sys.argv[1] + with codecs.open(ignore_file, 'r', 'utf-8') as fh: + for line in fh: # line in unicode + line = line.strip() + if len(line) > 0: + ignore_words.add(line) + continue + a = '--char=' + if sys.argv[1].startswith(a): + b = sys.argv[1][len(a):].lower() + del sys.argv[1] + tochar = (b == 'true') or (b != '0') + continue + a = '--v=' + if sys.argv[1].startswith(a): + b = sys.argv[1][len(a):].lower() + del sys.argv[1] + verbose = 0 + try: + verbose = int(b) + except: + if b == 'true' or b != '0': + verbose = 1 + continue + a = '--padding-symbol=' + if sys.argv[1].startswith(a): + b = sys.argv[1][len(a):].lower() + del sys.argv[1] + if b == 'space': + padding_symbol = ' ' + elif b == 'underline': + padding_symbol = '_' + continue + if True or sys.argv[1].startswith('-'): + #ignore invalid switch + del sys.argv[1] + continue + + if not case_sensitive: + ig = set([w.upper() for w in ignore_words]) + ignore_words = ig + + default_clusters = {} + default_words = {} + + ref_file = sys.argv[1] + hyp_file = sys.argv[2] + rec_set = {} + if split and not case_sensitive: + newsplit = dict() + for w in split: + words = split[w] + for i in range(len(words)): + words[i] = words[i].upper() + newsplit[w.upper()] = words + split = newsplit + + with codecs.open(hyp_file, 'r', 'utf-8') as fh: + for line in fh: + if tochar: + array = characterize(line) + else: + array = line.strip().split() + if len(array) == 0: continue + fid = array[0] + rec_set[fid] = normalize(array[1:], ignore_words, case_sensitive, + split) + + # compute error rate on the interaction of reference file and hyp file + for line in open(ref_file, 'r', encoding='utf-8'): if tochar: array = characterize(line) else: - array = line.strip().split() - if len(array)==0: continue + array = line.rstrip('\n').split() + if len(array) == 0: continue fid = array[0] - rec_set[fid] = normalize(array[1:], ignore_words, case_sensitive, split) + if fid not in rec_set: + continue + lab = normalize(array[1:], ignore_words, case_sensitive, split) + rec = rec_set[fid] + if verbose: + print('\nutt: %s' % fid) + + for word in rec + lab: + if word not in default_words: + default_cluster_name = default_cluster(word) + if default_cluster_name not in default_clusters: + default_clusters[default_cluster_name] = {} + if word not in default_clusters[default_cluster_name]: + default_clusters[default_cluster_name][word] = 1 + default_words[word] = default_cluster_name + + result = calculator.calculate(lab, rec) + if verbose: + if result['all'] != 0: + wer = float(result['ins'] + result['sub'] + + result['del']) * 100.0 / result['all'] + else: + wer = 0.0 + print('WER: %4.2f %%' % wer, end=' ') + print('N=%d C=%d S=%d D=%d I=%d' % + (result['all'], result['cor'], result['sub'], result['del'], + result['ins'])) + space = {} + space['lab'] = [] + space['rec'] = [] + for idx in range(len(result['lab'])): + len_lab = width(result['lab'][idx]) + len_rec = width(result['rec'][idx]) + length = max(len_lab, len_rec) + space['lab'].append(length - len_lab) + space['rec'].append(length - len_rec) + upper_lab = len(result['lab']) + upper_rec 
= len(result['rec']) + lab1, rec1 = 0, 0 + while lab1 < upper_lab or rec1 < upper_rec: + if verbose > 1: + print('lab(%s):' % fid.encode('utf-8'), end=' ') + else: + print('lab:', end=' ') + lab2 = min(upper_lab, lab1 + max_words_per_line) + for idx in range(lab1, lab2): + token = result['lab'][idx] + print('{token}'.format(token=token), end='') + for n in range(space['lab'][idx]): + print(padding_symbol, end='') + print(' ', end='') + print() + if verbose > 1: + print('rec(%s):' % fid.encode('utf-8'), end=' ') + else: + print('rec:', end=' ') + rec2 = min(upper_rec, rec1 + max_words_per_line) + for idx in range(rec1, rec2): + token = result['rec'][idx] + print('{token}'.format(token=token), end='') + for n in range(space['rec'][idx]): + print(padding_symbol, end='') + print(' ', end='') + print('\n', end='\n') + lab1 = lab2 + rec1 = rec2 - # compute error rate on the interaction of reference file and hyp file - for line in open(ref_file, 'r', encoding='utf-8') : - if tochar: - array = characterize(line) - else: - array = line.rstrip('\n').split() - if len(array)==0: continue - fid = array[0] - if fid not in rec_set: - continue - lab = normalize(array[1:], ignore_words, case_sensitive, split) - rec = rec_set[fid] - if verbose: - print('\nutt: %s' % fid) - - for word in rec + lab : - if word not in default_words : - default_cluster_name = default_cluster(word) - if default_cluster_name not in default_clusters : - default_clusters[default_cluster_name] = {} - if word not in default_clusters[default_cluster_name] : - default_clusters[default_cluster_name][word] = 1 - default_words[word] = default_cluster_name - - result = calculator.calculate(lab, rec) if verbose: - if result['all'] != 0 : - wer = float(result['ins'] + result['sub'] + result['del']) * 100.0 / result['all'] - else : - wer = 0.0 - print('WER: %4.2f %%' % wer, end = ' ') - print('N=%d C=%d S=%d D=%d I=%d' % - (result['all'], result['cor'], result['sub'], result['del'], result['ins'])) - space = {} - space['lab'] = [] - space['rec'] = [] - for idx in range(len(result['lab'])) : - len_lab = width(result['lab'][idx]) - len_rec = width(result['rec'][idx]) - length = max(len_lab, len_rec) - space['lab'].append(length-len_lab) - space['rec'].append(length-len_rec) - upper_lab = len(result['lab']) - upper_rec = len(result['rec']) - lab1, rec1 = 0, 0 - while lab1 < upper_lab or rec1 < upper_rec: - if verbose > 1: - print('lab(%s):' % fid.encode('utf-8'), end = ' ') - else: - print('lab:', end = ' ') - lab2 = min(upper_lab, lab1 + max_words_per_line) - for idx in range(lab1, lab2): - token = result['lab'][idx] - print('{token}'.format(token = token), end = '') - for n in range(space['lab'][idx]) : - print(padding_symbol, end = '') - print(' ',end='') - print() - if verbose > 1: - print('rec(%s):' % fid.encode('utf-8'), end = ' ') - else: - print('rec:', end = ' ') - rec2 = min(upper_rec, rec1 + max_words_per_line) - for idx in range(rec1, rec2): - token = result['rec'][idx] - print('{token}'.format(token = token), end = '') - for n in range(space['rec'][idx]) : - print(padding_symbol, end = '') - print(' ',end='') - print('\n', end='\n') - lab1 = lab2 - rec1 = rec2 - - if verbose: - print('===========================================================================') - print() - - result = calculator.overall() - if result['all'] != 0 : - wer = float(result['ins'] + result['sub'] + result['del']) * 100.0 / result['all'] - else : - wer = 0.0 - print('Overall -> %4.2f %%' % wer, end = ' ') - print('N=%d C=%d S=%d D=%d I=%d' % - 
(result['all'], result['cor'], result['sub'], result['del'], result['ins'])) - if not verbose: - print() - - if verbose: - for cluster_id in default_clusters : - result = calculator.cluster([ k for k in default_clusters[cluster_id] ]) - if result['all'] != 0 : - wer = float(result['ins'] + result['sub'] + result['del']) * 100.0 / result['all'] - else : + print( + '===========================================================================' + ) + print() + + result = calculator.overall() + if result['all'] != 0: + wer = float(result['ins'] + result['sub'] + + result['del']) * 100.0 / result['all'] + else: wer = 0.0 - print('%s -> %4.2f %%' % (cluster_id, wer), end = ' ') - print('N=%d C=%d S=%d D=%d I=%d' % - (result['all'], result['cor'], result['sub'], result['del'], result['ins'])) - if len(cluster_file) > 0 : # compute separated WERs for word clusters - cluster_id = '' - cluster = [] - for line in open(cluster_file, 'r', encoding='utf-8') : - for token in line.decode('utf-8').rstrip('\n').split() : - # end of cluster reached, like </Keyword> - if token[0:2] == '</' and token[len(token)-1] == '>' and \ - token.lstrip('</').rstrip('>') == cluster_id : - result = calculator.cluster(cluster) - if result['all'] != 0 : - wer = float(result['ins'] + result['sub'] + result['del']) * 100.0 / result['all'] - else : - wer = 0.0 - print('%s -> %4.2f %%' % (cluster_id, wer), end = ' ') - print('N=%d C=%d S=%d D=%d I=%d' % - (result['all'], result['cor'], result['sub'], result['del'], result['ins'])) - cluster_id = '' - cluster = [] - # begin of cluster reached, like <Keyword> - elif token[0] == '<' and token[len(token)-1] == '>' and \ - cluster_id == '' : - cluster_id = token.lstrip('<').rstrip('>') - cluster = [] - # general terms, like WEATHER / CAR / ... - else : - cluster.append(token) - print() - print('===========================================================================') + print('Overall -> %4.2f %%' % wer, end=' ') + print('N=%d C=%d S=%d D=%d I=%d' % + (result['all'], result['cor'], result['sub'], result['del'], + result['ins'])) + if not verbose: + print() + + if verbose: + for cluster_id in default_clusters: + result = calculator.cluster( + [k for k in default_clusters[cluster_id]]) + if result['all'] != 0: + wer = float(result['ins'] + result['sub'] + + result['del']) * 100.0 / result['all'] + else: + wer = 0.0 + print('%s -> %4.2f %%' % (cluster_id, wer), end=' ') + print('N=%d C=%d S=%d D=%d I=%d' % + (result['all'], result['cor'], result['sub'], result['del'], + result['ins'])) + if len(cluster_file) > 0:  # compute separated WERs for word clusters + cluster_id = '' + cluster = [] + for line in open(cluster_file, 'r', encoding='utf-8'): + for token in line.decode('utf-8').rstrip('\n').split(): + # end of cluster reached, like </Keyword> + if token[0:2] == '</' and token[len(token)-1] == '>' and \ + token.lstrip('</').rstrip('>') == cluster_id : + result = calculator.cluster(cluster) + if result['all'] != 0: + wer = float(result['ins'] + result['sub'] + + result['del']) * 100.0 / result['all'] + else: + wer = 0.0 + print('%s -> %4.2f %%' % (cluster_id, wer), end=' ') + print('N=%d C=%d S=%d D=%d I=%d' % + (result['all'], result['cor'], result['sub'], + result['del'], result['ins'])) + cluster_id = '' + cluster = [] + # begin of cluster reached, like <Keyword> + elif token[0] == '<' and token[len(token)-1] == '>' and \ + cluster_id == '' : + cluster_id = token.lstrip('<').rstrip('>') + cluster = [] + # general terms, like WEATHER / CAR / ...
+ else: + cluster.append(token) + print() + print( + '===========================================================================' + ) diff --git a/tools/compute_cmvn_stats.py b/tools/compute_cmvn_stats.py index 9c89789c4..cd3a2e2ef 100755 --- a/tools/compute_cmvn_stats.py +++ b/tools/compute_cmvn_stats.py @@ -32,7 +32,8 @@ def __call__(self, batch): value = item[1].strip().split(",") assert len(value) == 3 or len(value) == 1 wav_path = value[0] - sample_rate = torchaudio.backend.sox_io_backend.info(wav_path).sample_rate + sample_rate = torchaudio.backend.sox_io_backend.info( + wav_path).sample_rate resample_rate = sample_rate # len(value) == 3 means segmented wav.scp, # len(value) == 1 means original wav.scp @@ -64,6 +65,7 @@ def __call__(self, batch): class AudioDataset(Dataset): + def __init__(self, data_file): self.items = [] with codecs.open(data_file, 'r', encoding='utf-8') as f: @@ -101,7 +103,8 @@ def __getitem__(self, idx): feat_dim = configs['dataset_conf']['fbank_conf']['num_mel_bins'] resample_rate = 0 if 'resample_conf' in configs['dataset_conf']: - resample_rate = configs['dataset_conf']['resample_conf']['resample_rate'] + resample_rate = configs['dataset_conf']['resample_conf'][ + 'resample_rate'] print('using resample and new sample rate is {}'.format(resample_rate)) collate_func = CollateFunc(feat_dim, resample_rate) diff --git a/tools/flake8_hook.py b/tools/flake8_hook.py index bbe21bf4a..9c05575c5 100755 --- a/tools/flake8_hook.py +++ b/tools/flake8_hook.py @@ -5,9 +5,7 @@ from flake8.main import git if __name__ == '__main__': - sys.exit( - git.hook( - strict=True, - lazy=git.config_for('lazy'), - ) - ) + sys.exit(git.hook( + strict=True, + lazy=git.config_for('lazy'), + )) diff --git a/tools/fst/prepare_dict.py b/tools/fst/prepare_dict.py index c012eced3..865638a7d 100755 --- a/tools/fst/prepare_dict.py +++ b/tools/fst/prepare_dict.py @@ -50,17 +50,16 @@ def contain_oov(units): else: pieces = word if contain_oov(pieces): - print( - 'Ignoring words {}, which contains oov unit'.format( - ''.join(word).strip('▁')) - ) + print('Ignoring words {}, which contains oov unit'.format( + ''.join(word).strip('▁'))) continue chars = ' '.join( [p if p in unit_table else '' for p in pieces]) else: # ignore words with OOV if contain_oov(word): - print('Ignoring words {}, which contains oov unit'.format(word)) + print('Ignoring words {}, which contains oov unit'.format( + word)) continue # Optional, append ▁ in front of english word # we assume the model unit of our e2e system is char now. diff --git a/tools/k2/prepare_char.py b/tools/k2/prepare_char.py index 6e05042c4..b1e210787 100644 --- a/tools/k2/prepare_char.py +++ b/tools/k2/prepare_char.py @@ -16,8 +16,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - - """ This script generates the following files in the directory sys.argv[3]: @@ -139,9 +137,8 @@ def contain_oov(token_sym_table: Dict[str, int], tokens: List[str]) -> bool: return False -def generate_lexicon( - token_sym_table: Dict[str, int], words: List[str] -) -> Lexicon: +def generate_lexicon(token_sym_table: Dict[str, int], + words: List[str]) -> Lexicon: """Generate a lexicon from a word list and token_sym_table. 
Args: @@ -205,8 +202,10 @@ def generate_words(text_file: str) -> Dict[str, int]: # We put '<eps>' '<unk>' at begining of word2id # '#0', '<s>', '</s>' at end of word2id - words = [word for word in words - if word not in ['<eps>', '<unk>', '#0', '<s>', '</s>']] + words = [ + word for word in words + if word not in ['<eps>', '<unk>', '#0', '<s>', '</s>'] + ] words.insert(0, '<eps>') words.insert(1, '<unk>') words.append('#0') @@ -221,9 +220,10 @@ def main(): word2id = generate_words(sys.argv[2]) tgt_dir = Path(sys.argv[3]) - words = [word for word in word2id.keys() - if word not in - ["<eps>", "!SIL", "<SPOKEN_NOISE>", "<UNK>", "#0", "<s>", "</s>"]] + words = [ + word for word in word2id.keys() if word not in + ["<eps>", "!SIL", "<SPOKEN_NOISE>", "<UNK>", "#0", "<s>", "</s>"] + ] lexicon = generate_lexicon(token2id, words) lexicon_disambig, max_disambig = add_disambig_symbols(lexicon) diff --git a/tools/latency_metrics.py b/tools/latency_metrics.py index df2d8eee4..f408ced3e 100644 --- a/tools/latency_metrics.py +++ b/tools/latency_metrics.py @@ -34,33 +34,46 @@ def get_args(): parser = argparse.ArgumentParser( description='Analyze latency and plot CTC-Spike.') - parser.add_argument('--config', required=True, - type=str, help='configration') + parser.add_argument('--config', + required=True, + type=str, + help='configration') parser.add_argument('--gpu', type=int, default=0, help='gpu id for this rank, -1 for cpu') - parser.add_argument('--ckpt', required=True, - type=str, help='model checkpoint') - parser.add_argument('--tag', required=True, - type=str, help='image subtitle') - parser.add_argument('--wavscp', required=True, - type=str, help='wav.scp') - parser.add_argument('--alignment', required=True, - type=str, help='force alignment, generated by Kaldi.') - parser.add_argument('--chunk_size', required=True, - type=int, help='chunk size') - parser.add_argument('--left_chunks', default=-1, - type=int, help='left chunks') - parser.add_argument('--font', required=True, - type=str, help='font file') - parser.add_argument('--dict', required=True, - type=str, help='dict file') - parser.add_argument('--result_dir', required=True, - type=str, help='saving pdf') - parser.add_argument('--model_type', default='ctc', - choices=['ctc', 'transducer'], - help='show latency metrics from ctc models or rnn-t models') + parser.add_argument('--ckpt', + required=True, + type=str, + help='model checkpoint') + parser.add_argument('--tag', + required=True, + type=str, + help='image subtitle') + parser.add_argument('--wavscp', required=True, type=str, help='wav.scp') + parser.add_argument('--alignment', + required=True, + type=str, + help='force alignment, generated by Kaldi.') + parser.add_argument('--chunk_size', + required=True, + type=int, + help='chunk size') + parser.add_argument('--left_chunks', + default=-1, + type=int, + help='left chunks') + parser.add_argument('--font', required=True, type=str, help='font file') + parser.add_argument('--dict', required=True, type=str, help='dict file') + parser.add_argument('--result_dir', + required=True, + type=str, + help='saving pdf') + parser.add_argument( + '--model_type', + default='ctc', + choices=['ctc', 'transducer'], + help='show latency metrics from ctc models or rnn-t models') args = parser.parse_args() return args @@ -110,7 +123,8 @@ def main(): num_mel_bins=conf['dataset_conf']['fbank_conf']['num_mel_bins'], frame_length=conf['dataset_conf']['fbank_conf']['frame_length'], frame_shift=conf['dataset_conf']['fbank_conf']['frame_shift'], - dither=0.0, energy_floor=0.0, + dither=0.0, + energy_floor=0.0, sample_frequency=resample_rate, ) @@ -118,8 +132,9 @@ def main():
speech_lengths = torch.tensor([mat.size(0)]).to(device) # Let's assume batch_size = 1 - encoder_out, encoder_mask = model.encoder( - speech, speech_lengths, args.chunk_size, args.left_chunks) + encoder_out, encoder_mask = model.encoder(speech, speech_lengths, + args.chunk_size, + args.left_chunks) maxlen = encoder_out.size(1) # (B, maxlen, encoder_dim) encoder_out_lens = encoder_mask.squeeze(1).sum(1) @@ -146,7 +161,8 @@ def main(): padding = torch.zeros(1, 1).to(encoder_out.device) # sos pred_input_step = torch.tensor([model.blank]).reshape(1, 1) - cache = model.predictor.init_state(1, method="zero", + cache = model.predictor.init_state(1, + method="zero", device=encoder_out.device) new_cache: List[torch.Tensor] = [] t = 0 @@ -158,11 +174,12 @@ def main(): while t < encoder_out_lens: encoder_out_step = encoder_out[:, t:t + 1, :] # [1, 1, E] if prev_out_nblk: - step_outs = model.predictor.forward_step(pred_input_step, - padding, cache) + step_outs = model.predictor.forward_step( + pred_input_step, padding, cache) pred_out_step, new_cache = step_outs[0], step_outs[1] - joint_out_step = model.joint(encoder_out_step, pred_out_step) # [1,1,v] + joint_out_step = model.joint(encoder_out_step, + pred_out_step) # [1,1,v] joint_out_probs = joint_out_step.log_softmax(dim=-1) scores.append(torch.max(joint_out_probs).item()) @@ -220,11 +237,12 @@ def main(): # datas[i] = [key, text_fa, text_st, list_of_diff, # FirstTokenDelay, LastTokenDelay, AvgTokenDelay, # streaming_timestamps, force_alignment] - datas.append([key, text_fa, text_st, - [a - b for a, b in zip(st, fa)], - st[0] - fa[0], st[-1] - fa[-1], - (sum(st) - sum(fa)) / len(st), - timestamps[key], align.split()]) + datas.append([ + key, text_fa, text_st, + [a - b for a, b in zip(st, fa)], st[0] - fa[0], st[-1] - fa[-1], + (sum(st) - sum(fa)) / len(st), timestamps[key], + align.split() + ]) logging.info("not found: {}, length unequal: {}, ignored: {}, \ valid samples: {}".format(not_found, len_unequal, ignored, len(datas))) @@ -234,11 +252,18 @@ def main(): names = ['FirstTokenDelay', 'LastTokenDelay', 'AvgTokenDelay'] names_index = [4, 5, 6] parts = ['max', 'P90', 'P75', 'P50', 'P25', 'min'] - parts_index = [num_datas - 1, int(num_datas * 0.90), int(num_datas * 0.75), - int(num_datas * 0.50), int(num_datas * 0.25), 0] + parts_index = [ + num_datas - 1, + int(num_datas * 0.90), + int(num_datas * 0.75), + int(num_datas * 0.50), + int(num_datas * 0.25), 0 + ] for name, name_idx in zip(names, names_index): + def f(name_idx=name_idx): return name_idx + datas.sort(key=lambda x: x[f()]) logging.info("==========================") for p, i in zip(parts, parts_index): @@ -268,22 +293,25 @@ def f(name_idx=name_idx): for frame, token, prob in zip(x, hyps, scores): if char_dict[token] != '': axes[j].bar( - frame, np.exp(prob), - label='{} {:.3f}'.format( - char_dict[token], np.exp(prob)), + frame, + np.exp(prob), + label='{} {:.3f}'.format(char_dict[token], + np.exp(prob)), ) axes[j].text( - frame, np.exp(prob), - '{} {:.3f} {}'.format( - char_dict[token], np.exp(prob), frame), + frame, + np.exp(prob), + '{} {:.3f} {}'.format(char_dict[token], + np.exp(prob), frame), fontdict=dict(fontsize=24), fontproperties=font, ) else: axes[j].bar( - frame, 0.01, - label='{} {:.3f}'.format( - char_dict[token], np.exp(prob)), + frame, + 0.01, + label='{} {:.3f}'.format(char_dict[token], + np.exp(prob)), ) axes[j].tick_params(labelsize=25) @@ -294,8 +322,8 @@ def f(name_idx=name_idx): axes[-1].plot(time, samples) # i.e., RESULT_DIR/LTD_P90_120ms_BAC009S0768W0342.pdf - 
plt.savefig(args.result_dir + "/" + name + "_" + - p + "_" + str(data[f()]) + "ms" + "_" + data[0] + ".pdf") + plt.savefig(args.result_dir + "/" + name + "_" + p + "_" + + str(data[f()]) + "ms" + "_" + data[0] + ".pdf") if __name__ == '__main__': diff --git a/tools/onnx2horizonbin.py b/tools/onnx2horizonbin.py index e0db5e137..3f6474572 100755 --- a/tools/onnx2horizonbin.py +++ b/tools/onnx2horizonbin.py @@ -55,7 +55,6 @@ from wenet.bin.export_onnx_cpu import to_numpy from wenet.bin.export_onnx_bpu import export_encoder, export_ctc - try: import hbdk # noqa: F401 import horizon_nn # noqa: F401 @@ -64,7 +63,6 @@ print('Please install hbdk,horizon_nn,horizon_tc_ui !') sys.exit(1) - logger = logging.getLogger(__file__) logger.setLevel(logging.INFO) @@ -82,8 +80,11 @@ def make_calibration_data(enc, args, conf): conf['shuffle'] = True logger.info(conf) tokenizer = init_tokenizer(ali_conf, args.symbol_table, args.bpe_model) - dataset = Dataset( - "shard", args.cali_datalist, tokenizer, conf, partition=False) + dataset = Dataset("shard", + args.cali_datalist, + tokenizer, + conf, + partition=False) dataloader = DataLoader(dataset, batch_size=None, num_workers=0) subsampling = enc.embed.subsampling_rate @@ -105,14 +106,16 @@ def make_calibration_data(enc, args, conf): num_frames, prefix = feats.size(1), keys[0] att_cache = torch.zeros( [1, head * num_layers, d_k * 2, required_cache_size], - dtype=feats.dtype, device=feats.device) + dtype=feats.dtype, + device=feats.device) att_mask = torch.ones( [1, head, chunk_size, required_cache_size + chunk_size], - dtype=feats.dtype, device=feats.device) + dtype=feats.dtype, + device=feats.device) att_mask[:, :, :, :required_cache_size] = 0 - cnn_cache = torch.zeros( - [1, dim, num_layers, lorder], - dtype=feats.dtype, device=feats.device) + cnn_cache = torch.zeros([1, dim, num_layers, lorder], + dtype=feats.dtype, + device=feats.device) # Feed forward overlap input step by step random_high = (num_frames - context) // stride @@ -138,9 +141,10 @@ def make_calibration_data(enc, args, conf): prefix + "." + str(i)) save_data(att_mask, "{}/att_mask".format(cal_data_dir), prefix + "." 
+ str(i)) - (y, att_cache, cnn_cache) = enc.forward( - xs=chunk, att_cache=att_cache, - cnn_cache=cnn_cache, att_mask=att_mask) + (y, att_cache, cnn_cache) = enc.forward(xs=chunk, + att_cache=att_cache, + cnn_cache=cnn_cache, + att_mask=att_mask) # NOTE(xcsong): It's fast to calibrate ctc.onnx, # so it's okay to save all chunks save_data(y, "{}/hidden".format(cal_data_dir), @@ -150,8 +154,11 @@ def make_calibration_data(enc, args, conf): def check_wer(enc, ctc, args, conf): conf['shuffle'] = False tokenizer = init_tokenizer(ali_conf, args.symbol_table, args.bpe_model) - dataset = Dataset( - "shard", args.wer_datalist, tokenizer, conf, partition=False) + dataset = Dataset("shard", + args.wer_datalist, + tokenizer, + conf, + partition=False) dataloader = DataLoader(dataset, batch_size=None, num_workers=0) char_dict = {v: k for k, v in args.symbol_table.items()} eos = len(char_dict) - 1 @@ -178,14 +185,16 @@ def check_wer(enc, ctc, args, conf): num_frames, prefix = feats.size(1), keys[0] att_cache = torch.zeros( [1, head * num_layers, d_k * 2, required_cache_size], - dtype=feats.dtype, device=feats.device) + dtype=feats.dtype, + device=feats.device) att_mask = torch.ones( [1, head, chunk_size, required_cache_size + chunk_size], - dtype=feats.dtype, device=feats.device) + dtype=feats.dtype, + device=feats.device) att_mask[:, :, :, :required_cache_size] = 0 - cnn_cache = torch.zeros( - [1, dim, num_layers, lorder], - dtype=feats.dtype, device=feats.device) + cnn_cache = torch.zeros([1, dim, num_layers, lorder], + dtype=feats.dtype, + device=feats.device) onnx_att_cache = to_numpy(att_cache) onnx_cnn_cache = to_numpy(cnn_cache) @@ -204,21 +213,28 @@ def check_wer(enc, ctc, args, conf): if pad_len >= subsampling: att_mask[:, :, :, -(pad_len // subsampling):] = 0 # Torch model - (y, att_cache, cnn_cache) = enc.forward( - xs=chunk, att_cache=att_cache, - cnn_cache=cnn_cache, att_mask=att_mask) + (y, att_cache, cnn_cache) = enc.forward(xs=chunk, + att_cache=att_cache, + cnn_cache=cnn_cache, + att_mask=att_mask) torch_out.append(ctc.forward(y).transpose(1, 3).squeeze(2)) # Quantized onnx model ort_inputs = { - 'chunk': to_numpy(chunk), 'att_cache': onnx_att_cache, - 'cnn_cache': onnx_cnn_cache, 'att_mask': to_numpy(att_mask)} - ort_outs = enc_session.run_feature( - enc_session.output_names, ort_inputs, input_offset=0) + 'chunk': to_numpy(chunk), + 'att_cache': onnx_att_cache, + 'cnn_cache': onnx_cnn_cache, + 'att_mask': to_numpy(att_mask) + } + ort_outs = enc_session.run_feature(enc_session.output_names, + ort_inputs, + input_offset=0) onnx_att_cache, onnx_cnn_cache = ort_outs[1], ort_outs[2] - onnx_y = ctc_session.run_feature( - ctc_session.output_names, {'hidden': ort_outs[0]}, input_offset=0) - onnx_out.append(torch.from_numpy( - np.squeeze(onnx_y[0].transpose(0, 3, 2, 1), axis=2))) + onnx_y = ctc_session.run_feature(ctc_session.output_names, + {'hidden': ort_outs[0]}, + input_offset=0) + onnx_out.append( + torch.from_numpy( + np.squeeze(onnx_y[0].transpose(0, 3, 2, 1), axis=2))) def post_process(list_out, file_obj, keys): probs = torch.cat(list_out, dim=1) @@ -325,17 +341,16 @@ def generate_config(enc_session, ctc_session, args): ctc_cal_data = ";".join( [cal_data_dir + "/" + x for x in ctc_dic['input_name'].split(';')]) enc_config = template.format( - enc_onnx_path, "encoder", enc_log_path, - enc_dic['input_name'], enc_dic['input_type'], - enc_dic['input_layout_train'], enc_dic['input_shape'], - enc_dic['norm_type'], enc_dic['input_type'], enc_dic['input_layout_rt'], - enc_cal_data, 
args.calibration_type, args.extra_ops_run_on_cpu, "") + enc_onnx_path, "encoder", enc_log_path, enc_dic['input_name'], + enc_dic['input_type'], enc_dic['input_layout_train'], + enc_dic['input_shape'], enc_dic['norm_type'], enc_dic['input_type'], + enc_dic['input_layout_rt'], enc_cal_data, args.calibration_type, + args.extra_ops_run_on_cpu, "") ctc_config = template.format( - ctc_onnx_path, "ctc", ctc_log_path, - ctc_dic['input_name'], ctc_dic['input_type'], - ctc_dic['input_layout_train'], ctc_dic['input_shape'], - ctc_dic['norm_type'], ctc_dic['input_type'], ctc_dic['input_layout_rt'], - ctc_cal_data, "default", "", "") + ctc_onnx_path, "ctc", ctc_log_path, ctc_dic['input_name'], + ctc_dic['input_type'], ctc_dic['input_layout_train'], + ctc_dic['input_shape'], ctc_dic['norm_type'], ctc_dic['input_type'], + ctc_dic['input_layout_rt'], ctc_cal_data, "default", "", "") with open(output_dir + "/config_encoder.yaml", "w") as enc_yaml: enc_yaml.write(enc_config) with open(output_dir + "/config_ctc.yaml", "w") as ctc_yaml: @@ -343,32 +358,51 @@ def generate_config(enc_session, ctc_session, args): def get_args(): - parser = argparse.ArgumentParser(description='convert onnx to horizon .bin') + parser = argparse.ArgumentParser( + description='convert onnx to horizon .bin') parser.add_argument('--config', required=True, help='config file') parser.add_argument('--checkpoint', required=True, help='checkpoint model') parser.add_argument('--output_dir', required=True, help='output directory') - parser.add_argument('--chunk_size', required=True, - type=int, help='decoding chunk size') - parser.add_argument('--num_decoding_left_chunks', required=True, - type=int, help='cache chunks') - parser.add_argument('--reverse_weight', default=0.5, - type=float, help='reverse_weight in attention_rescoing') + parser.add_argument('--chunk_size', + required=True, + type=int, + help='decoding chunk size') + parser.add_argument('--num_decoding_left_chunks', + required=True, + type=int, + help='cache chunks') + parser.add_argument('--reverse_weight', + default=0.5, + type=float, + help='reverse_weight in attention_rescoing') parser.add_argument('--dict', type=str, required=True, help='dict file') - parser.add_argument('--max_samples', type=int, required=True, + parser.add_argument('--max_samples', + type=int, + required=True, help='maximum samples') - parser.add_argument('--cali_datalist', type=str, default=None, + parser.add_argument('--cali_datalist', + type=str, + default=None, help='make calibration data') - parser.add_argument('--wer_datalist', type=str, default=None, + parser.add_argument('--wer_datalist', + type=str, + default=None, help='check wer') - parser.add_argument('--wer_text', type=str, default=None, - help='check wer') - parser.add_argument('--bpe_model', default=None, type=str, + parser.add_argument('--wer_text', type=str, default=None, help='check wer') + parser.add_argument('--bpe_model', + default=None, + type=str, help='bpe model for english part') - parser.add_argument('--ln_run_on_bpu', action='store_true', + parser.add_argument('--ln_run_on_bpu', + action='store_true', help='layernorm running on bpu') - parser.add_argument('--extra_ops_run_on_cpu', type=str, default=None, + parser.add_argument('--extra_ops_run_on_cpu', + type=str, + default=None, help='extra operations running on cpu.') - parser.add_argument('--calibration_type', type=str, default='default', + parser.add_argument('--calibration_type', + type=str, + default='default', help='kl / max / default.') return parser @@ -453,33 +487,29 @@ 
def get_args(): output_dir = os.path.realpath(args.output_dir) logger.info("Stage-4: Make ctc.bin") - os.system( - "cd {} && mkdir -p hb_makertbin_log_ctc".format(output_dir) + - " && cd hb_makertbin_log_ctc &&" + - " hb_mapper makertbin --model-type \"onnx\" --config \"{}\"".format( - output_dir + "/config_ctc.yaml") - ) + os.system("cd {} && mkdir -p hb_makertbin_log_ctc".format(output_dir) + + " && cd hb_makertbin_log_ctc &&" + + " hb_mapper makertbin --model-type \"onnx\" --config \"{}\"". + format(output_dir + "/config_ctc.yaml")) logger.info("Stage-5: Make encoder.bin") os.system( "cd {} && mkdir -p hb_makertbin_log_encoder ".format(output_dir) + " && cd hb_makertbin_log_encoder &&" + - " hb_mapper makertbin --model-type \"onnx\" --config \"{}\"".format( - output_dir + "/config_encoder.yaml") - ) + " hb_mapper makertbin --model-type \"onnx\" --config \"{}\"". + format(output_dir + "/config_encoder.yaml")) if args.wer_datalist is not None: - logger.info("Stage-6: Check wer between torch model and quantized onnx") + logger.info( + "Stage-6: Check wer between torch model and quantized onnx") assert args.wer_text is not None check_wer(enc, ctc, args, conf) os.system( "python3 tools/compute-wer.py --char=1 --v=1 {} {} > {}".format( args.wer_text, args.output_dir + "/torch_text", - args.output_dir + "/torch_wer") - ) + args.output_dir + "/torch_wer")) os.system( "python3 tools/compute-wer.py --char=1 --v=1 {} {} > {}".format( args.wer_text, args.output_dir + "/onnx_text", - args.output_dir + "/onnx_wer") - ) - os.system("tail {} {}".format( - args.output_dir + "/torch_wer", args.output_dir + "/onnx_wer")) + args.output_dir + "/onnx_wer")) + os.system("tail {} {}".format(args.output_dir + "/torch_wer", + args.output_dir + "/onnx_wer")) diff --git a/tools/remove_longshortdata.py b/tools/remove_longshortdata.py index 7e92f8a42..a0596bf91 100755 --- a/tools/remove_longshortdata.py +++ b/tools/remove_longshortdata.py @@ -6,29 +6,40 @@ if __name__ == '__main__': parser = argparse.ArgumentParser( description='remove too long or too short data in format.data') - parser.add_argument('--data_file', - type=str, - help='input format data') + parser.add_argument('--data_file', type=str, help='input format data') parser.add_argument('--output_data_file', type=str, help='output format data') - parser.add_argument('--min_input_len', type=float, - default=0, - help='minimum input seq length, in seconds for raw wav, \ + parser.add_argument( + '--min_input_len', + type=float, + default=0, + help='minimum input seq length, in seconds for raw wav, \ in frame numbers for feature data') - parser.add_argument('--max_input_len', type=float, - default=20, - help='maximum output seq length, in seconds for raw wav, \ + parser.add_argument( + '--max_input_len', + type=float, + default=20, + help='maximum output seq length, in seconds for raw wav, \ in frame numbers for feature data') - parser.add_argument('--min_output_len', type=float, - default=0, help='minimum input seq length, in modeling units') - parser.add_argument('--max_output_len', type=float, + parser.add_argument('--min_output_len', + type=float, + default=0, + help='minimum input seq length, in modeling units') + parser.add_argument('--max_output_len', + type=float, default=500, help='maximum output seq length, in modeling units') - parser.add_argument('--min_output_input_ratio', type=float, default=0.05, - help='minimum output seq length/output seq length ratio') - parser.add_argument('--max_output_input_ratio', type=float, default=10, - help='maximum 
output seq length/output seq length ratio') + parser.add_argument( + '--min_output_input_ratio', + type=float, + default=0.05, + help='minimum output seq length/output seq length ratio') + parser.add_argument( + '--max_output_input_ratio', + type=float, + default=10, + help='maximum output seq length/output seq length ratio') args = parser.parse_args() data_file = args.data_file @@ -49,13 +60,14 @@ feature_shape = items[2] feat_len = float(feature_shape.split(':')[1].split(',')[0]) token_len = float(token_shape.split(':')[1].split(',')[0]) - condition = [feat_len > min_input_len, - feat_len < max_input_len, - token_len > min_output_len, - token_len < max_output_len, - token_len / feat_len > min_output_input_ratio, - token_len / feat_len < max_output_input_ratio, - ] + condition = [ + feat_len > min_input_len, + feat_len < max_input_len, + token_len > min_output_len, + token_len < max_output_len, + token_len / feat_len > min_output_input_ratio, + token_len / feat_len < max_output_input_ratio, + ] if all(condition): fout.write('{}\n'.format(l)) continue diff --git a/tools/segment.py b/tools/segment.py index a1a7f93a0..c0d77c41e 100755 --- a/tools/segment.py +++ b/tools/segment.py @@ -32,4 +32,5 @@ item = l.strip().split() if item[1] in wav_dic: item[1] = wav_dic[item[1]] - f.write("{} {},{},{}\n".format(item[0], item[1], item[2], item[3])) + f.write("{} {},{},{}\n".format(item[0], item[1], item[2], + item[3])) diff --git a/tools/spm_decode b/tools/spm_decode index 882b4f966..c6ad5fd4f 100755 --- a/tools/spm_decode +++ b/tools/spm_decode @@ -15,19 +15,24 @@ import sentencepiece as spm def main(): parser = argparse.ArgumentParser() - parser.add_argument("--model", required=True, + parser.add_argument("--model", + required=True, help="sentencepiece model to use for decoding") parser.add_argument("--input", default=None, help="input file to decode") - parser.add_argument("--input_format", choices=["piece", "id"], default="piece") + parser.add_argument("--input_format", + choices=["piece", "id"], + default="piece") args = parser.parse_args() sp = spm.SentencePieceProcessor() sp.Load(args.model) if args.input_format == "piece": + def decode(l): return "".join(sp.DecodePieces(l)) elif args.input_format == "id": + def decode(l): return "".join(sp.DecodeIds(l)) else: diff --git a/tools/spm_encode b/tools/spm_encode index 4dd2e1004..f22bd1833 100755 --- a/tools/spm_encode +++ b/tools/spm_encode @@ -16,16 +16,27 @@ import sentencepiece as spm def main(): parser = argparse.ArgumentParser() - parser.add_argument("--model", required=True, + parser.add_argument("--model", + required=True, help="sentencepiece model to use for encoding") - parser.add_argument("--inputs", nargs="+", default=['-'], + parser.add_argument("--inputs", + nargs="+", + default=['-'], help="input files to filter/encode") - parser.add_argument("--outputs", nargs="+", default=['-'], + parser.add_argument("--outputs", + nargs="+", + default=['-'], help="path to save encoded outputs") - parser.add_argument("--output_format", choices=["piece", "id"], default="piece") - parser.add_argument("--min-len", type=int, metavar="N", + parser.add_argument("--output_format", + choices=["piece", "id"], + default="piece") + parser.add_argument("--min-len", + type=int, + metavar="N", help="filter sentence pairs with fewer than N tokens") - parser.add_argument("--max-len", type=int, metavar="N", + parser.add_argument("--max-len", + type=int, + metavar="N", help="filter sentence pairs with more than N tokens") args = parser.parse_args() @@ -36,34 +47,34 
@@ def main(): sp.Load(args.model) if args.output_format == "piece": + def encode(l): return sp.EncodeAsPieces(l) elif args.output_format == "id": + def encode(l): return list(map(str, sp.EncodeAsIds(l))) else: raise NotImplementedError if args.min_len is not None or args.max_len is not None: + def valid(line): - return ( - (args.min_len is None or len(line) >= args.min_len) and - (args.max_len is None or len(line) <= args.max_len) - ) + return ((args.min_len is None or len(line) >= args.min_len) + and (args.max_len is None or len(line) <= args.max_len)) else: + def valid(lines): return True with contextlib.ExitStack() as stack: inputs = [ stack.enter_context(open(input, "r", encoding="utf-8")) - if input != "-" else sys.stdin - for input in args.inputs + if input != "-" else sys.stdin for input in args.inputs ] outputs = [ stack.enter_context(open(output, "w", encoding="utf-8")) - if output != "-" else sys.stdout - for output in args.outputs + if output != "-" else sys.stdout for output in args.outputs ] stats = { @@ -91,8 +102,10 @@ def main(): if i % 10000 == 0: print("processed {} lines".format(i), file=sys.stderr) - print("skipped {} empty lines".format(stats["num_empty"]), file=sys.stderr) - print("filtered {} lines".format(stats["num_filtered"]), file=sys.stderr) + print("skipped {} empty lines".format(stats["num_empty"]), + file=sys.stderr) + print("filtered {} lines".format(stats["num_filtered"]), + file=sys.stderr) if __name__ == "__main__": diff --git a/tools/spm_train b/tools/spm_train index 0b247aee0..134a0b1d3 100755 --- a/tools/spm_train +++ b/tools/spm_train @@ -8,6 +8,5 @@ import sys import sentencepiece as spm - if __name__ == "__main__": spm.SentencePieceTrainer.Train(" ".join(sys.argv[1:])) diff --git a/tools/ssh_launcher.py b/tools/ssh_launcher.py index 757e8d5ec..8d1318e7b 100644 --- a/tools/ssh_launcher.py +++ b/tools/ssh_launcher.py @@ -9,8 +9,6 @@ import socket import subprocess from threading import Thread - - """ Requirements: @@ -125,11 +123,8 @@ def run(prog): try: subprocess.check_call(prog, shell=True) except subprocess.CalledProcessError as e: - logging.info( - "subprocess({}) failed({})! {}".format( - e.cmd, e.returncode, e.output - ) - ) + logging.info("subprocess({}) failed({})! 
{}".format( + e.cmd, e.returncode, e.output)) os._exit(-1) pass_envs = os.environ.copy() @@ -145,12 +140,8 @@ def run(prog): prog = ( "ssh -o StrictHostKeyChecking=no " # + ssh_port_arg # no port available in aidi - + node - + " '" - + prog - + "'" - ) - thread = Thread(target=run, args=(prog,)) + + node + " '" + prog + "'") + thread = Thread(target=run, args=(prog, )) thread.setDaemon(True) thread.start() thread_list.append(thread) @@ -170,9 +161,10 @@ def run(prog): required=True, help="number of worker process to be launched", ) - parser.add_argument( - "-H", "--hostfile", type=str, help="the hostfile of workers" - ) + parser.add_argument("-H", + "--hostfile", + type=str, + help="the hostfile of workers") parser.add_argument( "-p", "--port", @@ -187,9 +179,9 @@ def run(prog): default=443, help="the port used for ssh connect, used when distribute-training", ) - parser.add_argument( - "command", nargs="+", help="command for plugin program" - ) + parser.add_argument("command", + nargs="+", + help="command for plugin program") args = parser.parse_args() cmd = " ".join(args.command) submit(args.nworker, args.hostfile, args.port, args.sshport, cmd) diff --git a/tools/text2token.py b/tools/text2token.py index 4f4dcc901..a230d6352 100755 --- a/tools/text2token.py +++ b/tools/text2token.py @@ -27,12 +27,14 @@ def exist_or_not(i, match_pos): return start_pos, end_pos + def seg_char(sent): pattern = re.compile(r'([\u4e00-\u9fa5])') chars = pattern.split(sent) chars = [w for w in chars if len(w.strip()) > 0] return chars + def get_parser(): parser = argparse.ArgumentParser( description='convert raw text to tokenized text', diff --git a/tools/wav2dur.py b/tools/wav2dur.py index 1bcc1b693..b53a7fe1d 100755 --- a/tools/wav2dur.py +++ b/tools/wav2dur.py @@ -4,6 +4,7 @@ import sys import torchaudio + torchaudio.set_audio_backend("sox_io") scp = sys.argv[1] diff --git a/tools/websocket/performance-ws.py b/tools/websocket/performance-ws.py index af77dea06..0810f80f6 100755 --- a/tools/websocket/performance-ws.py +++ b/tools/websocket/performance-ws.py @@ -24,15 +24,12 @@ import soundfile as sf import statistics - WS_START = json.dumps({ 'signal': 'start', 'nbest': 1, 'continuous_decoding': False, }) -WS_END = json.dumps({ - 'signal': 'end' -}) +WS_END = json.dumps({'signal': 'end'}) async def ws_rec(data, ws_uri): @@ -73,20 +70,27 @@ async def ws_rec(data, ws_uri): def get_args(): parser = argparse.ArgumentParser(description='') parser.add_argument( - '-u', '--ws_uri', required=True, + '-u', + '--ws_uri', + required=True, help="websocket_server_main's uri, e.g. 
ws://127.0.0.1:10086") - parser.add_argument( - '-w', '--wav_scp', required=True, - help='path to wav_scp_file') - parser.add_argument( - '-t', '--trans', required=True, - help='path to trans_text_file of wavs') - parser.add_argument( - '-s', '--save_to', required=True, - help='path to save transcription') - parser.add_argument( - '-n', '--num_concurrence', type=int, required=True, - help='num of concurrence for query') + parser.add_argument('-w', + '--wav_scp', + required=True, + help='path to wav_scp_file') + parser.add_argument('-t', + '--trans', + required=True, + help='path to trans_text_file of wavs') + parser.add_argument('-s', + '--save_to', + required=True, + help='path to save transcription') + parser.add_argument('-n', + '--num_concurrence', + type=int, + required=True, + help='num of concurrence for query') args = parser.parse_args() return args diff --git a/wenet/bin/alignment.py b/wenet/bin/alignment.py index 810d56e7b..1d510a9d8 100644 --- a/wenet/bin/alignment.py +++ b/wenet/bin/alignment.py @@ -74,9 +74,9 @@ def get_frames_timestamp(alignment, end += 1 local_start = end - 1 # find the possible front border for current token - while local_start >= start and (prob[local_start][0] < math.log( - blank_thres) or prob[local_start][alignment[ - end - 1]] > math.log(thres)): + while local_start >= start and ( + prob[local_start][0] < math.log(blank_thres) + or prob[local_start][alignment[end - 1]] > math.log(thres)): alignment[local_start] = alignment[end - 1] local_start -= 1 cur_alignment = alignment[start:end] @@ -183,7 +183,6 @@ def get_labformat(timestamp, subsample): char_dict[int(arr[1])] = arr[0] eos = len(char_dict) - 1 - # Init dataset and data loader ali_conf = copy.deepcopy(configs['dataset_conf']) @@ -202,7 +201,8 @@ def get_labformat(timestamp, subsample): ali_conf['batch_conf']['batch_type'] = "static" ali_conf['batch_conf']['batch_size'] = args.batch_size - tokenizer = init_tokenizer(ali_conf, args.dict, args.bpe_model, args.non_lang_syms) + tokenizer = init_tokenizer(ali_conf, args.dict, args.bpe_model, + args.non_lang_syms) ali_dataset = Dataset(args.data_type, args.input_file, tokenizer, diff --git a/wenet/bin/average_model.py b/wenet/bin/average_model.py index 9163b3b8f..a99e69884 100644 --- a/wenet/bin/average_model.py +++ b/wenet/bin/average_model.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
- import os import argparse import glob diff --git a/wenet/bin/export_ipex.py b/wenet/bin/export_ipex.py index bf76acd35..1d3ff181c 100644 --- a/wenet/bin/export_ipex.py +++ b/wenet/bin/export_ipex.py @@ -14,6 +14,7 @@ import intel_extension_for_pytorch as ipex from intel_extension_for_pytorch.quantization import prepare, convert + def get_args(): parser = argparse.ArgumentParser(description='export your script model') parser.add_argument('--config', required=True, help='config file') @@ -28,22 +29,21 @@ def get_args(): args = parser.parse_args() return args + def scripting(model): with torch.inference_mode(): script_model = torch.jit.script(model) script_model = torch.jit.freeze( script_model, - preserved_attrs=["forward_encoder_chunk", - "ctc_activation", - "forward_attention_decoder", - "subsampling_rate", - "right_context", - "sos_symbol", - "eos_symbol", - "is_bidirectional_decoder"] - ) + preserved_attrs=[ + "forward_encoder_chunk", "ctc_activation", + "forward_attention_decoder", "subsampling_rate", + "right_context", "sos_symbol", "eos_symbol", + "is_bidirectional_decoder" + ]) return script_model + def main(): args = get_args() logging.basicConfig(level=logging.DEBUG, @@ -62,8 +62,10 @@ def main(): model.to(memory_format=torch.channels_last) if args.dtype == "fp32": ipex_model = ipex.optimize(model) - elif args.dtype == "bf16": # For Intel 4th generation Xeon (SPR) - ipex_model = ipex.optimize(model, dtype=torch.bfloat16, weights_prepack=False) + elif args.dtype == "bf16": # For Intel 4th generation Xeon (SPR) + ipex_model = ipex.optimize(model, + dtype=torch.bfloat16, + weights_prepack=False) # Export jit torch script model if args.output_file: @@ -79,11 +81,8 @@ def main(): # Export quantized jit torch script model if args.output_quant_file: dynamic_qconfig = ipex.quantization.default_dynamic_qconfig - dummy_data = (torch.zeros(1, 67, 80), - 16, - -16, - torch.zeros(12, 4, 32, 128), - torch.zeros(12, 1, 256, 7)) + dummy_data = (torch.zeros(1, 67, 80), 16, -16, + torch.zeros(12, 4, 32, 128), torch.zeros(12, 1, 256, 7)) model = prepare(model, dynamic_qconfig, dummy_data) model = convert(model) script_quant_model = scripting(model) diff --git a/wenet/bin/export_jit.py b/wenet/bin/export_jit.py index ff4c79fce..70f53586b 100644 --- a/wenet/bin/export_jit.py +++ b/wenet/bin/export_jit.py @@ -58,8 +58,7 @@ def main(): # Export quantized jit torch script model if args.output_quant_file: quantized_model = torch.quantization.quantize_dynamic( - model, {torch.nn.Linear}, dtype=torch.qint8 - ) + model, {torch.nn.Linear}, dtype=torch.qint8) print(quantized_model) script_quant_model = torch.jit.script(quantized_model) script_quant_model.save(args.output_quant_file) diff --git a/wenet/bin/export_onnx_bpu.py b/wenet/bin/export_onnx_bpu.py index 3aaacffc9..a1d93a022 100644 --- a/wenet/bin/export_onnx_bpu.py +++ b/wenet/bin/export_onnx_bpu.py @@ -34,7 +34,6 @@ 2. 
specific decoding method: ctc_greedy_search """ - from __future__ import print_function import os @@ -53,7 +52,6 @@ from wenet.bin.export_onnx_cpu import (get_args, to_numpy, print_input_output_info) - try: import onnx import onnxruntime @@ -61,13 +59,13 @@ print('Please install onnx and onnxruntime!') sys.exit(1) - logger = logging.getLogger(__file__) logger.setLevel(logging.INFO) class BPULayerNorm(torch.nn.Module): """Refactor torch.nn.LayerNorm to meet 4-D dataflow.""" + def __init__(self, module, chunk_size=8, run_on_bpu=False): super().__init__() original = copy.deepcopy(module) @@ -77,21 +75,23 @@ def __init__(self, module, chunk_size=8, run_on_bpu=False): if self.run_on_bpu: self.weight = torch.nn.Parameter( - module.weight.reshape(1, self.hidden, 1, 1).repeat( - 1, 1, 1, chunk_size)) + module.weight.reshape(1, self.hidden, 1, + 1).repeat(1, 1, 1, chunk_size)) self.bias = torch.nn.Parameter( - module.bias.reshape(1, self.hidden, 1, 1).repeat( - 1, 1, 1, chunk_size)) + module.bias.reshape(1, self.hidden, 1, + 1).repeat(1, 1, 1, chunk_size)) self.negtive = torch.nn.Parameter( torch.ones((1, self.hidden, 1, chunk_size)) * -1.0) self.eps = torch.nn.Parameter( torch.zeros((1, self.hidden, 1, chunk_size)) + module.eps) self.mean_conv_1 = torch.nn.Conv2d(self.hidden, 1, 1, bias=False) self.mean_conv_1.weight = torch.nn.Parameter( - torch.ones(self.hidden, self.hidden, 1, 1) / (1.0 * self.hidden)) + torch.ones(self.hidden, self.hidden, 1, 1) / + (1.0 * self.hidden)) self.mean_conv_2 = torch.nn.Conv2d(self.hidden, 1, 1, bias=False) self.mean_conv_2.weight = torch.nn.Parameter( - torch.ones(self.hidden, self.hidden, 1, 1) / (1.0 * self.hidden)) + torch.ones(self.hidden, self.hidden, 1, 1) / + (1.0 * self.hidden)) else: self.norm = module @@ -101,9 +101,11 @@ def check_equal(self, module): random_data = torch.randn(1, self.chunk_size, self.hidden) orig_out = module(random_data) new_out = self.forward(random_data.transpose(1, 2).unsqueeze(2)) - np.testing.assert_allclose( - to_numpy(orig_out), to_numpy(new_out.squeeze(2).transpose(1, 2)), - rtol=1e-02, atol=1e-03) + np.testing.assert_allclose(to_numpy(orig_out), + to_numpy( + new_out.squeeze(2).transpose(1, 2)), + rtol=1e-02, + atol=1e-03) def forward(self, x: torch.Tensor) -> torch.Tensor: if self.run_on_bpu: @@ -125,22 +127,26 @@ class BPUIdentity(torch.nn.Module): """Refactor torch.nn.Identity(). For inserting BPU node whose input == output. """ + def __init__(self, channels): super().__init__() self.channels = channels - self.identity_conv = torch.nn.Conv2d( - channels, channels, 1, groups=channels, bias=False) - torch.nn.init.dirac_( - self.identity_conv.weight.data, groups=channels) + self.identity_conv = torch.nn.Conv2d(channels, + channels, + 1, + groups=channels, + bias=False) + torch.nn.init.dirac_(self.identity_conv.weight.data, groups=channels) self.check_equal() def check_equal(self): random_data = torch.randn(1, self.channels, 1, 10) result = self.forward(random_data) - np.testing.assert_allclose( - to_numpy(random_data), to_numpy(result), - rtol=1e-02, atol=1e-03) + np.testing.assert_allclose(to_numpy(random_data), + to_numpy(result), + rtol=1e-02, + atol=1e-03) def forward(self, x: torch.Tensor) -> torch.Tensor: """Identity with 4-D dataflow, input == output. 
@@ -155,6 +161,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class BPULinear(torch.nn.Module): """Refactor torch.nn.Linear or pointwise_conv""" + def __init__(self, module, is_pointwise_conv=False): super().__init__() # Unchanged submodules and attributes @@ -187,10 +194,11 @@ def check_equal(self, module): original_result = original_result.transpose(1, 2) random_data = random_data.transpose(1, 2).unsqueeze(2) new_result = self.forward(random_data) - np.testing.assert_allclose( - to_numpy(original_result), - to_numpy(new_result.squeeze(2).transpose(1, 2)), - rtol=1e-02, atol=1e-03) + np.testing.assert_allclose(to_numpy(original_result), + to_numpy( + new_result.squeeze(2).transpose(1, 2)), + rtol=1e-02, + atol=1e-03) def forward(self, x: torch.Tensor) -> torch.Tensor: """Linear with 4-D dataflow. @@ -204,6 +212,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class BPUGlobalCMVN(torch.nn.Module): """Refactor wenet/transformer/cmvn.py::GlobalCMVN""" + def __init__(self, module): super().__init__() # Unchanged submodules and attributes @@ -231,6 +240,7 @@ class BPUConv2dSubsampling8(torch.nn.Module): NOTE(xcsong): Only support pos_enc_class == NoPositionalEncoding """ + def __init__(self, module): super().__init__() # Unchanged submodules and attributes @@ -245,8 +255,7 @@ def __init__(self, module): self.conv = module.conv for idx in [0, 2, 4]: self.conv[idx].weight = torch.nn.Parameter( - module.conv[idx].weight.transpose(2, 3) - ) + module.conv[idx].weight.transpose(2, 3)) # 2. Modify self.linear # NOTE(xcsong): Split final projection to meet the requirment of @@ -256,20 +265,18 @@ def __init__(self, module): freq = module.linear.weight.size(1) // odim # 4608 // 512 == 9 self.odim, self.freq = odim, freq weight = module.linear.weight.reshape( - odim, odim, freq, 1) # (odim, odim * freq) -> (odim, odim, freq, 1) + odim, odim, freq, + 1) # (odim, odim * freq) -> (odim, odim, freq, 1) self.split_size = [] num_split = (freq - 1) // 7 + 1 # XJ3 requires kernel_size <= 7 slice_begin = 0 for idx in range(num_split): kernel_size = min(freq, (idx + 1) * 7) - idx * 7 - conv_ele = torch.nn.Conv2d( - odim, odim, (kernel_size, 1), (kernel_size, 1)) + conv_ele = torch.nn.Conv2d(odim, odim, (kernel_size, 1), + (kernel_size, 1)) conv_ele.weight = torch.nn.Parameter( - weight[:, :, slice_begin:slice_begin + kernel_size, :] - ) - conv_ele.bias = torch.nn.Parameter( - torch.zeros_like(conv_ele.bias) - ) + weight[:, :, slice_begin:slice_begin + kernel_size, :]) + conv_ele.bias = torch.nn.Parameter(torch.zeros_like(conv_ele.bias)) self.linear.append(conv_ele) self.split_size.append(kernel_size) slice_begin += kernel_size @@ -281,12 +288,14 @@ def check_equal(self, module): random_data = torch.randn(1, 67, 80) mask = torch.zeros(1, 1, 67) original_result, _, _ = module(random_data, mask) # (1, 8, 512) - random_data = random_data.transpose(1, 2).unsqueeze(0) # (1, 1, 80, 67) + random_data = random_data.transpose(1, + 2).unsqueeze(0) # (1, 1, 80, 67) new_result = self.forward(random_data) # (1, 512, 1, 8) - np.testing.assert_allclose( - to_numpy(original_result), - to_numpy(new_result.squeeze(2).transpose(1, 2)), - rtol=1e-02, atol=1e-03) + np.testing.assert_allclose(to_numpy(original_result), + to_numpy( + new_result.squeeze(2).transpose(1, 2)), + rtol=1e-02, + atol=1e-03) def forward(self, x: torch.Tensor) -> torch.Tensor: """Subsample x with 4-D dataflow. 
@@ -311,6 +320,7 @@ class BPUMultiHeadedAttention(torch.nn.Module): NOTE(xcsong): Only support attention_class == MultiHeadedAttention, we do not consider RelPositionMultiHeadedAttention currently. """ + def __init__(self, module, chunk_size, left_chunks): super().__init__() # Unchanged submodules and attributes @@ -340,26 +350,31 @@ def check_equal(self, module): dtype=torch.bool) cache = torch.zeros(1, self.h, self.chunk_size * self.left_chunks, self.d_k * 2) - original_out, original_cache = module( - random_data, random_data, random_data, - mask[:, 0, :, :], torch.empty(0), cache) + original_out, original_cache = module(random_data, random_data, + random_data, mask[:, 0, :, :], + torch.empty(0), cache) random_data = random_data.transpose(1, 2).unsqueeze(2) cache = cache.reshape(1, self.h, self.d_k * 2, self.chunk_size * self.left_chunks) - new_out, new_cache = self.forward( - random_data, random_data, random_data, mask, cache) - np.testing.assert_allclose( - to_numpy(original_out), - to_numpy(new_out.squeeze(2).transpose(1, 2)), - rtol=1e-02, atol=1e-03) - np.testing.assert_allclose( - to_numpy(original_cache), - to_numpy(new_cache.transpose(2, 3)), - rtol=1e-02, atol=1e-03) + new_out, new_cache = self.forward(random_data, random_data, + random_data, mask, cache) + np.testing.assert_allclose(to_numpy(original_out), + to_numpy( + new_out.squeeze(2).transpose(1, 2)), + rtol=1e-02, + atol=1e-03) + np.testing.assert_allclose(to_numpy(original_cache), + to_numpy(new_cache.transpose(2, 3)), + rtol=1e-02, + atol=1e-03) def forward( - self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - mask: torch.Tensor, cache: torch.Tensor, + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + mask: torch.Tensor, + cache: torch.Tensor, ) -> Tuple[torch.Tensor, torch.Tensor]: """Compute scaled dot product attention. @@ -410,6 +425,7 @@ class BPUConvolution(torch.nn.Module): NOTE(xcsong): Only suport use_layer_norm == False """ + def __init__(self, module): super().__init__() # Unchanged submodules and attributes @@ -426,9 +442,10 @@ def __init__(self, module): self.pointwise_conv1 = BPULinear(module.pointwise_conv1, True) # 2. Modify self.depthwise_conv - self.depthwise_conv = torch.nn.Conv2d( - channels, channels, (1, kernel_size), - stride=1, groups=channels) + self.depthwise_conv = torch.nn.Conv2d(channels, + channels, (1, kernel_size), + stride=1, + groups=channels) self.depthwise_conv.weight = torch.nn.Parameter( module.depthwise_conv.weight.unsqueeze(-2)) self.depthwise_conv.bias = torch.nn.Parameter( @@ -460,18 +477,18 @@ def check_equal(self, module): random_data = random_data.transpose(1, 2).unsqueeze(2) cache = cache.unsqueeze(2) new_out, new_cache = self.forward(random_data, cache) - np.testing.assert_allclose( - to_numpy(original_out), - to_numpy(new_out.squeeze(2).transpose(1, 2)), - rtol=1e-02, atol=1e-03) - np.testing.assert_allclose( - to_numpy(original_cache), - to_numpy(new_cache.squeeze(2)), - rtol=1e-02, atol=1e-03) - - def forward( - self, x: torch.Tensor, cache: torch.Tensor - ) -> Tuple[torch.Tensor, torch.Tensor]: + np.testing.assert_allclose(to_numpy(original_out), + to_numpy( + new_out.squeeze(2).transpose(1, 2)), + rtol=1e-02, + atol=1e-03) + np.testing.assert_allclose(to_numpy(original_cache), + to_numpy(new_cache.squeeze(2)), + rtol=1e-02, + atol=1e-03) + + def forward(self, x: torch.Tensor, + cache: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: """Compute convolution module. Args: x (torch.Tensor): Input tensor (#batch, channels, 1, chunk_size). 
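# A minimal sketch of the streaming attention-cache bookkeeping that
# BPUMultiHeadedAttention.check_equal compares against: previous keys/values
# live in a (batch, head, time_cache, d_k * 2) tensor, are split back into
# keys and values, and are concatenated in front of the current chunk before
# the dot product. All sizes are illustrative assumptions.
import torch

b, h, d_k, chunk, cache_t = 1, 4, 16, 8, 32
q = torch.randn(b, h, chunk, d_k)
k = torch.randn(b, h, chunk, d_k)
v = torch.randn(b, h, chunk, d_k)
cache = torch.zeros(b, h, cache_t, d_k * 2)               # zeros == "no history yet"

k_cache, v_cache = cache.split(d_k, dim=-1)
k_full = torch.cat([k_cache, k], dim=2)                   # (b, h, cache_t + chunk, d_k)
v_full = torch.cat([v_cache, v], dim=2)

scores = torch.matmul(q, k_full.transpose(-2, -1)) / d_k ** 0.5
out = torch.matmul(torch.softmax(scores, dim=-1), v_full) # (b, h, chunk, d_k)

new_cache = torch.cat([k_full, v_full], dim=-1)[:, :, -cache_t:, :]
print(out.shape, new_cache.shape)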
@@ -499,6 +516,7 @@ def forward( class BPUFFN(torch.nn.Module): """Refactor wenet/transformer/positionwise_feed_forward.py::PositionwiseFeedForward """ + def __init__(self, module): super().__init__() # Unchanged submodules and attributes @@ -516,10 +534,11 @@ def check_equal(self, module): original_out = module(random_data) random_data = random_data.transpose(1, 2).unsqueeze(2) new_out = self.forward(random_data) - np.testing.assert_allclose( - to_numpy(original_out), - to_numpy(new_out.squeeze(2).transpose(1, 2)), - rtol=1e-02, atol=1e-03) + np.testing.assert_allclose(to_numpy(original_out), + to_numpy( + new_out.squeeze(2).transpose(1, 2)), + rtol=1e-02, + atol=1e-03) def forward(self, x: torch.Tensor) -> torch.Tensor: """Forward function. @@ -535,6 +554,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class BPUConformerEncoderLayer(torch.nn.Module): """Refactor wenet/transformer/encoder_layer.py::ConformerEncoderLayer """ + def __init__(self, module, chunk_size, left_chunks, ln_run_on_bpu=False): super().__init__() # Unchanged submodules and attributes @@ -545,24 +565,25 @@ def __init__(self, module, chunk_size, left_chunks, ln_run_on_bpu=False): # 1. Modify submodules self.feed_forward_macaron = BPUFFN(module.feed_forward_macaron) - self.self_attn = BPUMultiHeadedAttention( - module.self_attn, chunk_size, left_chunks) + self.self_attn = BPUMultiHeadedAttention(module.self_attn, chunk_size, + left_chunks) self.conv_module = BPUConvolution(module.conv_module) self.feed_forward = BPUFFN(module.feed_forward) # 2. Modify norms self.norm_ff = BPULayerNorm(module.norm_ff, chunk_size, ln_run_on_bpu) - self.norm_mha = BPULayerNorm(module.norm_mha, chunk_size, ln_run_on_bpu) - self.norm_ff_macron = BPULayerNorm(module.norm_ff_macaron, - chunk_size, ln_run_on_bpu) - self.norm_conv = BPULayerNorm(module.norm_conv, - chunk_size, ln_run_on_bpu) - self.norm_final = BPULayerNorm(module.norm_final, - chunk_size, ln_run_on_bpu) + self.norm_mha = BPULayerNorm(module.norm_mha, chunk_size, + ln_run_on_bpu) + self.norm_ff_macron = BPULayerNorm(module.norm_ff_macaron, chunk_size, + ln_run_on_bpu) + self.norm_conv = BPULayerNorm(module.norm_conv, chunk_size, + ln_run_on_bpu) + self.norm_final = BPULayerNorm(module.norm_final, chunk_size, + ln_run_on_bpu) # 3. 
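# A minimal stand-in (not the exported module) showing the residual ordering
# the refactored ConformerEncoderLayer follows: half-scaled macaron FFN,
# self-attention, convolution, second half-scaled FFN, final norm. The
# submodules here are simplified placeholders; only the ordering and the 0.5
# ff_scale mirror the real layer.
import torch

class TinyConformerLayer(torch.nn.Module):
    def __init__(self, size=16, heads=4):
        super().__init__()
        self.ff_scale = 0.5
        self.norm_ff_macaron = torch.nn.LayerNorm(size)
        self.norm_mha = torch.nn.LayerNorm(size)
        self.norm_conv = torch.nn.LayerNorm(size)
        self.norm_ff = torch.nn.LayerNorm(size)
        self.norm_final = torch.nn.LayerNorm(size)
        self.feed_forward_macaron = torch.nn.Linear(size, size)
        self.self_attn = torch.nn.MultiheadAttention(size, heads, batch_first=True)
        self.conv_module = torch.nn.Conv1d(size, size, 3, padding=1)
        self.feed_forward = torch.nn.Linear(size, size)

    def forward(self, x):                                  # x: (batch, time, size)
        x = x + self.ff_scale * self.feed_forward_macaron(self.norm_ff_macaron(x))
        h = self.norm_mha(x)
        att, _ = self.self_attn(h, h, h)
        x = x + att
        c = self.conv_module(self.norm_conv(x).transpose(1, 2)).transpose(1, 2)
        x = x + c
        x = x + self.ff_scale * self.feed_forward(self.norm_ff(x))
        return self.norm_final(x)

print(TinyConformerLayer()(torch.randn(1, 8, 16)).shape)  # torch.Size([1, 8, 16])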
4-D ff_scale - self.register_buffer( - "ff_scale", torch.full((1, self.size, 1, 1), module.ff_scale)) + self.register_buffer("ff_scale", + torch.full((1, self.size, 1, 1), module.ff_scale)) self.check_equal(original) @@ -575,31 +596,32 @@ def check_equal(self, module): att_cache = torch.zeros(1, h, time2 - time1, d_k * 2) cnn_cache = torch.zeros(1, self.size, self.conv_module.lorder) original_x, _, original_att_cache, original_cnn_cache = module( - random_x, att_mask[:, 0, :, :], torch.empty(0), - att_cache=att_cache, cnn_cache=cnn_cache - ) + random_x, + att_mask[:, 0, :, :], + torch.empty(0), + att_cache=att_cache, + cnn_cache=cnn_cache) random_x = random_x.transpose(1, 2).unsqueeze(2) att_cache = att_cache.reshape(1, h, d_k * 2, time2 - time1) cnn_cache = cnn_cache.unsqueeze(2) new_x, new_att_cache, new_cnn_cache = self.forward( - random_x, att_mask, att_cache, cnn_cache - ) - np.testing.assert_allclose( - to_numpy(original_att_cache), - to_numpy(new_att_cache.transpose(2, 3)), - rtol=1e-02, atol=1e-03) - np.testing.assert_allclose( - to_numpy(original_x), - to_numpy(new_x.squeeze(2).transpose(1, 2)), - rtol=1e-02, atol=1e-03) - np.testing.assert_allclose( - to_numpy(original_cnn_cache), - to_numpy(new_cnn_cache.squeeze(2)), - rtol=1e-02, atol=1e-03) + random_x, att_mask, att_cache, cnn_cache) + np.testing.assert_allclose(to_numpy(original_att_cache), + to_numpy(new_att_cache.transpose(2, 3)), + rtol=1e-02, + atol=1e-03) + np.testing.assert_allclose(to_numpy(original_x), + to_numpy(new_x.squeeze(2).transpose(1, 2)), + rtol=1e-02, + atol=1e-03) + np.testing.assert_allclose(to_numpy(original_cnn_cache), + to_numpy(new_cnn_cache.squeeze(2)), + rtol=1e-02, + atol=1e-03) def forward( - self, x: torch.Tensor, att_mask: torch.Tensor, - att_cache: torch.Tensor, cnn_cache: torch.Tensor + self, x: torch.Tensor, att_mask: torch.Tensor, att_cache: torch.Tensor, + cnn_cache: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """Compute encoded features. @@ -625,8 +647,7 @@ def forward( # 2. attention residual = x x = self.norm_mha(x) - x_att, new_att_cache = self.self_attn( - x, x, x, att_mask, att_cache) + x_att, new_att_cache = self.self_attn(x, x, x, att_mask, att_cache) x = residual + x_att # 3. convolution @@ -649,6 +670,7 @@ def forward( class BPUConformerEncoder(torch.nn.Module): """Refactor wenet/transformer/encoder.py::ConformerEncoder """ + def __init__(self, module, chunk_size, left_chunks, ln_run_on_bpu=False): super().__init__() # Unchanged submodules and attributes @@ -666,8 +688,9 @@ def __init__(self, module, chunk_size, left_chunks, ln_run_on_bpu=False): self.embed = BPUConv2dSubsampling8(module.embed) self.encoders = torch.nn.ModuleList() for layer in module.encoders: - self.encoders.append(BPUConformerEncoderLayer( - layer, chunk_size, left_chunks, ln_run_on_bpu)) + self.encoders.append( + BPUConformerEncoderLayer(layer, chunk_size, left_chunks, + ln_run_on_bpu)) # 2. 
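# A minimal sketch of the cache packing the refactored encoder uses: per-layer
# attention caches of shape (1, head, d_k * 2, time) are concatenated along the
# head axis into a single export-friendly tensor and split back inside the
# layer loop. Sizes are illustrative assumptions.
import torch

layers, h, d_k, t_cache = 3, 4, 16, 32
per_layer = [torch.randn(1, h, d_k * 2, t_cache) for _ in range(layers)]

packed = torch.cat(per_layer, dim=1)                      # (1, h * layers, d_k * 2, t_cache)
unpacked = torch.split(packed, h, dim=1)                  # one (1, h, d_k * 2, t_cache) per layer
for original, restored in zip(per_layer, unpacked):
    torch.testing.assert_close(original, restored)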
Auxiliary conv self.identity_cnncache = BPUIdentity(output_size) @@ -688,29 +711,32 @@ def check_equal(self, module): att_cache = torch.zeros(layers, h, time2 - time1, d_k * 2) cnn_cache = torch.zeros(layers, 1, self._output_size, lorder) orig_x, orig_att_cache, orig_cnn_cache = module.forward_chunk( - random_x, 0, time2 - time1, att_mask=att_mask[:, 0, :, :], - att_cache=att_cache, cnn_cache=cnn_cache - ) + random_x, + 0, + time2 - time1, + att_mask=att_mask[:, 0, :, :], + att_cache=att_cache, + cnn_cache=cnn_cache) random_x = random_x.unsqueeze(0) att_cache = att_cache.reshape(1, h * layers, d_k * 2, time2 - time1) cnn_cache = cnn_cache.reshape(1, self._output_size, layers, lorder) new_x, new_att_cache, new_cnn_cache = self.forward( - random_x, att_cache, cnn_cache, att_mask - ) + random_x, att_cache, cnn_cache, att_mask) caches = torch.split(new_att_cache, h, dim=1) caches = [c.transpose(2, 3) for c in caches] - np.testing.assert_allclose( - to_numpy(orig_att_cache), - to_numpy(torch.cat(caches, dim=0)), - rtol=1e-02, atol=1e-03) - np.testing.assert_allclose( - to_numpy(orig_x), - to_numpy(new_x.squeeze(2).transpose(1, 2)), - rtol=1e-02, atol=1e-03) + np.testing.assert_allclose(to_numpy(orig_att_cache), + to_numpy(torch.cat(caches, dim=0)), + rtol=1e-02, + atol=1e-03) + np.testing.assert_allclose(to_numpy(orig_x), + to_numpy(new_x.squeeze(2).transpose(1, 2)), + rtol=1e-02, + atol=1e-03) np.testing.assert_allclose( to_numpy(orig_cnn_cache), to_numpy(new_cnn_cache.transpose(0, 2).transpose(1, 2)), - rtol=1e-02, atol=1e-03) + rtol=1e-02, + atol=1e-03) def forward( self, xs: torch.Tensor, att_cache: torch.Tensor, @@ -753,13 +779,14 @@ def forward( r_att_cache = [] r_cnn_cache = [] for i, layer in enumerate(self.encoders): - xs, new_att_cache, new_cnn_cache = layer( - xs, att_mask, att_cache=att_cache[i], cnn_cache=cnn_cache[i]) + xs, new_att_cache, new_cnn_cache = layer(xs, + att_mask, + att_cache=att_cache[i], + cnn_cache=cnn_cache[i]) r_att_cache.append(new_att_cache[:, :, :, self.chunk_size:]) r_cnn_cache.append(new_cnn_cache) r_att_cache = torch.cat(r_att_cache, dim=1) - r_cnn_cache = self.identity_cnncache( - torch.cat(r_cnn_cache, dim=2)) + r_cnn_cache = self.identity_cnncache(torch.cat(r_cnn_cache, dim=2)) xs = xs.squeeze(2).transpose(1, 2).contiguous() xs = self.after_norm(xs) @@ -772,6 +799,7 @@ def forward( class BPUCTC(torch.nn.Module): """Refactor wenet/transformer/ctc.py::CTC """ + def __init__(self, module): super().__init__() # Unchanged submodules and attributes @@ -803,10 +831,11 @@ def check_equal(self, module): original_result = module.ctc_lo(random_data) random_data = random_data.transpose(1, 2).unsqueeze(2) new_result = self.forward(random_data) - np.testing.assert_allclose( - to_numpy(original_result), - to_numpy(new_result.squeeze(2).transpose(1, 2)), - rtol=1e-02, atol=1e-03) + np.testing.assert_allclose(to_numpy(original_result), + to_numpy( + new_result.squeeze(2).transpose(1, 2)), + rtol=1e-02, + atol=1e-03) def forward(self, x: torch.Tensor) -> torch.Tensor: """frame activations, without softmax. 
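# A minimal sketch of the ctc_greedy_search this export targets, applied to the
# frame activations the CTC head produces: per-frame argmax, collapse repeats,
# drop blanks. The blank id 0 and the shapes are illustrative assumptions.
import torch

def ctc_greedy_search(log_probs: torch.Tensor, blank: int = 0) -> list:
    """log_probs: (time, vocab) frame-level log-probabilities."""
    best = log_probs.argmax(dim=-1).tolist()
    hyp, prev = [], blank
    for token in best:
        if token != blank and token != prev:
            hyp.append(token)
        prev = token
    return hyp

frames = torch.randn(20, 10).log_softmax(dim=-1)
print(ctc_greedy_search(frames))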
@@ -826,9 +855,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: def export_encoder(asr_model, args): logger.info("Stage-1: export encoder") decode_window, mel_dim = args.decoding_window, args.feature_size - encoder = BPUConformerEncoder( - asr_model.encoder, args.chunk_size, args.num_decoding_left_chunks, - args.ln_run_on_bpu) + encoder = BPUConformerEncoder(asr_model.encoder, args.chunk_size, + args.num_decoding_left_chunks, + args.ln_run_on_bpu) encoder.eval() encoder_outpath = os.path.join(args.output_dir, 'encoder.onnx') @@ -870,11 +899,16 @@ def export_encoder(asr_model, args): att_mask.size(1), att_mask.size(2), att_mask.size(3) ) torch.onnx.export( # NOTE(xcsong): only support opset==11 - encoder, inputs, encoder_outpath, opset_version=11, - export_params=True, do_constant_folding=True, + encoder, + inputs, + encoder_outpath, + opset_version=11, + export_params=True, + do_constant_folding=True, input_names=attributes['input_name'].split(';'), output_names=attributes['output_name'].split(';'), - dynamic_axes=None, verbose=False) + dynamic_axes=None, + verbose=False) onnx_encoder = onnx.load(encoder_outpath) for k in vars(args): meta = onnx_encoder.metadata_props.add() @@ -895,11 +929,10 @@ def export_encoder(asr_model, args): torch_cnn_cache = copy.deepcopy(cnn_cache) for i in range(10): logger.info("torch chunk-{}: {}, att_cache: {}, cnn_cache: {}" - ", att_mask: {}".format( - i, list(torch_chunk.size()), - list(torch_att_cache.size()), - list(torch_cnn_cache.size()), - list(torch_att_mask.size()))) + ", att_mask: {}".format(i, list(torch_chunk.size()), + list(torch_att_cache.size()), + list(torch_cnn_cache.size()), + list(torch_att_mask.size()))) torch_att_mask[:, :, :, -(encoder.chunk_size * (i + 1)):] = 1 out, torch_att_cache, torch_cnn_cache = encoder( torch_chunk, torch_att_cache, torch_cnn_cache, torch_att_mask) @@ -914,21 +947,26 @@ def export_encoder(asr_model, args): input_names = [node.name for node in onnx_encoder.graph.input] for i in range(10): logger.info("onnx chunk-{}: {}, att_cache: {}, cnn_cache: {}," - " att_mask: {}".format( - i, onnx_chunk.shape, onnx_att_cache.shape, - onnx_cnn_cache.shape, onnx_att_mask.shape)) + " att_mask: {}".format(i, onnx_chunk.shape, + onnx_att_cache.shape, + onnx_cnn_cache.shape, + onnx_att_mask.shape)) onnx_att_mask[:, :, :, -(encoder.chunk_size * (i + 1)):] = 1 ort_inputs = { - 'chunk': onnx_chunk, 'att_cache': onnx_att_cache, - 'cnn_cache': onnx_cnn_cache, 'att_mask': onnx_att_mask, + 'chunk': onnx_chunk, + 'att_cache': onnx_att_cache, + 'cnn_cache': onnx_cnn_cache, + 'att_mask': onnx_att_mask, } ort_outs = ort_session.run(None, ort_inputs) onnx_att_cache, onnx_cnn_cache = ort_outs[1], ort_outs[2] onnx_output.append(ort_outs[0]) onnx_output = np.concatenate(onnx_output, axis=-1) - np.testing.assert_allclose(to_numpy(torch_output), onnx_output, - rtol=1e-03, atol=1e-04) + np.testing.assert_allclose(to_numpy(torch_output), + onnx_output, + rtol=1e-03, + atol=1e-04) meta = ort_session.get_modelmeta() logger.info("custom_metadata_map={}".format(meta.custom_metadata_map)) logger.info("Check onnx_encoder, pass!") @@ -952,13 +990,21 @@ def export_ctc(asr_model, args): attributes['input_layout_train'] = "NCHW" attributes['input_layout_rt'] = "NCHW" attributes['input_shape'] = "{}x{}x{}x{}".format( - hidden.size(0), hidden.size(1), hidden.size(2), hidden.size(3), + hidden.size(0), + hidden.size(1), + hidden.size(2), + hidden.size(3), ) - torch.onnx.export( - ctc, hidden, ctc_outpath, opset_version=11, - export_params=True, 
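# A minimal, self-contained sketch of the export-then-verify pattern used in
# export_encoder/export_ctc above: export a toy module with torch.onnx.export,
# run it through onnxruntime and compare against the torch output with
# np.testing.assert_allclose. The file name and tolerances are illustrative
# assumptions.
import numpy as np
import onnxruntime
import torch

model = torch.nn.Linear(8, 4).eval()
dummy = torch.randn(1, 8)
torch.onnx.export(model, dummy, "toy.onnx", opset_version=11,
                  input_names=["x"], output_names=["y"])

sess = onnxruntime.InferenceSession("toy.onnx",
                                    providers=["CPUExecutionProvider"])
onnx_out = sess.run(None, {"x": dummy.numpy()})[0]
np.testing.assert_allclose(model(dummy).detach().numpy(), onnx_out,
                           rtol=1e-3, atol=1e-5)
print("torch and onnxruntime agree")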
do_constant_folding=True, - input_names=['hidden'], output_names=['probs'], - dynamic_axes=None, verbose=False) + torch.onnx.export(ctc, + hidden, + ctc_outpath, + opset_version=11, + export_params=True, + do_constant_folding=True, + input_names=['hidden'], + output_names=['probs'], + dynamic_axes=None, + verbose=False) onnx_ctc = onnx.load(ctc_outpath) for k in vars(args): meta = onnx_ctc.metadata_props.add() @@ -977,8 +1023,10 @@ def export_ctc(asr_model, args): ort_session = onnxruntime.InferenceSession(ctc_outpath) onnx_output = ort_session.run(None, {'hidden': to_numpy(hidden)}) - np.testing.assert_allclose(to_numpy(torch_output), onnx_output[0], - rtol=1e-03, atol=1e-04) + np.testing.assert_allclose(to_numpy(torch_output), + onnx_output[0], + rtol=1e-03, + atol=1e-04) meta = ort_session.get_modelmeta() logger.info("custom_metadata_map={}".format(meta.custom_metadata_map)) logger.info("Check onnx_ctc, pass!") diff --git a/wenet/bin/export_onnx_cpu.py b/wenet/bin/export_onnx_cpu.py index 276e06354..f382545a0 100644 --- a/wenet/bin/export_onnx_cpu.py +++ b/wenet/bin/export_onnx_cpu.py @@ -40,12 +40,18 @@ def get_args(): parser.add_argument('--config', required=True, help='config file') parser.add_argument('--checkpoint', required=True, help='checkpoint model') parser.add_argument('--output_dir', required=True, help='output directory') - parser.add_argument('--chunk_size', required=True, - type=int, help='decoding chunk size') - parser.add_argument('--num_decoding_left_chunks', required=True, - type=int, help='cache chunks') - parser.add_argument('--reverse_weight', default=0.5, - type=float, help='reverse_weight in attention_rescoing') + parser.add_argument('--chunk_size', + required=True, + type=int, + help='decoding chunk size') + parser.add_argument('--num_decoding_left_chunks', + required=True, + type=int, + help='cache chunks') + parser.add_argument('--reverse_weight', + default=0.5, + type=float, + help='reverse_weight in attention_rescoing') args = parser.parse_args() return args @@ -113,16 +119,15 @@ def export_encoder(asr_model, args): elif args['left_chunks'] <= 0: # 16/-1, -1/-1, 16/0 required_cache_size = -1 if args['left_chunks'] < 0 else 0 # Fake cache - att_cache = torch.zeros( - (args['num_blocks'], args['head'], 0, - args['output_size'] // args['head'] * 2)) + att_cache = torch.zeros((args['num_blocks'], args['head'], 0, + args['output_size'] // args['head'] * 2)) # Fake mask att_mask = torch.ones((0, 0, 0), dtype=torch.bool) cnn_cache = torch.zeros( - (args['num_blocks'], args['batch'], - args['output_size'], args['cnn_module_kernel'] - 1)) - inputs = (chunk, offset, required_cache_size, - att_cache, cnn_cache, att_mask) + (args['num_blocks'], args['batch'], args['output_size'], + args['cnn_module_kernel'] - 1)) + inputs = (chunk, offset, required_cache_size, att_cache, cnn_cache, + att_mask) print("\t\tchunk.size(): {}\n".format(chunk.size()), "\t\toffset: {}\n".format(offset), "\t\trequired_cache: {}\n".format(required_cache_size), @@ -132,11 +137,21 @@ def export_encoder(asr_model, args): print("\tStage-1.2: torch.onnx.export") dynamic_axes = { - 'chunk': {1: 'T'}, - 'att_cache': {2: 'T_CACHE'}, - 'att_mask': {2: 'T_ADD_T_CACHE'}, - 'output': {1: 'T'}, - 'r_att_cache': {2: 'T_CACHE'}, + 'chunk': { + 1: 'T' + }, + 'att_cache': { + 2: 'T_CACHE' + }, + 'att_mask': { + 2: 'T_ADD_T_CACHE' + }, + 'output': { + 1: 'T' + }, + 'r_att_cache': { + 2: 'T_CACHE' + }, } # NOTE(xcsong): We keep dynamic axes even if in 16/4 mode, this is # to avoid padding the last chunk (which 
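# A minimal sketch of why the dynamic axis on the time dimension is kept: with
# dynamic_axes the exported graph accepts a shorter final chunk instead of
# requiring it to be padded. The Conv1d stand-in and file name are
# illustrative assumptions, not the real encoder.
import onnxruntime
import torch

enc = torch.nn.Conv1d(80, 80, 1).eval()                   # stand-in for the encoder body
torch.onnx.export(enc, torch.randn(1, 80, 67), "enc_toy.onnx", opset_version=13,
                  input_names=["chunk"], output_names=["output"],
                  dynamic_axes={"chunk": {2: "T"}, "output": {2: "T"}})

sess = onnxruntime.InferenceSession("enc_toy.onnx",
                                    providers=["CPUExecutionProvider"])
for t in (67, 35):                                        # a full chunk and a shorter last chunk
    out = sess.run(None, {"chunk": torch.randn(1, 80, t).numpy()})[0]
    print(out.shape)                                      # time axis follows the input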
usually contains less @@ -151,15 +166,19 @@ def export_encoder(asr_model, args): # # be changed. # dynamic_axes.pop('att_cache') # dynamic_axes.pop('r_att_cache') - torch.onnx.export( - encoder, inputs, encoder_outpath, opset_version=13, - export_params=True, do_constant_folding=True, - input_names=[ - 'chunk', 'offset', 'required_cache_size', - 'att_cache', 'cnn_cache', 'att_mask' - ], - output_names=['output', 'r_att_cache', 'r_cnn_cache'], - dynamic_axes=dynamic_axes, verbose=False) + torch.onnx.export(encoder, + inputs, + encoder_outpath, + opset_version=13, + export_params=True, + do_constant_folding=True, + input_names=[ + 'chunk', 'offset', 'required_cache_size', + 'att_cache', 'cnn_cache', 'att_mask' + ], + output_names=['output', 'r_att_cache', 'r_cnn_cache'], + dynamic_axes=dynamic_axes, + verbose=False) onnx_encoder = onnx.load(encoder_outpath) for (k, v) in args.items(): meta = onnx_encoder.metadata_props.add() @@ -188,8 +207,8 @@ def export_encoder(asr_model, args): print("\t\ttorch chunk-{}: {}, offset: {}, att_cache: {}," " cnn_cache: {}, att_mask: {}".format( i, list(torch_chunk.size()), torch_offset, - list(torch_att_cache.size()), - list(torch_cnn_cache.size()), list(torch_att_mask.size()))) + list(torch_att_cache.size()), list(torch_cnn_cache.size()), + list(torch_att_mask.size()))) # NOTE(xsong): att_mask of the first few batches need changes if # we use 16/4 mode. if args['left_chunks'] > 0: # 16/4 @@ -208,22 +227,26 @@ def export_encoder(asr_model, args): onnx_att_cache = to_numpy(att_cache) onnx_cnn_cache = to_numpy(cnn_cache) onnx_att_mask = to_numpy(att_mask) - ort_session = onnxruntime.InferenceSession(encoder_outpath, - providers=['CPUExecutionProvider']) + ort_session = onnxruntime.InferenceSession( + encoder_outpath, providers=['CPUExecutionProvider']) input_names = [node.name for node in onnx_encoder.graph.input] for i in range(10): print("\t\tonnx chunk-{}: {}, offset: {}, att_cache: {}," - " cnn_cache: {}, att_mask: {}".format( - i, onnx_chunk.shape, onnx_offset, onnx_att_cache.shape, - onnx_cnn_cache.shape, onnx_att_mask.shape)) + " cnn_cache: {}, att_mask: {}".format(i, onnx_chunk.shape, + onnx_offset, + onnx_att_cache.shape, + onnx_cnn_cache.shape, + onnx_att_mask.shape)) # NOTE(xsong): att_mask of the first few batches need changes if # we use 16/4 mode. 
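# A minimal sketch of how the attention mask is opened up chunk by chunk in the
# 16/4 consistency loop above: at step i only the most recent (i + 1) chunks of
# key/value history are visible, capped by the cache length. The chunk_size and
# left_chunks values are illustrative assumptions.
import torch

chunk_size, left_chunks = 16, 4
total = chunk_size * (left_chunks + 1)                    # cache + current chunk
att_mask = torch.zeros(1, 1, total, dtype=torch.bool)

for i in range(6):
    att_mask[:, :, -(chunk_size * (i + 1)):] = True
    print(f"chunk {i}: {int(att_mask.sum())} visible frames")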
if args['left_chunks'] > 0: # 16/4 onnx_att_mask[:, :, -(args['chunk_size'] * (i + 1)):] = 1 ort_inputs = { - 'chunk': onnx_chunk, 'offset': onnx_offset, + 'chunk': onnx_chunk, + 'offset': onnx_offset, 'required_cache_size': onnx_required_cache_size, - 'att_cache': onnx_att_cache, 'cnn_cache': onnx_cnn_cache, + 'att_cache': onnx_att_cache, + 'cnn_cache': onnx_cnn_cache, 'att_mask': onnx_att_mask } # NOTE(xcsong): If we use 16/-1, -1/-1 or 16/0 mode, `next_cache_start` @@ -239,8 +262,10 @@ def export_encoder(asr_model, args): onnx_offset += ort_outs[0].shape[1] onnx_output = np.concatenate(onnx_output, axis=1) - np.testing.assert_allclose(to_numpy(torch_output), onnx_output, - rtol=1e-03, atol=1e-05) + np.testing.assert_allclose(to_numpy(torch_output), + onnx_output, + rtol=1e-03, + atol=1e-05) meta = ort_session.get_modelmeta() print("\t\tcustom_metadata_map={}".format(meta.custom_metadata_map)) print("\t\tCheck onnx_encoder, pass!") @@ -259,11 +284,16 @@ def export_ctc(asr_model, args): print("\tStage-2.2: torch.onnx.export") dynamic_axes = {'hidden': {1: 'T'}, 'probs': {1: 'T'}} - torch.onnx.export( - ctc, hidden, ctc_outpath, opset_version=13, - export_params=True, do_constant_folding=True, - input_names=['hidden'], output_names=['probs'], - dynamic_axes=dynamic_axes, verbose=False) + torch.onnx.export(ctc, + hidden, + ctc_outpath, + opset_version=13, + export_params=True, + do_constant_folding=True, + input_names=['hidden'], + output_names=['probs'], + dynamic_axes=dynamic_axes, + verbose=False) onnx_ctc = onnx.load(ctc_outpath) for (k, v) in args.items(): meta = onnx_ctc.metadata_props.add() @@ -280,12 +310,14 @@ def export_ctc(asr_model, args): print("\tStage-2.3: check onnx_ctc and torch_ctc") torch_output = ctc(hidden) - ort_session = onnxruntime.InferenceSession(ctc_outpath, - providers=['CPUExecutionProvider']) + ort_session = onnxruntime.InferenceSession( + ctc_outpath, providers=['CPUExecutionProvider']) onnx_output = ort_session.run(None, {'hidden': to_numpy(hidden)}) - np.testing.assert_allclose(to_numpy(torch_output), onnx_output[0], - rtol=1e-03, atol=1e-05) + np.testing.assert_allclose(to_numpy(torch_output), + onnx_output[0], + rtol=1e-03, + atol=1e-05) print("\t\tCheck onnx_ctc, pass!") @@ -300,24 +332,43 @@ def export_decoder(asr_model, args): print("\tStage-3.1: prepare inputs for decoder") # hardcode time->200 nbest->10 len->20, they are dynamic axes. 
encoder_out = torch.randn((1, 200, args['output_size'])) - hyps = torch.randint(low=0, high=args['vocab_size'], - size=[10, 20]) + hyps = torch.randint(low=0, high=args['vocab_size'], size=[10, 20]) hyps[:, 0] = args['vocab_size'] - 1 # hyps_lens = torch.randint(low=15, high=21, size=[10]) print("\tStage-3.2: torch.onnx.export") dynamic_axes = { - 'hyps': {0: 'NBEST', 1: 'L'}, 'hyps_lens': {0: 'NBEST'}, - 'encoder_out': {1: 'T'}, - 'score': {0: 'NBEST', 1: 'L'}, 'r_score': {0: 'NBEST', 1: 'L'} + 'hyps': { + 0: 'NBEST', + 1: 'L' + }, + 'hyps_lens': { + 0: 'NBEST' + }, + 'encoder_out': { + 1: 'T' + }, + 'score': { + 0: 'NBEST', + 1: 'L' + }, + 'r_score': { + 0: 'NBEST', + 1: 'L' + } } inputs = (hyps, hyps_lens, encoder_out, args['reverse_weight']) torch.onnx.export( - decoder, inputs, decoder_outpath, opset_version=13, - export_params=True, do_constant_folding=True, + decoder, + inputs, + decoder_outpath, + opset_version=13, + export_params=True, + do_constant_folding=True, input_names=['hyps', 'hyps_lens', 'encoder_out', 'reverse_weight'], output_names=['score', 'r_score'], - dynamic_axes=dynamic_axes, verbose=False) + dynamic_axes=dynamic_axes, + verbose=False) onnx_decoder = onnx.load(decoder_outpath) for (k, v) in args.items(): meta = onnx_decoder.metadata_props.add() @@ -329,14 +380,13 @@ def export_decoder(asr_model, args): model_fp32 = decoder_outpath model_quant = os.path.join(args['output_dir'], 'decoder.quant.onnx') quantize_dynamic(model_fp32, model_quant, weight_type=QuantType.QUInt8) - print('\t\tExport onnx_decoder, done! see {}'.format( - decoder_outpath)) + print('\t\tExport onnx_decoder, done! see {}'.format(decoder_outpath)) print("\tStage-3.3: check onnx_decoder and torch_decoder") - torch_score, torch_r_score = decoder( - hyps, hyps_lens, encoder_out, args['reverse_weight']) - ort_session = onnxruntime.InferenceSession(decoder_outpath, - providers=['CPUExecutionProvider']) + torch_score, torch_r_score = decoder(hyps, hyps_lens, encoder_out, + args['reverse_weight']) + ort_session = onnxruntime.InferenceSession( + decoder_outpath, providers=['CPUExecutionProvider']) input_names = [node.name for node in onnx_decoder.graph.input] ort_inputs = { 'hyps': to_numpy(hyps), @@ -349,11 +399,15 @@ def export_decoder(asr_model, args): ort_inputs.pop(k) onnx_output = ort_session.run(None, ort_inputs) - np.testing.assert_allclose(to_numpy(torch_score), onnx_output[0], - rtol=1e-03, atol=1e-05) + np.testing.assert_allclose(to_numpy(torch_score), + onnx_output[0], + rtol=1e-03, + atol=1e-05) if args['is_bidirectional_decoder'] and args['reverse_weight'] > 0.0: - np.testing.assert_allclose(to_numpy(torch_r_score), onnx_output[1], - rtol=1e-03, atol=1e-05) + np.testing.assert_allclose(to_numpy(torch_r_score), + onnx_output[1], + rtol=1e-03, + atol=1e-05) print("\t\tCheck onnx_decoder, pass!") @@ -381,7 +435,8 @@ def main(): arguments['reverse_weight'] = args.reverse_weight arguments['output_size'] = configs['encoder_conf']['output_size'] arguments['num_blocks'] = configs['encoder_conf']['num_blocks'] - arguments['cnn_module_kernel'] = configs['encoder_conf'].get('cnn_module_kernel', 1) + arguments['cnn_module_kernel'] = configs['encoder_conf'].get( + 'cnn_module_kernel', 1) arguments['head'] = configs['encoder_conf']['attention_heads'] arguments['feature_size'] = configs['input_dim'] arguments['vocab_size'] = configs['output_dim'] diff --git a/wenet/bin/export_onnx_gpu.py b/wenet/bin/export_onnx_gpu.py index 55540e8e2..9832519ee 100644 --- a/wenet/bin/export_onnx_gpu.py +++ 
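# A minimal sketch of the dynamic-quantization step applied to decoder.onnx
# above: onnxruntime's quantize_dynamic rewrites the fp32 weights as uint8.
# The toy model and file names are illustrative assumptions.
import torch
from onnxruntime.quantization import QuantType, quantize_dynamic

torch.onnx.export(torch.nn.Linear(8, 4).eval(), torch.randn(1, 8),
                  "decoder_toy.onnx", opset_version=13)
quantize_dynamic("decoder_toy.onnx", "decoder_toy.quant.onnx",
                 weight_type=QuantType.QUInt8)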
b/wenet/bin/export_onnx_gpu.py @@ -40,6 +40,7 @@ class Encoder(torch.nn.Module): + def __init__(self, encoder: BaseEncoder, ctc: CTC, beam_size: int = 10): super().__init__() self.encoder = encoder @@ -47,7 +48,9 @@ def __init__(self, encoder: BaseEncoder, ctc: CTC, beam_size: int = 10): self.beam_size = beam_size def forward( - self, speech: torch.Tensor, speech_lengths: torch.Tensor, + self, + speech: torch.Tensor, + speech_lengths: torch.Tensor, ): """Encoder Args: @@ -60,13 +63,14 @@ def forward( beam_log_probs: B x T x beam_size beam_log_probs_idx: B x T x beam_size """ - encoder_out, encoder_mask = self.encoder(speech, speech_lengths, -1, -1) + encoder_out, encoder_mask = self.encoder(speech, speech_lengths, -1, + -1) encoder_out_lens = encoder_mask.squeeze(1).sum(1) ctc_log_probs = self.ctc.log_softmax(encoder_out) encoder_out_lens = encoder_out_lens.int() - beam_log_probs, beam_log_probs_idx = torch.topk( - ctc_log_probs, self.beam_size, dim=2 - ) + beam_log_probs, beam_log_probs_idx = torch.topk(ctc_log_probs, + self.beam_size, + dim=2) return ( encoder_out, encoder_out_lens, @@ -77,6 +81,7 @@ def forward( class StreamingEncoder(torch.nn.Module): + def __init__( self, model, @@ -96,9 +101,8 @@ def __init__( self.transformer = transformer self.return_ctc_logprobs = return_ctc_logprobs - def forward( - self, chunk_xs, chunk_lens, offset, att_cache, cnn_cache, cache_mask - ): + def forward(self, chunk_xs, chunk_lens, offset, att_cache, cnn_cache, + cache_mask): """Streaming Encoder Args: xs (torch.Tensor): chunk input, with shape (b, time, mel-dim), @@ -175,8 +179,7 @@ def forward( # shape(new_att_cache) is (B, head, attention_key_size, d_k * 2), # shape(new_cnn_cache) is (B, hidden-dim, cache_t2) r_att_cache.append( - new_att_cache[:, :, next_cache_start:, :].unsqueeze(1) - ) + new_att_cache[:, :, next_cache_start:, :].unsqueeze(1)) if not self.transformer: r_cnn_cache.append(new_cnn_cache.unsqueeze(1)) if self.encoder.normalize_before: @@ -191,9 +194,9 @@ def forward( # <---------forward_chunk END---------> log_ctc_probs = self.ctc.log_softmax(chunk_out) - log_probs, log_probs_idx = torch.topk( - log_ctc_probs, self.beam_size, dim=2 - ) + log_probs, log_probs_idx = torch.topk(log_ctc_probs, + self.beam_size, + dim=2) log_probs = log_probs.to(chunk_xs.dtype) r_offset = offset + chunk_out.shape[1] @@ -226,6 +229,7 @@ def forward( class StreamingSqueezeformerEncoder(torch.nn.Module): + def __init__(self, model, required_cache_size, beam_size): super().__init__() self.ctc = model.ctc @@ -258,11 +262,10 @@ def calculate_downsampling_factor(self, i: int) -> int: for exp, rc_idx in enumerate(self.recover_idx): if i >= rc_idx: recover_exp = exp + 1 - return int(2 ** (reduce_exp - recover_exp)) + return int(2**(reduce_exp - recover_exp)) - def forward( - self, chunk_xs, chunk_lens, offset, att_cache, cnn_cache, cache_mask - ): + def forward(self, chunk_xs, chunk_lens, offset, att_cache, cnn_cache, + cache_mask): """Streaming Encoder Args: xs (torch.Tensor): chunk input, with shape (b, time, mel-dim), @@ -328,12 +331,14 @@ def forward( r_att_cache = [] r_cnn_cache = [] - mask_pad = torch.ones(1, xs.size(1), device=xs.device, dtype=torch.bool) + mask_pad = torch.ones(1, + xs.size(1), + device=xs.device, + dtype=torch.bool) mask_pad = mask_pad.unsqueeze(1) max_att_len: int = 0 - recover_activations: List[ - Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor] - ] = [] + recover_activations: List[Tuple[torch.Tensor, torch.Tensor, + torch.Tensor, torch.Tensor]] = [] index = 0 xs_lens = 
torch.tensor([xs.size(1)], device=xs.device, dtype=torch.int) xs = self.encoder.preln(xs) @@ -341,19 +346,17 @@ def forward( if self.reduce_idx is not None: if self.time_reduce is not None and i in self.reduce_idx: recover_activations.append( - (xs, att_mask, pos_emb, mask_pad) - ) + (xs, att_mask, pos_emb, mask_pad)) ( xs, xs_lens, att_mask, mask_pad, ) = self.encoder.time_reduction_layer( - xs, xs_lens, att_mask, mask_pad - ) + xs, xs_lens, att_mask, mask_pad) pos_emb = pos_emb[:, ::2, :] if self.encoder.pos_enc_layer_type == "rel_pos_repaired": - pos_emb = pos_emb[:, : xs.size(1) * 2 - 1, :] + pos_emb = pos_emb[:, :xs.size(1) * 2 - 1, :] index += 1 if self.recover_idx is not None: @@ -380,18 +383,15 @@ def forward( xs, att_mask, pos_emb, - att_cache=att_cache[i][:, :, ::factor, :][ - :, :, : pos_emb.size(1) - xs.size(1), : - ] - if elayers > 0 - else att_cache[:, :, ::factor, :], + att_cache=att_cache[i][:, :, ::factor, :] + [:, :, :pos_emb.size(1) - xs.size(1), :] + if elayers > 0 else att_cache[:, :, ::factor, :], cnn_cache=cnn_cache[i] if cnn_cache.size(0) > 0 else cnn_cache, ) - cached_att = new_att_cache[:, :, next_cache_start // factor :, :] + cached_att = new_att_cache[:, :, next_cache_start // factor:, :] cached_cnn = new_cnn_cache.unsqueeze(1) - cached_att = ( - cached_att.unsqueeze(3).repeat(1, 1, 1, factor, 1).flatten(2, 3) - ) + cached_att = (cached_att.unsqueeze(3).repeat(1, 1, 1, factor, + 1).flatten(2, 3)) if i == 0: # record length for the first block as max length max_att_len = cached_att.size(2) @@ -405,9 +405,9 @@ def forward( # <---------forward_chunk END---------> log_ctc_probs = self.ctc.log_softmax(chunk_out) - log_probs, log_probs_idx = torch.topk( - log_ctc_probs, self.beam_size, dim=2 - ) + log_probs, log_probs_idx = torch.topk(log_ctc_probs, + self.beam_size, + dim=2) log_probs = log_probs.to(chunk_xs.dtype) r_offset = offset + chunk_out.shape[1] @@ -430,6 +430,7 @@ def forward( class StreamingEfficientConformerEncoder(torch.nn.Module): + def __init__(self, model, required_cache_size, beam_size): super().__init__() self.ctc = model.ctc @@ -453,9 +454,8 @@ def calculate_downsampling_factor(self, i: int) -> int: factor *= self.stride[idx] return factor - def forward( - self, chunk_xs, chunk_lens, offset, att_cache, cnn_cache, cache_mask - ): + def forward(self, chunk_xs, chunk_lens, offset, att_cache, cnn_cache, + cache_mask): """Streaming Encoder Args: chunk_xs (torch.Tensor): chunk input, with shape (b, time, mel-dim), @@ -538,52 +538,37 @@ def forward( if xs.size(1) + att_cache.size(3) / factor > pos_emb.size(1): # The time step is not divisible by the downsampling multiple # We propose to double the chunk_size. 
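# A minimal sketch of the cache shrink/stretch trick used around the
# Squeezeformer time-reduction layers above: a cache stored at full frame rate
# is strided by `factor` before a reduced-rate block and repeated back to full
# rate afterwards. Sizes are illustrative assumptions.
import torch

factor = 2
cache = torch.randn(1, 4, 8, 16)                          # (batch, head, time, d_k * 2)
reduced = cache[:, :, ::factor, :]                        # fed to the reduced-rate block
restored = reduced.unsqueeze(3).repeat(1, 1, 1, factor, 1).flatten(2, 3)
print(cache.shape, reduced.shape, restored.shape)         # time: 8 -> 4 -> 8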
- att_cache_trunc = ( - xs.size(1) - + att_cache.size(3) // factor - - pos_emb.size(1) - + 1 - ) + att_cache_trunc = (xs.size(1) + att_cache.size(3) // factor - + pos_emb.size(1) + 1) xs, _, new_att_cache, new_cnn_cache = layer( xs, att_mask, pos_emb, mask_pad=mask_pad, - att_cache=att_cache[i][:, :, ::factor, :][ - :, :, att_cache_trunc:, : - ], + att_cache=att_cache[i][:, :, ::factor, :][:, :, + att_cache_trunc:, :], cnn_cache=cnn_cache[i, :, :, :] - if cnn_cache.size(0) > 0 - else cnn_cache, + if cnn_cache.size(0) > 0 else cnn_cache, ) if i in self.stride_layer_idx: # compute time dimension for next block efficient_index = self.stride_layer_idx.index(i) - att_mask = att_mask[ - :, - :: self.stride[efficient_index], - :: self.stride[efficient_index], - ] - mask_pad = mask_pad[ - :, - :: self.stride[efficient_index], - :: self.stride[efficient_index], - ] - pos_emb = pos_emb[:, :: self.stride[efficient_index], :] + att_mask = att_mask[:, ::self.stride[efficient_index], ::self. + stride[efficient_index], ] + mask_pad = mask_pad[:, ::self.stride[efficient_index], ::self. + stride[efficient_index], ] + pos_emb = pos_emb[:, ::self.stride[efficient_index], :] # shape(new_att_cache) = [batch, head, time2, outdim] - new_att_cache = new_att_cache[:, :, next_cache_start // factor :, :] + new_att_cache = new_att_cache[:, :, next_cache_start // factor:, :] # shape(new_cnn_cache) = [batch, 1, outdim, cache_t2] new_cnn_cache = new_cnn_cache.unsqueeze(1) # shape(1):layerID # use repeat_interleave to new_att_cache # new_att_cache = new_att_cache.repeat_interleave(repeats=factor, dim=2) - new_att_cache = ( - new_att_cache.unsqueeze(3) - .repeat(1, 1, 1, factor, 1) - .flatten(2, 3) - ) + new_att_cache = (new_att_cache.unsqueeze(3).repeat( + 1, 1, 1, factor, 1).flatten(2, 3)) # padding new_cnn_cache to cnn.lorder for casual convolution new_cnn_cache = F.pad( new_cnn_cache, @@ -596,9 +581,8 @@ def forward( max_cnn_len = new_cnn_cache.size(3) # update real shape of att_cache and cnn_cache - r_att_cache.append( - new_att_cache[:, :, -max_att_len:, :].unsqueeze(1) - ) + r_att_cache.append(new_att_cache[:, :, + -max_att_len:, :].unsqueeze(1)) r_cnn_cache.append(new_cnn_cache[:, :, :, -max_cnn_len:]) if self.encoder.normalize_before: @@ -614,9 +598,9 @@ def forward( # <---------forward_chunk END---------> log_ctc_probs = self.ctc.log_softmax(chunk_out) - log_probs, log_probs_idx = torch.topk( - log_ctc_probs, self.beam_size, dim=2 - ) + log_probs, log_probs_idx = torch.topk(log_ctc_probs, + self.beam_size, + dim=2) log_probs = log_probs.to(chunk_xs.dtype) r_offset = offset + chunk_out.shape[1] @@ -624,10 +608,8 @@ def forward( # chunk_out_lens = torch.div(chunk_lens, subsampling_rate, # rounding_mode='floor') chunk_out_lens = ( - chunk_lens - // self.subsampling_rate - // self.calculate_downsampling_factor(self.num_blocks + 1) - ) + chunk_lens // self.subsampling_rate // + self.calculate_downsampling_factor(self.num_blocks + 1)) chunk_out_lens += 1 r_offset = r_offset.unsqueeze(1) @@ -644,6 +626,7 @@ def forward( class Decoder(torch.nn.Module): + def __init__( self, decoder: TransformerDecoder, @@ -691,7 +674,7 @@ def forward( encoder_mask = encoder_mask.repeat(1, bz, 1).view(B2, 1, T) T2 = hyps_pad_sos_eos.shape[2] - 1 hyps_pad = hyps_pad_sos_eos.view(B2, T2 + 1) - hyps_lens = hyps_lens_sos.view(B2,) + hyps_lens = hyps_lens_sos.view(B2, ) hyps_pad_sos = hyps_pad[:, :-1].contiguous() hyps_pad_eos = hyps_pad[:, 1:].contiguous() @@ -718,17 +701,14 @@ def forward( score = score * mask decoder_out = 
decoder_out.view(B, bz, T2, V) if self.reverse_weight > 0: - r_decoder_out = torch.nn.functional.log_softmax( - r_decoder_out, dim=-1 - ) + r_decoder_out = torch.nn.functional.log_softmax(r_decoder_out, + dim=-1) r_decoder_out = r_decoder_out.view(B2, T2, V) index = torch.unsqueeze(r_hyps_pad_eos * mask, 2) r_score = r_decoder_out.gather(2, index).squeeze(2) r_score = r_score * mask - score = ( - score * (1 - self.reverse_weight) - + self.reverse_weight * r_score - ) + score = (score * (1 - self.reverse_weight) + + self.reverse_weight * r_score) r_decoder_out = r_decoder_out.view(B, bz, T2, V) score = torch.sum(score, axis=1) # B2 score = torch.reshape(score, (B, bz)) + self.ctc_weight * ctc_score @@ -770,9 +750,10 @@ def export_offline_encoder(model, configs, args, logger, encoder_onnx_path): feature_size = configs["input_dim"] speech = torch.randn(bz, seq_len, feature_size, dtype=torch.float32) - speech_lens = torch.randint( - low=10, high=seq_len, size=(bz,), dtype=torch.int32 - ) + speech_lens = torch.randint(low=10, + high=seq_len, + size=(bz, ), + dtype=torch.int32) encoder = Encoder(model.encoder, model.ctc, beam_size) encoder.eval() @@ -792,13 +773,32 @@ def export_offline_encoder(model, configs, args, logger, encoder_onnx_path): "beam_log_probs_idx", ], dynamic_axes={ - "speech": {0: "B", 1: "T"}, - "speech_lengths": {0: "B"}, - "encoder_out": {0: "B", 1: "T_OUT"}, - "encoder_out_lens": {0: "B"}, - "ctc_log_probs": {0: "B", 1: "T_OUT"}, - "beam_log_probs": {0: "B", 1: "T_OUT"}, - "beam_log_probs_idx": {0: "B", 1: "T_OUT"}, + "speech": { + 0: "B", + 1: "T" + }, + "speech_lengths": { + 0: "B" + }, + "encoder_out": { + 0: "B", + 1: "T_OUT" + }, + "encoder_out_lens": { + 0: "B" + }, + "ctc_log_probs": { + 0: "B", + 1: "T_OUT" + }, + "beam_log_probs": { + 0: "B", + 1: "T_OUT" + }, + "beam_log_probs_idx": { + 0: "B", + 1: "T_OUT" + }, }, verbose=False, ) @@ -807,9 +807,8 @@ def export_offline_encoder(model, configs, args, logger, encoder_onnx_path): o0, o1, o2, o3, o4 = encoder(speech, speech_lens) providers = ["CUDAExecutionProvider"] - ort_session = onnxruntime.InferenceSession( - encoder_onnx_path, providers=providers - ) + ort_session = onnxruntime.InferenceSession(encoder_onnx_path, + providers=providers) ort_inputs = { "speech": to_numpy(speech), "speech_lengths": to_numpy(speech_lens), @@ -846,13 +845,12 @@ def export_online_encoder(model, configs, args, logger, encoder_onnx_path): num_decoding_left_chunks = args.num_decoding_left_chunks required_cache_size = decoding_chunk_size * num_decoding_left_chunks if configs["encoder"] == "squeezeformer": - encoder = StreamingSqueezeformerEncoder( - model, required_cache_size, args.beam_size - ) + encoder = StreamingSqueezeformerEncoder(model, required_cache_size, + args.beam_size) elif configs["encoder"] == "efficientConformer": - encoder = StreamingEfficientConformerEncoder( - model, required_cache_size, args.beam_size - ) + encoder = StreamingEfficientConformerEncoder(model, + required_cache_size, + args.beam_size) else: encoder = StreamingEncoder( model, @@ -864,9 +862,10 @@ def export_online_encoder(model, configs, args, logger, encoder_onnx_path): encoder.eval() # begin to export encoder - chunk_xs = torch.randn( - batch_size, audio_len, feature_size, dtype=torch.float32 - ) + chunk_xs = torch.randn(batch_size, + audio_len, + feature_size, + dtype=torch.float32) chunk_lens = torch.ones(batch_size, dtype=torch.int32) * audio_len offset = torch.arange(0, batch_size).unsqueeze(1) @@ -889,9 +888,10 @@ def export_online_encoder(model, 
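# A minimal sketch of the rescoring arithmetic the exported Decoder performs:
# left-to-right and right-to-left decoder scores are mixed with reverse_weight,
# the CTC prefix score is added with ctc_weight, and the best hypothesis per
# utterance is the argmax. Shapes and weights are illustrative assumptions.
import torch

batch, beam = 2, 10
score = torch.randn(batch, beam)                          # summed L2R decoder log-probs
r_score = torch.randn(batch, beam)                        # summed R2L decoder log-probs
ctc_score = torch.randn(batch, beam)                      # CTC prefix beam search scores
reverse_weight, ctc_weight = 0.3, 0.5

final = score * (1 - reverse_weight) + reverse_weight * r_score
final = final + ctc_weight * ctc_score
best_index = final.argmax(dim=1)                          # one winner per utterance
print(best_index)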
configs, args, logger, encoder_onnx_path): dtype=torch.float32, ) - cache_mask = torch.ones( - batch_size, 1, required_cache_size, dtype=torch.float32 - ) + cache_mask = torch.ones(batch_size, + 1, + required_cache_size, + dtype=torch.float32) input_names = [ "chunk_xs", "chunk_lens", @@ -929,9 +929,8 @@ def export_online_encoder(model, configs, args, logger, encoder_onnx_path): cache_mask, ) if transformer: - assert ( - args.return_ctc_logprobs is False - ), "return_ctc_logprobs is not supported in transformer" + assert (args.return_ctc_logprobs is + False), "return_ctc_logprobs is not supported in transformer" output_names.pop(6) all_names = input_names + output_names @@ -955,14 +954,12 @@ def export_online_encoder(model, configs, args, logger, encoder_onnx_path): ) with torch.no_grad(): - torch_outs = encoder( - chunk_xs, chunk_lens, offset, att_cache, cnn_cache, cache_mask - ) + torch_outs = encoder(chunk_xs, chunk_lens, offset, att_cache, + cnn_cache, cache_mask) if transformer: torch_outs = list(torch_outs).pop(6) ort_session = onnxruntime.InferenceSession( - encoder_onnx_path, providers=["CUDAExecutionProvider"] - ) + encoder_onnx_path, providers=["CUDAExecutionProvider"]) ort_inputs = {} input_tensors = to_numpy(input_tensors) @@ -988,9 +985,8 @@ def export_online_encoder(model, configs, args, logger, encoder_onnx_path): return onnx_config -def export_rescoring_decoder( - model, configs, args, logger, decoder_onnx_path, decoder_fastertransformer -): +def export_rescoring_decoder(model, configs, args, logger, decoder_onnx_path, + decoder_fastertransformer): bz, seq_len = 32, 100 beam_size = args.beam_size decoder = Decoder( @@ -1002,21 +998,23 @@ def export_rescoring_decoder( ) decoder.eval() - hyps_pad_sos_eos = torch.randint( - low=3, high=1000, size=(bz, beam_size, seq_len) - ) - hyps_lens_sos = torch.randint( - low=3, high=seq_len, size=(bz, beam_size), dtype=torch.int32 - ) - r_hyps_pad_sos_eos = torch.randint( - low=3, high=1000, size=(bz, beam_size, seq_len) - ) + hyps_pad_sos_eos = torch.randint(low=3, + high=1000, + size=(bz, beam_size, seq_len)) + hyps_lens_sos = torch.randint(low=3, + high=seq_len, + size=(bz, beam_size), + dtype=torch.int32) + r_hyps_pad_sos_eos = torch.randint(low=3, + high=1000, + size=(bz, beam_size, seq_len)) output_size = configs["encoder_conf"]["output_size"] encoder_out = torch.randn(bz, seq_len, output_size, dtype=torch.float32) - encoder_out_lens = torch.randint( - low=3, high=seq_len, size=(bz,), dtype=torch.int32 - ) + encoder_out_lens = torch.randint(low=3, + high=seq_len, + size=(bz, ), + dtype=torch.int32) ctc_score = torch.randn(bz, beam_size, dtype=torch.float32) input_names = [ @@ -1048,13 +1046,30 @@ def export_rescoring_decoder( input_names=input_names, output_names=output_names, dynamic_axes={ - "encoder_out": {0: "B", 1: "T"}, - "encoder_out_lens": {0: "B"}, - "hyps_pad_sos_eos": {0: "B", 2: "T2"}, - "hyps_lens_sos": {0: "B"}, - "r_hyps_pad_sos_eos": {0: "B", 2: "T2"}, - "ctc_score": {0: "B"}, - "best_index": {0: "B"}, + "encoder_out": { + 0: "B", + 1: "T" + }, + "encoder_out_lens": { + 0: "B" + }, + "hyps_pad_sos_eos": { + 0: "B", + 2: "T2" + }, + "hyps_lens_sos": { + 0: "B" + }, + "r_hyps_pad_sos_eos": { + 0: "B", + 2: "T2" + }, + "ctc_score": { + 0: "B" + }, + "best_index": { + 0: "B" + }, }, verbose=False, ) @@ -1068,9 +1083,8 @@ def export_rescoring_decoder( ctc_score, ) providers = ["CUDAExecutionProvider"] - ort_session = onnxruntime.InferenceSession( - decoder_onnx_path, providers=providers - ) + ort_session = 
onnxruntime.InferenceSession(decoder_onnx_path, + providers=providers) input_tensors = [ encoder_out, @@ -1116,8 +1130,8 @@ def export_rescoring_decoder( default=-1.0, type=float, required=False, - help="reverse weight for bitransformer," - + "default value is in config file", + help="reverse weight for bitransformer," + + "default value is in config file", ) parser.add_argument( "--ctc_weight", @@ -1182,10 +1196,8 @@ def export_rescoring_decoder( configs = yaml.load(fin, Loader=yaml.FullLoader) if args.cmvn_file and os.path.exists(args.cmvn_file): configs["cmvn_file"] = args.cmvn_file - if ( - args.reverse_weight != -1.0 - and "reverse_weight" in configs["model_conf"] - ): + if (args.reverse_weight != -1.0 + and "reverse_weight" in configs["model_conf"]): configs["model_conf"]["reverse_weight"] = args.reverse_weight print("Update reverse weight to", args.reverse_weight) if args.ctc_weight != -1: @@ -1207,9 +1219,8 @@ def export_rescoring_decoder( else: export_enc_func = export_offline_encoder - onnx_config = export_enc_func( - model, configs, args, logger, encoder_onnx_path - ) + onnx_config = export_enc_func(model, configs, args, logger, + encoder_onnx_path) decoder_onnx_path = os.path.join(args.output_onnx_dir, "decoder.onnx") export_rescoring_decoder( @@ -1225,22 +1236,19 @@ def export_rescoring_decoder( try: import onnxmltools from onnxmltools.utils.float16_converter import ( - convert_float_to_float16, - ) + convert_float_to_float16, ) except ImportError: print("Please install onnxmltools!") sys.exit(1) encoder_onnx_model = onnxmltools.utils.load_model(encoder_onnx_path) encoder_onnx_model = convert_float_to_float16(encoder_onnx_model) - encoder_onnx_path = os.path.join( - args.output_onnx_dir, "encoder_fp16.onnx" - ) + encoder_onnx_path = os.path.join(args.output_onnx_dir, + "encoder_fp16.onnx") onnxmltools.utils.save_model(encoder_onnx_model, encoder_onnx_path) decoder_onnx_model = onnxmltools.utils.load_model(decoder_onnx_path) decoder_onnx_model = convert_float_to_float16(decoder_onnx_model) - decoder_onnx_path = os.path.join( - args.output_onnx_dir, "decoder_fp16.onnx" - ) + decoder_onnx_path = os.path.join(args.output_onnx_dir, + "decoder_fp16.onnx") onnxmltools.utils.save_model(decoder_onnx_model, decoder_onnx_path) # dump configurations diff --git a/wenet/bin/recognize.py b/wenet/bin/recognize.py index e7be3be39..699ecd1df 100644 --- a/wenet/bin/recognize.py +++ b/wenet/bin/recognize.py @@ -207,7 +207,8 @@ def main(): test_conf['batch_conf']['batch_type'] = "static" test_conf['batch_conf']['batch_size'] = args.batch_size - tokenizer = init_tokenizer(configs, args.dict, args.bpe_model, args.non_lang_syms) + tokenizer = init_tokenizer(configs, args.dict, args.bpe_model, + args.non_lang_syms) test_dataset = Dataset(args.data_type, args.test_data, tokenizer, @@ -227,8 +228,9 @@ def main(): context_graph = None if 'decoding-graph' in args.context_bias_mode: - context_graph = ContextGraph(args.context_list_path, tokenizer.symbol_table, - args.bpe_model, args.context_graph_score) + context_graph = ContextGraph(args.context_list_path, + tokenizer.symbol_table, args.bpe_model, + args.context_graph_score) _, blank_id = get_blank_id(configs, tokenizer.symbol_table) logging.info("blank_id is {}".format(blank_id)) diff --git a/wenet/bin/recognize_onnx_gpu.py b/wenet/bin/recognize_onnx_gpu.py index de06aef57..5a5cea66c 100644 --- a/wenet/bin/recognize_onnx_gpu.py +++ b/wenet/bin/recognize_onnx_gpu.py @@ -25,7 +25,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 
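# A minimal sketch of the optional fp16 conversion above: load the exported
# graph with onnxmltools, convert the weights to float16 and save it next to
# the fp32 file. The toy model and file names are illustrative assumptions.
import onnxmltools
import torch
from onnxmltools.utils.float16_converter import convert_float_to_float16

torch.onnx.export(torch.nn.Linear(8, 4).eval(), torch.randn(1, 8),
                  "toy_encoder.onnx", opset_version=13)
model = onnxmltools.utils.load_model("toy_encoder.onnx")
model = convert_float_to_float16(model)
onnxmltools.utils.save_model(model, "toy_encoder_fp16.onnx")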
implied. # See the License for the specific language governing permissions and # limitations under the License. - """ This script is for testing exported onnx encoder and decoder from export_onnx_gpu.py. The exported onnx models only support batch offline ASR inference. @@ -77,8 +76,12 @@ def get_args(): default=-1, help='gpu id for this rank, -1 for cpu') parser.add_argument('--dict', required=True, help='dict file') - parser.add_argument('--encoder_onnx', required=True, help='encoder onnx file') - parser.add_argument('--decoder_onnx', required=True, help='decoder onnx file') + parser.add_argument('--encoder_onnx', + required=True, + help='encoder onnx file') + parser.add_argument('--decoder_onnx', + required=True, + help='decoder onnx file') parser.add_argument('--result_file', required=True, help='asr result file') parser.add_argument('--batch_size', type=int, @@ -87,7 +90,8 @@ def get_args(): parser.add_argument('--mode', choices=[ 'ctc_greedy_search', 'ctc_prefix_beam_search', - 'attention_rescoring'], + 'attention_rescoring' + ], default='attention_rescoring', help='decoding mode') parser.add_argument('--bpe_model', @@ -151,10 +155,12 @@ def main(): else: EP_list = ['CPUExecutionProvider'] - encoder_ort_session = rt.InferenceSession(args.encoder_onnx, providers=EP_list) + encoder_ort_session = rt.InferenceSession(args.encoder_onnx, + providers=EP_list) decoder_ort_session = None if args.mode == "attention_rescoring": - decoder_ort_session = rt.InferenceSession(args.decoder_onnx, providers=EP_list) + decoder_ort_session = rt.InferenceSession(args.decoder_onnx, + providers=EP_list) # Load dict vocabulary = [] @@ -174,7 +180,8 @@ def main(): feats = feats.astype(np.float16) ort_inputs = { encoder_ort_session.get_inputs()[0].name: feats, - encoder_ort_session.get_inputs()[1].name: feats_lengths} + encoder_ort_session.get_inputs()[1].name: feats_lengths + } ort_outs = encoder_ort_session.run(None, ort_inputs) encoder_out, encoder_out_lens, ctc_log_probs, \ beam_log_probs, beam_log_probs_idx = ort_outs @@ -187,9 +194,10 @@ def main(): batch_sents = [] for idx, seq in enumerate(log_probs_idx): batch_sents.append(seq[0:encoder_out_lens[idx]].tolist()) - hyps = map_batch(batch_sents, vocabulary, num_processes, - True, 0) - elif args.mode in ('ctc_prefix_beam_search', "attention_rescoring"): + hyps = map_batch(batch_sents, vocabulary, num_processes, True, + 0) + elif args.mode in ('ctc_prefix_beam_search', + "attention_rescoring"): batch_log_probs_seq_list = beam_log_probs.tolist() batch_log_probs_idx_list = beam_log_probs_idx.tolist() batch_len_list = encoder_out_lens.tolist() @@ -207,13 +215,9 @@ def main(): root_dict[i] = PathTrie() batch_root.append(root_dict[i]) batch_start.append(True) - score_hyps = ctc_beam_search_decoder_batch(batch_log_probs_seq, - batch_log_probs_ids, - batch_root, - batch_start, - beam_size, - num_processes, - 0, -2, 0.99999) + score_hyps = ctc_beam_search_decoder_batch( + batch_log_probs_seq, batch_log_probs_ids, batch_root, + batch_start, beam_size, num_processes, 0, -2, 0.99999) if args.mode == 'ctc_prefix_beam_search': hyps = [] for cand_hyps in score_hyps: @@ -225,7 +229,8 @@ def main(): for hyps in score_hyps: cur_len = len(hyps) if len(hyps) < beam_size: - hyps += (beam_size - cur_len) * [(-float("INF"), (0,))] + hyps += (beam_size - cur_len) * [(-float("INF"), + (0, ))] cur_ctc_score = [] for hyp in hyps: cur_ctc_score.append(hyp[0]) @@ -238,17 +243,22 @@ def main(): else: ctc_score = np.array(ctc_score, dtype=np.float32) hyps_pad_sos_eos = np.ones( - 
(batch_size, beam_size, max_len + 2), dtype=np.int64) * IGNORE_ID + (batch_size, beam_size, max_len + 2), + dtype=np.int64) * IGNORE_ID r_hyps_pad_sos_eos = np.ones( - (batch_size, beam_size, max_len + 2), dtype=np.int64) * IGNORE_ID - hyps_lens_sos = np.ones((batch_size, beam_size), dtype=np.int32) + (batch_size, beam_size, max_len + 2), + dtype=np.int64) * IGNORE_ID + hyps_lens_sos = np.ones((batch_size, beam_size), + dtype=np.int32) k = 0 for i in range(batch_size): for j in range(beam_size): cand = all_hyps[k] l = len(cand) + 2 hyps_pad_sos_eos[i][j][0:l] = [sos] + cand + [eos] - r_hyps_pad_sos_eos[i][j][0:l] = [sos] + cand[::-1] + [eos] + r_hyps_pad_sos_eos[i][j][0:l] = [sos] + cand[::-1] + [ + eos + ] hyps_lens_sos[i][j] = len(cand) + 1 k += 1 decoder_ort_inputs = { @@ -256,15 +266,19 @@ def main(): decoder_ort_session.get_inputs()[1].name: encoder_out_lens, decoder_ort_session.get_inputs()[2].name: hyps_pad_sos_eos, decoder_ort_session.get_inputs()[3].name: hyps_lens_sos, - decoder_ort_session.get_inputs()[-1].name: ctc_score} + decoder_ort_session.get_inputs()[-1].name: ctc_score + } if reverse_weight > 0: - r_hyps_pad_sos_eos_name = decoder_ort_session.get_inputs()[4].name - decoder_ort_inputs[r_hyps_pad_sos_eos_name] = r_hyps_pad_sos_eos - best_index = decoder_ort_session.run(None, decoder_ort_inputs)[0] + r_hyps_pad_sos_eos_name = decoder_ort_session.get_inputs( + )[4].name + decoder_ort_inputs[ + r_hyps_pad_sos_eos_name] = r_hyps_pad_sos_eos + best_index = decoder_ort_session.run(None, + decoder_ort_inputs)[0] best_sents = [] k = 0 for idx in best_index: - cur_best_sent = all_hyps[k: k + beam_size][idx] + cur_best_sent = all_hyps[k:k + beam_size][idx] best_sents.append(cur_best_sent) k += beam_size hyps = map_batch(best_sents, vocabulary, num_processes) @@ -274,5 +288,6 @@ def main(): logging.info('{} {}'.format(key, content)) fout.write('{} {}\n'.format(key, content)) + if __name__ == '__main__': main() diff --git a/wenet/bin/train.py b/wenet/bin/train.py index 4516c8291..2cf70a8b8 100644 --- a/wenet/bin/train.py +++ b/wenet/bin/train.py @@ -29,15 +29,13 @@ from wenet.utils.config import override_config from wenet.utils.init_model import init_model from wenet.utils.init_tokenizer import init_tokenizer -from wenet.utils.train_utils import (add_model_args, add_dataset_args, - add_ddp_args, add_deepspeed_args, - add_trace_args, init_distributed, - init_dataset_and_dataloader, - check_modify_and_save_config, - init_optimizer_and_scheduler, - trace_and_print_model, wrap_cuda_model, - init_summarywriter, save_model, - log_per_epoch) +from wenet.utils.train_utils import ( + add_model_args, add_dataset_args, add_ddp_args, add_deepspeed_args, + add_trace_args, init_distributed, init_dataset_and_dataloader, + check_modify_and_save_config, init_optimizer_and_scheduler, + trace_and_print_model, wrap_cuda_model, init_summarywriter, save_model, + log_per_epoch) + def get_args(): parser = argparse.ArgumentParser(description='training your network') @@ -86,7 +84,8 @@ def main(): init_dataset_and_dataloader(args, configs, tokenizer) # Do some sanity checks and save config to arsg.model_dir - configs = check_modify_and_save_config(args, configs, tokenizer.symbol_table) + configs = check_modify_and_save_config(args, configs, + tokenizer.symbol_table) # Init asr model from configs model, configs = init_model(args, configs) @@ -105,10 +104,14 @@ def main(): args, configs, model) # Save checkpoints - save_model(model, info_dict={ - "save_time": datetime.datetime.now().strftime('%d/%m/%Y %H:%M:%S'), 
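# A minimal sketch of the sos/eos padding built for the rescoring decoder
# inputs above: each candidate is wrapped as [sos] + cand + [eos], its reverse
# likewise, and the remaining slots keep IGNORE_ID. Token ids and lengths are
# illustrative assumptions.
import numpy as np

sos, eos, IGNORE_ID, max_len = 101, 102, -1, 6
cand = [7, 8, 9]

hyps_pad = np.ones(max_len + 2, dtype=np.int64) * IGNORE_ID
r_hyps_pad = np.ones(max_len + 2, dtype=np.int64) * IGNORE_ID
l = len(cand) + 2
hyps_pad[0:l] = [sos] + cand + [eos]
r_hyps_pad[0:l] = [sos] + cand[::-1] + [eos]
hyps_len = len(cand) + 1                                  # sos + tokens, eos excluded
print(hyps_pad, r_hyps_pad, hyps_len)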
- "tag": "init", **configs - }) + save_model(model, + info_dict={ + "save_time": + datetime.datetime.now().strftime('%d/%m/%Y %H:%M:%S'), + "tag": + "init", + **configs + }) # Get executor executor = Executor() @@ -127,25 +130,33 @@ def main(): configs['epoch'] = epoch lr = optimizer.param_groups[0]['lr'] - logging.info('Epoch {} TRAIN info lr {} rank {}'.format(epoch, lr, rank)) + logging.info('Epoch {} TRAIN info lr {} rank {}'.format( + epoch, lr, rank)) - dist.barrier() # NOTE(xcsong): Ensure all ranks start Train at the same time. + dist.barrier( + ) # NOTE(xcsong): Ensure all ranks start Train at the same time. # NOTE(xcsong): Why we need a new group? see `train_utils.py::wenet_join` - group_join = dist.new_group(backend="gloo", - timeout=datetime.timedelta(seconds=args.timeout)) - executor.train(model, optimizer, scheduler, train_data_loader, - writer, configs, scaler, group_join) + group_join = dist.new_group( + backend="gloo", timeout=datetime.timedelta(seconds=args.timeout)) + executor.train(model, optimizer, scheduler, train_data_loader, writer, + configs, scaler, group_join) dist.destroy_process_group(group_join) - dist.barrier() # NOTE(xcsong): Ensure all ranks start CV at the same time. + dist.barrier( + ) # NOTE(xcsong): Ensure all ranks start CV at the same time. total_loss, num_seen_utts = executor.cv(model, cv_data_loader, configs) cv_loss = total_loss / num_seen_utts - logging.info('Epoch {} CV info cv_loss {} rank {}'.format(epoch, cv_loss, rank)) + logging.info('Epoch {} CV info cv_loss {} rank {}'.format( + epoch, cv_loss, rank)) info_dict = { - 'epoch': epoch, 'lr': lr, 'cv_loss': cv_loss, 'step': executor.step, + 'epoch': epoch, + 'lr': lr, + 'cv_loss': cv_loss, + 'step': executor.step, 'save_time': datetime.datetime.now().strftime('%d/%m/%Y %H:%M:%S'), - 'tag': str(epoch), **configs + 'tag': str(epoch), + **configs } log_per_epoch(writer, info_dict=info_dict) save_model(model, info_dict=info_dict) @@ -154,7 +165,8 @@ def main(): if final_epoch is not None and rank == 0: final_model_path = os.path.join(args.model_dir, 'final.pt') - os.remove(final_model_path) if os.path.exists(final_model_path) else None + os.remove(final_model_path) if os.path.exists( + final_model_path) else None os.symlink('{}.pt'.format(final_epoch), final_model_path) writer.close() diff --git a/wenet/branchformer/cgmlp.py b/wenet/branchformer/cgmlp.py index a991d5b0c..b56a2505e 100644 --- a/wenet/branchformer/cgmlp.py +++ b/wenet/branchformer/cgmlp.py @@ -13,7 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. # Modified from ESPnet(https://github.com/espnet/espnet) - """MLP with convolutional gating (cgMLP) definition. References: @@ -85,9 +84,7 @@ def espnet_initialization_fn(self): torch.nn.init.ones_(self.linear.bias) def forward( - self, - x: torch.Tensor, - cache: torch.Tensor = torch.zeros((0, 0, 0)) + self, x: torch.Tensor, cache: torch.Tensor = torch.zeros((0, 0, 0)) ) -> Tuple[torch.Tensor, torch.Tensor]: """Forward method @@ -118,7 +115,9 @@ def forward( # It's better we just return None if no cache is required, # However, for JIT export, here we just fake one tensor instead of # None. 
- new_cache = torch.zeros((0, 0, 0), dtype=x_g.dtype, device=x_g.device) + new_cache = torch.zeros((0, 0, 0), + dtype=x_g.dtype, + device=x_g.device) x_g = x_g.transpose(1, 2) x_g = self.norm(x_g) # (N, T, D/2) @@ -148,8 +147,7 @@ def __init__( super().__init__() self.channel_proj1 = torch.nn.Sequential( - torch.nn.Linear(size, linear_units), torch.nn.GELU() - ) + torch.nn.Linear(size, linear_units), torch.nn.GELU()) self.csgu = ConvolutionalSpatialGatingUnit( size=linear_units, kernel_size=kernel_size, diff --git a/wenet/branchformer/encoder.py b/wenet/branchformer/encoder.py index b35120c30..7d00b2a70 100644 --- a/wenet/branchformer/encoder.py +++ b/wenet/branchformer/encoder.py @@ -13,7 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. # Modified from ESPnet(https://github.com/espnet/espnet) - """Encoder definition.""" import torch @@ -25,7 +24,9 @@ from wenet.utils.mask import make_pad_mask from wenet.utils.mask import add_optional_chunk_mask from wenet.utils.class_utils import ( - WENET_ATTENTION_CLASSES, WENET_EMB_CLASSES, WENET_SUBSAMPLE_CLASSES, + WENET_ATTENTION_CLASSES, + WENET_EMB_CLASSES, + WENET_SUBSAMPLE_CLASSES, ) @@ -68,8 +69,8 @@ def __init__( input_size, output_size, dropout_rate, - WENET_EMB_CLASSES[pos_enc_layer_type]( - output_size, positional_dropout_rate), + WENET_EMB_CLASSES[pos_enc_layer_type](output_size, + positional_dropout_rate), ) encoder_selfattn_layer_args = ( @@ -93,36 +94,30 @@ def __init__( if len(stochastic_depth_rate) != num_blocks: raise ValueError( f"Length of stochastic_depth_rate ({len(stochastic_depth_rate)}) " - f"should be equal to num_blocks ({num_blocks})" - ) + f"should be equal to num_blocks ({num_blocks})") if isinstance(cgmlp_weight, float): cgmlp_weight = [cgmlp_weight] * num_blocks if len(cgmlp_weight) != num_blocks: raise ValueError( f"Length of cgmlp_weight ({len(cgmlp_weight)}) should be equal to " - f"num_blocks ({num_blocks})" - ) + f"num_blocks ({num_blocks})") if isinstance(attn_branch_drop_rate, float): attn_branch_drop_rate = [attn_branch_drop_rate] * num_blocks if len(attn_branch_drop_rate) != num_blocks: raise ValueError( f"Length of attn_branch_drop_rate ({len(attn_branch_drop_rate)}) " - f"should be equal to num_blocks ({num_blocks})" - ) + f"should be equal to num_blocks ({num_blocks})") - self.encoders = torch.nn.ModuleList([BranchformerEncoderLayer( - output_size, - WENET_ATTENTION_CLASSES[attention_layer_type](*encoder_selfattn_layer_args) - if use_attn - else None, - cgmlp_layer(*cgmlp_layer_args) if use_cgmlp else None, - dropout_rate, - merge_method, - cgmlp_weight[lnum], - attn_branch_drop_rate[lnum], - stochastic_depth_rate[lnum]) for lnum in range(num_blocks) + self.encoders = torch.nn.ModuleList([ + BranchformerEncoderLayer( + output_size, WENET_ATTENTION_CLASSES[attention_layer_type]( + *encoder_selfattn_layer_args) if use_attn else None, + cgmlp_layer(*cgmlp_layer_args) if use_cgmlp else None, + dropout_rate, merge_method, cgmlp_weight[lnum], + attn_branch_drop_rate[lnum], stochastic_depth_rate[lnum]) + for lnum in range(num_blocks) ]) self.after_norm = nn.LayerNorm(output_size) self.static_chunk_size = static_chunk_size @@ -174,7 +169,7 @@ def forward( self.static_chunk_size, num_decoding_left_chunks) for layer in self.encoders: - xs, chunk_masks, _ , _ = layer(xs, chunk_masks, pos_emb, mask_pad) + xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad) xs = self.after_norm(xs) # Here we assume the mask is not changed in encoder layers, so 
just @@ -236,8 +231,8 @@ def forward_chunk( elayers, cache_t1 = att_cache.size(0), att_cache.size(2) chunk_size = xs.size(1) attention_key_size = cache_t1 + chunk_size - pos_emb = self.embed.position_encoding( - offset=offset - cache_t1, size=attention_key_size) + pos_emb = self.embed.position_encoding(offset=offset - cache_t1, + size=attention_key_size) if required_cache_size < 0: next_cache_start = 0 elif required_cache_size == 0: @@ -251,10 +246,11 @@ def forward_chunk( # shape(att_cache[i:i + 1]) is (1, head, cache_t1, d_k * 2), # shape(cnn_cache[i]) is (b=1, hidden-dim, cache_t2) xs, _, new_att_cache, new_cnn_cache = layer( - xs, att_mask, pos_emb, + xs, + att_mask, + pos_emb, att_cache=att_cache[i:i + 1] if elayers > 0 else att_cache, - cnn_cache=cnn_cache[i] if cnn_cache.size(0) > 0 else cnn_cache - ) + cnn_cache=cnn_cache[i] if cnn_cache.size(0) > 0 else cnn_cache) # NOTE(xcsong): After layer.forward # shape(new_att_cache) is (1, head, attention_key_size, d_k * 2), # shape(new_cnn_cache) is (b=1, hidden-dim, cache_t2) @@ -321,10 +317,14 @@ def forward_chunk_by_chunk( for cur in range(0, num_frames - context + 1, stride): end = min(cur + decoding_window, num_frames) chunk_xs = xs[:, cur:end, :] - (y, att_cache, cnn_cache) = self.forward_chunk( - chunk_xs, offset, required_cache_size, att_cache, cnn_cache) + (y, att_cache, + cnn_cache) = self.forward_chunk(chunk_xs, offset, + required_cache_size, att_cache, + cnn_cache) outputs.append(y) offset += y.size(1) ys = torch.cat(outputs, 1) - masks = torch.ones((1, 1, ys.size(1)), device=ys.device, dtype=torch.bool) + masks = torch.ones((1, 1, ys.size(1)), + device=ys.device, + dtype=torch.bool) return ys, masks diff --git a/wenet/branchformer/encoder_layer.py b/wenet/branchformer/encoder_layer.py index a2cc663e8..9654a2405 100644 --- a/wenet/branchformer/encoder_layer.py +++ b/wenet/branchformer/encoder_layer.py @@ -13,7 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# Modified from ESPnet(https://github.com/espnet/espnet) - """BranchformerEncoderLayer definition.""" import torch @@ -50,8 +49,7 @@ def __init__( ): super().__init__() assert (attn is not None) or ( - cgmlp is not None - ), "At least one branch should be valid" + cgmlp is not None), "At least one branch should be valid" self.size = size self.attn = attn @@ -66,7 +64,8 @@ def __init__( self.norm_mha = nn.LayerNorm(size) # for the MHA module if cgmlp is not None: self.norm_mlp = nn.LayerNorm(size) # for the MLP module - self.norm_final = nn.LayerNorm(size) # for the final output of the block + self.norm_final = nn.LayerNorm( + size) # for the final output of the block self.dropout = torch.nn.Dropout(dropout_rate) @@ -87,9 +86,8 @@ def __init__( self.merge_proj = torch.nn.Linear(size, size) elif self.merge_method == "fixed_ave": - assert ( - 0.0 <= cgmlp_weight <= 1.0 - ), "cgmlp weight should be between 0.0 and 1.0" + assert (0.0 <= cgmlp_weight <= + 1.0), "cgmlp weight should be between 0.0 and 1.0" # remove the other branch if only one branch is used if cgmlp_weight == 0.0: @@ -157,7 +155,8 @@ def forward( # Branch 1: multi-headed attention module if self.attn is not None: x1 = self.norm_mha(x1) - x_att, new_att_cache = self.attn(x1, x1, x1, mask, pos_emb, att_cache) + x_att, new_att_cache = self.attn(x1, x1, x1, mask, pos_emb, + att_cache) x1 = self.dropout(x_att) # Branch 2: convolutional gating mlp @@ -172,57 +171,54 @@ def forward( if self.use_two_branches: if self.merge_method == "concat": x = x + stoch_layer_coeff * self.dropout( - self.merge_proj(torch.cat([x1, x2], dim=-1)) - ) + self.merge_proj(torch.cat([x1, x2], dim=-1))) elif self.merge_method == "learned_ave": - if ( - self.training - and self.attn_branch_drop_rate > 0 - and torch.rand(1).item() < self.attn_branch_drop_rate - ): + if (self.training and self.attn_branch_drop_rate > 0 + and torch.rand(1).item() < self.attn_branch_drop_rate): # Drop the attn branch w1, w2 = torch.tensor(0.0), torch.tensor(1.0) else: # branch1 - score1 = (self.pooling_proj1(x1).transpose(1, 2) / self.size**0.5) + score1 = (self.pooling_proj1(x1).transpose(1, 2) / + self.size**0.5) score1 = score1.masked_fill(mask_pad.eq(0), -float('inf')) score1 = torch.softmax(score1, dim=-1).masked_fill( - mask_pad.eq(0), 0.0 - ) + mask_pad.eq(0), 0.0) - pooled1 = torch.matmul(score1, x1).squeeze(1) # (batch, size) + pooled1 = torch.matmul(score1, + x1).squeeze(1) # (batch, size) weight1 = self.weight_proj1(pooled1) # (batch, 1) # branch2 - score2 = (self.pooling_proj2(x2).transpose(1, 2) / self.size**0.5) + score2 = (self.pooling_proj2(x2).transpose(1, 2) / + self.size**0.5) score2 = score2.masked_fill(mask_pad.eq(0), -float('inf')) score2 = torch.softmax(score2, dim=-1).masked_fill( - mask_pad.eq(0), 0.0 - ) + mask_pad.eq(0), 0.0) - pooled2 = torch.matmul(score2, x2).squeeze(1) # (batch, size) + pooled2 = torch.matmul(score2, + x2).squeeze(1) # (batch, size) weight2 = self.weight_proj2(pooled2) # (batch, 1) # normalize weights of two branches - merge_weights = torch.softmax( - torch.cat([weight1, weight2], dim=-1), dim=-1 - ) # (batch, 2) + merge_weights = torch.softmax(torch.cat([weight1, weight2], + dim=-1), + dim=-1) # (batch, 2) merge_weights = merge_weights.unsqueeze(-1).unsqueeze( - -1 - ) # (batch, 2, 1, 1) - w1, w2 = merge_weights[:, 0], merge_weights[:, 1] # (batch, 1, 1) + -1) # (batch, 2, 1, 1) + w1, w2 = merge_weights[:, + 0], merge_weights[:, + 1] # (batch, 1, 1) x = x + stoch_layer_coeff * self.dropout( - self.merge_proj(w1 * x1 + w2 * x2) - ) + 
self.merge_proj(w1 * x1 + w2 * x2)) elif self.merge_method == "fixed_ave": x = x + stoch_layer_coeff * self.dropout( - self.merge_proj( - (1.0 - self.cgmlp_weight) * x1 + self.cgmlp_weight * x2 - ) - ) + self.merge_proj((1.0 - self.cgmlp_weight) * x1 + + self.cgmlp_weight * x2)) else: - raise RuntimeError(f"unknown merge method: {self.merge_method}") + raise RuntimeError( + f"unknown merge method: {self.merge_method}") else: if self.attn is None: x = x + stoch_layer_coeff * self.dropout(self.merge_proj(x2)) @@ -230,7 +226,8 @@ def forward( x = x + stoch_layer_coeff * self.dropout(self.merge_proj(x1)) else: # This should not happen - raise RuntimeError("Both branches are not None, which is unexpected.") + raise RuntimeError( + "Both branches are not None, which is unexpected.") x = self.norm_final(x) diff --git a/wenet/cli/model.py b/wenet/cli/model.py index 3dd952ef0..182bfc5f6 100644 --- a/wenet/cli/model.py +++ b/wenet/cli/model.py @@ -28,8 +28,13 @@ class Model: - def __init__(self, model_dir: str, gpu: int = -1, beam: int = 5, - context_path: str = None, context_score: float = 6.0, + + def __init__(self, + model_dir: str, + gpu: int = -1, + beam: int = 5, + context_path: str = None, + context_score: float = 6.0, resample_rate: int = 16000): model_path = os.path.join(model_dir, 'final.zip') units_path = os.path.join(model_dir, 'units.txt') @@ -46,7 +51,8 @@ def __init__(self, model_dir: str, gpu: int = -1, beam: int = 5, self.char_dict = {v: k for k, v in self.symbol_table.items()} self.beam = beam if context_path is not None: - self.context_graph = ContextGraph(context_path, self.symbol_table, + self.context_graph = ContextGraph(context_path, + self.symbol_table, context_score=context_score) else: self.context_graph = None @@ -74,14 +80,15 @@ def _decode(self, label: str = None) -> dict: feats = self.compute_feats(audio_file) encoder_out, _, _ = self.model.forward_encoder_chunk(feats, 0, -1) - encoder_lens = torch.tensor([ - encoder_out.size(1)], - dtype=torch.long, - device=encoder_out.device) + encoder_lens = torch.tensor([encoder_out.size(1)], + dtype=torch.long, + device=encoder_out.device) ctc_probs = self.model.ctc_activation(encoder_out) if label is None: ctc_prefix_results = ctc_prefix_beam_search( - ctc_probs, encoder_lens, self.beam, + ctc_probs, + encoder_lens, + self.beam, context_graph=self.context_graph) else: # force align mode, construct ctc prefix result from alignment label_t = self.tokenize(label) diff --git a/wenet/cli/paraformer_model.py b/wenet/cli/paraformer_model.py index 0ecd03b79..8f602baeb 100644 --- a/wenet/cli/paraformer_model.py +++ b/wenet/cli/paraformer_model.py @@ -11,7 +11,9 @@ class Paraformer: - def __init__(self, model_dir: str, device: int = -1, + def __init__(self, + model_dir: str, + device: int = -1, resample_rate: int = 16000) -> None: model_path = os.path.join(model_dir, 'final.zip') @@ -41,10 +43,9 @@ def transcribe(self, audio_file: str, tokens_info: bool = False) -> dict: energy_floor=0.0, sample_frequency=self.resample_rate) feats = feats.unsqueeze(0) - feats_lens = torch.tensor([ - feats.size(1)], - dtype=torch.int64, - device=feats.device) + feats_lens = torch.tensor([feats.size(1)], + dtype=torch.int64, + device=feats.device) decoder_out, token_num = self.model.forward_paraformer( feats, feats_lens) diff --git a/wenet/cli/transcribe.py b/wenet/cli/transcribe.py index f0086f4af..7fe6a1c47 100644 --- a/wenet/cli/transcribe.py +++ b/wenet/cli/transcribe.py @@ -50,11 +50,14 @@ def get_args(): parser.add_argument('--paraformer', 
action='store_true', help='whether to use the best chinese model') - parser.add_argument('--beam', type=int, default=5, - help="beam size") - parser.add_argument('--context_path', type=str, default=None, + parser.add_argument('--beam', type=int, default=5, help="beam size") + parser.add_argument('--context_path', + type=str, + default=None, help='context list file') - parser.add_argument('--context_score', type=float, default=6.0, + parser.add_argument('--context_score', + type=float, + default=6.0, help='context score') args = parser.parse_args() return args @@ -66,8 +69,8 @@ def main(): if args.paraformer: model = load_paraformer(args.model_dir, args.gpu) else: - model = load_model(args.language, args.model_dir, args.gpu, - args.beam, args.context_path, args.context_score) + model = load_model(args.language, args.model_dir, args.gpu, args.beam, + args.context_path, args.context_score) if args.align: result = model.align(args.audio_file, args.label) else: diff --git a/wenet/ctl_model/asr_model_ctl.py b/wenet/ctl_model/asr_model_ctl.py index 23581556d..de0d4d725 100644 --- a/wenet/ctl_model/asr_model_ctl.py +++ b/wenet/ctl_model/asr_model_ctl.py @@ -25,6 +25,7 @@ from wenet.transformer.asr_model import ASRModel from wenet.utils.common import IGNORE_ID + class CTLModel(ASRModel): """ Implementation of Interspeecch 2023 paper: @@ -32,6 +33,7 @@ class CTLModel(ASRModel): with Contrastive Learning' https://arxiv.org/abs/2306.00755 """ + def __init__( self, vocab_size: int, @@ -48,8 +50,8 @@ def __init__( ctl_weight: float = 1, ): assert 0.0 <= ctc_weight <= 1.0, ctc_weight - super().__init__(vocab_size, encoder, decoder, ctc, - ctc_weight, ignore_id, reverse_weight, lsm_weight, + super().__init__(vocab_size, encoder, decoder, ctc, ctc_weight, + ignore_id, reverse_weight, lsm_weight, length_normalized_loss) # For CTL Loss @@ -76,13 +78,18 @@ def forward( num = encoder_out_full.size(1) targets = encoder_out_full src = encoder_out - negs, negs_idxs = self.sample_negatives( - targets, targets.size(1), speech_lengths=lens_chunk) + negs, negs_idxs = self.sample_negatives(targets, + targets.size(1), + speech_lengths=lens_chunk) ctl_loss = self.CTL(src, targets, negs, encoder_mask) loss = loss_full + loss_chunk + self.ctl_weight * ctl_loss - return {"loss": loss, "loss_full": loss_full, - "loss_chunk": loss_chunk, "loss_ctl": ctl_loss} + return { + "loss": loss, + "loss_full": loss_full, + "loss_chunk": loss_chunk, + "loss_ctl": ctl_loss + } def forward_full( self, @@ -107,7 +114,8 @@ def forward_full( text_lengths.shape[0]), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape) # 1. Encoder - encoder_out, encoder_mask = self.encoder.forward_full(speech, speech_lengths) + encoder_out, encoder_mask = self.encoder.forward_full( + speech, speech_lengths) encoder_out_lens = encoder_mask.squeeze(1).sum(1) # 2a. 
Attention-decoder branch @@ -194,33 +202,30 @@ def sample_negatives(self, y, num, padding_count=0, speech_lengths=None): assert high > 1, f"{bsz,tsz,fsz}" if self.n_negatives > 0: - tszs = ( - torch.arange(num) - .unsqueeze(-1) - .expand(-1, self.n_negatives) - .flatten() - ) + tszs = (torch.arange(num).unsqueeze(-1).expand( + -1, self.n_negatives).flatten()) if speech_lengths is not None: - neg_idxs = [torch.randint( - low=0, high=speech_lengths[i].item() - 1, - size=(1, self.n_negatives * tsz)) - for i in range(len(speech_lengths))] - neg_idxs = torch.cat(neg_idxs).reshape(bsz, self.n_negatives * tsz) + neg_idxs = [ + torch.randint(low=0, + high=speech_lengths[i].item() - 1, + size=(1, self.n_negatives * tsz)) + for i in range(len(speech_lengths)) + ] + neg_idxs = torch.cat(neg_idxs).reshape( + bsz, self.n_negatives * tsz) else: - neg_idxs = torch.randint( - low=0, high=num - 1, size=(bsz, self.n_negatives * tsz) - ) + neg_idxs = torch.randint(low=0, + high=num - 1, + size=(bsz, + self.n_negatives * tsz)) neg_idxs[neg_idxs >= tszs] += 1 if self.n_negatives > 0: neg_idxs = neg_idxs + (torch.arange(bsz).unsqueeze(1) * high) negs = y[neg_idxs.view(-1)] - negs = negs.contiguous().view( - bsz, num, self.n_negatives, fsz - ).permute( - 2, 0, 1, 3 - ) # to NxBxTxC + negs = negs.contiguous().view(bsz, num, self.n_negatives, + fsz).permute(2, 0, 1, 3) # to NxBxTxC return negs, neg_idxs def compute_preds(self, x, y, negatives): diff --git a/wenet/ctl_model/encoder.py b/wenet/ctl_model/encoder.py index 28c31d6a9..6b71d0cf8 100644 --- a/wenet/ctl_model/encoder.py +++ b/wenet/ctl_model/encoder.py @@ -14,7 +14,6 @@ # See the License for the specific language governing permissions and # limitations under the License. # Modified from ESPnet(https://github.com/espnet/espnet) - """Encoder definition.""" from typing import Tuple @@ -24,8 +23,10 @@ from wenet.utils.mask import add_optional_chunk_mask from wenet.transformer.encoder import TransformerEncoder, ConformerEncoder + class DualTransformerEncoder(TransformerEncoder): """Transformer encoder module.""" + def __init__( self, input_size: int, @@ -51,8 +52,8 @@ def __init__( linear_units, num_blocks, dropout_rate, positional_dropout_rate, attention_dropout_rate, input_layer, pos_enc_layer_type, normalize_before, - static_chunk_size, use_dynamic_chunk, - global_cmvn, use_dynamic_left_chunk) + static_chunk_size, use_dynamic_chunk, global_cmvn, + use_dynamic_left_chunk) def forward( self, @@ -86,7 +87,8 @@ def forward( xs = self.global_cmvn(xs) xs, pos_emb, masks = self.embed(xs, masks) mask_pad = masks # (B, 1, T/subsample_rate) - chunk_masks = add_optional_chunk_mask(xs, masks, + chunk_masks = add_optional_chunk_mask(xs, + masks, self.use_dynamic_chunk, self.use_dynamic_left_chunk, decoding_chunk_size, @@ -121,8 +123,10 @@ def forward_full( xs = self.after_norm(xs) return xs, masks + class DualConformerEncoder(ConformerEncoder): """Conformer encoder module.""" + def __init__( self, input_size: int, @@ -156,11 +160,10 @@ def __init__( linear_units, num_blocks, dropout_rate, positional_dropout_rate, attention_dropout_rate, input_layer, pos_enc_layer_type, normalize_before, - static_chunk_size, use_dynamic_chunk, - global_cmvn, use_dynamic_left_chunk, - positionwise_conv_kernel_size, macaron_style, - selfattention_layer_type, activation_type, - use_cnn_module, cnn_module_kernel, + static_chunk_size, use_dynamic_chunk, global_cmvn, + use_dynamic_left_chunk, positionwise_conv_kernel_size, + macaron_style, selfattention_layer_type, + activation_type, 
use_cnn_module, cnn_module_kernel, causal, cnn_module_norm) def forward( @@ -195,7 +198,8 @@ def forward( xs = self.global_cmvn(xs) xs, pos_emb, masks = self.embed(xs, masks) mask_pad = masks # (B, 1, T/subsample_rate) - chunk_masks = add_optional_chunk_mask(xs, masks, + chunk_masks = add_optional_chunk_mask(xs, + masks, self.use_dynamic_chunk, self.use_dynamic_left_chunk, decoding_chunk_size, diff --git a/wenet/dataset/kaldi_io.py b/wenet/dataset/kaldi_io.py index c9bef293c..b686380e4 100644 --- a/wenet/dataset/kaldi_io.py +++ b/wenet/dataset/kaldi_io.py @@ -12,139 +12,166 @@ # Select kaldi, if not 'KALDI_ROOT' in os.environ: - # Default! To change run python with 'export KALDI_ROOT=/some_dir python' - os.environ['KALDI_ROOT']='/mnt/matylda5/iveselyk/Tools/kaldi-trunk' + # Default! To change run python with 'export KALDI_ROOT=/some_dir python' + os.environ['KALDI_ROOT'] = '/mnt/matylda5/iveselyk/Tools/kaldi-trunk' # Add kaldi tools to path, -os.environ['PATH'] = os.popen('echo $KALDI_ROOT/src/bin:$KALDI_ROOT/tools/openfst/bin:$KALDI_ROOT/src/fstbin/:$KALDI_ROOT/src/gmmbin/:$KALDI_ROOT/src/featbin/:$KALDI_ROOT/src/lm/:$KALDI_ROOT/src/sgmmbin/:$KALDI_ROOT/src/sgmm2bin/:$KALDI_ROOT/src/fgmmbin/:$KALDI_ROOT/src/latbin/:$KALDI_ROOT/src/nnetbin:$KALDI_ROOT/src/nnet2bin:$KALDI_ROOT/src/nnet3bin:$KALDI_ROOT/src/online2bin/:$KALDI_ROOT/src/ivectorbin/:$KALDI_ROOT/src/lmbin/').readline().strip() + ':' + os.environ['PATH'] +os.environ['PATH'] = os.popen( + 'echo $KALDI_ROOT/src/bin:$KALDI_ROOT/tools/openfst/bin:$KALDI_ROOT/src/fstbin/:$KALDI_ROOT/src/gmmbin/:$KALDI_ROOT/src/featbin/:$KALDI_ROOT/src/lm/:$KALDI_ROOT/src/sgmmbin/:$KALDI_ROOT/src/sgmm2bin/:$KALDI_ROOT/src/fgmmbin/:$KALDI_ROOT/src/latbin/:$KALDI_ROOT/src/nnetbin:$KALDI_ROOT/src/nnet2bin:$KALDI_ROOT/src/nnet3bin:$KALDI_ROOT/src/online2bin/:$KALDI_ROOT/src/ivectorbin/:$KALDI_ROOT/src/lmbin/' +).readline().strip() + ':' + os.environ['PATH'] ################################################# # Define all custom exceptions, -class UnsupportedDataType(Exception): pass -class UnknownVectorHeader(Exception): pass -class UnknownMatrixHeader(Exception): pass +class UnsupportedDataType(Exception): + pass -class BadSampleSize(Exception): pass -class BadInputFormat(Exception): pass -class SubprocessFailed(Exception): pass +class UnknownVectorHeader(Exception): + pass + + +class UnknownMatrixHeader(Exception): + pass + + +class BadSampleSize(Exception): + pass + + +class BadInputFormat(Exception): + pass + + +class SubprocessFailed(Exception): + pass + ################################################# # Data-type independent helper functions, + def open_or_fd(file, mode='rb'): - """ fd = open_or_fd(file) + """ fd = open_or_fd(file) Open file, gzipped file, pipe, or forward the file-descriptor. Eventually seeks in the 'file' argument contains ':offset' suffix. """ - offset = None - try: - # strip 'ark:' prefix from r{x,w}filename (optional), - if re.search('^(ark|scp)(,scp|,b|,t|,n?f|,n?p|,b?o|,n?s|,n?cs)*:', file): - (prefix,file) = file.split(':',1) - # separate offset from filename (optional), - if re.search(':[0-9]+$', file): - (file,offset) = file.rsplit(':',1) - # input pipe? - if file[-1] == '|': - fd = popen(file[:-1], 'rb') # custom, - # output pipe? - elif file[0] == '|': - fd = popen(file[1:], 'wb') # custom, - # is it gzipped? - elif file.split('.')[-1] == 'gz': - fd = gzip.open(file, mode) - # a normal file... 
- else: - fd = open(file, mode) - except TypeError: - # 'file' is opened file descriptor, - fd = file - # Eventually seek to offset, - if offset != None: fd.seek(int(offset)) - return fd + offset = None + try: + # strip 'ark:' prefix from r{x,w}filename (optional), + if re.search('^(ark|scp)(,scp|,b|,t|,n?f|,n?p|,b?o|,n?s|,n?cs)*:', + file): + (prefix, file) = file.split(':', 1) + # separate offset from filename (optional), + if re.search(':[0-9]+$', file): + (file, offset) = file.rsplit(':', 1) + # input pipe? + if file[-1] == '|': + fd = popen(file[:-1], 'rb') # custom, + # output pipe? + elif file[0] == '|': + fd = popen(file[1:], 'wb') # custom, + # is it gzipped? + elif file.split('.')[-1] == 'gz': + fd = gzip.open(file, mode) + # a normal file... + else: + fd = open(file, mode) + except TypeError: + # 'file' is opened file descriptor, + fd = file + # Eventually seek to offset, + if offset != None: fd.seek(int(offset)) + return fd + # based on '/usr/local/lib/python3.4/os.py' def popen(cmd, mode="rb"): - if not isinstance(cmd, str): - raise TypeError("invalid cmd type (%s, expected string)" % type(cmd)) - - import subprocess, io, threading - - # cleanup function for subprocesses, - def cleanup(proc, cmd): - ret = proc.wait() - if ret > 0: - raise SubprocessFailed('cmd %s returned %d !' % (cmd,ret)) - return - - # text-mode, - if mode == "r": - proc = subprocess.Popen(cmd, shell=True, stdout=subprocess.PIPE) - threading.Thread(target=cleanup,args=(proc,cmd)).start() # clean-up thread, - return io.TextIOWrapper(proc.stdout) - elif mode == "w": - proc = subprocess.Popen(cmd, shell=True, stdin=subprocess.PIPE) - threading.Thread(target=cleanup,args=(proc,cmd)).start() # clean-up thread, - return io.TextIOWrapper(proc.stdin) - # binary, - elif mode == "rb": - proc = subprocess.Popen(cmd, shell=True, stdout=subprocess.PIPE) - threading.Thread(target=cleanup,args=(proc,cmd)).start() # clean-up thread, - return proc.stdout - elif mode == "wb": - proc = subprocess.Popen(cmd, shell=True, stdin=subprocess.PIPE) - threading.Thread(target=cleanup,args=(proc,cmd)).start() # clean-up thread, - return proc.stdin - # sanity, - else: - raise ValueError("invalid mode %s" % mode) + if not isinstance(cmd, str): + raise TypeError("invalid cmd type (%s, expected string)" % type(cmd)) + + import subprocess, io, threading + + # cleanup function for subprocesses, + def cleanup(proc, cmd): + ret = proc.wait() + if ret > 0: + raise SubprocessFailed('cmd %s returned %d !' % (cmd, ret)) + return + + # text-mode, + if mode == "r": + proc = subprocess.Popen(cmd, shell=True, stdout=subprocess.PIPE) + threading.Thread(target=cleanup, + args=(proc, cmd)).start() # clean-up thread, + return io.TextIOWrapper(proc.stdout) + elif mode == "w": + proc = subprocess.Popen(cmd, shell=True, stdin=subprocess.PIPE) + threading.Thread(target=cleanup, + args=(proc, cmd)).start() # clean-up thread, + return io.TextIOWrapper(proc.stdin) + # binary, + elif mode == "rb": + proc = subprocess.Popen(cmd, shell=True, stdout=subprocess.PIPE) + threading.Thread(target=cleanup, + args=(proc, cmd)).start() # clean-up thread, + return proc.stdout + elif mode == "wb": + proc = subprocess.Popen(cmd, shell=True, stdin=subprocess.PIPE) + threading.Thread(target=cleanup, + args=(proc, cmd)).start() # clean-up thread, + return proc.stdin + # sanity, + else: + raise ValueError("invalid mode %s" % mode) def read_key(fd): - """ [key] = read_key(fd) + """ [key] = read_key(fd) Read the utterance-key from the opened ark/stream descriptor 'fd'. 
""" - key = '' - while 1: - char = fd.read(1).decode("latin1") - if char == '' : break - if char == ' ' : break - key += char - key = key.strip() - if key == '': return None # end of file, - assert(re.match('^\S+$',key) != None) # check format (no whitespace!) - return key + key = '' + while 1: + char = fd.read(1).decode("latin1") + if char == '': break + if char == ' ': break + key += char + key = key.strip() + if key == '': return None # end of file, + assert (re.match('^\S+$', key) != None) # check format (no whitespace!) + return key ################################################# # Integer vectors (alignments, ...), + def read_ali_ark(file_or_fd): - """ Alias to 'read_vec_int_ark()' """ - return read_vec_int_ark(file_or_fd) + """ Alias to 'read_vec_int_ark()' """ + return read_vec_int_ark(file_or_fd) + def read_vec_int_ark(file_or_fd): - """ generator(key,vec) = read_vec_int_ark(file_or_fd) + """ generator(key,vec) = read_vec_int_ark(file_or_fd) Create generator of (key,vector) tuples, which reads from the ark file/stream. file_or_fd : ark, gzipped ark, pipe or opened file descriptor. Read ark to a 'dictionary': d = { u:d for u,d in kaldi_io.read_vec_int_ark(file) } """ - fd = open_or_fd(file_or_fd) - try: - key = read_key(fd) - while key: - ali = read_vec_int(fd) - yield key, ali - key = read_key(fd) - finally: - if fd is not file_or_fd: fd.close() + fd = open_or_fd(file_or_fd) + try: + key = read_key(fd) + while key: + ali = read_vec_int(fd) + yield key, ali + key = read_key(fd) + finally: + if fd is not file_or_fd: fd.close() + def read_vec_int_scp(file_or_fd): - """ generator(key,vec) = read_vec_int_scp(file_or_fd) + """ generator(key,vec) = read_vec_int_scp(file_or_fd) Returns generator of (key,vector) tuples, read according to kaldi scp. file_or_fd : scp, gzipped scp, pipe or opened file descriptor. 
@@ -155,41 +182,48 @@ def read_vec_int_scp(file_or_fd): Read scp to a 'dictionary': d = { key:vec for key,mat in kaldi_io.read_vec_int_scp(file) } """ - fd = open_or_fd(file_or_fd) - try: - for line in fd: - (key,rxfile) = line.decode().split(' ') - vec = read_vec_int(rxfile) - yield key, vec - finally: - if fd is not file_or_fd : fd.close() + fd = open_or_fd(file_or_fd) + try: + for line in fd: + (key, rxfile) = line.decode().split(' ') + vec = read_vec_int(rxfile) + yield key, vec + finally: + if fd is not file_or_fd: fd.close() + def read_vec_int(file_or_fd): - """ [int-vec] = read_vec_int(file_or_fd) + """ [int-vec] = read_vec_int(file_or_fd) Read kaldi integer vector, ascii or binary input, """ - fd = open_or_fd(file_or_fd) - binary = fd.read(2).decode() - if binary == '\0B': # binary flag - assert(fd.read(1).decode() == '\4'); # int-size - vec_size = np.frombuffer(fd.read(4), dtype='int32', count=1)[0] # vector dim - # Elements from int32 vector are sored in tuples: (sizeof(int32), value), - vec = np.frombuffer(fd.read(vec_size*5), dtype=[('size','int8'),('value','int32')], count=vec_size) - assert(vec[0]['size'] == 4) # int32 size, - ans = vec[:]['value'] # values are in 2nd column, - else: # ascii, - arr = (binary + fd.readline().decode()).strip().split() - try: - arr.remove('['); arr.remove(']') # optionally - except ValueError: - pass - ans = np.array(arr, dtype=int) - if fd is not file_or_fd : fd.close() # cleanup - return ans + fd = open_or_fd(file_or_fd) + binary = fd.read(2).decode() + if binary == '\0B': # binary flag + assert (fd.read(1).decode() == '\4') + # int-size + vec_size = np.frombuffer(fd.read(4), dtype='int32', + count=1)[0] # vector dim + # Elements from int32 vector are sored in tuples: (sizeof(int32), value), + vec = np.frombuffer(fd.read(vec_size * 5), + dtype=[('size', 'int8'), ('value', 'int32')], + count=vec_size) + assert (vec[0]['size'] == 4) # int32 size, + ans = vec[:]['value'] # values are in 2nd column, + else: # ascii, + arr = (binary + fd.readline().decode()).strip().split() + try: + arr.remove('[') + arr.remove(']') # optionally + except ValueError: + pass + ans = np.array(arr, dtype=int) + if fd is not file_or_fd: fd.close() # cleanup + return ans + # Writing, def write_vec_int(file_or_fd, v, key=''): - """ write_vec_int(f, v, key='') + """ write_vec_int(f, v, key='') Write a binary kaldi integer vector to filename or stream. Arguments: file_or_fd : filename or opened file descriptor for writing, @@ -204,28 +238,32 @@ def write_vec_int(file_or_fd, v, key=''): for key,vec in dict.iteritems(): kaldi_io.write_vec_flt(f, vec, key=key) """ - fd = open_or_fd(file_or_fd, mode='wb') - if sys.version_info[0] == 3: assert(fd.mode == 'wb') - try: - if key != '' : fd.write((key+' ').encode("latin1")) # ark-files have keys (utterance-id), - fd.write('\0B'.encode()) # we write binary! - # dim, - fd.write('\4'.encode()) # int32 type, - fd.write(struct.pack(np.dtype('int32').char, v.shape[0])) - # data, - for i in range(len(v)): - fd.write('\4'.encode()) # int32 type, - fd.write(struct.pack(np.dtype('int32').char, v[i])) # binary, - finally: - if fd is not file_or_fd : fd.close() + fd = open_or_fd(file_or_fd, mode='wb') + if sys.version_info[0] == 3: assert (fd.mode == 'wb') + try: + if key != '': + fd.write( + (key + + ' ').encode("latin1")) # ark-files have keys (utterance-id), + fd.write('\0B'.encode()) # we write binary! 
+ # dim, + fd.write('\4'.encode()) # int32 type, + fd.write(struct.pack(np.dtype('int32').char, v.shape[0])) + # data, + for i in range(len(v)): + fd.write('\4'.encode()) # int32 type, + fd.write(struct.pack(np.dtype('int32').char, v[i])) # binary, + finally: + if fd is not file_or_fd: fd.close() ################################################# # Float vectors (confidences, ivectors, ...), + # Reading, def read_vec_flt_scp(file_or_fd): - """ generator(key,mat) = read_vec_flt_scp(file_or_fd) + """ generator(key,mat) = read_vec_flt_scp(file_or_fd) Returns generator of (key,vector) tuples, read according to kaldi scp. file_or_fd : scp, gzipped scp, pipe or opened file descriptor. @@ -236,68 +274,74 @@ def read_vec_flt_scp(file_or_fd): Read scp to a 'dictionary': d = { key:mat for key,mat in kaldi_io.read_mat_scp(file) } """ - fd = open_or_fd(file_or_fd) - try: - for line in fd: - (key,rxfile) = line.decode().split(' ') - vec = read_vec_flt(rxfile) - yield key, vec - finally: - if fd is not file_or_fd : fd.close() + fd = open_or_fd(file_or_fd) + try: + for line in fd: + (key, rxfile) = line.decode().split(' ') + vec = read_vec_flt(rxfile) + yield key, vec + finally: + if fd is not file_or_fd: fd.close() + def read_vec_flt_ark(file_or_fd): - """ generator(key,vec) = read_vec_flt_ark(file_or_fd) + """ generator(key,vec) = read_vec_flt_ark(file_or_fd) Create generator of (key,vector) tuples, reading from an ark file/stream. file_or_fd : ark, gzipped ark, pipe or opened file descriptor. Read ark to a 'dictionary': d = { u:d for u,d in kaldi_io.read_vec_flt_ark(file) } """ - fd = open_or_fd(file_or_fd) - try: - key = read_key(fd) - while key: - ali = read_vec_flt(fd) - yield key, ali - key = read_key(fd) - finally: - if fd is not file_or_fd: fd.close() + fd = open_or_fd(file_or_fd) + try: + key = read_key(fd) + while key: + ali = read_vec_flt(fd) + yield key, ali + key = read_key(fd) + finally: + if fd is not file_or_fd: fd.close() + def read_vec_flt(file_or_fd): - """ [flt-vec] = read_vec_flt(file_or_fd) + """ [flt-vec] = read_vec_flt(file_or_fd) Read kaldi float vector, ascii or binary input, """ - fd = open_or_fd(file_or_fd) - binary = fd.read(2).decode() - if binary == '\0B': # binary flag - # Data type, - header = fd.read(3).decode() - if header == 'FV ': sample_size = 4 # floats - elif header == 'DV ': sample_size = 8 # doubles - else: raise UnknownVectorHeader("The header contained '%s'" % header) - assert(sample_size > 0) - # Dimension, - assert(fd.read(1).decode() == '\4'); # int-size - vec_size = np.frombuffer(fd.read(4), dtype='int32', count=1)[0] # vector dim - # Read whole vector, - buf = fd.read(vec_size * sample_size) - if sample_size == 4 : ans = np.frombuffer(buf, dtype='float32') - elif sample_size == 8 : ans = np.frombuffer(buf, dtype='float64') - else : raise BadSampleSize + fd = open_or_fd(file_or_fd) + binary = fd.read(2).decode() + if binary == '\0B': # binary flag + # Data type, + header = fd.read(3).decode() + if header == 'FV ': sample_size = 4 # floats + elif header == 'DV ': sample_size = 8 # doubles + else: raise UnknownVectorHeader("The header contained '%s'" % header) + assert (sample_size > 0) + # Dimension, + assert (fd.read(1).decode() == '\4') + # int-size + vec_size = np.frombuffer(fd.read(4), dtype='int32', + count=1)[0] # vector dim + # Read whole vector, + buf = fd.read(vec_size * sample_size) + if sample_size == 4: ans = np.frombuffer(buf, dtype='float32') + elif sample_size == 8: ans = np.frombuffer(buf, dtype='float64') + else: raise BadSampleSize + return 
ans + else: # ascii, + arr = (binary + fd.readline().decode()).strip().split() + try: + arr.remove('[') + arr.remove(']') # optionally + except ValueError: + pass + ans = np.array(arr, dtype=float) + if fd is not file_or_fd: fd.close() # cleanup return ans - else: # ascii, - arr = (binary + fd.readline().decode()).strip().split() - try: - arr.remove('['); arr.remove(']') # optionally - except ValueError: - pass - ans = np.array(arr, dtype=float) - if fd is not file_or_fd : fd.close() # cleanup - return ans + # Writing, def write_vec_flt(file_or_fd, v, key=''): - """ write_vec_flt(f, v, key='') + """ write_vec_flt(f, v, key='') Write a binary kaldi vector to filename or stream. Supports 32bit and 64bit floats. Arguments: file_or_fd : filename or opened file descriptor for writing, @@ -312,30 +356,36 @@ def write_vec_flt(file_or_fd, v, key=''): for key,vec in dict.iteritems(): kaldi_io.write_vec_flt(f, vec, key=key) """ - fd = open_or_fd(file_or_fd, mode='wb') - if sys.version_info[0] == 3: assert(fd.mode == 'wb') - try: - if key != '' : fd.write((key+' ').encode("latin1")) # ark-files have keys (utterance-id), - fd.write('\0B'.encode()) # we write binary! - # Data-type, - if v.dtype == 'float32': fd.write('FV '.encode()) - elif v.dtype == 'float64': fd.write('DV '.encode()) - else: raise UnsupportedDataType("'%s', please use 'float32' or 'float64'" % v.dtype) - # Dim, - fd.write('\04'.encode()) - fd.write(struct.pack(np.dtype('uint32').char, v.shape[0])) # dim - # Data, - fd.write(v.tobytes()) - finally: - if fd is not file_or_fd : fd.close() + fd = open_or_fd(file_or_fd, mode='wb') + if sys.version_info[0] == 3: assert (fd.mode == 'wb') + try: + if key != '': + fd.write( + (key + + ' ').encode("latin1")) # ark-files have keys (utterance-id), + fd.write('\0B'.encode()) # we write binary! + # Data-type, + if v.dtype == 'float32': fd.write('FV '.encode()) + elif v.dtype == 'float64': fd.write('DV '.encode()) + else: + raise UnsupportedDataType( + "'%s', please use 'float32' or 'float64'" % v.dtype) + # Dim, + fd.write('\04'.encode()) + fd.write(struct.pack(np.dtype('uint32').char, v.shape[0])) # dim + # Data, + fd.write(v.tobytes()) + finally: + if fd is not file_or_fd: fd.close() ################################################# # Float matrices (features, transformations, ...), + # Reading, def read_mat_scp(file_or_fd): - """ generator(key,mat) = read_mat_scp(file_or_fd) + """ generator(key,mat) = read_mat_scp(file_or_fd) Returns generator of (key,matrix) tuples, read according to kaldi scp. file_or_fd : scp, gzipped scp, pipe or opened file descriptor. @@ -346,17 +396,18 @@ def read_mat_scp(file_or_fd): Read scp to a 'dictionary': d = { key:mat for key,mat in kaldi_io.read_mat_scp(file) } """ - fd = open_or_fd(file_or_fd) - try: - for line in fd: - (key,rxfile) = line.decode().split(' ') - mat = read_mat(rxfile) - yield key, mat - finally: - if fd is not file_or_fd : fd.close() + fd = open_or_fd(file_or_fd) + try: + for line in fd: + (key, rxfile) = line.decode().split(' ') + mat = read_mat(rxfile) + yield key, mat + finally: + if fd is not file_or_fd: fd.close() + def read_mat_ark(file_or_fd): - """ generator(key,mat) = read_mat_ark(file_or_fd) + """ generator(key,mat) = read_mat_ark(file_or_fd) Returns generator of (key,matrix) tuples, read from ark file/stream. file_or_fd : scp, gzipped scp, pipe or opened file descriptor. 
@@ -367,122 +418,145 @@ def read_mat_ark(file_or_fd): Read ark to a 'dictionary': d = { key:mat for key,mat in kaldi_io.read_mat_ark(file) } """ - fd = open_or_fd(file_or_fd) - try: - key = read_key(fd) - while key: - mat = read_mat(fd) - yield key, mat - key = read_key(fd) - finally: - if fd is not file_or_fd : fd.close() + fd = open_or_fd(file_or_fd) + try: + key = read_key(fd) + while key: + mat = read_mat(fd) + yield key, mat + key = read_key(fd) + finally: + if fd is not file_or_fd: fd.close() + def read_mat(file_or_fd): - """ [mat] = read_mat(file_or_fd) + """ [mat] = read_mat(file_or_fd) Reads single kaldi matrix, supports ascii and binary. file_or_fd : file, gzipped file, pipe or opened file descriptor. """ - fd = open_or_fd(file_or_fd) - try: - binary = fd.read(2).decode() - if binary == '\0B' : - mat = _read_mat_binary(fd) - else: - assert(binary == ' [') - mat = _read_mat_ascii(fd) - finally: - if fd is not file_or_fd: fd.close() - return mat + fd = open_or_fd(file_or_fd) + try: + binary = fd.read(2).decode() + if binary == '\0B': + mat = _read_mat_binary(fd) + else: + assert (binary == ' [') + mat = _read_mat_ascii(fd) + finally: + if fd is not file_or_fd: fd.close() + return mat + def _read_mat_binary(fd): - # Data type - header = fd.read(3).decode() - # 'CM', 'CM2', 'CM3' are possible values, - if header.startswith('CM'): return _read_compressed_mat(fd, header) - elif header == 'FM ': sample_size = 4 # floats - elif header == 'DM ': sample_size = 8 # doubles - else: raise UnknownMatrixHeader("The header contained '%s'" % header) - assert(sample_size > 0) - # Dimensions - s1, rows, s2, cols = np.frombuffer(fd.read(10), dtype='int8,int32,int8,int32', count=1)[0] - # Read whole matrix - buf = fd.read(rows * cols * sample_size) - if sample_size == 4 : vec = np.frombuffer(buf, dtype='float32') - elif sample_size == 8 : vec = np.frombuffer(buf, dtype='float64') - else : raise BadSampleSize - mat = np.reshape(vec,(rows,cols)) - return mat + # Data type + header = fd.read(3).decode() + # 'CM', 'CM2', 'CM3' are possible values, + if header.startswith('CM'): return _read_compressed_mat(fd, header) + elif header == 'FM ': sample_size = 4 # floats + elif header == 'DM ': sample_size = 8 # doubles + else: raise UnknownMatrixHeader("The header contained '%s'" % header) + assert (sample_size > 0) + # Dimensions + s1, rows, s2, cols = np.frombuffer(fd.read(10), + dtype='int8,int32,int8,int32', + count=1)[0] + # Read whole matrix + buf = fd.read(rows * cols * sample_size) + if sample_size == 4: vec = np.frombuffer(buf, dtype='float32') + elif sample_size == 8: vec = np.frombuffer(buf, dtype='float64') + else: raise BadSampleSize + mat = np.reshape(vec, (rows, cols)) + return mat + def _read_mat_ascii(fd): - rows = [] - while 1: - line = fd.readline().decode() - if (len(line) == 0) : raise BadInputFormat # eof, should not happen! - if len(line.strip()) == 0 : continue # skip empty line - arr = line.strip().split() - if arr[-1] != ']': - rows.append(np.array(arr,dtype='float32')) # not last line - else: - rows.append(np.array(arr[:-1],dtype='float32')) # last line - mat = np.vstack(rows) - return mat + rows = [] + while 1: + line = fd.readline().decode() + if (len(line) == 0): raise BadInputFormat # eof, should not happen! 
+ if len(line.strip()) == 0: continue # skip empty line + arr = line.strip().split() + if arr[-1] != ']': + rows.append(np.array(arr, dtype='float32')) # not last line + else: + rows.append(np.array(arr[:-1], dtype='float32')) # last line + mat = np.vstack(rows) + return mat def _read_compressed_mat(fd, format): - """ Read a compressed matrix, + """ Read a compressed matrix, see: https://github.com/kaldi-asr/kaldi/blob/master/src/matrix/compressed-matrix.h methods: CompressedMatrix::Read(...), CompressedMatrix::CopyToMat(...), """ - assert(format == 'CM ') # The formats CM2, CM3 are not supported... - - # Format of header 'struct', - global_header = np.dtype([('minvalue','float32'),('range','float32'),('num_rows','int32'),('num_cols','int32')]) # member '.format' is not written, - per_col_header = np.dtype([('percentile_0','uint16'),('percentile_25','uint16'),('percentile_75','uint16'),('percentile_100','uint16')]) - - # Mapping for percentiles in col-headers, - def uint16_to_float(value, min, range): - return np.float32(min + range * 1.52590218966964e-05 * value) - - # Mapping for matrix elements, - def uint8_to_float_v2(vec, p0, p25, p75, p100): - # Split the vector by masks, - mask_0_64 = (vec <= 64); - mask_193_255 = (vec > 192); - mask_65_192 = (~(mask_0_64 | mask_193_255)); - # Sanity check (useful but slow...), - # assert(len(vec) == np.sum(np.hstack([mask_0_64,mask_65_192,mask_193_255]))) - # assert(len(vec) == np.sum(np.any([mask_0_64,mask_65_192,mask_193_255], axis=0))) - # Build the float vector, - ans = np.empty(len(vec), dtype='float32') - ans[mask_0_64] = p0 + (p25 - p0) / 64. * vec[mask_0_64] - ans[mask_65_192] = p25 + (p75 - p25) / 128. * (vec[mask_65_192] - 64) - ans[mask_193_255] = p75 + (p100 - p75) / 63. * (vec[mask_193_255] - 192) - return ans - - # Read global header, - globmin, globrange, rows, cols = np.frombuffer(fd.read(16), dtype=global_header, count=1)[0] - - # The data is structed as [Colheader, ... , Colheader, Data, Data , .... ] - # { cols }{ size } - col_headers = np.frombuffer(fd.read(cols*8), dtype=per_col_header, count=cols) - data = np.reshape(np.frombuffer(fd.read(cols*rows), dtype='uint8', count=cols*rows), newshape=(cols,rows)) # stored as col-major, + assert (format == 'CM ') # The formats CM2, CM3 are not supported... + + # Format of header 'struct', + global_header = np.dtype([('minvalue', 'float32'), ('range', 'float32'), + ('num_rows', 'int32'), ('num_cols', 'int32') + ]) # member '.format' is not written, + per_col_header = np.dtype([('percentile_0', 'uint16'), + ('percentile_25', 'uint16'), + ('percentile_75', 'uint16'), + ('percentile_100', 'uint16')]) + + # Mapping for percentiles in col-headers, + def uint16_to_float(value, min, range): + return np.float32(min + range * 1.52590218966964e-05 * value) + + # Mapping for matrix elements, + def uint8_to_float_v2(vec, p0, p25, p75, p100): + # Split the vector by masks, + mask_0_64 = (vec <= 64) + mask_193_255 = (vec > 192) + mask_65_192 = (~(mask_0_64 | mask_193_255)) + # Sanity check (useful but slow...), + # assert(len(vec) == np.sum(np.hstack([mask_0_64,mask_65_192,mask_193_255]))) + # assert(len(vec) == np.sum(np.any([mask_0_64,mask_65_192,mask_193_255], axis=0))) + # Build the float vector, + ans = np.empty(len(vec), dtype='float32') + ans[mask_0_64] = p0 + (p25 - p0) / 64. * vec[mask_0_64] + ans[mask_65_192] = p25 + (p75 - p25) / 128. * (vec[mask_65_192] - 64) + ans[mask_193_255] = p75 + (p100 - p75) / 63. 
* (vec[mask_193_255] - + 192) + return ans + + # Read global header, + globmin, globrange, rows, cols = np.frombuffer(fd.read(16), + dtype=global_header, + count=1)[0] + + # The data is structed as [Colheader, ... , Colheader, Data, Data , .... ] + # { cols }{ size } + col_headers = np.frombuffer(fd.read(cols * 8), + dtype=per_col_header, + count=cols) + data = np.reshape(np.frombuffer(fd.read(cols * rows), + dtype='uint8', + count=cols * rows), + newshape=(cols, rows)) # stored as col-major, + + mat = np.empty((cols, rows), dtype='float32') + for i, col_header in enumerate(col_headers): + col_header_flt = [ + uint16_to_float(percentile, globmin, globrange) + for percentile in col_header + ] + mat[i] = uint8_to_float_v2(data[i], *col_header_flt) + + return mat.T # transpose! col-major -> row-major, - mat = np.empty((cols,rows), dtype='float32') - for i, col_header in enumerate(col_headers): - col_header_flt = [ uint16_to_float(percentile, globmin, globrange) for percentile in col_header ] - mat[i] = uint8_to_float_v2(data[i], *col_header_flt) - - return mat.T # transpose! col-major -> row-major, def write_ark_scp(key, mat, ark_fout, scp_out): - mat_offset = write_mat(ark_fout, mat, key) - scp_line = '{}\t{}:{}'.format(key, ark_fout.name, mat_offset) - scp_out.write(scp_line) - scp_out.write('\n') + mat_offset = write_mat(ark_fout, mat, key) + scp_line = '{}\t{}:{}'.format(key, ark_fout.name, mat_offset) + scp_out.write(scp_line) + scp_out.write('\n') + # Writing, def write_mat(file_or_fd, m, key=''): - """ write_mat(f, m, key='') + """ write_mat(f, m, key='') Write a binary kaldi matrix to filename or stream. Supports 32bit and 64bit floats. Arguments: file_or_fd : filename of opened file descriptor for writing, @@ -497,27 +571,33 @@ def write_mat(file_or_fd, m, key=''): for key,mat in dict.iteritems(): kaldi_io.write_mat(f, mat, key=key) """ - mat_offset = 0 - fd = open_or_fd(file_or_fd, mode='wb') - if sys.version_info[0] == 3: assert(fd.mode == 'wb') - try: - if key != '' : fd.write((key+' ').encode("latin1")) # ark-files have keys (utterance-id), - mat_offset = fd.tell() - fd.write('\0B'.encode()) # we write binary! - # Data-type, - if m.dtype == 'float32': fd.write('FM '.encode()) - elif m.dtype == 'float64': fd.write('DM '.encode()) - else: raise UnsupportedDataType("'%s', please use 'float32' or 'float64'" % m.dtype) - # Dims, - fd.write('\04'.encode()) - fd.write(struct.pack(np.dtype('uint32').char, m.shape[0])) # rows - fd.write('\04'.encode()) - fd.write(struct.pack(np.dtype('uint32').char, m.shape[1])) # cols - # Data, - fd.write(m.tobytes()) - finally: - if fd is not file_or_fd : fd.close() - return mat_offset + mat_offset = 0 + fd = open_or_fd(file_or_fd, mode='wb') + if sys.version_info[0] == 3: assert (fd.mode == 'wb') + try: + if key != '': + fd.write( + (key + + ' ').encode("latin1")) # ark-files have keys (utterance-id), + mat_offset = fd.tell() + fd.write('\0B'.encode()) # we write binary! 
+ # Data-type, + if m.dtype == 'float32': fd.write('FM '.encode()) + elif m.dtype == 'float64': fd.write('DM '.encode()) + else: + raise UnsupportedDataType( + "'%s', please use 'float32' or 'float64'" % m.dtype) + # Dims, + fd.write('\04'.encode()) + fd.write(struct.pack(np.dtype('uint32').char, m.shape[0])) # rows + fd.write('\04'.encode()) + fd.write(struct.pack(np.dtype('uint32').char, m.shape[1])) # cols + # Data, + fd.write(m.tobytes()) + finally: + if fd is not file_or_fd: fd.close() + return mat_offset + ################################################# # 'Posterior' kaldi type (posteriors, confusion network, nnet1 training targets, ...) @@ -527,12 +607,14 @@ def write_mat(file_or_fd, m, key=''): # - tuple: int = index, float = value # + def read_cnet_ark(file_or_fd): - """ Alias of function 'read_post_ark()', 'cnet' = confusion network """ - return read_post_ark(file_or_fd) + """ Alias of function 'read_post_ark()', 'cnet' = confusion network """ + return read_post_ark(file_or_fd) + def read_post_ark(file_or_fd): - """ generator(key,vec>) = read_post_ark(file) + """ generator(key,vec>) = read_post_ark(file) Returns generator of (key,posterior) tuples, read from ark file. file_or_fd : ark, gzipped ark, pipe or opened file descriptor. @@ -543,18 +625,19 @@ def read_post_ark(file_or_fd): Read ark to a 'dictionary': d = { key:post for key,post in kaldi_io.read_post_ark(file) } """ - fd = open_or_fd(file_or_fd) - try: - key = read_key(fd) - while key: - post = read_post(fd) - yield key, post - key = read_key(fd) - finally: - if fd is not file_or_fd: fd.close() + fd = open_or_fd(file_or_fd) + try: + key = read_key(fd) + while key: + post = read_post(fd) + yield key, post + key = read_key(fd) + finally: + if fd is not file_or_fd: fd.close() + def read_post(file_or_fd): - """ [post] = read_post(file_or_fd) + """ [post] = read_post(file_or_fd) Reads single kaldi 'Posterior' in binary format. The 'Posterior' is C++ type 'vector > >', @@ -565,23 +648,34 @@ def read_post(file_or_fd): Returns vector of vectors of tuples. 
""" - fd = open_or_fd(file_or_fd) - ans=[] - binary = fd.read(2).decode(); assert(binary == '\0B'); # binary flag - assert(fd.read(1).decode() == '\4'); # int-size - outer_vec_size = np.frombuffer(fd.read(4), dtype='int32', count=1)[0] # number of frames (or bins) - - # Loop over 'outer-vector', - for i in range(outer_vec_size): - assert(fd.read(1).decode() == '\4'); # int-size - inner_vec_size = np.frombuffer(fd.read(4), dtype='int32', count=1)[0] # number of records for frame (or bin) - data = np.frombuffer(fd.read(inner_vec_size*10), dtype=[('size_idx','int8'),('idx','int32'),('size_post','int8'),('post','float32')], count=inner_vec_size) - assert(data[0]['size_idx'] == 4) - assert(data[0]['size_post'] == 4) - ans.append(data[['idx','post']].tolist()) + fd = open_or_fd(file_or_fd) + ans = [] + binary = fd.read(2).decode() + assert (binary == '\0B') + # binary flag + assert (fd.read(1).decode() == '\4') + # int-size + outer_vec_size = np.frombuffer(fd.read(4), dtype='int32', + count=1)[0] # number of frames (or bins) + + # Loop over 'outer-vector', + for i in range(outer_vec_size): + assert (fd.read(1).decode() == '\4') + # int-size + inner_vec_size = np.frombuffer( + fd.read(4), dtype='int32', + count=1)[0] # number of records for frame (or bin) + data = np.frombuffer(fd.read(inner_vec_size * 10), + dtype=[('size_idx', 'int8'), ('idx', 'int32'), + ('size_post', 'int8'), + ('post', 'float32')], + count=inner_vec_size) + assert (data[0]['size_idx'] == 4) + assert (data[0]['size_post'] == 4) + ans.append(data[['idx', 'post']].tolist()) - if fd is not file_or_fd: fd.close() - return ans + if fd is not file_or_fd: fd.close() + return ans ################################################# @@ -589,8 +683,9 @@ def read_post(file_or_fd): # (kaldi stores CNs time info separately from the Posterior). # + def read_cntime_ark(file_or_fd): - """ generator(key,vec>) = read_cntime_ark(file_or_fd) + """ generator(key,vec>) = read_cntime_ark(file_or_fd) Returns generator of (key,cntime) tuples, read from ark file. file_or_fd : file, gzipped file, pipe or opened file descriptor. @@ -601,18 +696,19 @@ def read_cntime_ark(file_or_fd): Read ark to a 'dictionary': d = { key:time for key,time in kaldi_io.read_post_ark(file) } """ - fd = open_or_fd(file_or_fd) - try: - key = read_key(fd) - while key: - cntime = read_cntime(fd) - yield key, cntime - key = read_key(fd) - finally: - if fd is not file_or_fd : fd.close() + fd = open_or_fd(file_or_fd) + try: + key = read_key(fd) + while key: + cntime = read_cntime(fd) + yield key, cntime + key = read_key(fd) + finally: + if fd is not file_or_fd: fd.close() + def read_cntime(file_or_fd): - """ [cntime] = read_cntime(file_or_fd) + """ [cntime] = read_cntime(file_or_fd) Reads single kaldi 'Confusion Network time info', in binary format: C++ type: vector >. (begin/end times of bins at the confusion network). @@ -623,44 +719,54 @@ def read_cntime(file_or_fd): Returns vector of tuples. 
""" - fd = open_or_fd(file_or_fd) - binary = fd.read(2).decode(); assert(binary == '\0B'); # assuming it's binary - - assert(fd.read(1).decode() == '\4'); # int-size - vec_size = np.frombuffer(fd.read(4), dtype='int32', count=1)[0] # number of frames (or bins) - - data = np.frombuffer(fd.read(vec_size*10), dtype=[('size_beg','int8'),('t_beg','float32'),('size_end','int8'),('t_end','float32')], count=vec_size) - assert(data[0]['size_beg'] == 4) - assert(data[0]['size_end'] == 4) - ans = data[['t_beg','t_end']].tolist() # Return vector of tuples (t_beg,t_end), + fd = open_or_fd(file_or_fd) + binary = fd.read(2).decode() + assert (binary == '\0B') + # assuming it's binary + + assert (fd.read(1).decode() == '\4') + # int-size + vec_size = np.frombuffer(fd.read(4), dtype='int32', + count=1)[0] # number of frames (or bins) + + data = np.frombuffer(fd.read(vec_size * 10), + dtype=[('size_beg', 'int8'), ('t_beg', 'float32'), + ('size_end', 'int8'), ('t_end', 'float32')], + count=vec_size) + assert (data[0]['size_beg'] == 4) + assert (data[0]['size_end'] == 4) + ans = data[['t_beg', + 't_end']].tolist() # Return vector of tuples (t_beg,t_end), - if fd is not file_or_fd : fd.close() - return ans + if fd is not file_or_fd: fd.close() + return ans ################################################# # Segments related, # + # Segments as 'Bool vectors' can be handy, # - for 'superposing' the segmentations, # - for frame-selection in Speaker-ID experiments, def read_segments_as_bool_vec(segments_file): - """ [ bool_vec ] = read_segments_as_bool_vec(segments_file) + """ [ bool_vec ] = read_segments_as_bool_vec(segments_file) using kaldi 'segments' file for 1 wav, format : ' ' - t-beg, t-end is in seconds, - assumed 100 frames/second, """ - segs = np.loadtxt(segments_file, dtype='object,object,f,f', ndmin=1) - # Sanity checks, - assert(len(segs) > 0) # empty segmentation is an error, - assert(len(np.unique([rec[1] for rec in segs ])) == 1) # segments with only 1 wav-file, - # Convert time to frame-indexes, - start = np.rint([100 * rec[2] for rec in segs]).astype(int) - end = np.rint([100 * rec[3] for rec in segs]).astype(int) - # Taken from 'read_lab_to_bool_vec', htk.py, - frms = np.repeat(np.r_[np.tile([False,True], len(end)), False], - np.r_[np.c_[start - np.r_[0, end[:-1]], end-start].flat, 0]) - assert np.sum(end-start) == np.sum(frms) - return frms - + segs = np.loadtxt(segments_file, dtype='object,object,f,f', ndmin=1) + # Sanity checks, + assert (len(segs) > 0) # empty segmentation is an error, + assert (len(np.unique([rec[1] for rec in segs])) == 1 + ) # segments with only 1 wav-file, + # Convert time to frame-indexes, + start = np.rint([100 * rec[2] for rec in segs]).astype(int) + end = np.rint([100 * rec[3] for rec in segs]).astype(int) + # Taken from 'read_lab_to_bool_vec', htk.py, + frms = np.repeat( + np.r_[np.tile([False, True], len(end)), False], + np.r_[np.c_[start - np.r_[0, end[:-1]], end - start].flat, 0]) + assert np.sum(end - start) == np.sum(frms) + return frms diff --git a/wenet/dataset/processor.py b/wenet/dataset/processor.py index a769eba8e..7992d4303 100644 --- a/wenet/dataset/processor.py +++ b/wenet/dataset/processor.py @@ -348,20 +348,25 @@ def compute_log_mel_spectrogram(data, if padding > 0: waveform = F.pad(waveform, (0, padding)) window = torch.hann_window(n_fft) - stft = torch.stft(waveform, n_fft, hop_length, - window=window, return_complex=True) - magnitudes = stft[..., :-1].abs() ** 2 + stft = torch.stft(waveform, + n_fft, + hop_length, + window=window, + 
return_complex=True) + magnitudes = stft[..., :-1].abs()**2 filters = torch.from_numpy( - librosa.filters.mel(sr=sample_rate, n_fft=n_fft, n_mels=num_mel_bins) - ) + librosa.filters.mel(sr=sample_rate, + n_fft=n_fft, + n_mels=num_mel_bins)) mel_spec = filters @ magnitudes # NOTE(xcsong): https://github.com/openai/whisper/discussions/269 log_spec = torch.clamp(mel_spec, min=1e-10).log10() log_spec = torch.maximum(log_spec, log_spec.max() - 8.0) log_spec = (log_spec + 4.0) / 4.0 - yield dict(key=sample['key'], label=sample['label'], + yield dict(key=sample['key'], + label=sample['label'], feat=log_spec.transpose(0, 1)) @@ -617,5 +622,10 @@ def padding(data): batch_first=True, padding_value=-1) - yield {"keys": sorted_keys, "feats": padded_feats, "target": padding_labels, - "feats_lengths": feats_lengths, "target_lengths": label_lengths} + yield { + "keys": sorted_keys, + "feats": padded_feats, + "target": padding_labels, + "feats_lengths": feats_lengths, + "target_lengths": label_lengths + } diff --git a/wenet/dataset/wav_distortion.py b/wenet/dataset/wav_distortion.py index 2917d3cc6..e4e60c036 100644 --- a/wenet/dataset/wav_distortion.py +++ b/wenet/dataset/wav_distortion.py @@ -18,15 +18,18 @@ import torchaudio import torch + torchaudio.set_audio_backend("sox_io") def db2amp(db): return pow(10, db / 20) + def amp2db(amp): return 20 * math.log10(amp) + def make_poly_distortion(conf): """Generate a db-domain ploynomial distortion function @@ -63,10 +66,13 @@ def poly_distortion(x): else: x = -amp return x + return poly_distortion + def make_quad_distortion(): - return make_poly_distortion({'a' : 1, 'm' : 1, 'n' : 1}) + return make_poly_distortion({'a': 1, 'm': 1, 'n': 1}) + # the amplitude are set to max for all non-zero point def make_max_distortion(conf): @@ -94,8 +100,8 @@ def max_distortion(x): else: x = 0.0 return x - return max_distortion + return max_distortion def make_amp_mask(db_mask=None): @@ -112,6 +118,7 @@ def make_amp_mask(db_mask=None): amp_mask = [(db2amp(db[0]), db2amp(db[1])) for db in db_mask] return amp_mask + default_mask = make_amp_mask() @@ -158,7 +165,7 @@ def make_fence_distortion(conf): mask_number = conf['mask_number'] max_db = conf['max_db'] max_amp = db2amp(max_db) # 0.997 - if mask_number <= 0 : + if mask_number <= 0: positive_mask = default_mask negative_mask = make_amp_mask([(-50, 0)]) else: @@ -186,6 +193,7 @@ def fence_distortion(x): return fence_distortion + # def make_jag_distortion(conf): """Generate a jag distortion function @@ -203,7 +211,7 @@ def make_jag_distortion(conf): a float amplitude value """ mask_number = conf['mask_number'] - if mask_number <= 0 : + if mask_number <= 0: positive_mask = default_mask negative_mask = make_amp_mask([(-50, 0)]) else: @@ -231,6 +239,7 @@ def jag_distortion(x): return jag_distortion + # gaining 20db means amp = amp * 10 # gaining -20db means amp = amp / 10 def make_gain_db(conf): @@ -269,6 +278,7 @@ def distort(x, func, rate=0.8): x[0][i] = func(float(x[0][i])) return x + def distort_chain(x, funcs, rate=0.8): for i in range(0, x.shape[1]): a = random.uniform(0, 1) @@ -277,6 +287,7 @@ def distort_chain(x, funcs, rate=0.8): x[0][i] = func(float(x[0][i])) return x + # x is numpy def distort_wav_conf(x, distort_type, distort_conf, rate=0.1): if distort_type == 'gain_db': @@ -303,12 +314,15 @@ def distort_wav_conf(x, distort_type, distort_conf, rate=0.1): print('unsupport type') return x -def distort_wav_conf_and_save(distort_type, distort_conf, rate, wav_in, wav_out): + +def distort_wav_conf_and_save(distort_type, 
distort_conf, rate, wav_in, + wav_out): x, sr = torchaudio.load(wav_in) x = x.detach().numpy() out = distort_wav_conf(x, distort_type, distort_conf, rate) torchaudio.save(wav_out, torch.from_numpy(out), sr) + if __name__ == "__main__": distort_type = sys.argv[1] wav_in = sys.argv[2] @@ -316,9 +330,9 @@ def distort_wav_conf_and_save(distort_type, distort_conf, rate, wav_in, wav_out) conf = None rate = 0.1 if distort_type == 'new_jag_distortion': - conf = {'mask_number' : 4} + conf = {'mask_number': 4} elif distort_type == 'new_fence_distortion': - conf = {'mask_number' : 1, 'max_db' : -30} + conf = {'mask_number': 1, 'max_db': -30} elif distort_type == 'poly_distortion': - conf = {'a' : 4, 'm' : 2, "n" : 2} + conf = {'a': 4, 'm': 2, "n": 2} distort_wav_conf_and_save(distort_type, conf, rate, wav_in, wav_out) diff --git a/wenet/e_branchformer/encoder.py b/wenet/e_branchformer/encoder.py index fc84e2347..2d4c6097e 100644 --- a/wenet/e_branchformer/encoder.py +++ b/wenet/e_branchformer/encoder.py @@ -14,7 +14,6 @@ # See the License for the specific language governing permissions and # limitations under the License. # Modified from ESPnet(https://github.com/espnet/espnet) - """Encoder definition.""" import torch @@ -27,10 +26,13 @@ from wenet.utils.mask import make_pad_mask from wenet.utils.mask import add_optional_chunk_mask from wenet.utils.class_utils import ( - WENET_ATTENTION_CLASSES, WENET_EMB_CLASSES, WENET_SUBSAMPLE_CLASSES, + WENET_ATTENTION_CLASSES, + WENET_EMB_CLASSES, + WENET_SUBSAMPLE_CLASSES, WENET_ACTIVATION_CLASSES, ) + class EBranchformerEncoder(nn.Module): """E-Branchformer encoder module.""" @@ -72,8 +74,8 @@ def __init__( input_size, output_size, dropout_rate, - WENET_EMB_CLASSES[pos_enc_layer_type]( - output_size, positional_dropout_rate), + WENET_EMB_CLASSES[pos_enc_layer_type](output_size, + positional_dropout_rate), ) encoder_selfattn_layer_args = ( @@ -83,15 +85,9 @@ def __init__( ) cgmlp_layer = ConvolutionalGatingMLP - cgmlp_layer_args = ( - output_size, - cgmlp_linear_units, - cgmlp_conv_kernel, - dropout_rate, - use_linear_after_conv, - gate_activation, - causal - ) + cgmlp_layer_args = (output_size, cgmlp_linear_units, cgmlp_conv_kernel, + dropout_rate, use_linear_after_conv, + gate_activation, causal) # feed-forward module definition positionwise_layer = PositionwiseFeedForward @@ -107,15 +103,16 @@ def __init__( if len(stochastic_depth_rate) != num_blocks: raise ValueError( f"Length of stochastic_depth_rate ({len(stochastic_depth_rate)}) " - f"should be equal to num_blocks ({num_blocks})" - ) + f"should be equal to num_blocks ({num_blocks})") self.encoders = torch.nn.ModuleList([ EBranchformerEncoderLayer( output_size, - WENET_ATTENTION_CLASSES[attention_layer_type](*encoder_selfattn_layer_args), + WENET_ATTENTION_CLASSES[attention_layer_type]( + *encoder_selfattn_layer_args), cgmlp_layer(*cgmlp_layer_args), - positionwise_layer(*positionwise_layer_args) if use_ffn else None, + positionwise_layer( + *positionwise_layer_args) if use_ffn else None, positionwise_layer(*positionwise_layer_args) if use_ffn and macaron_style else None, dropout_rate, @@ -175,7 +172,7 @@ def forward( self.static_chunk_size, num_decoding_left_chunks) for layer in self.encoders: - xs, chunk_masks, _ , _ = layer(xs, chunk_masks, pos_emb, mask_pad) + xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad) xs = self.after_norm(xs) # Here we assume the mask is not changed in encoder layers, so just @@ -237,8 +234,8 @@ def forward_chunk( elayers, cache_t1 = att_cache.size(0), 
att_cache.size(2) chunk_size = xs.size(1) attention_key_size = cache_t1 + chunk_size - pos_emb = self.embed.position_encoding( - offset=offset - cache_t1, size=attention_key_size) + pos_emb = self.embed.position_encoding(offset=offset - cache_t1, + size=attention_key_size) if required_cache_size < 0: next_cache_start = 0 elif required_cache_size == 0: @@ -252,10 +249,11 @@ def forward_chunk( # shape(att_cache[i:i + 1]) is (1, head, cache_t1, d_k * 2), # shape(cnn_cache[i]) is (b=1, hidden-dim, cache_t2) xs, _, new_att_cache, new_cnn_cache = layer( - xs, att_mask, pos_emb, + xs, + att_mask, + pos_emb, att_cache=att_cache[i:i + 1] if elayers > 0 else att_cache, - cnn_cache=cnn_cache[i] if cnn_cache.size(0) > 0 else cnn_cache - ) + cnn_cache=cnn_cache[i] if cnn_cache.size(0) > 0 else cnn_cache) # NOTE(xcsong): After layer.forward # shape(new_att_cache) is (1, head, attention_key_size, d_k * 2), # shape(new_cnn_cache) is (b=1, hidden-dim, cache_t2) @@ -322,10 +320,14 @@ def forward_chunk_by_chunk( for cur in range(0, num_frames - context + 1, stride): end = min(cur + decoding_window, num_frames) chunk_xs = xs[:, cur:end, :] - (y, att_cache, cnn_cache) = self.forward_chunk( - chunk_xs, offset, required_cache_size, att_cache, cnn_cache) + (y, att_cache, + cnn_cache) = self.forward_chunk(chunk_xs, offset, + required_cache_size, att_cache, + cnn_cache) outputs.append(y) offset += y.size(1) ys = torch.cat(outputs, 1) - masks = torch.ones((1, 1, ys.size(1)), device=ys.device, dtype=torch.bool) + masks = torch.ones((1, 1, ys.size(1)), + device=ys.device, + dtype=torch.bool) return ys, masks diff --git a/wenet/e_branchformer/encoder_layer.py b/wenet/e_branchformer/encoder_layer.py index 08b0c4d92..dba232383 100644 --- a/wenet/e_branchformer/encoder_layer.py +++ b/wenet/e_branchformer/encoder_layer.py @@ -14,7 +14,6 @@ # See the License for the specific language governing permissions and # limitations under the License. # Modified from ESPnet(https://github.com/espnet/espnet) - """EBranchformerEncoderLayer definition.""" import torch @@ -135,8 +134,7 @@ def forward( residual = x x = self.norm_ff_macaron(x) x = residual + stoch_layer_coeff * self.ff_scale * self.dropout( - self.feed_forward_macaron(x) - ) + self.feed_forward_macaron(x)) # Two branches x1 = x @@ -162,15 +160,15 @@ def forward( assert x_tmp.size(2) > self.lorder x_tmp = self.depthwise_conv_fusion(x_tmp) x_tmp = x_tmp.transpose(1, 2) - x = x + stoch_layer_coeff * self.dropout(self.merge_proj(x_concat + x_tmp)) + x = x + stoch_layer_coeff * self.dropout( + self.merge_proj(x_concat + x_tmp)) if self.feed_forward is not None: # feed forward module residual = x x = self.norm_ff(x) x = residual + stoch_layer_coeff * self.ff_scale * self.dropout( - self.feed_forward(x) - ) + self.feed_forward(x)) x = self.norm_final(x) diff --git a/wenet/efficient_conformer/attention.py b/wenet/efficient_conformer/attention.py index 475131b15..dc7628760 100644 --- a/wenet/efficient_conformer/attention.py +++ b/wenet/efficient_conformer/attention.py @@ -14,7 +14,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Multi-Head Attention layer definition.""" import math @@ -36,6 +35,7 @@ class GroupedRelPositionMultiHeadedAttention(MultiHeadedAttention): n_feat (int): The number of features. dropout_rate (float): Dropout rate. 
""" + def __init__(self, n_head, n_feat, dropout_rate, group_size=3): """Construct an RelPositionMultiHeadedAttention object.""" super().__init__(n_head, n_feat, dropout_rate) @@ -46,8 +46,10 @@ def __init__(self, n_head, n_feat, dropout_rate, group_size=3): self.n_feat = n_feat # these two learnable bias are used in matrix c and matrix d # as described in https://arxiv.org/abs/1901.02860 Section 3.3 - self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k * self.group_size)) - self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k * self.group_size)) + self.pos_bias_u = nn.Parameter( + torch.Tensor(self.h, self.d_k * self.group_size)) + self.pos_bias_v = nn.Parameter( + torch.Tensor(self.h, self.d_k * self.group_size)) torch.nn.init.xavier_uniform_(self.pos_bias_u) torch.nn.init.xavier_uniform_(self.pos_bias_v) @@ -102,7 +104,7 @@ def pad4group(self, Q, K, V, P, mask, group_size: int = 3): K = F.pad(K, (0, 0, 0, padding_KV), value=0.0) V = F.pad(V, (0, 0, 0, padding_KV), value=0.0) - if mask is not None and mask.size(2) > 0 : # time2 > 0: + if mask is not None and mask.size(2) > 0: # time2 > 0: mask = mask[:, ::group_size, ::group_size] Q = Q.transpose(1, 2).contiguous().view( @@ -117,15 +119,17 @@ def pad4group(self, Q, K, V, P, mask, group_size: int = 3): overflow_P = P.size(1) % group_size padding_P = group_size - overflow_P if overflow_P else 0 P = F.pad(P, (0, 0, 0, padding_P), value=0.0) - P = P.view(P_batch_size, -1, self.h, self.d_k * group_size).transpose(1, 2) + P = P.view(P_batch_size, -1, self.h, + self.d_k * group_size).transpose(1, 2) return Q, K, V, P, mask, padding_Q - def forward_attention( - self, value: torch.Tensor, scores: torch.Tensor, - mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool), - padding_q: Optional[int] = None - ) -> torch.Tensor: + def forward_attention(self, + value: torch.Tensor, + scores: torch.Tensor, + mask: torch.Tensor = torch.ones((0, 0, 0), + dtype=torch.bool), + padding_q: Optional[int] = None) -> torch.Tensor: """Compute attention context vector. Args: @@ -147,7 +151,7 @@ def forward_attention( # 1. onnx(16/4) [WHY? Because we feed real cache & real mask for the # 1st chunk to ease the onnx export.] # 2. pytorch training - if mask.size(2) > 0 : # time2 > 0 + if mask.size(2) > 0: # time2 > 0 mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2) # For last chunk, time2 might be larger than scores.size(-1) mask = mask[:, :, :, :scores.size(-1)] # (batch, 1, *, time2) @@ -172,12 +176,15 @@ def forward_attention( return self.linear_out(x) # (batch, time1, d_model) - def forward(self, query: torch.Tensor, key: torch.Tensor, - value: torch.Tensor, - mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool), - pos_emb: torch.Tensor = torch.empty(0), - cache: torch.Tensor = torch.zeros((0, 0, 0, 0)), - ) -> Tuple[torch.Tensor, torch.Tensor]: + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool), + pos_emb: torch.Tensor = torch.empty(0), + cache: torch.Tensor = torch.zeros((0, 0, 0, 0)), + ) -> Tuple[torch.Tensor, torch.Tensor]: """Compute 'Scaled Dot Product Attention' with rel. positional encoding. Args: query (torch.Tensor): Query tensor (#batch, time1, size). 
@@ -197,9 +204,9 @@ def forward(self, query: torch.Tensor, key: torch.Tensor, and `head * d_k == size` """ q = self.linear_q(query) - k = self.linear_k(key) # (#batch, time2, size) + k = self.linear_k(key) # (#batch, time2, size) v = self.linear_v(value) - p = self.linear_pos(pos_emb) # (#batch, time2, size) + p = self.linear_pos(pos_emb) # (#batch, time2, size) batch_size, seq_len_KV, _ = k.size() # seq_len_KV = time2 @@ -209,8 +216,9 @@ def forward(self, query: torch.Tensor, key: torch.Tensor, v = v.view(batch_size, -1, self.h, self.d_k).transpose(1, 2) if cache.size(0) > 0: # use attention cache - key_cache, value_cache = torch.split( - cache, cache.size(-1) // 2, dim=-1) + key_cache, value_cache = torch.split(cache, + cache.size(-1) // 2, + dim=-1) k = torch.cat([key_cache, k], dim=2) v = torch.cat([value_cache, v], dim=2) new_cache = torch.cat((k, v), dim=-1) @@ -222,7 +230,8 @@ def forward(self, query: torch.Tensor, key: torch.Tensor, v = v[:, :, -time2:, :] # q k v p: (batch, head, time1, d_k) - q, k, v, p, mask, padding_q = self.pad4group(q, k, v, p, mask, self.group_size) + q, k, v, p, mask, padding_q = self.pad4group(q, k, v, p, mask, + self.group_size) # q_with_bias_u & q_with_bias_v = (batch, head, time1, d_k) q = q.transpose(1, 2) # (batch, time1, head, d_k) diff --git a/wenet/efficient_conformer/convolution.py b/wenet/efficient_conformer/convolution.py index d95f38a76..3fa3dff25 100644 --- a/wenet/efficient_conformer/convolution.py +++ b/wenet/efficient_conformer/convolution.py @@ -13,7 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. # Modified from ESPnet(https://github.com/espnet/espnet) - """ConvolutionModule definition.""" from typing import Tuple @@ -23,6 +22,7 @@ class ConvolutionModule(nn.Module): """ConvolutionModule in Conformer model.""" + def __init__(self, channels: int, kernel_size: int = 15, diff --git a/wenet/efficient_conformer/encoder.py b/wenet/efficient_conformer/encoder.py index ad01b6a63..ab2284c54 100644 --- a/wenet/efficient_conformer/encoder.py +++ b/wenet/efficient_conformer/encoder.py @@ -15,7 +15,6 @@ # limitations under the License. 
# Modified from EfficientConformer(https://github.com/burchim/EfficientConformer) # Paper(https://arxiv.org/abs/2109.01163) - """Encoder definition.""" from typing import Tuple, Optional, List, Union @@ -32,42 +31,45 @@ from wenet.utils.mask import make_pad_mask from wenet.utils.mask import add_optional_chunk_mask from wenet.utils.class_utils import ( - WENET_ATTENTION_CLASSES, WENET_EMB_CLASSES, WENET_SUBSAMPLE_CLASSES, + WENET_ATTENTION_CLASSES, + WENET_EMB_CLASSES, + WENET_SUBSAMPLE_CLASSES, WENET_ACTIVATION_CLASSES, ) + class EfficientConformerEncoder(torch.nn.Module): """Conformer encoder module.""" - def __init__( - self, - input_size: int, - output_size: int = 256, - attention_heads: int = 4, - linear_units: int = 2048, - num_blocks: int = 6, - dropout_rate: float = 0.1, - positional_dropout_rate: float = 0.1, - attention_dropout_rate: float = 0.0, - input_layer: str = "conv2d", - pos_enc_layer_type: str = "rel_pos", - normalize_before: bool = True, - static_chunk_size: int = 0, - use_dynamic_chunk: bool = False, - global_cmvn: torch.nn.Module = None, - use_dynamic_left_chunk: bool = False, - macaron_style: bool = True, - activation_type: str = "swish", - use_cnn_module: bool = True, - cnn_module_kernel: int = 15, - causal: bool = False, - cnn_module_norm: str = "batch_norm", - stride_layer_idx: Optional[Union[int, List[int]]] = 3, - stride: Optional[Union[int, List[int]]] = 2, - group_layer_idx: Optional[Union[int, List[int], tuple]] = (0, 1, 2, 3), - group_size: int = 3, - stride_kernel: bool = True, - **kwargs - ): + + def __init__(self, + input_size: int, + output_size: int = 256, + attention_heads: int = 4, + linear_units: int = 2048, + num_blocks: int = 6, + dropout_rate: float = 0.1, + positional_dropout_rate: float = 0.1, + attention_dropout_rate: float = 0.0, + input_layer: str = "conv2d", + pos_enc_layer_type: str = "rel_pos", + normalize_before: bool = True, + static_chunk_size: int = 0, + use_dynamic_chunk: bool = False, + global_cmvn: torch.nn.Module = None, + use_dynamic_left_chunk: bool = False, + macaron_style: bool = True, + activation_type: str = "swish", + use_cnn_module: bool = True, + cnn_module_kernel: int = 15, + causal: bool = False, + cnn_module_norm: str = "batch_norm", + stride_layer_idx: Optional[Union[int, List[int]]] = 3, + stride: Optional[Union[int, List[int]]] = 2, + group_layer_idx: Optional[Union[int, List[int], + tuple]] = (0, 1, 2, 3), + group_size: int = 3, + stride_kernel: bool = True, + **kwargs): """Construct Efficient Conformer Encoder Args: @@ -87,15 +89,17 @@ def __init__( super().__init__() self._output_size = output_size - logging.info(f"input_layer = {input_layer}, " - f"subsampling_class = {WENET_SUBSAMPLE_CLASSES[input_layer]}") + logging.info( + f"input_layer = {input_layer}, " + f"subsampling_class = {WENET_SUBSAMPLE_CLASSES[input_layer]}") self.global_cmvn = global_cmvn self.embed = WENET_SUBSAMPLE_CLASSES[input_layer]( input_size, output_size, dropout_rate, - WENET_EMB_CLASSES[pos_enc_layer_type](output_size, positional_dropout_rate), + WENET_EMB_CLASSES[pos_enc_layer_type](output_size, + positional_dropout_rate), ) self.input_layer = input_layer self.normalize_before = normalize_before @@ -118,13 +122,15 @@ def __init__( if type(stride) == int else stride self.group_layer_idx = [group_layer_idx] \ if type(group_layer_idx) == int else group_layer_idx - self.grouped_size = group_size # group size of every GroupedAttention layer + self.grouped_size = group_size # group size of every GroupedAttention layer assert len(self.stride) == 
len(self.stride_layer_idx) - self.cnn_module_kernels = [cnn_module_kernel] # kernel size of each StridedConv + self.cnn_module_kernels = [cnn_module_kernel + ] # kernel size of each StridedConv for i in self.stride: if stride_kernel: - self.cnn_module_kernels.append(self.cnn_module_kernels[-1] // i) + self.cnn_module_kernels.append(self.cnn_module_kernels[-1] // + i) else: self.cnn_module_kernels.append(self.cnn_module_kernels[-1]) @@ -151,21 +157,20 @@ def __init__( for i in range(num_blocks): # self-attention module definition if i in self.group_layer_idx: - encoder_selfattn_layer = WENET_ATTENTION_CLASSES["grouped_rel_selfattn"] - encoder_selfattn_layer_args = ( - attention_heads, - output_size, - attention_dropout_rate, - self.grouped_size) + encoder_selfattn_layer = WENET_ATTENTION_CLASSES[ + "grouped_rel_selfattn"] + encoder_selfattn_layer_args = (attention_heads, output_size, + attention_dropout_rate, + self.grouped_size) else: if pos_enc_layer_type == "no_pos": - encoder_selfattn_layer = WENET_ATTENTION_CLASSES["selfattn"] + encoder_selfattn_layer = WENET_ATTENTION_CLASSES[ + "selfattn"] else: - encoder_selfattn_layer = WENET_ATTENTION_CLASSES["rel_selfattn"] - encoder_selfattn_layer_args = ( - attention_heads, - output_size, - attention_dropout_rate) + encoder_selfattn_layer = WENET_ATTENTION_CLASSES[ + "rel_selfattn"] + encoder_selfattn_layer_args = (attention_heads, output_size, + attention_dropout_rate) # conformer module definition if i in self.stride_layer_idx: @@ -173,38 +178,42 @@ def __init__( convolution_layer_args_stride = ( output_size, self.cnn_module_kernels[index], activation, cnn_module_norm, causal, True, self.stride[index]) - layers.append(StrideConformerEncoderLayer( - output_size, - encoder_selfattn_layer(*encoder_selfattn_layer_args), - positionwise_layer(*positionwise_layer_args), - positionwise_layer( - *positionwise_layer_args) if macaron_style else None, - convolution_layer( - *convolution_layer_args_stride) if use_cnn_module else None, - torch.nn.AvgPool1d( - kernel_size=self.stride[index], stride=self.stride[index], - padding=0, ceil_mode=True, - count_include_pad=False), # pointwise_conv_layer - dropout_rate, - normalize_before, - )) + layers.append( + StrideConformerEncoderLayer( + output_size, + encoder_selfattn_layer(*encoder_selfattn_layer_args), + positionwise_layer(*positionwise_layer_args), + positionwise_layer(*positionwise_layer_args) + if macaron_style else None, + convolution_layer(*convolution_layer_args_stride) + if use_cnn_module else None, + torch.nn.AvgPool1d( + kernel_size=self.stride[index], + stride=self.stride[index], + padding=0, + ceil_mode=True, + count_include_pad=False), # pointwise_conv_layer + dropout_rate, + normalize_before, + )) index = index + 1 else: # conformer block convolution_layer_args_normal = ( output_size, self.cnn_module_kernels[index], activation, cnn_module_norm, causal) - layers.append(ConformerEncoderLayer( - output_size, - encoder_selfattn_layer(*encoder_selfattn_layer_args), - positionwise_layer(*positionwise_layer_args), - positionwise_layer( - *positionwise_layer_args) if macaron_style else None, - convolution_layer( - *convolution_layer_args_normal) if use_cnn_module else None, - dropout_rate, - normalize_before, - )) + layers.append( + ConformerEncoderLayer( + output_size, + encoder_selfattn_layer(*encoder_selfattn_layer_args), + positionwise_layer(*positionwise_layer_args), + positionwise_layer(*positionwise_layer_args) + if macaron_style else None, + convolution_layer(*convolution_layer_args_normal) + 
if use_cnn_module else None, + dropout_rate, + normalize_before, + )) self.encoders = torch.nn.ModuleList(layers) @@ -232,12 +241,13 @@ def calculate_downsampling_factor(self, i: int) -> int: factor *= self.stride[idx] return factor - def forward(self, - xs: torch.Tensor, - xs_lens: torch.Tensor, - decoding_chunk_size: int = 0, - num_decoding_left_chunks: int = -1, - ) -> Tuple[torch.Tensor, torch.Tensor]: + def forward( + self, + xs: torch.Tensor, + xs_lens: torch.Tensor, + decoding_chunk_size: int = 0, + num_decoding_left_chunks: int = -1, + ) -> Tuple[torch.Tensor, torch.Tensor]: """Embed positions in tensor. Args: xs: padded input tensor (B, T, D) @@ -274,8 +284,8 @@ def forward(self, xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad) if i in self.stride_layer_idx: masks = masks[:, :, ::self.stride[index]] - chunk_masks = chunk_masks[:, ::self.stride[index], - ::self.stride[index]] + chunk_masks = chunk_masks[:, ::self.stride[index], ::self. + stride[index]] mask_pad = masks pos_emb = pos_emb[:, ::self.stride[index], :] index = index + 1 @@ -366,7 +376,7 @@ def forward_chunk( xs.size(1), device=xs.device, dtype=torch.bool) - mask_pad = mask_pad.unsqueeze(1) # batchPad (b=1, 1, time=chunk_size) + mask_pad = mask_pad.unsqueeze(1) # batchPad (b=1, 1, time=chunk_size) if self.global_chunk_size > 0: # for ONNX decode simulation @@ -376,10 +386,10 @@ def forward_chunk( att_mask[:, :, -self.global_chunk_size:] = chunk_masks mask_pad = chunk_masks.to(torch.bool) else: - pos_emb = self.embed.position_encoding( - offset=offset - cache_t1, size=attention_key_size) + pos_emb = self.embed.position_encoding(offset=offset - cache_t1, + size=attention_key_size) - max_att_len, max_cnn_len = 0, 0 # for repeat_interleave of new_att_cache + max_att_len, max_cnn_len = 0, 0 # for repeat_interleave of new_att_cache for i, layer in enumerate(self.encoders): factor = self.calculate_downsampling_factor(i) # NOTE(xcsong): Before layer.forward @@ -392,20 +402,23 @@ def forward_chunk( att_cache_trunc = xs.size(1) + \ att_cache.size(2) // factor - pos_emb.size(1) + 1 xs, _, new_att_cache, new_cnn_cache = layer( - xs, att_mask, pos_emb, + xs, + att_mask, + pos_emb, mask_pad=mask_pad, - att_cache=att_cache[i:i + 1, :, ::factor, :][:, :, att_cache_trunc:, :], + att_cache=att_cache[i:i + + 1, :, ::factor, :][:, :, + att_cache_trunc:, :], cnn_cache=cnn_cache[i, :, :, :] - if cnn_cache.size(0) > 0 else cnn_cache - ) + if cnn_cache.size(0) > 0 else cnn_cache) if i in self.stride_layer_idx: # compute time dimension for next block efficient_index = self.stride_layer_idx.index(i) - att_mask = att_mask[:, ::self.stride[efficient_index], - ::self.stride[efficient_index]] - mask_pad = mask_pad[:, ::self.stride[efficient_index], - ::self.stride[efficient_index]] + att_mask = att_mask[:, ::self.stride[efficient_index], ::self. + stride[efficient_index]] + mask_pad = mask_pad[:, ::self.stride[efficient_index], ::self. 
+ stride[efficient_index]] pos_emb = pos_emb[:, ::self.stride[efficient_index], :] # shape(new_att_cache) = [batch, head, time2, outdim] @@ -414,7 +427,8 @@ def forward_chunk( new_cnn_cache = new_cnn_cache.unsqueeze(0) # use repeat_interleave to new_att_cache - new_att_cache = new_att_cache.repeat_interleave(repeats=factor, dim=2) + new_att_cache = new_att_cache.repeat_interleave(repeats=factor, + dim=2) # padding new_cnn_cache to cnn.lorder for casual convolution new_cnn_cache = F.pad( new_cnn_cache, @@ -448,12 +462,11 @@ def forward_chunk( return xs, r_att_cache, r_cnn_cache def forward_chunk_by_chunk( - self, - xs: torch.Tensor, - decoding_chunk_size: int, - num_decoding_left_chunks: int = -1, - use_onnx=False - ) -> Tuple[torch.Tensor, torch.Tensor]: + self, + xs: torch.Tensor, + decoding_chunk_size: int, + num_decoding_left_chunks: int = -1, + use_onnx=False) -> Tuple[torch.Tensor, torch.Tensor]: """ Forward input chunk by chunk with chunk_size like a streaming fashion @@ -501,13 +514,16 @@ def forward_chunk_by_chunk( self.output_size() // self.attention_heads * 2), device=xs.device) cnn_cache: torch.Tensor = torch.zeros( - (self.num_blocks, 1, self.output_size(), self.cnn_module_kernel - 1), + (self.num_blocks, 1, self.output_size(), + self.cnn_module_kernel - 1), device=xs.device) self.set_global_chunk_size(chunk_size=decoding_chunk_size) else: logging.info("Simulating for JIT runtime ...") - att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0), device=xs.device) - cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0), device=xs.device) + att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0), + device=xs.device) + cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0), + device=xs.device) # Feed forward overlap input step by step for cur in range(0, num_frames - context + 1, stride): @@ -518,12 +534,14 @@ def forward_chunk_by_chunk( if use_onnx: att_mask: torch.Tensor = torch.ones( (1, 1, required_cache_size + decoding_chunk_size), - dtype=torch.bool, device=xs.device) + dtype=torch.bool, + device=xs.device) if cur == 0: att_mask[:, :, :required_cache_size] = 0 else: - att_mask: torch.Tensor = torch.ones( - (0, 0, 0), dtype=torch.bool, device=xs.device) + att_mask: torch.Tensor = torch.ones((0, 0, 0), + dtype=torch.bool, + device=xs.device) chunk_xs = xs[:, cur:end, :] (y, att_cache, cnn_cache) = \ @@ -534,5 +552,9 @@ def forward_chunk_by_chunk( offset += y.size(1) ys = torch.cat(outputs, 1) - masks = torch.ones(1, 1, ys.size(1), device=ys.device, dtype=torch.bool) + masks = torch.ones(1, + 1, + ys.size(1), + device=ys.device, + dtype=torch.bool) return ys, masks diff --git a/wenet/efficient_conformer/encoder_layer.py b/wenet/efficient_conformer/encoder_layer.py index d35c2a99e..5d160564f 100644 --- a/wenet/efficient_conformer/encoder_layer.py +++ b/wenet/efficient_conformer/encoder_layer.py @@ -14,7 +14,6 @@ # See the License for the specific language governing permissions and # limitations under the License. # Modified from ESPnet(https://github.com/espnet/espnet) - """Encoder self-attention layer definition.""" from typing import Optional, Tuple @@ -41,17 +40,16 @@ class StrideConformerEncoderLayer(nn.Module): True: use layer_norm before each sub-block. False: use layer_norm after each sub-block. 
""" - def __init__( - self, - size: int, - self_attn: torch.nn.Module, - feed_forward: Optional[nn.Module] = None, - feed_forward_macaron: Optional[nn.Module] = None, - conv_module: Optional[nn.Module] = None, - pointwise_conv_layer: Optional[nn.Module] = None, - dropout_rate: float = 0.1, - normalize_before: bool = True - ): + + def __init__(self, + size: int, + self_attn: torch.nn.Module, + feed_forward: Optional[nn.Module] = None, + feed_forward_macaron: Optional[nn.Module] = None, + conv_module: Optional[nn.Module] = None, + pointwise_conv_layer: Optional[nn.Module] = None, + dropout_rate: float = 0.1, + normalize_before: bool = True): """Construct an EncoderLayer object.""" super().__init__() self.self_attn = self_attn @@ -67,8 +65,7 @@ def __init__( else: self.ff_scale = 1.0 if self.conv_module is not None: - self.norm_conv = nn.LayerNorm(size, - eps=1e-5) # for the CNN module + self.norm_conv = nn.LayerNorm(size, eps=1e-5) # for the CNN module self.norm_final = nn.LayerNorm( size, eps=1e-5) # for the final output of the block self.dropout = nn.Dropout(dropout_rate) @@ -122,8 +119,8 @@ def forward( if self.normalize_before: x = self.norm_mha(x) - x_att, new_att_cache = self.self_attn( - x, x, x, mask, pos_emb, att_cache) + x_att, new_att_cache = self.self_attn(x, x, x, mask, pos_emb, + att_cache) x = residual + self.dropout(x_att) if not self.normalize_before: diff --git a/wenet/efficient_conformer/subsampling.py b/wenet/efficient_conformer/subsampling.py index 98b2c2228..ec05d9fb7 100644 --- a/wenet/efficient_conformer/subsampling.py +++ b/wenet/efficient_conformer/subsampling.py @@ -13,8 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. # Modified from ESPnet(https://github.com/espnet/espnet) - - """Subsampling layer definition.""" from typing import Tuple, Union @@ -32,14 +30,13 @@ class Conv2dSubsampling2(BaseSubsampling): dropout_rate (float): Dropout rate. """ + def __init__(self, idim: int, odim: int, dropout_rate: float, pos_enc_class: torch.nn.Module): """Construct an Conv2dSubsampling4 object.""" super().__init__() - self.conv = torch.nn.Sequential( - torch.nn.Conv2d(1, odim, 3, 2), - torch.nn.ReLU() - ) + self.conv = torch.nn.Sequential(torch.nn.Conv2d(1, odim, 3, 2), + torch.nn.ReLU()) self.out = torch.nn.Sequential( torch.nn.Linear(odim * ((idim - 1) // 2), odim)) self.pos_enc = pos_enc_class @@ -50,10 +47,10 @@ def __init__(self, idim: int, odim: int, dropout_rate: float, self.right_context = 2 def forward( - self, - x: torch.Tensor, - x_mask: torch.Tensor, - offset: Union[int, torch.Tensor] = 0 + self, + x: torch.Tensor, + x_mask: torch.Tensor, + offset: Union[int, torch.Tensor] = 0 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """Subsample x. diff --git a/wenet/k2/model.py b/wenet/k2/model.py index e9e39764e..f38c0c6df 100644 --- a/wenet/k2/model.py +++ b/wenet/k2/model.py @@ -25,6 +25,7 @@ class K2Model(ASRModel): + def __init__( self, vocab_size: int, diff --git a/wenet/squeezeformer/attention.py b/wenet/squeezeformer/attention.py index 97412badb..a3c973840 100644 --- a/wenet/squeezeformer/attention.py +++ b/wenet/squeezeformer/attention.py @@ -14,7 +14,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
- """Multi-Head Attention layer definition.""" import math @@ -33,8 +32,13 @@ class RelPositionMultiHeadedAttention(MultiHeadedAttention): dropout_rate (float): Dropout rate. """ - def __init__(self, n_head, n_feat, dropout_rate, - do_rel_shift=False, adaptive_scale=False, init_weights=False): + def __init__(self, + n_head, + n_feat, + dropout_rate, + do_rel_shift=False, + adaptive_scale=False, + init_weights=False): """Construct an RelPositionMultiHeadedAttention object.""" super().__init__(n_head, n_feat, dropout_rate) # linear transformation for positional encoding @@ -47,15 +51,15 @@ def __init__(self, n_head, n_feat, dropout_rate, torch.nn.init.xavier_uniform_(self.pos_bias_u) torch.nn.init.xavier_uniform_(self.pos_bias_v) self.adaptive_scale = adaptive_scale - self.ada_scale = nn.Parameter( - torch.ones([1, 1, n_feat]), requires_grad=adaptive_scale) - self.ada_bias = nn.Parameter( - torch.zeros([1, 1, n_feat]), requires_grad=adaptive_scale) + self.ada_scale = nn.Parameter(torch.ones([1, 1, n_feat]), + requires_grad=adaptive_scale) + self.ada_bias = nn.Parameter(torch.zeros([1, 1, n_feat]), + requires_grad=adaptive_scale) if init_weights: self.init_weights() def init_weights(self): - input_max = (self.h * self.d_k) ** -0.5 + input_max = (self.h * self.d_k)**-0.5 torch.nn.init.uniform_(self.linear_q.weight, -input_max, input_max) torch.nn.init.uniform_(self.linear_q.bias, -input_max, input_max) torch.nn.init.uniform_(self.linear_k.weight, -input_max, input_max) @@ -93,8 +97,10 @@ def rel_shift(self, x, zero_triu: bool = False): return x def forward_attention( - self, value: torch.Tensor, scores: torch.Tensor, - mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool) + self, + value: torch.Tensor, + scores: torch.Tensor, + mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool) ) -> torch.Tensor: """Compute attention context vector. @@ -137,12 +143,15 @@ def forward_attention( return self.linear_out(x) # (batch, time1, d_model) - def forward(self, query: torch.Tensor, - key: torch.Tensor, value: torch.Tensor, - mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool), - pos_emb: torch.Tensor = torch.empty(0), - cache: torch.Tensor = torch.zeros((0, 0, 0, 0)) - ) -> Tuple[torch.Tensor, torch.Tensor]: + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool), + pos_emb: torch.Tensor = torch.empty(0), + cache: torch.Tensor = torch.zeros((0, 0, 0, 0)) + ) -> Tuple[torch.Tensor, torch.Tensor]: """Compute 'Scaled Dot Product Attention' with rel. positional encoding. Args: query (torch.Tensor): Query tensor (#batch, time1, size). @@ -185,8 +194,9 @@ def forward(self, query: torch.Tensor, # >>> d = torch.split(a, 2, dim=-1) # >>> torch.equal(d[0], d[1]) # True if cache.size(0) > 0: - key_cache, value_cache = torch.split( - cache, cache.size(-1) // 2, dim=-1) + key_cache, value_cache = torch.split(cache, + cache.size(-1) // 2, + dim=-1) k = torch.cat([key_cache, k], dim=2) v = torch.cat([value_cache, v], dim=2) # NOTE(xcsong): We do cache slicing in encoder.forward_chunk, since it's diff --git a/wenet/squeezeformer/conv2d.py b/wenet/squeezeformer/conv2d.py index c23026339..5107d2533 100644 --- a/wenet/squeezeformer/conv2d.py +++ b/wenet/squeezeformer/conv2d.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
- """Conv2d Module with Valid Padding""" import torch.nn.functional as F @@ -22,6 +21,7 @@ class Conv2dValid(_ConvNd): """ Conv2d operator for VALID mode padding. """ + def __init__( self, in_channels: int, @@ -36,31 +36,31 @@ def __init__( device=None, dtype=None, valid_trigx: bool = False, - valid_trigy: bool = False - ) -> None: + valid_trigy: bool = False) -> None: factory_kwargs = {'device': device, 'dtype': dtype} kernel_size_ = _pair(kernel_size) stride_ = _pair(stride) padding_ = padding if isinstance(padding, str) else _pair(padding) dilation_ = _pair(dilation) - super(Conv2dValid, self).__init__( - in_channels, out_channels, kernel_size_, - stride_, padding_, dilation_, False, _pair(0), - groups, bias, padding_mode, **factory_kwargs) + super(Conv2dValid, + self).__init__(in_channels, out_channels, + kernel_size_, stride_, padding_, dilation_, False, + _pair(0), groups, bias, padding_mode, + **factory_kwargs) self.valid_trigx = valid_trigx self.valid_trigy = valid_trigy - def _conv_forward( - self, input: Tensor, weight: Tensor, bias: Optional[Tensor]): + def _conv_forward(self, input: Tensor, weight: Tensor, + bias: Optional[Tensor]): validx, validy = 0, 0 if self.valid_trigx: - validx = (input.size(-2) * (self.stride[-2] - 1) - 1 - + self.kernel_size[-2]) // 2 + validx = (input.size(-2) * + (self.stride[-2] - 1) - 1 + self.kernel_size[-2]) // 2 if self.valid_trigy: - validy = (input.size(-1) * (self.stride[-1] - 1) - 1 - + self.kernel_size[-1]) // 2 - return F.conv2d(input, weight, bias, self.stride, - (validx, validy), self.dilation, self.groups) + validy = (input.size(-1) * + (self.stride[-1] - 1) - 1 + self.kernel_size[-1]) // 2 + return F.conv2d(input, weight, bias, self.stride, (validx, validy), + self.dilation, self.groups) def forward(self, input: Tensor) -> Tensor: return self._conv_forward(input, self.weight, self.bias) diff --git a/wenet/squeezeformer/convolution.py b/wenet/squeezeformer/convolution.py index 9e22f5e74..4218cbacb 100644 --- a/wenet/squeezeformer/convolution.py +++ b/wenet/squeezeformer/convolution.py @@ -13,7 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. # Modified from ESPnet(https://github.com/espnet/espnet) - """ConvolutionModule definition.""" from typing import Tuple @@ -24,6 +23,7 @@ class ConvolutionModule(nn.Module): """ConvolutionModule in Conformer model.""" + def __init__(self, channels: int, kernel_size: int = 15, @@ -32,8 +32,7 @@ def __init__(self, causal: bool = False, bias: bool = True, adaptive_scale: bool = False, - init_weights: bool = False - ): + init_weights: bool = False): """Construct an ConvolutionModule object. Args: channels (int): The number of channels of conv layers. 
@@ -45,10 +44,10 @@ def __init__(self, self.channels = channels self.kernel_size = kernel_size self.adaptive_scale = adaptive_scale - self.ada_scale = torch.nn.Parameter( - torch.ones([1, 1, channels]), requires_grad=adaptive_scale) - self.ada_bias = torch.nn.Parameter( - torch.zeros([1, 1, channels]), requires_grad=adaptive_scale) + self.ada_scale = torch.nn.Parameter(torch.ones([1, 1, channels]), + requires_grad=adaptive_scale) + self.ada_bias = torch.nn.Parameter(torch.zeros([1, 1, channels]), + requires_grad=adaptive_scale) self.pointwise_conv1 = nn.Conv1d( channels, @@ -101,17 +100,23 @@ def __init__(self, self.init_weights() def init_weights(self): - pw_max = self.channels ** -0.5 - dw_max = self.kernel_size ** -0.5 - torch.nn.init.uniform_(self.pointwise_conv1.weight.data, -pw_max, pw_max) + pw_max = self.channels**-0.5 + dw_max = self.kernel_size**-0.5 + torch.nn.init.uniform_(self.pointwise_conv1.weight.data, -pw_max, + pw_max) if self.bias: - torch.nn.init.uniform_(self.pointwise_conv1.bias.data, -pw_max, pw_max) - torch.nn.init.uniform_(self.depthwise_conv.weight.data, -dw_max, dw_max) + torch.nn.init.uniform_(self.pointwise_conv1.bias.data, -pw_max, + pw_max) + torch.nn.init.uniform_(self.depthwise_conv.weight.data, -dw_max, + dw_max) if self.bias: - torch.nn.init.uniform_(self.depthwise_conv.bias.data, -dw_max, dw_max) - torch.nn.init.uniform_(self.pointwise_conv2.weight.data, -pw_max, pw_max) + torch.nn.init.uniform_(self.depthwise_conv.bias.data, -dw_max, + dw_max) + torch.nn.init.uniform_(self.pointwise_conv2.weight.data, -pw_max, + pw_max) if self.bias: - torch.nn.init.uniform_(self.pointwise_conv2.bias.data, -pw_max, pw_max) + torch.nn.init.uniform_(self.pointwise_conv2.bias.data, -pw_max, + pw_max) def forward( self, diff --git a/wenet/squeezeformer/encoder.py b/wenet/squeezeformer/encoder.py index e9d97c462..bee71b92b 100644 --- a/wenet/squeezeformer/encoder.py +++ b/wenet/squeezeformer/encoder.py @@ -33,37 +33,36 @@ class SqueezeformerEncoder(nn.Module): - def __init__( - self, - input_size: int = 80, - encoder_dim: int = 256, - output_size: int = 256, - attention_heads: int = 4, - num_blocks: int = 12, - reduce_idx: Optional[Union[int, List[int]]] = 5, - recover_idx: Optional[Union[int, List[int]]] = 11, - feed_forward_expansion_factor: int = 4, - dw_stride: bool = False, - input_dropout_rate: float = 0.1, - pos_enc_layer_type: str = "rel_pos", - time_reduction_layer_type: str = "conv1d", - do_rel_shift: bool = True, - feed_forward_dropout_rate: float = 0.1, - attention_dropout_rate: float = 0.1, - cnn_module_kernel: int = 31, - cnn_norm_type: str = "batch_norm", - dropout: float = 0.1, - causal: bool = False, - adaptive_scale: bool = True, - activation_type: str = "swish", - init_weights: bool = True, - global_cmvn: torch.nn.Module = None, - normalize_before: bool = False, - use_dynamic_chunk: bool = False, - concat_after: bool = False, - static_chunk_size: int = 0, - use_dynamic_left_chunk: bool = False - ): + + def __init__(self, + input_size: int = 80, + encoder_dim: int = 256, + output_size: int = 256, + attention_heads: int = 4, + num_blocks: int = 12, + reduce_idx: Optional[Union[int, List[int]]] = 5, + recover_idx: Optional[Union[int, List[int]]] = 11, + feed_forward_expansion_factor: int = 4, + dw_stride: bool = False, + input_dropout_rate: float = 0.1, + pos_enc_layer_type: str = "rel_pos", + time_reduction_layer_type: str = "conv1d", + do_rel_shift: bool = True, + feed_forward_dropout_rate: float = 0.1, + attention_dropout_rate: float = 0.1, + 
cnn_module_kernel: int = 31, + cnn_norm_type: str = "batch_norm", + dropout: float = 0.1, + causal: bool = False, + adaptive_scale: bool = True, + activation_type: str = "swish", + init_weights: bool = True, + global_cmvn: torch.nn.Module = None, + normalize_before: bool = False, + use_dynamic_chunk: bool = False, + concat_after: bool = False, + static_chunk_size: int = 0, + use_dynamic_left_chunk: bool = False): """Construct SqueezeformerEncoder Args: @@ -126,51 +125,38 @@ def __init__( ) else: encoder_selfattn_layer = RelPositionMultiHeadedAttention - encoder_selfattn_layer_args = ( - attention_heads, - encoder_dim, - attention_dropout_rate, - do_rel_shift, - adaptive_scale, - init_weights - ) + encoder_selfattn_layer_args = (attention_heads, encoder_dim, + attention_dropout_rate, + do_rel_shift, adaptive_scale, + init_weights) # feed-forward module definition positionwise_layer = PositionwiseFeedForward - positionwise_layer_args = ( - encoder_dim, - encoder_dim * feed_forward_expansion_factor, - feed_forward_dropout_rate, - activation, - adaptive_scale, - init_weights - ) + positionwise_layer_args = (encoder_dim, + encoder_dim * feed_forward_expansion_factor, + feed_forward_dropout_rate, activation, + adaptive_scale, init_weights) # convolution module definition convolution_layer = ConvolutionModule - convolution_layer_args = ( - encoder_dim, cnn_module_kernel, activation, - cnn_norm_type, causal, True, adaptive_scale, init_weights) + convolution_layer_args = (encoder_dim, cnn_module_kernel, activation, + cnn_norm_type, causal, True, adaptive_scale, + init_weights) self.embed = DepthwiseConv2dSubsampling4( - 1, encoder_dim, - RelPositionalEncoding(encoder_dim, dropout_rate=0.1), - dw_stride, - input_size, - input_dropout_rate, - init_weights - ) + 1, encoder_dim, RelPositionalEncoding(encoder_dim, + dropout_rate=0.1), dw_stride, + input_size, input_dropout_rate, init_weights) self.preln = nn.LayerNorm(encoder_dim) - self.encoders = torch.nn.ModuleList([SqueezeformerEncoderLayer( - encoder_dim, - encoder_selfattn_layer(*encoder_selfattn_layer_args), - positionwise_layer(*positionwise_layer_args), - convolution_layer(*convolution_layer_args), - positionwise_layer(*positionwise_layer_args), - normalize_before, - dropout, - concat_after) for _ in range(num_blocks) + self.encoders = torch.nn.ModuleList([ + SqueezeformerEncoderLayer( + encoder_dim, + encoder_selfattn_layer(*encoder_selfattn_layer_args), + positionwise_layer(*positionwise_layer_args), + convolution_layer(*convolution_layer_args), + positionwise_layer(*positionwise_layer_args), normalize_before, + dropout, concat_after) for _ in range(num_blocks) ]) if time_reduction_layer_type == 'conv1d': time_reduction_layer = TimeReductionLayer1D @@ -188,7 +174,8 @@ def __init__( time_reduction_layer = TimeReductionLayer2D time_reduction_layer_args = {'encoder_dim': encoder_dim} - self.time_reduction_layer = time_reduction_layer(**time_reduction_layer_args) + self.time_reduction_layer = time_reduction_layer( + **time_reduction_layer_args) self.time_recover_layer = nn.Linear(encoder_dim, encoder_dim) self.final_proj = None if output_size != encoder_dim: @@ -198,11 +185,11 @@ def output_size(self) -> int: return self._output_size def forward( - self, - xs: torch.Tensor, - xs_lens: torch.Tensor, - decoding_chunk_size: int = 0, - num_decoding_left_chunks: int = -1, + self, + xs: torch.Tensor, + xs_lens: torch.Tensor, + decoding_chunk_size: int = 0, + num_decoding_left_chunks: int = -1, ) -> Tuple[torch.Tensor, torch.Tensor]: T = xs.size(1) 
masks = ~make_pad_mask(xs_lens, T).unsqueeze(1) # (B, 1, T) @@ -224,7 +211,8 @@ def forward( for i, layer in enumerate(self.encoders): if self.reduce_idx is not None: if self.time_reduce is not None and i in self.reduce_idx: - recover_activations.append((xs, chunk_masks, pos_emb, mask_pad)) + recover_activations.append( + (xs, chunk_masks, pos_emb, mask_pad)) xs, xs_lens, chunk_masks, mask_pad = \ self.time_reduction_layer(xs, xs_lens, chunk_masks, mask_pad) pos_emb = pos_emb[:, ::2, :] @@ -272,16 +260,16 @@ def calculate_downsampling_factor(self, i: int) -> int: for exp, rc_idx in enumerate(self.recover_idx): if i >= rc_idx: recover_exp = exp + 1 - return int(2 ** (reduce_exp - recover_exp)) + return int(2**(reduce_exp - recover_exp)) def forward_chunk( - self, - xs: torch.Tensor, - offset: int, - required_cache_size: int, - att_cache: torch.Tensor = torch.zeros(0, 0, 0, 0), - cnn_cache: torch.Tensor = torch.zeros(0, 0, 0, 0), - att_mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool), + self, + xs: torch.Tensor, + offset: int, + required_cache_size: int, + att_cache: torch.Tensor = torch.zeros(0, 0, 0, 0), + cnn_cache: torch.Tensor = torch.zeros(0, 0, 0, 0), + att_mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool), ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Forward just one chunk @@ -328,8 +316,8 @@ def forward_chunk( elayers, cache_t1 = att_cache.size(0), att_cache.size(2) chunk_size = xs.size(1) attention_key_size = cache_t1 + chunk_size - pos_emb = self.embed.position_encoding( - offset=offset - cache_t1, size=attention_key_size) + pos_emb = self.embed.position_encoding(offset=offset - cache_t1, + size=attention_key_size) if required_cache_size < 0: next_cache_start = 0 elif required_cache_size == 0: @@ -357,7 +345,8 @@ def forward_chunk( # shape(cnn_cache[i]) is (b=1, hidden-dim, cache_t2) if self.reduce_idx is not None: if self.time_reduce is not None and i in self.reduce_idx: - recover_activations.append((xs, att_mask, pos_emb, mask_pad)) + recover_activations.append( + (xs, att_mask, pos_emb, mask_pad)) xs, xs_lens, att_mask, mask_pad = \ self.time_reduction_layer(xs, xs_lens, att_mask, mask_pad) pos_emb = pos_emb[:, ::2, :] @@ -378,17 +367,19 @@ def forward_chunk( pos_emb = recover_pos_emb mask_pad = recover_mask_pad if att_mask.size(1) != 0: - xs = xs.masked_fill(~att_mask[:, 0, :].unsqueeze(-1), 0.0) + xs = xs.masked_fill(~att_mask[:, 0, :].unsqueeze(-1), + 0.0) factor = self.calculate_downsampling_factor(i) xs, _, new_att_cache, new_cnn_cache = layer( - xs, att_mask, pos_emb, + xs, + att_mask, + pos_emb, att_cache=att_cache[i:i + 1][:, :, ::factor, :] - [:, :, :pos_emb.size(1) - xs.size(1), :] if - elayers > 0 else att_cache[:, :, ::factor, :], - cnn_cache=cnn_cache[i] if cnn_cache.size(0) > 0 else cnn_cache - ) + [:, :, :pos_emb.size(1) - xs.size(1), :] + if elayers > 0 else att_cache[:, :, ::factor, :], + cnn_cache=cnn_cache[i] if cnn_cache.size(0) > 0 else cnn_cache) # NOTE(xcsong): After layer.forward # shape(new_att_cache) is (1, head, attention_key_size, d_k * 2), # shape(new_cnn_cache) is (b=1, hidden-dim, cache_t2) @@ -413,10 +404,10 @@ def forward_chunk( return (xs, r_att_cache, r_cnn_cache) def forward_chunk_by_chunk( - self, - xs: torch.Tensor, - decoding_chunk_size: int, - num_decoding_left_chunks: int = -1, + self, + xs: torch.Tensor, + decoding_chunk_size: int, + num_decoding_left_chunks: int = -1, ) -> Tuple[torch.Tensor, torch.Tensor]: """ Forward input chunk by chunk with chunk_size like a streaming fashion @@ -469,5 +460,7 @@ 
def forward_chunk_by_chunk( outputs.append(y) offset += y.size(1) ys = torch.cat(outputs, 1) - masks = torch.ones((1, 1, ys.size(1)), device=ys.device, dtype=torch.bool) + masks = torch.ones((1, 1, ys.size(1)), + device=ys.device, + dtype=torch.bool) return ys, masks diff --git a/wenet/squeezeformer/encoder_layer.py b/wenet/squeezeformer/encoder_layer.py index 3c6bdd44a..b354b3032 100644 --- a/wenet/squeezeformer/encoder_layer.py +++ b/wenet/squeezeformer/encoder_layer.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """SqueezeformerEncoderLayer definition.""" import torch @@ -39,15 +38,15 @@ class SqueezeformerEncoderLayer(nn.Module): """ def __init__( - self, - size: int, - self_attn: torch.nn.Module, - feed_forward1: Optional[nn.Module] = None, - conv_module: Optional[nn.Module] = None, - feed_forward2: Optional[nn.Module] = None, - normalize_before: bool = False, - dropout_rate: float = 0.1, - concat_after: bool = False, + self, + size: int, + self_attn: torch.nn.Module, + feed_forward1: Optional[nn.Module] = None, + conv_module: Optional[nn.Module] = None, + feed_forward2: Optional[nn.Module] = None, + normalize_before: bool = False, + dropout_rate: float = 0.1, + concat_after: bool = False, ): super(SqueezeformerEncoderLayer, self).__init__() self.size = size @@ -68,19 +67,20 @@ def __init__( self.concat_linear = nn.Identity() def forward( - self, - x: torch.Tensor, - mask: torch.Tensor, - pos_emb: torch.Tensor, - mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool), - att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)), - cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)), + self, + x: torch.Tensor, + mask: torch.Tensor, + pos_emb: torch.Tensor, + mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool), + att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)), + cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)), ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: # self attention module residual = x if self.normalize_before: x = self.layer_norm1(x) - x_att, new_att_cache = self.self_attn(x, x, x, mask, pos_emb, att_cache) + x_att, new_att_cache = self.self_attn(x, x, x, mask, pos_emb, + att_cache) if self.concat_after: x_concat = torch.cat((x, x_att), dim=-1) x = residual + self.concat_linear(x_concat) diff --git a/wenet/squeezeformer/positionwise_feed_forward.py b/wenet/squeezeformer/positionwise_feed_forward.py index 289062dcf..40100959b 100644 --- a/wenet/squeezeformer/positionwise_feed_forward.py +++ b/wenet/squeezeformer/positionwise_feed_forward.py @@ -13,7 +13,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
- """Positionwise feed forward layer definition.""" import torch @@ -38,8 +37,7 @@ def __init__(self, dropout_rate: float, activation: torch.nn.Module = torch.nn.ReLU(), adaptive_scale: bool = False, - init_weights: bool = False - ): + init_weights: bool = False): """Construct a PositionwiseFeedForward object.""" super(PositionwiseFeedForward, self).__init__() self.idim = idim @@ -51,16 +49,16 @@ def __init__(self, self.ada_scale = None self.ada_bias = None self.adaptive_scale = adaptive_scale - self.ada_scale = torch.nn.Parameter( - torch.ones([1, 1, idim]), requires_grad=adaptive_scale) - self.ada_bias = torch.nn.Parameter( - torch.zeros([1, 1, idim]), requires_grad=adaptive_scale) + self.ada_scale = torch.nn.Parameter(torch.ones([1, 1, idim]), + requires_grad=adaptive_scale) + self.ada_bias = torch.nn.Parameter(torch.zeros([1, 1, idim]), + requires_grad=adaptive_scale) if init_weights: self.init_weights() def init_weights(self): - ffn1_max = self.idim ** -0.5 - ffn2_max = self.hidden_units ** -0.5 + ffn1_max = self.idim**-0.5 + ffn2_max = self.hidden_units**-0.5 torch.nn.init.uniform_(self.w_1.weight.data, -ffn1_max, ffn1_max) torch.nn.init.uniform_(self.w_1.bias.data, -ffn1_max, ffn1_max) torch.nn.init.uniform_(self.w_2.weight.data, -ffn2_max, ffn2_max) diff --git a/wenet/squeezeformer/subsampling.py b/wenet/squeezeformer/subsampling.py index fdb0101d6..c769e1025 100644 --- a/wenet/squeezeformer/subsampling.py +++ b/wenet/squeezeformer/subsampling.py @@ -14,7 +14,6 @@ # Modified from Squeezeformer(https://github.com/kssteven418/Squeezeformer) # Squeezeformer(https://github.com/upskyy/Squeezeformer) # NeMo(https://github.com/NVIDIA/NeMo) - """DepthwiseConv2dSubsampling4 and TimeReductionLayer definition.""" import torch @@ -37,37 +36,39 @@ class DepthwiseConv2dSubsampling4(BaseSubsampling): """ - def __init__( - self, idim: int, odim: int, - pos_enc_class: torch.nn.Module, - dw_stride: bool = False, - input_size: int = 80, - input_dropout_rate: float = 0.1, - init_weights: bool = True - ): + def __init__(self, + idim: int, + odim: int, + pos_enc_class: torch.nn.Module, + dw_stride: bool = False, + input_size: int = 80, + input_dropout_rate: float = 0.1, + init_weights: bool = True): super(DepthwiseConv2dSubsampling4, self).__init__() self.idim = idim self.odim = odim - self.pw_conv = nn.Conv2d( - in_channels=idim, out_channels=odim, kernel_size=3, stride=2) + self.pw_conv = nn.Conv2d(in_channels=idim, + out_channels=odim, + kernel_size=3, + stride=2) self.act1 = nn.ReLU() - self.dw_conv = nn.Conv2d( - in_channels=odim, out_channels=odim, kernel_size=3, stride=2, - groups=odim if dw_stride else 1 - ) + self.dw_conv = nn.Conv2d(in_channels=odim, + out_channels=odim, + kernel_size=3, + stride=2, + groups=odim if dw_stride else 1) self.act2 = nn.ReLU() self.pos_enc = pos_enc_class self.input_proj = nn.Sequential( - nn.Linear( - odim * (((input_size - 1) // 2 - 1) // 2), odim), + nn.Linear(odim * (((input_size - 1) // 2 - 1) // 2), odim), nn.Dropout(p=input_dropout_rate), ) if init_weights: - linear_max = (odim * input_size / 4) ** -0.5 - torch.nn.init.uniform_( - self.input_proj.state_dict()['0.weight'], -linear_max, linear_max) - torch.nn.init.uniform_( - self.input_proj.state_dict()['0.bias'], -linear_max, linear_max) + linear_max = (odim * input_size / 4)**-0.5 + torch.nn.init.uniform_(self.input_proj.state_dict()['0.weight'], + -linear_max, linear_max) + torch.nn.init.uniform_(self.input_proj.state_dict()['0.bias'], + -linear_max, linear_max) self.subsampling_rate = 4 # 6 = (3 - 1) * 1 
+ (3 - 1) * 2 self.right_context = 6 @@ -105,8 +106,11 @@ class TimeReductionLayer1D(nn.Module): stride (int): Downsampling factor in time dimension. """ - def __init__(self, channel: int, out_dim: int, - kernel_size: int = 5, stride: int = 2): + def __init__(self, + channel: int, + out_dim: int, + kernel_size: int = 5, + stride: int = 2): super(TimeReductionLayer1D, self).__init__() self.channel = channel @@ -125,24 +129,31 @@ def __init__(self, channel: int, out_dim: int, ) self.pw_conv = nn.Conv1d( - in_channels=channel, out_channels=out_dim, - kernel_size=1, stride=1, padding=0, groups=1, + in_channels=channel, + out_channels=out_dim, + kernel_size=1, + stride=1, + padding=0, + groups=1, ) self.init_weights() def init_weights(self): - dw_max = self.kernel_size ** -0.5 - pw_max = self.channel ** -0.5 + dw_max = self.kernel_size**-0.5 + pw_max = self.channel**-0.5 torch.nn.init.uniform_(self.dw_conv.weight, -dw_max, dw_max) torch.nn.init.uniform_(self.dw_conv.bias, -dw_max, dw_max) torch.nn.init.uniform_(self.pw_conv.weight, -pw_max, pw_max) torch.nn.init.uniform_(self.pw_conv.bias, -pw_max, pw_max) - def forward(self, xs, xs_lens: torch.Tensor, - mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool), - mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool), - ): + def forward( + self, + xs, + xs_lens: torch.Tensor, + mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool), + mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool), + ): xs = xs.transpose(1, 2) # [B, C, T] xs = xs.masked_fill(mask_pad.eq(0), 0.0) @@ -167,18 +178,19 @@ def forward(self, xs, xs_lens: torch.Tensor, class TimeReductionLayer2D(nn.Module): - def __init__( - self, kernel_size: int = 5, stride: int = 2, encoder_dim: int = 256): + + def __init__(self, + kernel_size: int = 5, + stride: int = 2, + encoder_dim: int = 256): super(TimeReductionLayer2D, self).__init__() self.encoder_dim = encoder_dim self.kernel_size = kernel_size - self.dw_conv = Conv2dValid( - in_channels=encoder_dim, - out_channels=encoder_dim, - kernel_size=(kernel_size, 1), - stride=stride, - valid_trigy=True - ) + self.dw_conv = Conv2dValid(in_channels=encoder_dim, + out_channels=encoder_dim, + kernel_size=(kernel_size, 1), + stride=stride, + valid_trigy=True) self.pw_conv = Conv2dValid( in_channels=encoder_dim, out_channels=encoder_dim, @@ -193,23 +205,26 @@ def __init__( self.init_weights() def init_weights(self): - dw_max = self.kernel_size ** -0.5 - pw_max = self.encoder_dim ** -0.5 + dw_max = self.kernel_size**-0.5 + pw_max = self.encoder_dim**-0.5 torch.nn.init.uniform_(self.dw_conv.weight, -dw_max, dw_max) torch.nn.init.uniform_(self.dw_conv.bias, -dw_max, dw_max) torch.nn.init.uniform_(self.pw_conv.weight, -pw_max, pw_max) torch.nn.init.uniform_(self.pw_conv.bias, -pw_max, pw_max) def forward( - self, xs: torch.Tensor, xs_lens: torch.Tensor, - mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool), - mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool), + self, + xs: torch.Tensor, + xs_lens: torch.Tensor, + mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool), + mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool), ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: xs = xs.masked_fill(mask_pad.transpose(1, 2).eq(0), 0.0) xs = xs.unsqueeze(2) padding1 = self.kernel_size - self.stride xs = F.pad(xs, (0, 0, 0, 0, 0, padding1, 0, 0), - mode='constant', value=0.) + mode='constant', + value=0.) 
xs = self.dw_conv(xs.permute(0, 3, 1, 2)) xs = self.pw_conv(xs).permute(0, 3, 2, 1).squeeze(1).contiguous() tmp_length = xs.size(1) @@ -236,8 +251,11 @@ class TimeReductionLayerStream(nn.Module): stride (int): Downsampling factor in time dimension. """ - def __init__(self, channel: int, out_dim: int, - kernel_size: int = 1, stride: int = 2): + def __init__(self, + channel: int, + out_dim: int, + kernel_size: int = 1, + stride: int = 2): super(TimeReductionLayerStream, self).__init__() self.channel = channel @@ -255,24 +273,31 @@ def __init__(self, channel: int, out_dim: int, ) self.pw_conv = nn.Conv1d( - in_channels=channel, out_channels=out_dim, - kernel_size=1, stride=1, padding=0, groups=1, + in_channels=channel, + out_channels=out_dim, + kernel_size=1, + stride=1, + padding=0, + groups=1, ) self.init_weights() def init_weights(self): - dw_max = self.kernel_size ** -0.5 - pw_max = self.channel ** -0.5 + dw_max = self.kernel_size**-0.5 + pw_max = self.channel**-0.5 torch.nn.init.uniform_(self.dw_conv.weight, -dw_max, dw_max) torch.nn.init.uniform_(self.dw_conv.bias, -dw_max, dw_max) torch.nn.init.uniform_(self.pw_conv.weight, -pw_max, pw_max) torch.nn.init.uniform_(self.pw_conv.bias, -pw_max, pw_max) - def forward(self, xs, xs_lens: torch.Tensor, - mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool), - mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool), - ): + def forward( + self, + xs, + xs_lens: torch.Tensor, + mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool), + mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool), + ): xs = xs.transpose(1, 2) # [B, C, T] xs = xs.masked_fill(mask_pad.eq(0), 0.0) diff --git a/wenet/ssl/bestrq/mask.py b/wenet/ssl/bestrq/mask.py index 1b345d388..6fc8b2b74 100644 --- a/wenet/ssl/bestrq/mask.py +++ b/wenet/ssl/bestrq/mask.py @@ -1,6 +1,7 @@ import torch import numpy as np + def _sampler(pdf: torch.Tensor, num_samples: int, device=torch.device('cpu')) -> torch.Tensor: size = pdf.size() diff --git a/wenet/text/tokenize_utils.py b/wenet/text/tokenize_utils.py index 2abd3fc62..3a1bffbf2 100644 --- a/wenet/text/tokenize_utils.py +++ b/wenet/text/tokenize_utils.py @@ -14,6 +14,7 @@ import re + def tokenize_by_bpe_model(sp, txt): tokens = [] # CJK(China Japan Korea) unicode range is [U+4E00, U+9FFF], ref: diff --git a/wenet/transformer/asr_model.py b/wenet/transformer/asr_model.py index 20503ef02..63e7bc812 100644 --- a/wenet/transformer/asr_model.py +++ b/wenet/transformer/asr_model.py @@ -239,8 +239,8 @@ def decode( ctc_probs, encoder_lens, blank_id) if 'ctc_prefix_beam_search' in methods: ctc_prefix_result = ctc_prefix_beam_search(ctc_probs, encoder_lens, - beam_size, context_graph, - blank_id) + beam_size, + context_graph, blank_id) results['ctc_prefix_beam_search'] = ctc_prefix_result if 'attention_rescoring' in methods: # attention_rescoring depends on ctc_prefix_beam_search nbest diff --git a/wenet/transformer/attention.py b/wenet/transformer/attention.py index ba6d8e0f8..3b215c10b 100644 --- a/wenet/transformer/attention.py +++ b/wenet/transformer/attention.py @@ -13,7 +13,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Multi-Head Attention layer definition.""" import math @@ -32,7 +31,11 @@ class MultiHeadedAttention(nn.Module): dropout_rate (float): Dropout rate. 
""" - def __init__(self, n_head: int, n_feat: int, dropout_rate: float, + + def __init__(self, + n_head: int, + n_feat: int, + dropout_rate: float, key_bias: bool = True): """Construct an MultiHeadedAttention object.""" super().__init__() @@ -76,7 +79,9 @@ def forward_qkv( return q, k, v def forward_attention( - self, value: torch.Tensor, scores: torch.Tensor, + self, + value: torch.Tensor, + scores: torch.Tensor, mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool) ) -> torch.Tensor: """Compute attention context vector. @@ -99,7 +104,7 @@ def forward_attention( # 1. onnx(16/4) [WHY? Because we feed real cache & real mask for the # 1st chunk to ease the onnx export.] # 2. pytorch training - if mask.size(2) > 0 : # time2 > 0 + if mask.size(2) > 0: # time2 > 0 mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2) # For last chunk, time2 might be larger than scores.size(-1) mask = mask[:, :, :, :scores.size(-1)] # (batch, 1, *, time2) @@ -120,12 +125,15 @@ def forward_attention( return self.linear_out(x) # (batch, time1, d_model) - def forward(self, query: torch.Tensor, key: torch.Tensor, - value: torch.Tensor, - mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool), - pos_emb: torch.Tensor = torch.empty(0), - cache: torch.Tensor = torch.zeros((0, 0, 0, 0)) - ) -> Tuple[torch.Tensor, torch.Tensor]: + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool), + pos_emb: torch.Tensor = torch.empty(0), + cache: torch.Tensor = torch.zeros((0, 0, 0, 0)) + ) -> Tuple[torch.Tensor, torch.Tensor]: """Compute scaled dot product attention. Args: @@ -175,8 +183,9 @@ def forward(self, query: torch.Tensor, key: torch.Tensor, # >>> d = torch.split(a, 2, dim=-1) # >>> torch.equal(d[0], d[1]) # True if cache.size(0) > 0: - key_cache, value_cache = torch.split( - cache, cache.size(-1) // 2, dim=-1) + key_cache, value_cache = torch.split(cache, + cache.size(-1) // 2, + dim=-1) k = torch.cat([key_cache, k], dim=2) v = torch.cat([value_cache, v], dim=2) # NOTE(xcsong): We do cache slicing in encoder.forward_chunk, since it's @@ -195,7 +204,11 @@ class RelPositionMultiHeadedAttention(MultiHeadedAttention): n_feat (int): The number of features. dropout_rate (float): Dropout rate. """ - def __init__(self, n_head: int, n_feat: int, dropout_rate: float, + + def __init__(self, + n_head: int, + n_feat: int, + dropout_rate: float, key_bias: bool = True): """Construct an RelPositionMultiHeadedAttention object.""" super().__init__(n_head, n_feat, dropout_rate, key_bias) @@ -234,12 +247,15 @@ def rel_shift(self, x, zero_triu: bool = False): return x - def forward(self, query: torch.Tensor, - key: torch.Tensor, value: torch.Tensor, - mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool), - pos_emb: torch.Tensor = torch.empty(0), - cache: torch.Tensor = torch.zeros((0, 0, 0, 0)) - ) -> Tuple[torch.Tensor, torch.Tensor]: + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool), + pos_emb: torch.Tensor = torch.empty(0), + cache: torch.Tensor = torch.zeros((0, 0, 0, 0)) + ) -> Tuple[torch.Tensor, torch.Tensor]: """Compute 'Scaled Dot Product Attention' with rel. positional encoding. Args: query (torch.Tensor): Query tensor (#batch, time1, size). 
@@ -278,8 +294,9 @@ def forward(self, query: torch.Tensor, # >>> d = torch.split(a, 2, dim=-1) # >>> torch.equal(d[0], d[1]) # True if cache.size(0) > 0: - key_cache, value_cache = torch.split( - cache, cache.size(-1) // 2, dim=-1) + key_cache, value_cache = torch.split(cache, + cache.size(-1) // 2, + dim=-1) k = torch.cat([key_cache, k], dim=2) v = torch.cat([value_cache, v], dim=2) # NOTE(xcsong): We do cache slicing in encoder.forward_chunk, since it's diff --git a/wenet/transformer/cmvn.py b/wenet/transformer/cmvn.py index 3a1e7457f..754b22168 100644 --- a/wenet/transformer/cmvn.py +++ b/wenet/transformer/cmvn.py @@ -16,6 +16,7 @@ class GlobalCMVN(torch.nn.Module): + def __init__(self, mean: torch.Tensor, istd: torch.Tensor, diff --git a/wenet/transformer/convolution.py b/wenet/transformer/convolution.py index 722ef3a2d..071f25aac 100644 --- a/wenet/transformer/convolution.py +++ b/wenet/transformer/convolution.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. # Modified from ESPnet(https://github.com/espnet/espnet) - """ConvolutionModule definition.""" from typing import Tuple @@ -23,6 +22,7 @@ class ConvolutionModule(nn.Module): """ConvolutionModule in Conformer model.""" + def __init__(self, channels: int, kernel_size: int = 15, diff --git a/wenet/transformer/ctc.py b/wenet/transformer/ctc.py index ef6662060..08ec50630 100644 --- a/wenet/transformer/ctc.py +++ b/wenet/transformer/ctc.py @@ -19,6 +19,7 @@ class CTC(torch.nn.Module): """CTC module""" + def __init__( self, odim: int, @@ -41,7 +42,8 @@ def __init__( self.ctc_lo = torch.nn.Linear(eprojs, odim) reduction_type = "sum" if reduce else "none" - self.ctc_loss = torch.nn.CTCLoss(blank=blank_id, reduction=reduction_type) + self.ctc_loss = torch.nn.CTCLoss(blank=blank_id, + reduction=reduction_type) def forward(self, hs_pad: torch.Tensor, hlens: torch.Tensor, ys_pad: torch.Tensor, ys_lens: torch.Tensor) -> torch.Tensor: diff --git a/wenet/transformer/decoder.py b/wenet/transformer/decoder.py index 159c3caf4..4bcec75a6 100644 --- a/wenet/transformer/decoder.py +++ b/wenet/transformer/decoder.py @@ -22,7 +22,8 @@ from wenet.transformer.decoder_layer import DecoderLayer from wenet.transformer.positionwise_feed_forward import PositionwiseFeedForward from wenet.utils.class_utils import ( - WENET_EMB_CLASSES, WENET_ATTENTION_CLASSES, + WENET_EMB_CLASSES, + WENET_ATTENTION_CLASSES, WENET_ACTIVATION_CLASSES, ) from wenet.utils.mask import (subsequent_mask, make_pad_mask) @@ -75,9 +76,10 @@ def __init__( activation = WENET_ACTIVATION_CLASSES[activation_type]() self.embed = torch.nn.Sequential( - torch.nn.Identity() if input_layer == "no_pos" else torch.nn.Embedding( - vocab_size, attention_dim), - WENET_EMB_CLASSES[input_layer](attention_dim, positional_dropout_rate), + torch.nn.Identity() if input_layer == "no_pos" else + torch.nn.Embedding(vocab_size, attention_dim), + WENET_EMB_CLASSES[input_layer](attention_dim, + positional_dropout_rate), ) self.normalize_before = normalize_before @@ -93,12 +95,10 @@ def __init__( attention_dim, WENET_ATTENTION_CLASSES["selfattn"]( attention_heads, attention_dim, - self_attention_dropout_rate, key_bias - ), + self_attention_dropout_rate, key_bias), WENET_ATTENTION_CLASSES["selfattn"]( - attention_heads, attention_dim, - src_attention_dropout_rate, key_bias - ) if src_attention else None, + attention_heads, attention_dim, src_attention_dropout_rate, + key_bias) if src_attention else None, PositionwiseFeedForward(attention_dim, 
linear_units, dropout_rate, activation), dropout_rate, @@ -150,7 +150,8 @@ def forward( tgt_mask = tgt_mask & m x, _ = self.embed(tgt) if self.gradient_checkpointing and self.training: - x = self.forward_layers_checkpointed(x, tgt_mask, memory, memory_mask) + x = self.forward_layers_checkpointed(x, tgt_mask, memory, + memory_mask) else: x = self.forward_layers(x, tgt_mask, memory, memory_mask) if self.normalize_before: @@ -160,20 +161,19 @@ def forward( olens = tgt_mask.sum(1) return x, torch.tensor(0.0), olens - def forward_layers( - self, x: torch.Tensor, tgt_mask: torch.Tensor, - memory: torch.Tensor, memory_mask: torch.Tensor - ) -> torch.Tensor: + def forward_layers(self, x: torch.Tensor, tgt_mask: torch.Tensor, + memory: torch.Tensor, + memory_mask: torch.Tensor) -> torch.Tensor: for layer in self.decoders: x, tgt_mask, memory, memory_mask = layer(x, tgt_mask, memory, memory_mask) return x @torch.jit.ignore(drop=True) - def forward_layers_checkpointed( - self, x: torch.Tensor, tgt_mask: torch.Tensor, - memory: torch.Tensor, memory_mask: torch.Tensor - ) -> torch.Tensor: + def forward_layers_checkpointed(self, x: torch.Tensor, + tgt_mask: torch.Tensor, + memory: torch.Tensor, + memory_mask: torch.Tensor) -> torch.Tensor: for layer in self.decoders: x, tgt_mask, memory, memory_mask = ckpt.checkpoint( layer.__call__, x, tgt_mask, memory, memory_mask) @@ -229,7 +229,8 @@ def tie_or_clone_weights(self, jit_mode: bool = True): return if jit_mode: logging.info("clone emb.weight to output.weight") - self.output_layer.weight = torch.nn.Parameter(self.embed[0].weight.clone()) + self.output_layer.weight = torch.nn.Parameter( + self.embed[0].weight.clone()) else: logging.info("tie emb.weight with output.weight") self.output_layer.weight = self.embed[0].weight @@ -239,7 +240,8 @@ def tie_or_clone_weights(self, jit_mode: bool = True): self.output_layer.bias.data, ( 0, - self.output_layer.weight.shape[0] - self.output_layer.bias.shape[0], + self.output_layer.weight.shape[0] - + self.output_layer.bias.shape[0], ), "constant", 0, @@ -287,18 +289,36 @@ def __init__( super().__init__() self.left_decoder = TransformerDecoder( - vocab_size, encoder_output_size, attention_heads, linear_units, - num_blocks, dropout_rate, positional_dropout_rate, - self_attention_dropout_rate, src_attention_dropout_rate, - input_layer, use_output_layer, normalize_before, - key_bias=key_bias, gradient_checkpointing=gradient_checkpointing) + vocab_size, + encoder_output_size, + attention_heads, + linear_units, + num_blocks, + dropout_rate, + positional_dropout_rate, + self_attention_dropout_rate, + src_attention_dropout_rate, + input_layer, + use_output_layer, + normalize_before, + key_bias=key_bias, + gradient_checkpointing=gradient_checkpointing) self.right_decoder = TransformerDecoder( - vocab_size, encoder_output_size, attention_heads, linear_units, - r_num_blocks, dropout_rate, positional_dropout_rate, - self_attention_dropout_rate, src_attention_dropout_rate, - input_layer, use_output_layer, normalize_before, - key_bias=key_bias, gradient_checkpointing=gradient_checkpointing) + vocab_size, + encoder_output_size, + attention_heads, + linear_units, + r_num_blocks, + dropout_rate, + positional_dropout_rate, + self_attention_dropout_rate, + src_attention_dropout_rate, + input_layer, + use_output_layer, + normalize_before, + key_bias=key_bias, + gradient_checkpointing=gradient_checkpointing) def forward( self, @@ -331,8 +351,8 @@ def forward( ys_in_lens) r_x = torch.tensor(0.0) if reverse_weight > 0.0: - r_x, _, olens = 
self.right_decoder(memory, memory_mask, r_ys_in_pad, - ys_in_lens) + r_x, _, olens = self.right_decoder(memory, memory_mask, + r_ys_in_pad, ys_in_lens) return l_x, r_x, olens def forward_one_step( diff --git a/wenet/transformer/decoder_layer.py b/wenet/transformer/decoder_layer.py index 195ca5f79..91c7c5d7f 100644 --- a/wenet/transformer/decoder_layer.py +++ b/wenet/transformer/decoder_layer.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Decoder self-attention layer definition.""" from typing import Optional, Tuple @@ -38,6 +37,7 @@ class DecoderLayer(nn.Module): True: use layer_norm before each sub-block. False: to use layer_norm after each sub-block. """ + def __init__( self, size: int, diff --git a/wenet/transformer/embedding.py b/wenet/transformer/embedding.py index 845caaaa2..17d8810ff 100644 --- a/wenet/transformer/embedding.py +++ b/wenet/transformer/embedding.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. # Modified from ESPnet(https://github.com/espnet/espnet) - """Positonal Encoding Module.""" import math @@ -22,6 +21,7 @@ import torch.nn.functional as F import numpy as np + class PositionalEncoding(torch.nn.Module): """Positional encoding. @@ -32,6 +32,7 @@ class PositionalEncoding(torch.nn.Module): PE(pos, 2i) = sin(pos/(10000^(2i/dmodel))) PE(pos, 2i+1) = cos(pos/(10000^(2i/dmodel))) """ + def __init__(self, d_model: int, dropout_rate: float, @@ -74,7 +75,9 @@ def forward(self, x = x * self.xscale + pos_emb return self.dropout(x), self.dropout(pos_emb) - def position_encoding(self, offset: Union[int, torch.Tensor], size: int, + def position_encoding(self, + offset: Union[int, torch.Tensor], + size: int, apply_dropout: bool = True) -> torch.Tensor: """ For getting encoding in a streaming fashion @@ -112,6 +115,7 @@ def position_encoding(self, offset: Union[int, torch.Tensor], size: int, pos_emb = self.dropout(pos_emb) return pos_emb + class RelPositionalEncoding(PositionalEncoding): """Relative positional encoding module. See : Appendix B in https://arxiv.org/abs/1901.02860 @@ -120,6 +124,7 @@ class RelPositionalEncoding(PositionalEncoding): dropout_rate (float): Dropout rate. max_len (int): Maximum input length. 
""" + def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000): """Initialize class.""" super().__init__(d_model, dropout_rate, max_len, reverse=True) @@ -144,6 +149,7 @@ def forward(self, class WhisperPositionalEncoding(PositionalEncoding): """ Sinusoids position encoding used in openai-whisper.encoder """ + def __init__(self, d_model: int, dropout_rate: float, max_len: int = 1500): super().__init__(d_model, dropout_rate, max_len) self.xscale = 1.0 @@ -160,6 +166,7 @@ def __init__(self, d_model: int, dropout_rate: float, max_len: int = 1500): class LearnablePositionalEncoding(PositionalEncoding): """ Learnable position encoding used in openai-whisper.decoder """ + def __init__(self, d_model: int, dropout_rate: float, max_len: int = 448): super().__init__(d_model, dropout_rate, max_len) # NOTE(xcsong): overwrite self.pe & self.xscale @@ -170,6 +177,7 @@ def __init__(self, d_model: int, dropout_rate: float, max_len: int = 448): class NoPositionalEncoding(torch.nn.Module): """ No position encoding """ + def __init__(self, d_model: int, dropout_rate: float): super().__init__() self.d_model = d_model @@ -184,6 +192,6 @@ def forward(self, pos_emb = torch.zeros(1, x.size(1), self.d_model).to(x.device) return self.dropout(x), pos_emb - def position_encoding( - self, offset: Union[int, torch.Tensor], size: int) -> torch.Tensor: + def position_encoding(self, offset: Union[int, torch.Tensor], + size: int) -> torch.Tensor: return torch.zeros(1, size, self.d_model) diff --git a/wenet/transformer/encoder.py b/wenet/transformer/encoder.py index 4549838fb..894caf59d 100644 --- a/wenet/transformer/encoder.py +++ b/wenet/transformer/encoder.py @@ -13,7 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# Modified from ESPnet(https://github.com/espnet/espnet) - """Encoder definition.""" from typing import Tuple @@ -25,7 +24,9 @@ from wenet.transformer.encoder_layer import ConformerEncoderLayer from wenet.transformer.positionwise_feed_forward import PositionwiseFeedForward from wenet.utils.class_utils import ( - WENET_EMB_CLASSES, WENET_SUBSAMPLE_CLASSES, WENET_ATTENTION_CLASSES, + WENET_EMB_CLASSES, + WENET_SUBSAMPLE_CLASSES, + WENET_ATTENTION_CLASSES, WENET_ACTIVATION_CLASSES, ) from wenet.utils.mask import make_pad_mask @@ -33,6 +34,7 @@ class BaseEncoder(torch.nn.Module): + def __init__( self, input_size: int, @@ -91,7 +93,8 @@ def __init__( input_size, output_size, dropout_rate, - WENET_EMB_CLASSES[pos_enc_layer_type](output_size, positional_dropout_rate), + WENET_EMB_CLASSES[pos_enc_layer_type](output_size, + positional_dropout_rate), ) self.normalize_before = normalize_before @@ -147,7 +150,8 @@ def forward( self.static_chunk_size, num_decoding_left_chunks) if self.gradient_checkpointing and self.training: - xs = self.forward_layers_checkpointed(xs, chunk_masks, pos_emb, mask_pad) + xs = self.forward_layers_checkpointed(xs, chunk_masks, pos_emb, + mask_pad) else: xs = self.forward_layers(xs, chunk_masks, pos_emb, mask_pad) if self.normalize_before: @@ -157,22 +161,22 @@ def forward( # for cross attention with decoder later return xs, masks - def forward_layers( - self, xs: torch.Tensor, chunk_masks: torch.Tensor, - pos_emb: torch.Tensor, mask_pad: torch.Tensor - ) -> torch.Tensor: + def forward_layers(self, xs: torch.Tensor, chunk_masks: torch.Tensor, + pos_emb: torch.Tensor, + mask_pad: torch.Tensor) -> torch.Tensor: for layer in self.encoders: xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad) return xs @torch.jit.ignore(drop=True) - def forward_layers_checkpointed( - self, xs: torch.Tensor, chunk_masks: torch.Tensor, - pos_emb: torch.Tensor, mask_pad: torch.Tensor - ) -> torch.Tensor: + def forward_layers_checkpointed(self, xs: torch.Tensor, + chunk_masks: torch.Tensor, + pos_emb: torch.Tensor, + mask_pad: torch.Tensor) -> torch.Tensor: for layer in self.encoders: - xs, chunk_masks, _, _ = ckpt.checkpoint( - layer.__call__, xs, chunk_masks, pos_emb, mask_pad) + xs, chunk_masks, _, _ = ckpt.checkpoint(layer.__call__, xs, + chunk_masks, pos_emb, + mask_pad) return xs def forward_chunk( @@ -229,8 +233,8 @@ def forward_chunk( elayers, cache_t1 = att_cache.size(0), att_cache.size(2) chunk_size = xs.size(1) attention_key_size = cache_t1 + chunk_size - pos_emb = self.embed.position_encoding( - offset=offset - cache_t1, size=attention_key_size) + pos_emb = self.embed.position_encoding(offset=offset - cache_t1, + size=attention_key_size) if required_cache_size < 0: next_cache_start = 0 elif required_cache_size == 0: @@ -244,10 +248,11 @@ def forward_chunk( # shape(att_cache[i:i + 1]) is (1, head, cache_t1, d_k * 2), # shape(cnn_cache[i]) is (b=1, hidden-dim, cache_t2) xs, _, new_att_cache, new_cnn_cache = layer( - xs, att_mask, pos_emb, + xs, + att_mask, + pos_emb, att_cache=att_cache[i:i + 1] if elayers > 0 else att_cache, - cnn_cache=cnn_cache[i] if cnn_cache.size(0) > 0 else cnn_cache - ) + cnn_cache=cnn_cache[i] if cnn_cache.size(0) > 0 else cnn_cache) # NOTE(xcsong): After layer.forward # shape(new_att_cache) is (1, head, attention_key_size, d_k * 2), # shape(new_cnn_cache) is (b=1, hidden-dim, cache_t2) @@ -314,17 +319,22 @@ def forward_chunk_by_chunk( for cur in range(0, num_frames - context + 1, stride): end = min(cur + decoding_window, num_frames) chunk_xs = xs[:, 
cur:end, :] - (y, att_cache, cnn_cache) = self.forward_chunk( - chunk_xs, offset, required_cache_size, att_cache, cnn_cache) + (y, att_cache, + cnn_cache) = self.forward_chunk(chunk_xs, offset, + required_cache_size, att_cache, + cnn_cache) outputs.append(y) offset += y.size(1) ys = torch.cat(outputs, 1) - masks = torch.ones((1, 1, ys.size(1)), device=ys.device, dtype=torch.bool) + masks = torch.ones((1, 1, ys.size(1)), + device=ys.device, + dtype=torch.bool) return ys, masks class TransformerEncoder(BaseEncoder): """Transformer encoder module.""" + def __init__( self, input_size: int, @@ -354,24 +364,25 @@ def __init__( linear_units, num_blocks, dropout_rate, positional_dropout_rate, attention_dropout_rate, input_layer, pos_enc_layer_type, normalize_before, - static_chunk_size, use_dynamic_chunk, - global_cmvn, use_dynamic_left_chunk, - gradient_checkpointing) + static_chunk_size, use_dynamic_chunk, global_cmvn, + use_dynamic_left_chunk, gradient_checkpointing) activation = WENET_ACTIVATION_CLASSES[activation_type]() self.encoders = torch.nn.ModuleList([ TransformerEncoderLayer( output_size, - WENET_ATTENTION_CLASSES["selfattn"](attention_heads, output_size, - attention_dropout_rate, key_bias), + WENET_ATTENTION_CLASSES["selfattn"](attention_heads, + output_size, + attention_dropout_rate, + key_bias), PositionwiseFeedForward(output_size, linear_units, dropout_rate, activation), - dropout_rate, - normalize_before) for _ in range(num_blocks) + dropout_rate, normalize_before) for _ in range(num_blocks) ]) class ConformerEncoder(BaseEncoder): """Conformer encoder module.""" + def __init__( self, input_size: int, @@ -421,9 +432,8 @@ def __init__( linear_units, num_blocks, dropout_rate, positional_dropout_rate, attention_dropout_rate, input_layer, pos_enc_layer_type, normalize_before, - static_chunk_size, use_dynamic_chunk, - global_cmvn, use_dynamic_left_chunk, - gradient_checkpointing) + static_chunk_size, use_dynamic_chunk, global_cmvn, + use_dynamic_left_chunk, gradient_checkpointing) activation = WENET_ACTIVATION_CLASSES[activation_type]() # self-attention module definition diff --git a/wenet/transformer/encoder_layer.py b/wenet/transformer/encoder_layer.py index 6807c3267..aafcec412 100644 --- a/wenet/transformer/encoder_layer.py +++ b/wenet/transformer/encoder_layer.py @@ -13,7 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. # Modified from ESPnet(https://github.com/espnet/espnet) - """Encoder self-attention layer definition.""" from typing import Optional, Tuple @@ -37,6 +36,7 @@ class TransformerEncoderLayer(nn.Module): True: use layer_norm before each sub-block. False: to use layer_norm after each sub-block. """ + def __init__( self, size: int, @@ -90,8 +90,7 @@ def forward( residual = x if self.normalize_before: x = self.norm1(x) - x_att, new_att_cache = self.self_attn( - x, x, x, mask, cache=att_cache) + x_att, new_att_cache = self.self_attn(x, x, x, mask, cache=att_cache) x = residual + self.dropout(x_att) if not self.normalize_before: x = self.norm1(x) @@ -126,6 +125,7 @@ class ConformerEncoderLayer(nn.Module): True: use layer_norm before each sub-block. False: use layer_norm after each sub-block. 
""" + def __init__( self, size: int, @@ -150,15 +150,13 @@ def __init__( else: self.ff_scale = 1.0 if self.conv_module is not None: - self.norm_conv = nn.LayerNorm(size, - eps=1e-5) # for the CNN module + self.norm_conv = nn.LayerNorm(size, eps=1e-5) # for the CNN module self.norm_final = nn.LayerNorm( size, eps=1e-5) # for the final output of the block self.dropout = nn.Dropout(dropout_rate) self.size = size self.normalize_before = normalize_before - def forward( self, x: torch.Tensor, @@ -204,8 +202,8 @@ def forward( residual = x if self.normalize_before: x = self.norm_mha(x) - x_att, new_att_cache = self.self_attn( - x, x, x, mask, pos_emb, att_cache) + x_att, new_att_cache = self.self_attn(x, x, x, mask, pos_emb, + att_cache) x = residual + self.dropout(x_att) if not self.normalize_before: x = self.norm_mha(x) diff --git a/wenet/transformer/label_smoothing_loss.py b/wenet/transformer/label_smoothing_loss.py index 428fedcb0..feacabf09 100644 --- a/wenet/transformer/label_smoothing_loss.py +++ b/wenet/transformer/label_smoothing_loss.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Label smoothing module.""" import torch @@ -51,6 +50,7 @@ class LabelSmoothingLoss(nn.Module): normalize loss by sequence length if True normalize loss by batch size if False """ + def __init__(self, size: int, padding_idx: int, diff --git a/wenet/transformer/positionwise_feed_forward.py b/wenet/transformer/positionwise_feed_forward.py index 73ba239e3..25c578ee1 100644 --- a/wenet/transformer/positionwise_feed_forward.py +++ b/wenet/transformer/positionwise_feed_forward.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Positionwise feed forward layer definition.""" import torch @@ -30,6 +29,7 @@ class PositionwiseFeedForward(torch.nn.Module): dropout_rate (float): Dropout rate. 
activation (torch.nn.Module): Activation function """ + def __init__(self, idim: int, hidden_units: int, diff --git a/wenet/transformer/search.py b/wenet/transformer/search.py index 7140897fe..7e47b135b 100644 --- a/wenet/transformer/search.py +++ b/wenet/transformer/search.py @@ -28,6 +28,7 @@ class DecodeResult: + def __init__(self, tokens: List[int], score: float = 0.0, @@ -60,6 +61,7 @@ def __init__(self, class PrefixScore: """ For CTC prefix beam search """ + def __init__(self, s: float = float('-inf'), ns: float = float('-inf'), @@ -120,10 +122,13 @@ def ctc_greedy_search(ctc_probs: torch.Tensor, return results -def ctc_prefix_beam_search(ctc_probs: torch.Tensor, ctc_lens: torch.Tensor, - beam_size: int, context_graph: ContextGraph = None, - blank_id: int = 0, - ) -> List[DecodeResult]: +def ctc_prefix_beam_search( + ctc_probs: torch.Tensor, + ctc_lens: torch.Tensor, + beam_size: int, + context_graph: ContextGraph = None, + blank_id: int = 0, +) -> List[DecodeResult]: """ Returns: List[List[List[int]]]: nbest result for each utterance @@ -265,7 +270,8 @@ def attention_beam_search( device=device) # (B*N, 4) # TODO(xcsong): add args for language, task, etc hyps[:, 0] = model.special_tokens["sot"] - hyps[:, 1] = model.special_tokens["sot"] + 1 + WHISPER_LANGS.index("zh") + hyps[:, + 1] = model.special_tokens["sot"] + 1 + WHISPER_LANGS.index("zh") hyps[:, 2] = model.special_tokens["transcribe"] hyps[:, 3] = model.special_tokens["no_timestamps"] else: @@ -374,10 +380,13 @@ def attention_rescoring( if model.special_tokens is not None and "transcribe" in model.special_tokens: # TODO(xcsong): add args for language, task, etc prev_len = hyps_pad.size(1) - hyps_pad, _ = add_whisper_tokens( - model.special_tokens, hyps_pad, model.ignore_id, task="transcribe", - no_timestamp=True, language="zh", use_prev=False - ) + hyps_pad, _ = add_whisper_tokens(model.special_tokens, + hyps_pad, + model.ignore_id, + task="transcribe", + no_timestamp=True, + language="zh", + use_prev=False) cur_len = hyps_pad.size(1) hyps_lens = hyps_lens + cur_len - prev_len prefix_len = 4 @@ -404,7 +413,8 @@ def attention_rescoring( if reverse_weight > 0 and r_decoder_out.dim() > 0: r_score = 0.0 for j, w in enumerate(hyp): - s = r_decoder_out[i][len(hyp) - j - 1 + (prefix_len - 1)][w] + s = r_decoder_out[i][len(hyp) - j - 1 + + (prefix_len - 1)][w] r_score += s tc[j] = (tc[j] + math.exp(s)) / 2 r_score += r_decoder_out[i][len(hyp) + (prefix_len - 1)][eos] diff --git a/wenet/transformer/swish.py b/wenet/transformer/swish.py index b4250f5c9..c5cffc5e1 100644 --- a/wenet/transformer/swish.py +++ b/wenet/transformer/swish.py @@ -13,7 +13,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
- """Swish() activation function for Conformer.""" import torch @@ -21,6 +20,7 @@ class Swish(torch.nn.Module): """Construct an Swish object.""" + def forward(self, x: torch.Tensor) -> torch.Tensor: """Return Swish activation function.""" return x * torch.sigmoid(x) diff --git a/wenet/utils/checkpoint.py b/wenet/utils/checkpoint.py index 89f813c84..42fc8fe67 100644 --- a/wenet/utils/checkpoint.py +++ b/wenet/utils/checkpoint.py @@ -26,8 +26,8 @@ def load_checkpoint(model: torch.nn.Module, path: str) -> dict: logging.info('Checkpoint: loading from checkpoint %s' % path) checkpoint = torch.load(path, map_location='cpu') - missing_keys, unexpected_keys = model.load_state_dict( - checkpoint, strict=False) + missing_keys, unexpected_keys = model.load_state_dict(checkpoint, + strict=False) for key in missing_keys: logging.info("missing tensor: {}".format(key)) for key in unexpected_keys: diff --git a/wenet/utils/class_utils.py b/wenet/utils/class_utils.py index 547117e9b..2ae1cc943 100644 --- a/wenet/utils/class_utils.py +++ b/wenet/utils/class_utils.py @@ -5,23 +5,24 @@ from wenet.transformer.swish import Swish from wenet.transformer.subsampling import ( - LinearNoSubsampling, EmbedinigNoSubsampling, - Conv1dSubsampling2, Conv2dSubsampling4, - Conv2dSubsampling6, Conv2dSubsampling8, + LinearNoSubsampling, + EmbedinigNoSubsampling, + Conv1dSubsampling2, + Conv2dSubsampling4, + Conv2dSubsampling6, + Conv2dSubsampling8, ) from wenet.efficient_conformer.subsampling import Conv2dSubsampling2 from wenet.squeezeformer.subsampling import DepthwiseConv2dSubsampling4 -from wenet.transformer.embedding import ( - PositionalEncoding, RelPositionalEncoding, - WhisperPositionalEncoding, LearnablePositionalEncoding, - NoPositionalEncoding -) -from wenet.transformer.attention import ( - MultiHeadedAttention, RelPositionMultiHeadedAttention -) +from wenet.transformer.embedding import (PositionalEncoding, + RelPositionalEncoding, + WhisperPositionalEncoding, + LearnablePositionalEncoding, + NoPositionalEncoding) +from wenet.transformer.attention import (MultiHeadedAttention, + RelPositionMultiHeadedAttention) from wenet.efficient_conformer.attention import GroupedRelPositionMultiHeadedAttention - WENET_ACTIVATION_CLASSES = { "hardtanh": torch.nn.Hardtanh, "tanh": torch.nn.Tanh, diff --git a/wenet/utils/common.py b/wenet/utils/common.py index 6c4585b2c..df1d1f36f 100644 --- a/wenet/utils/common.py +++ b/wenet/utils/common.py @@ -50,14 +50,22 @@ def pad_list(xs: List[torch.Tensor], pad_value: int): batchs = len(xs) ndim = xs[0].ndim if ndim == 1: - pad_res = torch.zeros(batchs, max_len, - dtype=xs[0].dtype, device=xs[0].device) + pad_res = torch.zeros(batchs, + max_len, + dtype=xs[0].dtype, + device=xs[0].device) elif ndim == 2: - pad_res = torch.zeros(batchs, max_len, xs[0].shape[1], - dtype=xs[0].dtype, device=xs[0].device) + pad_res = torch.zeros(batchs, + max_len, + xs[0].shape[1], + dtype=xs[0].dtype, + device=xs[0].device) elif ndim == 3: - pad_res = torch.zeros(batchs, max_len, xs[0].shape[1], - xs[0].shape[2], dtype=xs[0].dtype, + pad_res = torch.zeros(batchs, + max_len, + xs[0].shape[1], + xs[0].shape[2], + dtype=xs[0].dtype, device=xs[0].device) else: raise ValueError(f"Unsupported ndim: {ndim}") @@ -147,11 +155,9 @@ def add_sos_eos(ys_pad: torch.Tensor, sos: int, eos: int, return pad_list(ys_in, eos), pad_list(ys_out, ignore_id) -def add_whisper_tokens( - special_tokens, ys_pad: torch.Tensor, - ignore_id: int, task: str, no_timestamp: bool, - language: str, use_prev: bool -) -> Tuple[torch.Tensor, 
torch.Tensor]: +def add_whisper_tokens(special_tokens, ys_pad: torch.Tensor, ignore_id: int, + task: str, no_timestamp: bool, language: str, + use_prev: bool) -> Tuple[torch.Tensor, torch.Tensor]: """Add whisper-style tokens. ([PREV] -> [previous text tokens or hotwords]).optional -- @@ -214,8 +220,10 @@ def add_whisper_tokens( else: raise NotImplementedError - _sot = torch.tensor(_sot, dtype=torch.long, - requires_grad=False, device=ys_pad.device) + _sot = torch.tensor(_sot, + dtype=torch.long, + requires_grad=False, + device=ys_pad.device) ys_in = [torch.cat([_sot, y], dim=0) for y in ys] ys_out = [torch.cat([_sot[1:], y, _eot], dim=0) for y in ys] return pad_list(ys_in, special_tokens["eot"]), pad_list(ys_out, ignore_id) diff --git a/wenet/utils/config.py b/wenet/utils/config.py index 50170ced4..e153d0242 100644 --- a/wenet/utils/config.py +++ b/wenet/utils/config.py @@ -12,9 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. - import copy + def override_config(configs, override_list): new_configs = copy.deepcopy(configs) for item in override_list: diff --git a/wenet/utils/ctc_utils.py b/wenet/utils/ctc_utils.py index 084e32c1c..718d42926 100644 --- a/wenet/utils/ctc_utils.py +++ b/wenet/utils/ctc_utils.py @@ -19,7 +19,8 @@ import torch -def remove_duplicates_and_blank(hyp: List[int], blank_id: int = 0) -> List[int]: +def remove_duplicates_and_blank(hyp: List[int], + blank_id: int = 0) -> List[int]: new_hyp: List[int] = [] cur = 0 while cur < len(hyp): @@ -31,14 +32,16 @@ def remove_duplicates_and_blank(hyp: List[int], blank_id: int = 0) -> List[int]: return new_hyp -def replace_duplicates_with_blank(hyp: List[int], blank_id: int = 0) -> List[int]: +def replace_duplicates_with_blank(hyp: List[int], + blank_id: int = 0) -> List[int]: new_hyp: List[int] = [] cur = 0 while cur < len(hyp): new_hyp.append(hyp[cur]) prev = cur cur += 1 - while cur < len(hyp) and hyp[cur] == hyp[prev] and hyp[cur] != blank_id: + while cur < len( + hyp) and hyp[cur] == hyp[prev] and hyp[cur] != blank_id: new_hyp.append(blank_id) cur += 1 return new_hyp @@ -164,10 +167,12 @@ def get_blank_id(configs, symbol_table): if '' in symbol_table: if 'ctc_blank_id' in configs['ctc_conf']: - assert configs['ctc_conf']['ctc_blank_id'] == symbol_table[''] + assert configs['ctc_conf']['ctc_blank_id'] == symbol_table[ + ''] else: configs['ctc_conf']['ctc_blank_id'] = symbol_table[''] else: - assert 'ctc_blank_id' in configs['ctc_conf'], "PLZ set ctc_blank_id in yaml" + assert 'ctc_blank_id' in configs[ + 'ctc_conf'], "PLZ set ctc_blank_id in yaml" return configs, configs['ctc_conf']['ctc_blank_id'] diff --git a/wenet/utils/executor.py b/wenet/utils/executor.py index 3cc24e699..cfe1084b0 100644 --- a/wenet/utils/executor.py +++ b/wenet/utils/executor.py @@ -29,8 +29,8 @@ class Executor: def __init__(self): self.step = 0 - def train(self, model, optimizer, scheduler, data_loader, writer, - configs, scaler, group_join): + def train(self, model, optimizer, scheduler, data_loader, writer, configs, + scaler, group_join): ''' Train one epoch ''' model.train() @@ -70,15 +70,16 @@ def train(self, model, optimizer, scheduler, data_loader, writer, context = nullcontext with context(): - info_dict = batch_forward(model, batch_dict, scaler, info_dict) + info_dict = batch_forward(model, batch_dict, scaler, + info_dict) info_dict = batch_backward(model, scaler, info_dict) - info_dict = update_parameter_and_lr( - model, optimizer, scheduler, - scaler, info_dict - ) + info_dict = 
update_parameter_and_lr(model, optimizer, + scheduler, scaler, + info_dict) log_per_step(writer, info_dict) - self.step += 1 if (batch_idx + 1) % info_dict["accum_grad"] == 0 else 0 + self.step += 1 if (batch_idx + + 1) % info_dict["accum_grad"] == 0 else 0 def cv(self, model, data_loader, configs): ''' Cross validation on diff --git a/wenet/utils/init_model.py b/wenet/utils/init_model.py index 537049890..adbead528 100644 --- a/wenet/utils/init_model.py +++ b/wenet/utils/init_model.py @@ -99,7 +99,8 @@ def init_model(args, configs): assert configs['decoder_conf']['r_num_blocks'] > 0 decoder = BiTransformerDecoder(vocab_size, encoder.output_size(), **configs['decoder_conf']) - ctc = CTC(vocab_size, encoder.output_size(), + ctc = CTC(vocab_size, + encoder.output_size(), blank_id=configs['ctc_conf']['ctc_blank_id'] if 'ctc_conf' in configs else 0) diff --git a/wenet/utils/mask.py b/wenet/utils/mask.py index 6da5890b8..0480fb4f6 100644 --- a/wenet/utils/mask.py +++ b/wenet/utils/mask.py @@ -13,9 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. - import torch - ''' def subsequent_mask( size: int, @@ -50,6 +48,7 @@ def subsequent_mask( return torch.tril(ret) ''' + def subsequent_mask( size: int, device: torch.device = torch.device("cpu"), @@ -124,10 +123,12 @@ def subsequent_chunk_mask( return ret -def add_optional_chunk_mask(xs: torch.Tensor, masks: torch.Tensor, +def add_optional_chunk_mask(xs: torch.Tensor, + masks: torch.Tensor, use_dynamic_chunk: bool, use_dynamic_left_chunk: bool, - decoding_chunk_size: int, static_chunk_size: int, + decoding_chunk_size: int, + static_chunk_size: int, num_decoding_left_chunks: int, enable_full_context: bool = True): """ Apply optional mask for encoder. @@ -301,6 +302,7 @@ def mask_finished_preds(pred: torch.Tensor, flag: torch.Tensor, finished = flag.repeat([1, beam_size]) return pred.masked_fill_(finished, eos) + def causal_or_lookahead_mask( mask: torch.Tensor, right_context: int, diff --git a/wenet/utils/scheduler.py b/wenet/utils/scheduler.py index f24c84406..6a78bb6c7 100644 --- a/wenet/utils/scheduler.py +++ b/wenet/utils/scheduler.py @@ -41,10 +41,10 @@ class WarmupLR(_LRScheduler): """ def __init__( - self, - optimizer: torch.optim.Optimizer, - warmup_steps: Union[int, float] = 25000, - last_epoch: int = -1, + self, + optimizer: torch.optim.Optimizer, + warmup_steps: Union[int, float] = 25000, + last_epoch: int = -1, ): self.warmup_steps = warmup_steps @@ -58,15 +58,11 @@ def __repr__(self): def get_lr(self): step_num = self.last_epoch + 1 if self.warmup_steps == 0: - return [ - lr * step_num ** -0.5 - for lr in self.base_lrs - ] + return [lr * step_num**-0.5 for lr in self.base_lrs] else: return [ - lr - * self.warmup_steps ** 0.5 - * min(step_num ** -0.5, step_num * self.warmup_steps ** -1.5) + lr * self.warmup_steps**0.5 * + min(step_num**-0.5, step_num * self.warmup_steps**-1.5) for lr in self.base_lrs ] @@ -84,8 +80,14 @@ class WarmupPolicy(_LRScheduler): infinite training """ - def __init__(self, optimizer, *, warmup_steps=None, warmup_ratio=None, - max_steps=None, min_lr=0.0, last_epoch=-1): + def __init__(self, + optimizer, + *, + warmup_steps=None, + warmup_ratio=None, + max_steps=None, + min_lr=0.0, + last_epoch=-1): assert not (warmup_steps is not None and warmup_ratio is not None),\ "Either use particular number of step or ratio" assert warmup_ratio is None or max_steps is not None, \ @@ -109,8 +111,8 @@ def get_lr(self): warnings.warn( "To get the last learning rate 
computed " "by the scheduler, please use `get_last_lr()`.", - UserWarning, stacklevel=2 - ) + UserWarning, + stacklevel=2) step = self.last_epoch @@ -141,10 +143,14 @@ class SquareRootConstantPolicy(_LRScheduler): infinite training """ - def __init__( - self, optimizer, *, constant_steps=None, constant_ratio=None, - max_steps=None, min_lr=0.0, last_epoch=-1 - ): + def __init__(self, + optimizer, + *, + constant_steps=None, + constant_ratio=None, + max_steps=None, + min_lr=0.0, + last_epoch=-1): assert not (constant_steps is not None and constant_ratio is not None), \ "Either use particular number of step or ratio" @@ -161,7 +167,7 @@ def __init__( else: self.constant_steps = 0 - self.constant_lr = 1 / (constant_steps ** 0.5) + self.constant_lr = 1 / (constant_steps**0.5) self.min_lr = min_lr super().__init__(optimizer, last_epoch) @@ -170,8 +176,8 @@ def get_lr(self): warnings.warn( "To get the last learning rate computed " "by the scheduler, please use `get_last_lr()`.", - UserWarning, stacklevel=2 - ) + UserWarning, + stacklevel=2) step = self.last_epoch @@ -203,16 +209,16 @@ class WarmupHoldPolicy(WarmupPolicy): """ def __init__( - self, - optimizer, - *, - warmup_steps=None, - warmup_ratio=None, - hold_steps=None, - hold_ratio=None, - max_steps=None, - min_lr=0.0, - last_epoch=-1, + self, + optimizer, + *, + warmup_steps=None, + warmup_ratio=None, + hold_steps=None, + hold_ratio=None, + max_steps=None, + min_lr=0.0, + last_epoch=-1, ): assert not (hold_steps is not None and hold_ratio is not None), \ "Either use particular number of step or ratio" @@ -251,9 +257,10 @@ def get_lr(self): if not self._get_lr_called_within_step: warnings.warn( "To get the last learning rate computed by the scheduler," - " " "please use `get_last_lr()`.", - UserWarning, stacklevel=2 - ) + " " + "please use `get_last_lr()`.", + UserWarning, + stacklevel=2) step = self.last_epoch @@ -285,16 +292,16 @@ class WarmupAnnealHoldPolicy(_LRScheduler): """ def __init__( - self, - optimizer, - *, - warmup_steps=None, - warmup_ratio=None, - constant_steps=None, - constant_ratio=None, - max_steps=None, - min_lr=0.0, - last_epoch=-1, + self, + optimizer, + *, + warmup_steps=None, + warmup_ratio=None, + constant_steps=None, + constant_ratio=None, + max_steps=None, + min_lr=0.0, + last_epoch=-1, ): assert not (warmup_steps is not None and warmup_ratio is not None), \ @@ -323,7 +330,8 @@ def __init__( else: self.constant_steps = 0 - self.decay_steps = max_steps - (self.constant_steps + self.warmup_steps) + self.decay_steps = max_steps - (self.constant_steps + + self.warmup_steps) self.min_lr = min_lr super().__init__(optimizer, last_epoch) @@ -333,8 +341,8 @@ def get_lr(self): warnings.warn( "To get the last learning rate computed " "by the scheduler, please use `get_last_lr()`.", - UserWarning, stacklevel=2 - ) + UserWarning, + stacklevel=2) step = self.last_epoch @@ -366,14 +374,14 @@ def _get_lr(self, step): def _squareroot_annealing(initial_lr, step, max_steps, min_lr): - mult = ((max_steps - step) / max_steps) ** 0.5 + mult = ((max_steps - step) / max_steps)**0.5 out_lr = initial_lr * mult out_lr = max(out_lr, min_lr) return out_lr def _square_annealing(initial_lr, step, max_steps, min_lr): - mult = ((max_steps - step) / max_steps) ** 2 + mult = ((max_steps - step) / max_steps)**2 out_lr = initial_lr * mult out_lr = max(out_lr, min_lr) return out_lr @@ -421,22 +429,31 @@ def _poly_decay(initial_lr, step, decay_steps, power, min_lr, cycle): return lr -def _noam_hold_annealing(initial_lr, step, warmup_steps, - 
hold_steps, decay_rate, min_lr): +def _noam_hold_annealing(initial_lr, step, warmup_steps, hold_steps, + decay_rate, min_lr): # hold_steps = total number of steps # to hold the LR, not the warmup + hold steps. - T_warmup_decay = max(1, warmup_steps ** decay_rate) - T_hold_decay = max(1, (step - hold_steps) ** decay_rate) + T_warmup_decay = max(1, warmup_steps**decay_rate) + T_hold_decay = max(1, (step - hold_steps)**decay_rate) lr = (initial_lr * T_warmup_decay) / T_hold_decay lr = max(lr, min_lr) return lr class SquareAnnealing(WarmupPolicy): - def __init__(self, optimizer, *, max_steps, min_lr=1e-5, last_epoch=-1, + + def __init__(self, + optimizer, + *, + max_steps, + min_lr=1e-5, + last_epoch=-1, **kwargs): - super().__init__(optimizer=optimizer, max_steps=max_steps, - last_epoch=last_epoch, min_lr=min_lr, **kwargs) + super().__init__(optimizer=optimizer, + max_steps=max_steps, + last_epoch=last_epoch, + min_lr=min_lr, + **kwargs) def _get_lr(self, step): new_lrs = [ @@ -445,40 +462,58 @@ def _get_lr(self, step): step=step - self.warmup_steps, max_steps=self.max_steps - self.warmup_steps, min_lr=self.min_lr, - ) - for initial_lr in self.base_lrs + ) for initial_lr in self.base_lrs ] return new_lrs class SquareRootAnnealing(WarmupPolicy): - def __init__(self, optimizer, *, max_steps, min_lr=0, last_epoch=-1, + + def __init__(self, + optimizer, + *, + max_steps, + min_lr=0, + last_epoch=-1, **kwargs): - super().__init__(optimizer=optimizer, max_steps=max_steps, - last_epoch=last_epoch, min_lr=min_lr, **kwargs) + super().__init__(optimizer=optimizer, + max_steps=max_steps, + last_epoch=last_epoch, + min_lr=min_lr, + **kwargs) def _get_lr(self, step): new_lrs = [ - _squareroot_annealing(initial_lr=initial_lr, step=step, - max_steps=self.max_steps, min_lr=self.min_lr) + _squareroot_annealing(initial_lr=initial_lr, + step=step, + max_steps=self.max_steps, + min_lr=self.min_lr) for initial_lr in self.base_lrs ] return new_lrs class CosineAnnealing(WarmupAnnealHoldPolicy): - def __init__(self, optimizer, *, max_steps, min_lr=0, last_epoch=-1, + + def __init__(self, + optimizer, + *, + max_steps, + min_lr=0, + last_epoch=-1, **kwargs): - super().__init__(optimizer=optimizer, max_steps=max_steps, - last_epoch=last_epoch, min_lr=min_lr, **kwargs) + super().__init__(optimizer=optimizer, + max_steps=max_steps, + last_epoch=last_epoch, + min_lr=min_lr, + **kwargs) def _get_lr(self, step): for initial_lr in self.base_lrs: if initial_lr < self.min_lr: raise ValueError( f"{self} received an initial learning rate " - f"that was lower than the minimum learning rate." 
- ) + f"that was lower than the minimum learning rate.") if self.constant_steps is None or self.constant_steps == 0: new_lrs = [ @@ -487,8 +522,7 @@ def _get_lr(self, step): step=step - self.warmup_steps, max_steps=self.max_steps - self.warmup_steps, min_lr=self.min_lr, - ) - for initial_lr in self.base_lrs + ) for initial_lr in self.base_lrs ] else: new_lrs = self._get_linear_warmup_with_cosine_annealing_lr(step) @@ -515,18 +549,23 @@ def _get_linear_warmup_with_cosine_annealing_lr(self, step): step=step, decay_steps=self.decay_steps, min_lr=self.min_lr, - ) - for _ in self.base_lrs + ) for _ in self.base_lrs ] return new_lrs class NoamAnnealing(_LRScheduler): - def __init__( - self, optimizer, *, d_model, warmup_steps=None, warmup_ratio=None, - max_steps=None, min_lr=0.0, last_epoch=-1 - ): - self._normalize = d_model ** (-0.5) + + def __init__(self, + optimizer, + *, + d_model, + warmup_steps=None, + warmup_ratio=None, + max_steps=None, + min_lr=0.0, + last_epoch=-1): + self._normalize = d_model**(-0.5) assert not (warmup_steps is not None and warmup_ratio is not None), \ "Either use particular number of step or ratio" @@ -551,8 +590,8 @@ def get_lr(self): warnings.warn( "To get the last learning rate computed " "by the scheduler, please use `get_last_lr()`.", - UserWarning, stacklevel=2 - ) + UserWarning, + stacklevel=2) step = max(1, self.last_epoch) @@ -560,19 +599,20 @@ def get_lr(self): if initial_lr < self.min_lr: raise ValueError( f"{self} received an initial learning rate " - f"that was lower than the minimum learning rate." - ) + f"that was lower than the minimum learning rate.") - new_lrs = [self._noam_annealing(initial_lr=initial_lr, step=step) for - initial_lr in self.base_lrs] + new_lrs = [ + self._noam_annealing(initial_lr=initial_lr, step=step) + for initial_lr in self.base_lrs + ] return new_lrs def _noam_annealing(self, initial_lr, step): if self.warmup_steps > 0: - mult = self._normalize * min(step ** (-0.5), - step * (self.warmup_steps ** (-1.5))) + mult = self._normalize * min(step**(-0.5), + step * (self.warmup_steps**(-1.5))) else: - mult = self._normalize * step ** (-0.5) + mult = self._normalize * step**(-0.5) out_lr = initial_lr * mult if step > self.warmup_steps: @@ -581,8 +621,15 @@ def _noam_annealing(self, initial_lr, step): class NoamHoldAnnealing(WarmupHoldPolicy): - def __init__(self, optimizer, *, max_steps, decay_rate=0.5, min_lr=0.0, - last_epoch=-1, **kwargs): + + def __init__(self, + optimizer, + *, + max_steps, + decay_rate=0.5, + min_lr=0.0, + last_epoch=-1, + **kwargs): """ From Nemo: Implementation of the Noam Hold Annealing policy @@ -637,8 +684,11 @@ def __init__(self, optimizer, *, max_steps, decay_rate=0.5, min_lr=0.0, min_lr: Minimum learning rate. 
""" self.decay_rate = decay_rate - super().__init__(optimizer=optimizer, max_steps=max_steps, - last_epoch=last_epoch, min_lr=min_lr, **kwargs) + super().__init__(optimizer=optimizer, + max_steps=max_steps, + last_epoch=last_epoch, + min_lr=min_lr, + **kwargs) def _get_lr(self, step): if self.warmup_steps is None or self.warmup_steps == 0: @@ -658,8 +708,7 @@ def _get_lr(self, step): hold_steps=hold_steps, decay_rate=self.decay_rate, min_lr=self.min_lr, - ) - for initial_lr in self.base_lrs + ) for initial_lr in self.base_lrs ] return new_lrs diff --git a/wenet/utils/train_utils.py b/wenet/utils/train_utils.py index 9343c2431..7c3abcf06 100644 --- a/wenet/utils/train_utils.py +++ b/wenet/utils/train_utils.py @@ -29,14 +29,11 @@ from torch.utils.data import DataLoader from torch.nn.utils import clip_grad_norm_ from deepspeed.runtime.zero.stage_1_and_2 import ( - estimate_zero2_model_states_mem_needs_all_live -) + estimate_zero2_model_states_mem_needs_all_live) from deepspeed.runtime.zero.stage3 import ( - estimate_zero3_model_states_mem_needs_all_live -) + estimate_zero3_model_states_mem_needs_all_live) from deepspeed.utils.zero_to_fp32 import ( - convert_zero_checkpoint_to_fp32_state_dict -) + convert_zero_checkpoint_to_fp32_state_dict) from wenet.dataset.dataset import Dataset from wenet.utils.checkpoint import save_checkpoint from wenet.utils.scheduler import WarmupLR, NoamHoldAnnealing @@ -54,8 +51,9 @@ def add_model_args(parser): parser.add_argument('--symbol_table', required=True, help='model unit symbol table for training') - parser.add_argument("--non_lang_syms", - help="non-linguistic symbol file. One symbol per line.") + parser.add_argument( + "--non_lang_syms", + help="non-linguistic symbol file. One symbol per line.") parser.add_argument('--bpe_model', default=None, type=str, @@ -68,10 +66,11 @@ def add_model_args(parser): default=None, type=str, help="Pre-trained model to initialize encoder") - parser.add_argument("--enc_init_mods", - default="encoder.", - type=lambda s: [str(mod) for mod in s.split(",") if s != ""], - help="List of encoder modules \ + parser.add_argument( + "--enc_init_mods", + default="encoder.", + type=lambda s: [str(mod) for mod in s.split(",") if s != ""], + help="List of encoder modules \ to initialize ,separated by a comma") parser.add_argument('--lfmmi_dir', default='', @@ -79,6 +78,7 @@ def add_model_args(parser): help='LF-MMI dir') return parser + def add_trace_args(parser): parser.add_argument('--jit', action='store_true', @@ -90,6 +90,7 @@ def add_trace_args(parser): help='print model') return parser + def add_dataset_args(parser): parser.add_argument('--data_type', default='raw', @@ -130,10 +131,14 @@ def add_ddp_args(parser): def add_deepspeed_args(parser): - parser.add_argument('--timeout', default=30, type=int, + parser.add_argument('--timeout', + default=30, + type=int, help='timeout (in seconds) of wenet_join. 
' + - '30s for aishell & 300s for wenetspeech') - parser.add_argument('--local_rank', type=int, default=-1, + '30s for aishell & 300s for wenetspeech') + parser.add_argument('--local_rank', + type=int, + default=-1, help='local rank passed from distributed launcher') parser.add_argument('--deepspeed.save_states', dest='save_states', @@ -201,14 +206,16 @@ def check_modify_and_save_config(args, configs, symbol_table): else: configs["dtype"] = "fp32" assert ds_configs["train_micro_batch_size_per_gpu"] == 1 - assert ds_configs["gradient_accumulation_steps"] == configs['accum_grad'] + assert ds_configs["gradient_accumulation_steps"] == configs[ + 'accum_grad'] assert ds_configs["gradient_clipping"] == configs['grad_clip'] assert ds_configs["steps_per_print"] == configs['log_interval'] if 'fbank_conf' in configs['dataset_conf']: input_dim = configs['dataset_conf']['fbank_conf']['num_mel_bins'] elif 'log_mel_spectrogram_conf' in configs['dataset_conf']: - input_dim = configs['dataset_conf']['log_mel_spectrogram_conf']['num_mel_bins'] + input_dim = configs['dataset_conf']['log_mel_spectrogram_conf'][ + 'num_mel_bins'] else: input_dim = configs['dataset_conf']['mfcc_conf']['num_mel_bins'] @@ -245,11 +252,8 @@ def init_dataset_and_dataloader(args, configs, tokenizer): cv_conf['shuffle'] = False configs['vocab_size'] = tokenizer.vocab_size() - train_dataset = Dataset(args.data_type, - args.train_data, - tokenizer, - train_conf, - True) + train_dataset = Dataset(args.data_type, args.train_data, tokenizer, + train_conf, True) cv_dataset = Dataset(args.data_type, args.cv_data, tokenizer, @@ -285,25 +289,25 @@ def wrap_cuda_model(args, model): device = torch.device("cuda") if args.fp16_grad_sync: from torch.distributed.algorithms.ddp_comm_hooks import ( - default as comm_hooks, - ) - model.register_comm_hook( - state=None, hook=comm_hooks.fp16_compress_hook - ) + default as comm_hooks, ) + model.register_comm_hook(state=None, + hook=comm_hooks.fp16_compress_hook) elif args.train_engine == "deepspeed": # deepspeed # NOTE(xcsong): look in detail how the memory estimator API works: # https://deepspeed.readthedocs.io/en/latest/memory.html#discussion if int(os.environ.get('RANK', 0)) == 0: logging.info("Estimating model states memory needs (zero2)...") estimate_zero2_model_states_mem_needs_all_live( - model, num_gpus_per_node=local_world_size, + model, + num_gpus_per_node=local_world_size, num_nodes=world_size // local_world_size) logging.info("Estimating model states memory needs (zero3)...") estimate_zero3_model_states_mem_needs_all_live( - model, num_gpus_per_node=local_world_size, + model, + num_gpus_per_node=local_world_size, num_nodes=world_size // local_world_size) - device = None # Init device later - pass # Init DeepSpeed later + device = None # Init device later + pass # Init DeepSpeed later else: logging.error("not supported engine: {}".format(args.train_engine)) @@ -343,11 +347,16 @@ def init_optimizer_and_scheduler(args, configs, model): if "scheduler" in ds_configs: scheduler = None else: + def scheduler(opt): return scheduler_type(opt, **configs['scheduler_conf']) + model, optimizer, _, scheduler = deepspeed.initialize( - args=args, model=model, optimizer=optimizer, - lr_scheduler=scheduler, model_parameters=model.parameters()) + args=args, + model=model, + optimizer=optimizer, + lr_scheduler=scheduler, + model_parameters=model.parameters()) step = configs["init_infos"].get("step", -1) scheduler.set_step(step) @@ -388,10 +397,13 @@ def save_model(model, info_dict): # 
https://github.com/microsoft/DeepSpeed/issues/2993 with torch.no_grad(): model.save_checkpoint(save_dir=model_dir, - tag=tag, client_state=info_dict) + tag=tag, + client_state=info_dict) if info_dict["save_states"] == "model_only" and rank == 0: - convert_zero_checkpoint_to_fp32_state_dict( - model_dir, "{}/{}.pt".format(model_dir, tag), tag=tag) + convert_zero_checkpoint_to_fp32_state_dict(model_dir, + "{}/{}.pt".format( + model_dir, tag), + tag=tag) os.system("rm -rf {}/{}".format(model_dir, tag)) elif rank == 0: # NOTE(xcsong): For torch_ddp, only rank-0 should call this. @@ -426,8 +438,8 @@ def wenet_join(group_join, info_dict): except RuntimeError as e: logging.info("Detected uneven workload distribution: {}\n".format(e) + "Break current worker to manually join all workers, " + - "world_size {}, current rank {}, current local_rank {}\n".format( - world_size, rank, local_rank)) + "world_size {}, current rank {}, current local_rank {}\n". + format(world_size, rank, local_rank)) return True return False @@ -448,9 +460,9 @@ def batch_forward(model, batch, scaler, info_dict): if train_engine == "deepspeed": # deepspeed - with torch.cuda.amp.autocast( - enabled=dtype is not None, dtype=dtype, cache_enabled=False - ): + with torch.cuda.amp.autocast(enabled=dtype is not None, + dtype=dtype, + cache_enabled=False): loss_dict = model(batch["feats"].to(device), batch["feats_lengths"].to(device), batch["target"].to(device), @@ -494,10 +506,7 @@ def batch_backward(model, scaler, info_dict): return info_dict -def update_parameter_and_lr( - model, optimizer, - scheduler, scaler, info_dict -): +def update_parameter_and_lr(model, optimizer, scheduler, scaler, info_dict): rank = int(os.environ.get('RANK', 0)) train_engine = info_dict.get("train_engine", "torch_ddp") accum_grad = info_dict.get('accum_grad', 1) @@ -570,7 +579,8 @@ def log_per_step(writer, info_dict): (train_engine == "torch_ddp" and (batch_idx + 1) % accum_grad == 0): writer.add_scalar('train/train_loss', loss_dict['loss'].item() * accum_grad, step + 1) - writer.add_scalar('train/grad_norm', info_dict['grad_norm'], step + 1) + writer.add_scalar('train/grad_norm', info_dict['grad_norm'], + step + 1) if (batch_idx + 1) % log_interval == 0: log_str = '{} Batch {}/{} loss {:.6f} '.format( diff --git a/wenet/whisper/convert_whisper_to_wenet_config_and_ckpt.py b/wenet/whisper/convert_whisper_to_wenet_config_and_ckpt.py index 039f3c67d..3caf1ebaa 100644 --- a/wenet/whisper/convert_whisper_to_wenet_config_and_ckpt.py +++ b/wenet/whisper/convert_whisper_to_wenet_config_and_ckpt.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
- """ Requirements: @@ -29,7 +28,6 @@ ``` """ - import argparse import copy import os @@ -40,6 +38,7 @@ _cpath_ = sys.path[0] sys.path.remove(_cpath_) from whisper.tokenizer import get_tokenizer + sys.path.insert(0, _cpath_) @@ -102,18 +101,22 @@ def convert_to_wenet_yaml(tokenizer, dims, wenet_yaml_path: str): configs['model_conf']['special_tokens']['sot'] = tokenizer.sot configs['model_conf']['special_tokens']['eot'] = tokenizer.sot configs['model_conf']['special_tokens']['sot_prev'] = tokenizer.sot_prev - configs['model_conf']['special_tokens']['transcribe'] = tokenizer.transcribe + configs['model_conf']['special_tokens'][ + 'transcribe'] = tokenizer.transcribe configs['model_conf']['special_tokens']['translate'] = tokenizer.translate - configs['model_conf']['special_tokens']['no_timestamps'] = tokenizer.no_timestamps + configs['model_conf']['special_tokens'][ + 'no_timestamps'] = tokenizer.no_timestamps configs['model_conf']['special_tokens']['no_speech'] = tokenizer.no_speech configs['model_conf']['special_tokens']['timestamp_begin'] = \ tokenizer.timestamp_begin configs['dataset_conf'] = {} configs['dataset_conf']['filter_conf'] = {} - configs['dataset_conf']['filter_conf']['max_length'] = dims['n_audio_ctx'] * 2 # 1/2 subsample # noqa + configs['dataset_conf']['filter_conf'][ + 'max_length'] = dims['n_audio_ctx'] * 2 # 1/2 subsample # noqa configs['dataset_conf']['filter_conf']['min_length'] = 0 - configs['dataset_conf']['filter_conf']['token_max_length'] = dims['n_text_ctx'] + configs['dataset_conf']['filter_conf']['token_max_length'] = dims[ + 'n_text_ctx'] configs['dataset_conf']['filter_conf']['token_min_length'] = 1 configs['dataset_conf']['resample_conf'] = {} configs['dataset_conf']['resample_conf']['resample_rate'] = 16000 @@ -137,7 +140,8 @@ def convert_to_wenet_yaml(tokenizer, dims, wenet_yaml_path: str): configs['dataset_conf']['log_mel_spectrogram_conf'] = {} configs['dataset_conf']['log_mel_spectrogram_conf']['n_fft'] = 400 configs['dataset_conf']['log_mel_spectrogram_conf']['hop_length'] = 160 - configs['dataset_conf']['log_mel_spectrogram_conf']['num_mel_bins'] = dims['n_mels'] + configs['dataset_conf']['log_mel_spectrogram_conf']['num_mel_bins'] = dims[ + 'n_mels'] configs['dataset_conf']['log_mel_spectrogram_conf']['padding'] = 0 configs['dataset_conf']['batch_conf'] = {} configs['dataset_conf']['batch_conf']['batch_type'] = 'dynamic' @@ -167,7 +171,9 @@ def convert_to_wenet_yaml(tokenizer, dims, wenet_yaml_path: str): def convert_to_wenet_state_dict(whisper_state_dict, wenet_state_dict_path): wenet_state_dict = {} unused = [] - print("===================== start CKPT Conversion =========================") + print( + "===================== start CKPT Conversion =========================" + ) for name in whisper_state_dict.keys(): original_name = copy.deepcopy(name) name = name.replace("encoder.conv1", "encoder.embed.conv.0") @@ -211,7 +217,9 @@ def convert_to_wenet_state_dict(whisper_state_dict, wenet_state_dict_path): print("NOTE!!! 
drop {}".format(name)) print("Saving fp32 ckpt to {}...".format(wenet_state_dict_path)) torch.save(wenet_state_dict, wenet_state_dict_path) - print("DONE\n===================== End CKPT Conversion =========================\n") + print( + "DONE\n===================== End CKPT Conversion =========================\n" + ) def convert_to_wenet_units(tokenizer, units_txt_path): @@ -235,11 +243,17 @@ def convert_to_wenet_units(tokenizer, units_txt_path): def get_args(): parser = argparse.ArgumentParser(description='load and parse whisper') - parser.add_argument('--whisper_ckpt', required=True, - help='https://openaipublic.azureedge.net/main/whisper/models/e5b1a55b89c1367dacf97e3e19bfd829a01529dbfdeefa8caeb59b3f1b81dadb/large-v3.pt') - parser.add_argument('--output_dir', default='.', + # yapf: disable + parser.add_argument( + '--whisper_ckpt', + required=True, + help='https://openaipublic.azureedge.net/main/whisper/models/e5b1a55b89c1367dacf97e3e19bfd829a01529dbfdeefa8caeb59b3f1b81dadb/large-v3.pt' # noqa + ) + # yapf: enable + parser.add_argument('--output_dir', + default='.', help='output file in wenet\'s style: ' + - 'units.txt, train.yaml, model.pt') + 'units.txt, train.yaml, model.pt') args = parser.parse_args() return args @@ -249,20 +263,16 @@ def main(): checkpoint = torch.load(args.whisper_ckpt, map_location="cpu") multilingual = checkpoint["dims"]['n_vocab'] >= 51865 num_languages = checkpoint["dims"]['n_vocab'] - 51765 - int(multilingual) - tokenizer = get_tokenizer(multilingual=multilingual, num_languages=num_languages) + tokenizer = get_tokenizer(multilingual=multilingual, + num_languages=num_languages) convert_to_wenet_state_dict( checkpoint["model_state_dict"], - os.path.join(args.output_dir, 'wenet_whisper.pt') - ) - convert_to_wenet_units( - tokenizer, - os.path.join(args.output_dir, 'units.txt') - ) - convert_to_wenet_yaml( - tokenizer, checkpoint["dims"], - os.path.join(args.output_dir, 'train.yaml') - ) + os.path.join(args.output_dir, 'wenet_whisper.pt')) + convert_to_wenet_units(tokenizer, os.path.join(args.output_dir, + 'units.txt')) + convert_to_wenet_yaml(tokenizer, checkpoint["dims"], + os.path.join(args.output_dir, 'train.yaml')) if __name__ == "__main__": diff --git a/wenet/whisper/whisper.py b/wenet/whisper/whisper.py index 0cc73e6ea..48e64f60e 100644 --- a/wenet/whisper/whisper.py +++ b/wenet/whisper/whisper.py @@ -26,6 +26,7 @@ class Whisper(ASRModel): + def __init__( self, vocab_size: int, @@ -39,9 +40,9 @@ def __init__( length_normalized_loss: bool = False, special_tokens: dict = None, ): - super().__init__(vocab_size, encoder, decoder, ctc, ctc_weight, ignore_id, - reverse_weight, lsm_weight, length_normalized_loss, - special_tokens) + super().__init__(vocab_size, encoder, decoder, ctc, ctc_weight, + ignore_id, reverse_weight, lsm_weight, + length_normalized_loss, special_tokens) assert reverse_weight == 0.0 self.sos = special_tokens["sot"] self.eos = special_tokens["eot"] @@ -67,10 +68,13 @@ def _calc_att_loss( ) -> Tuple[torch.Tensor, float]: # TODO(xcsong): add args for no_timestamp, language, etc prev_len = ys_pad.size(1) - ys_in_pad, ys_out_pad = add_whisper_tokens( - self.special_tokens, ys_pad, self.ignore_id, task="transcribe", - no_timestamp=True, language="zh", use_prev=False - ) + ys_in_pad, ys_out_pad = add_whisper_tokens(self.special_tokens, + ys_pad, + self.ignore_id, + task="transcribe", + no_timestamp=True, + language="zh", + use_prev=False) cur_len = ys_in_pad.size(1) ys_in_lens = ys_pad_lens + cur_len - prev_len