[lint] auto format all by pre-commit, including c++, python #2199

Merged
3 commits, merged Dec 6, 2023
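This PR runs pre-commit auto-formatting over the code base, so the diffs below are formatting-only changes. For reference, a minimal .pre-commit-config.yaml that would drive this kind of Python and C++ auto-formatting is sketched here; the hook repos, revisions, and options are illustrative assumptions and may not match the configuration actually used in wenet:

    # .pre-commit-config.yaml -- illustrative sketch, not the repository's actual config
    repos:
      - repo: https://github.com/google/yapf
        rev: v0.40.2                  # assumed version
        hooks:
          - id: yapf                  # reformats Python files in place
      - repo: https://github.com/pre-commit/mirrors-clang-format
        rev: v17.0.6                  # assumed version
        hooks:
          - id: clang-format          # reformats C/C++ files in place
            types_or: [c++, c]

A repository-wide sweep like this one is typically produced by running "pre-commit run --all-files" once and committing the result.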
5 changes: 1 addition & 4 deletions docs/conf.py
@@ -12,16 +12,15 @@
#
import os
import sys
sys.path.insert(0, os.path.abspath('..'))

sys.path.insert(0, os.path.abspath('..'))

# -- Project information -----------------------------------------------------

project = 'wenet'
copyright = '2020, wenet-team'
author = 'wenet-team'


# -- General configuration ---------------------------------------------------

# Add any Sphinx extension module names here, as strings. They can be
@@ -43,7 +42,6 @@
# Add any paths that contain templates here, relative to this directory.
templates_path = ['_templates']


# The suffix(es) of source filenames.
# You can specify multiple suffix as a list of string:
source_suffix = {
@@ -57,7 +55,6 @@
# This pattern also affects html_static_path and html_extra_path.
exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store']


# -- Options for HTML output -------------------------------------------------

# The theme to use for HTML and HTML Help pages. See the documentation for
9 changes: 6 additions & 3 deletions examples/aishell/NST/local/generate_data_list.py
@@ -15,6 +15,7 @@
import os
import random


def get_args():
parser = argparse.ArgumentParser(description='generate data.list file ')
parser.add_argument('--tar_dir', help='path for tar file')
@@ -23,8 +24,8 @@ def get_args():
parser.add_argument('--pseudo_data_ratio',
type=float,
help='ratio of pseudo data, '
'0 means none pseudo data, '
'1 means all using pseudo data.')
'0 means none pseudo data, '
'1 means all using pseudo data.')
parser.add_argument('--out_data_list', help='output path for data list')
args = parser.parse_args()
return args
@@ -55,7 +56,9 @@ def main():
for i in range(len(pseudo_data_list)):
pseudo_data_list[i] = target_dir + "/" + pseudo_data_list[i] + "\n"

fused_list = pseudo_data_list[:pseudo_len] + supervised_data_list[:supervised_len]
fused_list = pseudo_data_list[:
pseudo_len] + supervised_data_list[:
supervised_len]

with open(output_file, "w") as writer:
for line in fused_list:
58 changes: 37 additions & 21 deletions examples/aishell/NST/local/generate_filtered_pseudo_label.py
@@ -19,28 +19,41 @@


def get_args():
parser = argparse.ArgumentParser(description='generate filter pseudo label')
parser.add_argument('--dir_num', required=True, help='split directory number')
parser.add_argument('--cer_hypo_dir', required=True,
parser = argparse.ArgumentParser(
description='generate filter pseudo label')
parser.add_argument('--dir_num',
required=True,
help='split directory number')
parser.add_argument('--cer_hypo_dir',
required=True,
help='prefix for cer_hypo_dir')
parser.add_argument('--utter_time_file', required=True,
parser.add_argument('--utter_time_file',
required=True,
help='the json file that contains audio time infos ')
parser.add_argument('--cer_hypo_threshold', required=True, type=float,
parser.add_argument('--cer_hypo_threshold',
required=True,
type=float,
help='the cer-hypo threshold used to filter')
parser.add_argument('--speak_rate_threshold', type=float,
parser.add_argument('--speak_rate_threshold',
type=float,
help='the cer threshold we use to filter')
parser.add_argument('--dir', required=True, help='dir for the experiment ')
# output untar and tar
parser.add_argument('--untar_dir', required=True,
parser.add_argument('--untar_dir',
required=True,
help='the output path, '
'eg: data/train/wenet_untar_cer_hypo_nst1/')
parser.add_argument('--tar_dir', required=True,
'eg: data/train/wenet_untar_cer_hypo_nst1/')
parser.add_argument('--tar_dir',
required=True,
help='the tar file path, '
'eg: data/train/wenet_tar_cer_hypo_leq_10_nst1/')
parser.add_argument('--wav_dir', required=True,
'eg: data/train/wenet_tar_cer_hypo_leq_10_nst1/')
parser.add_argument('--wav_dir',
required=True,
help='dir to store wav files, '
'eg "data/train/wenet_1k_untar/"')
parser.add_argument('--start_tar_id', default=0 , type=int,
'eg "data/train/wenet_1k_untar/"')
parser.add_argument('--start_tar_id',
default=0,
type=int,
help='the initial tar id (for debugging)')
args = parser.parse_args()
return args
@@ -118,11 +131,14 @@ def main():

utt_time = utter_time[utt_id]

cer_dict[utt_id] = [pred_no_lm, pred_lm, wer_pred_lm,
utt_time, n_hypo, prediction]
cer_dict[utt_id] = [
pred_no_lm, pred_lm, wer_pred_lm, utt_time, n_hypo,
prediction
]
else:
cer_dict[utt_id] = [pred_no_lm, pred_lm,
wer_pred_lm, -1, -1, prediction]
cer_dict[utt_id] = [
pred_no_lm, pred_lm, wer_pred_lm, -1, -1, prediction
]

c = 0
cer_preds = []
@@ -170,8 +186,8 @@ def main():
os.makedirs(out_put_dir, exist_ok=True)

for i in range(cur_id, end_id):
print("dir:", dir_num, ", " "tar: ", tar_id,
", ", "progress:", i / len_data)
print("dir:", dir_num, ", "
"tar: ", tar_id, ", ", "progress:", i / len_data)

t_id, utter = data_filtered[i]

@@ -184,8 +200,8 @@ def main():
with open(output_path, "w", encoding="utf-8") as writer:
writer.write(utter)
# update .wav
os.system("cp" + " " + wav_path + " "
+ out_put_dir + t_id + ".wav")
os.system("cp" + " " + wav_path + " " + out_put_dir + t_id +
".wav")
else:
print(" wav does not exists ! ", wav_path)
not_exist.append(wav_path)
27 changes: 19 additions & 8 deletions examples/aishell/NST/local/get_wav_labels.py
@@ -16,17 +16,27 @@

def get_args():
parser = argparse.ArgumentParser(description='sum up prediction wer')
parser.add_argument('--job_num', type=int, default=8,
parser.add_argument('--job_num',
type=int,
default=8,
help='number of total split dir')
parser.add_argument('--dir_split', required=True,
parser.add_argument('--dir_split',
required=True,
help='the path to the data_list dir '
'eg data/train/wenet1k_good_split_60/')
parser.add_argument('--label', type=int, default=0,
'eg data/train/wenet1k_good_split_60/')
parser.add_argument('--label',
type=int,
default=0,
help='if ture, label file will also be considered.')
parser.add_argument('--hypo_name', type=str, required=True,
parser.add_argument('--hypo_name',
type=str,
required=True,
help='the hypothesis path. eg. /hypothesis_0.txt ')
parser.add_argument('--wav_dir', type=str, required=True,
help='the wav dir path. eg. data/train/wenet_1k_untar/ ')
parser.add_argument(
'--wav_dir',
type=str,
required=True,
help='the wav dir path. eg. data/train/wenet_1k_untar/ ')
args = parser.parse_args()
return args

@@ -53,7 +63,8 @@ def main():
output_wav = data_list_dir + "data_sublist" + str(i) + "/wav.scp"
output_label = data_list_dir + "data_sublist" + str(i) + "/label.txt"
# bad lines are just for debugging
output_bad_lines = data_list_dir + "data_sublist" + str(i) + "/bad_line.txt"
output_bad_lines = data_list_dir + "data_sublist" + str(
i) + "/bad_line.txt"

with open(hypo_path, 'r', encoding="utf-8") as reader:
hypo_lines = reader.readlines()
16 changes: 10 additions & 6 deletions examples/aishell/NST/local/split_data_list.py
@@ -17,13 +17,17 @@

def get_args():
parser = argparse.ArgumentParser(description='')
parser.add_argument('--job_nums', type=int, default=8,
parser.add_argument('--job_nums',
type=int,
default=8,
help='number of total split jobs')
parser.add_argument('--data_list_path', required=True,
parser.add_argument('--data_list_path',
required=True,
help='the path to the data.list file')
parser.add_argument('--output_dir', required=True,
parser.add_argument('--output_dir',
required=True,
help='path to output dir, '
'eg --output_dir=data/train/aishell_split_60')
'eg --output_dir=data/train/aishell_split_60')
args = parser.parse_args()
return args

@@ -46,7 +50,7 @@ def main():
len_d = int(len(data_list_we) / num_lists)
rest_lines = data_list_we[num_lists * len_d:]
rest_len = len(rest_lines)
print("total num of lines", len(data_list_we) , "rest len is", rest_len)
print("total num of lines", len(data_list_we), "rest len is", rest_len)

# generate N sublist
for i in range(num_lists):
@@ -57,7 +61,7 @@ def main():

with open(output_list, 'w', encoding="utf-8") as writer:

new_list = data_list_we[i * len_d: (i + 1) * len_d]
new_list = data_list_we[i * len_d:(i + 1) * len_d]
if i < rest_len:
new_list.append(rest_lines[i])
for x in new_list:
1 change: 1 addition & 0 deletions examples/aishell4/s0/local/aishell4_process_textgrid.py
@@ -9,6 +9,7 @@


class Segment(object):

def __init__(self, uttid, spkr, stime, etime, text):
self.uttid = uttid
self.spkr = spkr
9 changes: 7 additions & 2 deletions examples/commonvoice/fr/local/create_scp_text.py
@@ -4,20 +4,25 @@
import sys
import os
import re


def process(src_str):
punc = '~`!#$%^&*()_+-=|\';":/.,?><~·!@#¥%……&*()——+-=“:’;、。,?》《{}'
return re.sub(r"[{0}]+".format(punc), "", src_str).upper()


if __name__ == '__main__':
src_dir = sys.argv[1]
tsv_file = src_dir + "/" + sys.argv[2] + ".tsv"
output_dir = sys.argv[3]
for file_path in os.listdir(src_dir + "/clips"):
if(os.path.exists(src_dir + "/wavs/" + file_path.split('.')[0] + ".wav")):
if (os.path.exists(src_dir + "/wavs/" + file_path.split('.')[0] +
".wav")):
continue
t_str = src_dir + "/clips/" + file_path
tt_str = src_dir + "/wavs/" + file_path.split('.')[0] + ".wav"
os.system("ffmpeg -i {0} -ac 1 -ar 16000 -f wav {1}".format(t_str, tt_str))
os.system("ffmpeg -i {0} -ac 1 -ar 16000 -f wav {1}".format(
t_str, tt_str))
import pandas
tsv_content = pandas.read_csv(tsv_file, sep="\t")
path_list = tsv_content["path"]
12 changes: 8 additions & 4 deletions examples/csj/s0/csj_tools/wn.0.parse.py
@@ -1,11 +1,11 @@

# parse xml files and output simplified version

import xml.dom.minidom
import os
import sys
import multiprocessing


def parsexml(afile, outpath):
outfile = os.path.join(outpath, afile.split('/')[-1] + '.simp')

@@ -40,7 +40,8 @@ def parsexml(afile, outpath):
if suw.hasAttribute('OrthographicTranscription'):
txt = suw.getAttribute('OrthographicTranscription')
if suw.hasAttribute('PlainOrthographicTranscription'):
plaintxt = suw.getAttribute('PlainOrthographicTranscription')
plaintxt = suw.getAttribute(
'PlainOrthographicTranscription')
if suw.hasAttribute('PhoneticTranscription'):
prontxt = suw.getAttribute('PhoneticTranscription')
wlist.append(txt)
@@ -63,10 +64,11 @@ def parsexml(afile, outpath):
lemmasent = ' '.join(lemmalist)
dictlemmasent = ' '.join(dictlemmalist)
outrow = '{}\t{}\t{}\t{}\t{}\t{}\t{}\n'.format(
starttime, endtime, txtsent, plaintxtsent,
prontxtsent, lemmasent, dictlemmasent)
starttime, endtime, txtsent, plaintxtsent, prontxtsent,
lemmasent, dictlemmasent)
bw.write(outrow)


def procfolder_orig(apath, outpath):
count = 0
for afile in os.listdir(apath):
@@ -77,6 +79,7 @@ def parsexml(afile, outpath):
count += 1
print('done: {} [{}]'.format(afile, count))


def procfolder(apath, outpath):
# count = 0
fnlist = list()
@@ -98,6 +101,7 @@ def procfolder(apath, outpath):
print('parallel {} threads done for {} files in total.'.format(
nthreads, len(fnlist)))


if __name__ == '__main__':
if len(sys.argv) < 3:
print("Usage: {} <in.csj.path> <out.csj.path>".format(sys.argv[0]))