Skip to content

Commit

Permalink
refactor(yaml): Config ctc/cmvn/tokenizer in train.yaml (#2205)
Browse files Browse the repository at this point in the history
* refactor(yaml): Config ctc/cmvn/tokenizer in train.yaml

* refactor(yaml): pass training

* refactor(yaml): try to pass unittest

* refactor(yaml): remove lfmmi

* refactor(yaml): nst recipe

* [refactor] refine run.sh

* [refactor] refine run.sh

* [refactor] rebase main

* [refactor] try to pass ut

* [refactor] refine librispeech in next PR

* [refactor] add todo

* [refactor] refine paraformer in next PR

* [refactor] make sos = 2

* [refactor] make sos = 2

* [refactor] try to pass ut

* [refactor] refine onnx_gpu

* [refactor] try to pass ut

* [refactor] try to pass ut

* [refactor] try to pass ut

* [refactor] try to pass ut

* refactor: pass decoding

* refactor: pass decoding

* refactor: pass decoding

* refactor: refine tokenizer

* refactor: try to pass ut
  • Loading branch information
xingchensong authored Dec 13, 2023
1 parent cd9c93a commit fac1f0c
Show file tree
Hide file tree
Showing 38 changed files with 613 additions and 128 deletions.
25 changes: 25 additions & 0 deletions examples/aishell/NST/conf/train_conformer.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,37 @@ decoder_conf:
self_attention_dropout_rate: 0.0
src_attention_dropout_rate: 0.0

# Tokenizer: char-level modeling units read from the symbol table.
tokenizer: char
tokenizer_conf:
  symbol_table_path: 'data/dict/lang_char.txt'
  split_with_space: false
  bpe_path: null
  non_lang_syms_path: null
  is_multilingual: false
  num_languages: 1
  special_tokens:
    <blank>: 0
    <unk>: 1
    # NOTE(review): <sos> and <eos> deliberately share id 2 — this matches
    # the single "<sos/eos> 2" entry the recipe writes into the dict.
    <sos>: 2
    <eos>: 2

# CTC branch; blank symbol uses id 0 (consistent with <blank>: 0 above).
ctc: ctc
ctc_conf:
  ctc_blank_id: 0

# Global CMVN feature normalization (JSON-format stats file).
cmvn: global_cmvn
cmvn_conf:
  cmvn_file: 'data/train/global_cmvn'
  is_json_cmvn: true

# hybrid CTC/attention
model: asr_model
model_conf:
  ctc_weight: 0.3
  lsm_weight: 0.1  # label smoothing option
  length_normalized_loss: false

dataset: asr
dataset_conf:
  filter_conf:
    max_length: 1200
Expand Down
9 changes: 0 additions & 9 deletions examples/aishell/NST/run_nst.sh
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,6 @@ data_type=shard
num_utts_per_shard=1000
train_set=train
train_config=conf/train_conformer.yaml
cmvn=true
average_checkpoint=true
target_pt=80
decode_checkpoint=$dir/$target_pt.pt
Expand Down Expand Up @@ -113,9 +112,6 @@ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
dist_backend="nccl"
# the global_cmvn file need to be calculated by combining both supervised/unsupervised datasets,
# and it should be positioned at data/${train_set}/global_cmvn .
cmvn_opts=
$cmvn && cp data/${train_set}/global_cmvn $dir/global_cmvn
$cmvn && cmvn_opts="--cmvn ${dir}/global_cmvn"

# train.py rewrite $train_config to $dir/train.yaml with model input
# and output dimension, and $dir/train.yaml will be used for inference
Expand All @@ -133,14 +129,12 @@ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
--train_engine ${train_engine} \
--config $train_config \
--data_type $data_type \
--symbol_table $dict \
--train_data data/$train_set/$data_list \
--cv_data data/dev/data.list \
${checkpoint:+--checkpoint $checkpoint} \
--model_dir $dir \
--ddp.dist_backend $dist_backend \
--num_workers 1 \
$cmvn_opts \
--pin_memory \
--deepspeed_config ${deepspeed_config} \
--deepspeed.save_states ${deepspeed_save_states}
Expand Down Expand Up @@ -190,7 +184,6 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
--beam_size 10 \
--batch_size 1 \
--penalty 0.0 \
--dict $dict \
--ctc_weight $ctc_weight \
--reverse_weight $reverse_weight \
--result_file $test_dir/text \
Expand All @@ -216,7 +209,6 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
--beam_size 10 \
--batch_size 1 \
--penalty 0.0 \
--dict $dict \
--ctc_weight $ctc_weight \
--reverse_weight $reverse_weight \
--result_file $dev_dir/text \
Expand Down Expand Up @@ -275,7 +267,6 @@ if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
--beam_size 10 \
--batch_size 1 \
--penalty 0.0 \
--dict $dict \
--ctc_weight $ctc_weight \
--reverse_weight $reverse_weight \
--result_file data/train/${dir_split}data_sublist${job_num}/${hypo_name} \
Expand Down
24 changes: 24 additions & 0 deletions examples/aishell/rnnt/conf/conformer_rnnt.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,29 @@ decoder_conf:
self_attention_dropout_rate: 0.1
src_attention_dropout_rate: 0.1

# Tokenizer: char-level modeling units read from the symbol table.
tokenizer: char
tokenizer_conf:
  symbol_table_path: 'data/dict/lang_char.txt'
  split_with_space: false
  bpe_path: null
  non_lang_syms_path: null
  is_multilingual: false
  num_languages: 1
  special_tokens:
    <blank>: 0
    <unk>: 1
    # NOTE(review): <sos> and <eos> deliberately share id 2 — this matches
    # the single "<sos/eos> 2" entry the recipe writes into the dict.
    <sos>: 2
    <eos>: 2

# CTC branch; blank symbol uses id 0 (consistent with <blank>: 0 above).
ctc: ctc
ctc_conf:
  ctc_blank_id: 0

# Global CMVN feature normalization (JSON-format stats file).
cmvn: global_cmvn
cmvn_conf:
  cmvn_file: 'data/train/global_cmvn'
  is_json_cmvn: true

# hybrid transducer+ctc+attention
model: transducer
model_conf:
Expand All @@ -59,6 +82,7 @@ model_conf:
length_normalized_loss: false
reverse_weight: 0.3

dataset: asr
dataset_conf:
filter_conf:
max_length: 40960
Expand Down
24 changes: 24 additions & 0 deletions examples/aishell/rnnt/conf/conformer_u2pp_rnnt.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,29 @@ decoder_conf:
self_attention_dropout_rate: 0.1
src_attention_dropout_rate: 0.1

# Tokenizer: char-level modeling units read from the symbol table.
tokenizer: char
tokenizer_conf:
  symbol_table_path: 'data/dict/lang_char.txt'
  split_with_space: false
  bpe_path: null
  non_lang_syms_path: null
  is_multilingual: false
  num_languages: 1
  special_tokens:
    <blank>: 0
    <unk>: 1
    # NOTE(review): <sos> and <eos> deliberately share id 2 — this matches
    # the single "<sos/eos> 2" entry the recipe writes into the dict.
    <sos>: 2
    <eos>: 2

# CTC branch; blank symbol uses id 0 (consistent with <blank>: 0 above).
ctc: ctc
ctc_conf:
  ctc_blank_id: 0

# Global CMVN feature normalization (JSON-format stats file).
cmvn: global_cmvn
cmvn_conf:
  cmvn_file: 'data/train/global_cmvn'
  is_json_cmvn: true

# hybrid transducer+ctc+attention
model: transducer
model_conf:
Expand All @@ -63,6 +86,7 @@ model_conf:
length_normalized_loss: false
reverse_weight: 0.3

dataset: asr
dataset_conf:
filter_conf:
max_length: 40960
Expand Down
24 changes: 24 additions & 0 deletions examples/aishell/rnnt/conf/example_embedding_predictor.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,29 @@ decoder_conf:
self_attention_dropout_rate: 0.1
src_attention_dropout_rate: 0.1

# Tokenizer: char-level modeling units read from the symbol table.
tokenizer: char
tokenizer_conf:
  symbol_table_path: 'data/dict/lang_char.txt'
  split_with_space: false
  bpe_path: null
  non_lang_syms_path: null
  is_multilingual: false
  num_languages: 1
  special_tokens:
    <blank>: 0
    <unk>: 1
    # NOTE(review): <sos> and <eos> deliberately share id 2 — this matches
    # the single "<sos/eos> 2" entry the recipe writes into the dict.
    <sos>: 2
    <eos>: 2

# CTC branch; blank symbol uses id 0 (consistent with <blank>: 0 above).
ctc: ctc
ctc_conf:
  ctc_blank_id: 0

# Global CMVN feature normalization (JSON-format stats file).
cmvn: global_cmvn
cmvn_conf:
  cmvn_file: 'data/train/global_cmvn'
  is_json_cmvn: true

# hybrid transducer+ctc+attention
model: transducer
model_conf:
Expand All @@ -55,6 +78,7 @@ model_conf:
length_normalized_loss: false
reverse_weight: 0.3

dataset: asr
dataset_conf:
filter_conf:
max_length: 40960
Expand Down
12 changes: 2 additions & 10 deletions examples/aishell/rnnt/run.sh
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@ num_utts_per_shard=1000

train_set=train
train_config=conf/conformer_u2pp_rnnt.yaml
cmvn=true
dir=exp/conformer_rnnt
checkpoint=

Expand Down Expand Up @@ -92,11 +91,10 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
mkdir -p $(dirname $dict)
echo "<blank> 0" > ${dict} # 0 is for "blank" in CTC
echo "<unk> 1" >> ${dict} # <unk> must be 1
echo "<sos/eos> 2" >> $dict
tools/text2token.py -s 1 -n 1 data/train/text | cut -f 2- -d" " \
| tr " " "\n" | sort | uniq | grep -a -v -e '^\s*$' | \
awk '{print $0 " " NR+1}' >> ${dict}
num_token=$(cat $dict | wc -l)
echo "<sos/eos> $num_token" >> $dict
awk '{print $0 " " NR+2}' >> ${dict}
fi

if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
Expand All @@ -118,9 +116,6 @@ if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
num_gpus=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
# Use "nccl" if it works, otherwise use "gloo"
dist_backend="nccl"
cmvn_opts=
$cmvn && cp data/${train_set}/global_cmvn $dir
$cmvn && cmvn_opts="--cmvn ${dir}/global_cmvn"

# train.py rewrite $train_config to $dir/train.yaml with model input
# and output dimension, and $dir/train.yaml will be used for inference
Expand All @@ -137,14 +132,12 @@ if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
--train_engine ${train_engine} \
--config $train_config \
--data_type $data_type \
--symbol_table $dict \
--train_data data/$train_set/data.list \
--cv_data data/dev/data.list \
${checkpoint:+--checkpoint $checkpoint} \
--model_dir $dir \
--ddp.dist_backend $dist_backend \
--num_workers 1 \
$cmvn_opts \
--pin_memory \
--deepspeed_config ${deepspeed_config} \
--deepspeed.save_states ${deepspeed_save_states}
Expand Down Expand Up @@ -183,7 +176,6 @@ if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
--beam_size 10 \
--batch_size 32 \
--penalty 0.0 \
--dict $dict \
--ctc_weight $rescore_ctc_weight \
--transducer_weight $rescore_transducer_weight \
--attn_weight $rescore_attn_weight \
Expand Down
25 changes: 25 additions & 0 deletions examples/aishell/s0/conf/train_conformer.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,37 @@ decoder_conf:
self_attention_dropout_rate: 0.0
src_attention_dropout_rate: 0.0

# Tokenizer: char-level modeling units read from the symbol table.
tokenizer: char
tokenizer_conf:
  symbol_table_path: 'data/dict/lang_char.txt'
  split_with_space: false
  bpe_path: null
  non_lang_syms_path: null
  is_multilingual: false
  num_languages: 1
  special_tokens:
    <blank>: 0
    <unk>: 1
    # NOTE(review): <sos> and <eos> deliberately share id 2 — this matches
    # the single "<sos/eos> 2" entry the recipe writes into the dict.
    <sos>: 2
    <eos>: 2

# CTC branch; blank symbol uses id 0 (consistent with <blank>: 0 above).
ctc: ctc
ctc_conf:
  ctc_blank_id: 0

# Global CMVN feature normalization (JSON-format stats file).
cmvn: global_cmvn
cmvn_conf:
  cmvn_file: 'data/train/global_cmvn'
  is_json_cmvn: true

# hybrid CTC/attention
model: asr_model
model_conf:
  ctc_weight: 0.3
  lsm_weight: 0.1  # label smoothing option
  length_normalized_loss: false

dataset: asr
dataset_conf:
  filter_conf:
    max_length: 40960
Expand Down
25 changes: 25 additions & 0 deletions examples/aishell/s0/conf/train_conformer_no_pos.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,37 @@ decoder_conf:
self_attention_dropout_rate: 0.0
src_attention_dropout_rate: 0.0

# Tokenizer: char-level modeling units read from the symbol table.
tokenizer: char
tokenizer_conf:
  symbol_table_path: 'data/dict/lang_char.txt'
  split_with_space: false
  bpe_path: null
  non_lang_syms_path: null
  is_multilingual: false
  num_languages: 1
  special_tokens:
    <blank>: 0
    <unk>: 1
    # NOTE(review): <sos> and <eos> deliberately share id 2 — this matches
    # the single "<sos/eos> 2" entry the recipe writes into the dict.
    <sos>: 2
    <eos>: 2

# CTC branch; blank symbol uses id 0 (consistent with <blank>: 0 above).
ctc: ctc
ctc_conf:
  ctc_blank_id: 0

# Global CMVN feature normalization (JSON-format stats file).
cmvn: global_cmvn
cmvn_conf:
  cmvn_file: 'data/train/global_cmvn'
  is_json_cmvn: true

# hybrid CTC/attention
model: asr_model
model_conf:
  ctc_weight: 0.3
  lsm_weight: 0.1  # label smoothing option
  length_normalized_loss: false

dataset: asr
dataset_conf:
  filter_conf:
    max_length: 40960
Expand Down
25 changes: 25 additions & 0 deletions examples/aishell/s0/conf/train_ebranchformer.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,37 @@ decoder_conf:
self_attention_dropout_rate: 0.1
src_attention_dropout_rate: 0.1

# Tokenizer: char-level modeling units read from the symbol table.
tokenizer: char
tokenizer_conf:
  symbol_table_path: 'data/dict/lang_char.txt'
  split_with_space: false
  bpe_path: null
  non_lang_syms_path: null
  is_multilingual: false
  num_languages: 1
  special_tokens:
    <blank>: 0
    <unk>: 1
    # NOTE(review): <sos> and <eos> deliberately share id 2 — this matches
    # the single "<sos/eos> 2" entry the recipe writes into the dict.
    <sos>: 2
    <eos>: 2

# CTC branch; blank symbol uses id 0 (consistent with <blank>: 0 above).
ctc: ctc
ctc_conf:
  ctc_blank_id: 0

# Global CMVN feature normalization (JSON-format stats file).
cmvn: global_cmvn
cmvn_conf:
  cmvn_file: 'data/train/global_cmvn'
  is_json_cmvn: true

# hybrid CTC/attention
model: asr_model
model_conf:
  ctc_weight: 0.3
  lsm_weight: 0.1  # label smoothing option
  length_normalized_loss: false

dataset: asr
dataset_conf:
  filter_conf:
    max_length: 40960
Expand Down
Loading

0 comments on commit fac1f0c

Please sign in to comment.