From 77b7198fbe2fc6e9b45909870cc859dd41f0fd3d Mon Sep 17 00:00:00 2001 From: xingchensong Date: Fri, 8 Dec 2023 23:21:55 +0800 Subject: [PATCH 1/9] [refactor] use model_class to unify module init --- .../aishell/rnnt/conf/conformer_rnnt.yaml | 3 +- .../rnnt/conf/conformer_u2pp_rnnt.yaml | 3 +- .../conf/example_embedding_predictor.yaml | 3 +- .../s0/conf/train_unified_conformer_ctl.yaml | 4 +- examples/aishell/s0/run.sh | 6 +- wenet/utils/init_model.py | 193 +++++++++--------- 6 files changed, 104 insertions(+), 108 deletions(-) diff --git a/examples/aishell/rnnt/conf/conformer_rnnt.yaml b/examples/aishell/rnnt/conf/conformer_rnnt.yaml index 3af76ed78..e162f59d2 100644 --- a/examples/aishell/rnnt/conf/conformer_rnnt.yaml +++ b/examples/aishell/rnnt/conf/conformer_rnnt.yaml @@ -17,7 +17,7 @@ encoder_conf: pos_enc_layer_type: 'rel_pos' selfattention_layer_type: 'rel_selfattn' - +joint: transducerjoint joint_conf: enc_output_size: 256 pred_output_size: 256 @@ -50,6 +50,7 @@ decoder_conf: src_attention_dropout_rate: 0.1 # hybrid transducer+ctc+attention +model: transducer model_conf: transducer_weight: 0.75 ctc_weight: 0.1 diff --git a/examples/aishell/rnnt/conf/conformer_u2pp_rnnt.yaml b/examples/aishell/rnnt/conf/conformer_u2pp_rnnt.yaml index 3481f20b3..3bd8ab36c 100644 --- a/examples/aishell/rnnt/conf/conformer_u2pp_rnnt.yaml +++ b/examples/aishell/rnnt/conf/conformer_u2pp_rnnt.yaml @@ -21,7 +21,7 @@ encoder_conf: cnn_module_norm: 'layer_norm' # using nn.LayerNorm makes model converge faster use_dynamic_left_chunk: false - +joint: transducerjoint joint_conf: enc_output_size: 256 pred_output_size: 256 @@ -54,6 +54,7 @@ decoder_conf: src_attention_dropout_rate: 0.1 # hybrid transducer+ctc+attention +model: transducer model_conf: transducer_weight: 0.75 ctc_weight: 0.1 diff --git a/examples/aishell/rnnt/conf/example_embedding_predictor.yaml b/examples/aishell/rnnt/conf/example_embedding_predictor.yaml index ce701b57c..3f1423169 100644 --- a/examples/aishell/rnnt/conf/example_embedding_predictor.yaml +++ b/examples/aishell/rnnt/conf/example_embedding_predictor.yaml @@ -15,7 +15,7 @@ encoder_conf: pos_enc_layer_type: 'rel_pos' selfattention_layer_type: 'rel_selfattn' - +joint: transducerjoint joint_conf: enc_output_size: 256 pred_output_size: 320 @@ -46,6 +46,7 @@ decoder_conf: src_attention_dropout_rate: 0.1 # hybrid transducer+ctc+attention +model: transducer model_conf: transducer_weight: 0.4 ctc_weight: 0.2 diff --git a/examples/aishell/s0/conf/train_unified_conformer_ctl.yaml b/examples/aishell/s0/conf/train_unified_conformer_ctl.yaml index 188bdb414..ea2f548cd 100644 --- a/examples/aishell/s0/conf/train_unified_conformer_ctl.yaml +++ b/examples/aishell/s0/conf/train_unified_conformer_ctl.yaml @@ -1,6 +1,6 @@ # network architecture # encoder related -encoder: conformer +encoder: dual_conformer encoder_conf: output_size: 256 # dimension of attention attention_heads: 4 @@ -32,8 +32,8 @@ decoder_conf: self_attention_dropout_rate: 0.0 src_attention_dropout_rate: 0.0 -ctlmodel: true # hybrid CTC/attention +model: ctlmodel model_conf: ctc_weight: 0.3 lsm_weight: 0.1 # label smoothing option diff --git a/examples/aishell/s0/run.sh b/examples/aishell/s0/run.sh index 3f940d1ea..ee126e6d0 100644 --- a/examples/aishell/s0/run.sh +++ b/examples/aishell/s0/run.sh @@ -304,8 +304,10 @@ if [ ${stage} -le 9 ] && [ ${stop_stage} -ge 9 ]; then # 9.1 Build token level bigram fst for LF-MMI training tools/k2/prepare_mmi.sh data/train/ data/dev data/local/lfmmi - # 9.2 Run LF-MMI training from stage 4, with 
below new args - # --lfmmi_dir data/local/lfmmi + # 9.2 Run LF-MMI training from stage 4, modify below args in train.yaml + # model: k2model + # model_conf: + # lfmmi_dir data/local/lfmmi # 9.3 Run HLG decode from stage 8.2 fi diff --git a/wenet/utils/init_model.py b/wenet/utils/init_model.py index 31dff5531..4f83af30a 100644 --- a/wenet/utils/init_model.py +++ b/wenet/utils/init_model.py @@ -22,27 +22,58 @@ from wenet.transformer.asr_model import ASRModel from wenet.transformer.cmvn import GlobalCMVN from wenet.transformer.ctc import CTC +from wenet.transformer.encoder import TransformerEncoder, ConformerEncoder from wenet.transformer.decoder import BiTransformerDecoder, TransformerDecoder from wenet.branchformer.encoder import BranchformerEncoder from wenet.e_branchformer.encoder import EBranchformerEncoder from wenet.squeezeformer.encoder import SqueezeformerEncoder from wenet.efficient_conformer.encoder import EfficientConformerEncoder +from wenet.ctl_model.encoder import DualTransformerEncoder, DualConformerEncoder from wenet.ctl_model.asr_model_ctl import CTLModel from wenet.whisper.whisper import Whisper from wenet.utils.cmvn import load_cmvn from wenet.utils.checkpoint import load_checkpoint, load_trained_modules +WENET_ENCODER_CLASSES = { + "transformer": TransformerEncoder, + "conformer": ConformerEncoder, + "squeezeformer": SqueezeformerEncoder, + "efficientConformer": EfficientConformerEncoder, + "branchformer": BranchformerEncoder, + "e_branchformer": EBranchformerEncoder, + "dual_transformer": DualTransformerEncoder, + "dual_conformer": DualConformerEncoder, +} + +WENET_DECODER_CLASSES = { + "transformer": TransformerDecoder, + "bitransformer": BiTransformerDecoder, +} + +WENET_CTC_CLASSES = { + "ctc": CTC, +} + +WENET_PREDICTOR_CLASSES = { + "rnn": RNNPredictor, + "embedding": EmbeddingPredictor, + "conv": ConvPredictor, +} + +WENET_JOINT_CLASSES = { + "transducerjoint": TransducerJoint, +} + +WENET_MODEL_CLASSES = { + "asrmodel": ASRModel, + "ctlmodel": CTLModel, + "whisper": Whisper, + "k2model": K2Model, + "transducer": Transducer, +} + def init_model(args, configs): - if 'paraformer' in configs: - """ NOTE(Mddct): support fintune paraformer, if there is a need for - sanmencoder/decoder in the future, simplify here. 
- """ - from wenet.utils.init_ali_paraformer import (init_model as - init_ali_paraformer_model) - model, configs = init_ali_paraformer_model(configs, args.checkpoint) - print(configs) - return model, configs if configs['cmvn_file'] is not None: mean, istd = load_cmvn(configs['cmvn_file'], configs['is_json_cmvn']) @@ -57,101 +88,61 @@ def init_model(args, configs): encoder_type = configs.get('encoder', 'conformer') decoder_type = configs.get('decoder', 'bitransformer') + ctc_type = configs.get('ctc', 'ctc') - if 'ctlmodel' in configs: - from wenet.ctl_model.encoder import DualConformerEncoder as ConformerEncoder - from wenet.ctl_model.encoder import DualTransformerEncoder as TransformerEncoder - else: - from wenet.transformer.encoder import ConformerEncoder, TransformerEncoder - - if encoder_type == 'conformer': - encoder = ConformerEncoder(input_dim, - global_cmvn=global_cmvn, - **configs['encoder_conf']) - elif encoder_type == 'squeezeformer': - encoder = SqueezeformerEncoder(input_dim, - global_cmvn=global_cmvn, - **configs['encoder_conf']) - elif encoder_type == 'efficientConformer': - encoder = EfficientConformerEncoder( - input_dim, - global_cmvn=global_cmvn, - **configs['encoder_conf'], - **configs['encoder_conf']['efficient_conf'] - if 'efficient_conf' in configs['encoder_conf'] else {}) - elif encoder_type == 'branchformer': - encoder = BranchformerEncoder(input_dim, - global_cmvn=global_cmvn, - **configs['encoder_conf']) - elif encoder_type == 'e_branchformer': - encoder = EBranchformerEncoder(input_dim, - global_cmvn=global_cmvn, - **configs['encoder_conf']) - else: - encoder = TransformerEncoder(input_dim, - global_cmvn=global_cmvn, - **configs['encoder_conf']) - if decoder_type == 'transformer': - decoder = TransformerDecoder(vocab_size, encoder.output_size(), - **configs['decoder_conf']) - else: - assert 0.0 < configs['model_conf']['reverse_weight'] < 1.0 - assert configs['decoder_conf']['r_num_blocks'] > 0 - decoder = BiTransformerDecoder(vocab_size, encoder.output_size(), - **configs['decoder_conf']) - ctc = CTC(vocab_size, - encoder.output_size(), - blank_id=configs['ctc_conf']['ctc_blank_id'] - if 'ctc_conf' in configs else 0) - - # Init joint CTC/Attention or Transducer model - if 'predictor' in configs: + encoder = WENET_ENCODER_CLASSES[encoder_type]( + input_dim, + global_cmvn=global_cmvn, + **configs['encoder_conf'], + **configs['encoder_conf']['efficient_conf'] + if 'efficient_conf' in configs['encoder_conf'] else {}) + + decoder = WENET_DECODER_CLASSES[decoder_type](vocab_size, + encoder.output_size(), + **configs['decoder_conf']) + + ctc = WENET_CTC_CLASSES[ctc_type]( + vocab_size, + encoder.output_size(), + blank_id=configs['ctc_conf']['ctc_blank_id'] + if 'ctc_conf' in configs else 0) + + if configs['model'] == "transducer": predictor_type = configs.get('predictor', 'rnn') - if predictor_type == 'rnn': - predictor = RNNPredictor(vocab_size, **configs['predictor_conf']) - elif predictor_type == 'embedding': - predictor = EmbeddingPredictor(vocab_size, - **configs['predictor_conf']) - elif predictor_type == 'conv': - predictor = ConvPredictor(vocab_size, **configs['predictor_conf']) - else: - raise NotImplementedError( - "only rnn, embedding and conv type support now") - joint = TransducerJoint(vocab_size, **configs['joint_conf']) - model = Transducer(vocab_size=vocab_size, - blank=0, - predictor=predictor, - encoder=encoder, - attention_decoder=decoder, - joint=joint, - ctc=ctc, - **configs['model_conf']) - elif 'ctlmodel' in configs: - model = 
CTLModel(vocab_size=vocab_size, - encoder=encoder, - decoder=decoder, - ctc=ctc, - **configs['model_conf']) - elif 'whisper' in configs: - model = Whisper(vocab_size=vocab_size, - encoder=encoder, - decoder=decoder, - ctc=ctc, - **configs['model_conf']) + joint_type = configs.get('joint', 'transducerjoint') + predictor = WENET_PREDICTOR_CLASSES[predictor_type]( + vocab_size, **configs['predictor_conf']) + joint = WENET_JOINT_CLASSES[joint_type](vocab_size, + **configs['joint_conf']) + model = WENET_MODEL_CLASSES[configs['model']]( + vocab_size=vocab_size, + blank=0, + predictor=predictor, + encoder=encoder, + attention_decoder=decoder, + joint=joint, + ctc=ctc, + special_tokens=configs['tokenizer_conf']['special_tokens'], + **configs['model_conf']) + elif configs['model'] == 'paraformer': + """ NOTE(Mddct): support fintune paraformer, if there is a need for + sanmencoder/decoder in the future, simplify here. + """ + # TODO(Mddct): refine this + from wenet.utils.init_ali_paraformer import (init_model as + init_ali_paraformer_model) + model, configs = init_ali_paraformer_model(configs, args.checkpoint) + print(configs) + return model, configs else: - if configs.get('lfmmi_dir', '') != '': - model = K2Model(vocab_size=vocab_size, - encoder=encoder, - decoder=decoder, - ctc=ctc, - lfmmi_dir=configs['lfmmi_dir'], - **configs['model_conf']) - else: - model = ASRModel(vocab_size=vocab_size, - encoder=encoder, - decoder=decoder, - ctc=ctc, - **configs['model_conf']) + model = WENET_MODEL_CLASSES[configs['model']]( + vocab_size=vocab_size, + encoder=encoder, + decoder=decoder, + ctc=ctc, + special_tokens=configs['tokenizer_conf']['special_tokens'], + **configs['model_conf']) + # If specify checkpoint, load some info from checkpoint if args.checkpoint is not None: infos = load_checkpoint(model, args.checkpoint) From f7b57ff89dd55ddb5ae964e7b6b1a188b1e46041 Mon Sep 17 00:00:00 2001 From: xingchensong Date: Sat, 9 Dec 2023 12:02:49 +0800 Subject: [PATCH 2/9] [refactor] add config in yaml --- examples/aishell/s0/conf/train_conformer.yaml | 1 + examples/aishell/s0/conf/train_conformer_no_pos.yaml | 1 + examples/aishell/s0/conf/train_ebranchformer.yaml | 1 + examples/aishell/s0/conf/train_transformer.yaml | 1 + examples/aishell/s0/conf/train_u2++_branchformer.yaml | 1 + examples/aishell/s0/conf/train_u2++_conformer.yaml | 1 + examples/aishell/s0/conf/train_u2++_conformer_1.8B.yaml | 1 + examples/aishell/s0/conf/train_u2++_efficonformer_v1.yaml | 1 + .../aishell/s0/conf/train_u2++_efficonformer_v1_stream.yaml | 1 + examples/aishell/s0/conf/train_u2++_efficonformer_v2.yaml | 1 + examples/aishell/s0/conf/train_u2++_lite_conformer.yaml | 1 + examples/aishell/s0/conf/train_u2++_transformer.yaml | 1 + examples/aishell/s0/conf/train_unified_conformer.yaml | 1 + examples/aishell/s0/conf/train_unified_transformer.yaml | 1 + wenet/whisper/convert_whisper_to_wenet_config_and_ckpt.py | 6 +----- 15 files changed, 15 insertions(+), 5 deletions(-) diff --git a/examples/aishell/s0/conf/train_conformer.yaml b/examples/aishell/s0/conf/train_conformer.yaml index b8ce511cd..d33ff0d2a 100644 --- a/examples/aishell/s0/conf/train_conformer.yaml +++ b/examples/aishell/s0/conf/train_conformer.yaml @@ -29,6 +29,7 @@ decoder_conf: src_attention_dropout_rate: 0.0 # hybrid CTC/attention +model: asrmodel model_conf: ctc_weight: 0.3 lsm_weight: 0.1 # label smoothing option diff --git a/examples/aishell/s0/conf/train_conformer_no_pos.yaml b/examples/aishell/s0/conf/train_conformer_no_pos.yaml index a2d5d03f5..1e8aba35c 100644 --- 
a/examples/aishell/s0/conf/train_conformer_no_pos.yaml +++ b/examples/aishell/s0/conf/train_conformer_no_pos.yaml @@ -29,6 +29,7 @@ decoder_conf: src_attention_dropout_rate: 0.0 # hybrid CTC/attention +model: asrmodel model_conf: ctc_weight: 0.3 lsm_weight: 0.1 # label smoothing option diff --git a/examples/aishell/s0/conf/train_ebranchformer.yaml b/examples/aishell/s0/conf/train_ebranchformer.yaml index edc952295..218cd13a6 100644 --- a/examples/aishell/s0/conf/train_ebranchformer.yaml +++ b/examples/aishell/s0/conf/train_ebranchformer.yaml @@ -32,6 +32,7 @@ decoder_conf: src_attention_dropout_rate: 0.1 # hybrid CTC/attention +model: asrmodel model_conf: ctc_weight: 0.3 lsm_weight: 0.1 # label smoothing option diff --git a/examples/aishell/s0/conf/train_transformer.yaml b/examples/aishell/s0/conf/train_transformer.yaml index b7d7eee83..88cb293e1 100644 --- a/examples/aishell/s0/conf/train_transformer.yaml +++ b/examples/aishell/s0/conf/train_transformer.yaml @@ -24,6 +24,7 @@ decoder_conf: src_attention_dropout_rate: 0.0 # hybrid CTC/attention +model: asrmodel model_conf: ctc_weight: 0.3 lsm_weight: 0.1 # label smoothing option diff --git a/examples/aishell/s0/conf/train_u2++_branchformer.yaml b/examples/aishell/s0/conf/train_u2++_branchformer.yaml index ef12c13a4..6256927f4 100644 --- a/examples/aishell/s0/conf/train_u2++_branchformer.yaml +++ b/examples/aishell/s0/conf/train_u2++_branchformer.yaml @@ -38,6 +38,7 @@ decoder_conf: src_attention_dropout_rate: 0.1 # hybrid CTC/attention +model: asrmodel model_conf: ctc_weight: 0.3 lsm_weight: 0.1 # label smoothing option diff --git a/examples/aishell/s0/conf/train_u2++_conformer.yaml b/examples/aishell/s0/conf/train_u2++_conformer.yaml index b4587bce3..cd25fec6f 100644 --- a/examples/aishell/s0/conf/train_u2++_conformer.yaml +++ b/examples/aishell/s0/conf/train_u2++_conformer.yaml @@ -34,6 +34,7 @@ decoder_conf: src_attention_dropout_rate: 0.1 # hybrid CTC/attention +model: asrmodel model_conf: ctc_weight: 0.3 lsm_weight: 0.1 # label smoothing option diff --git a/examples/aishell/s0/conf/train_u2++_conformer_1.8B.yaml b/examples/aishell/s0/conf/train_u2++_conformer_1.8B.yaml index c13b4b295..ed91c2240 100644 --- a/examples/aishell/s0/conf/train_u2++_conformer_1.8B.yaml +++ b/examples/aishell/s0/conf/train_u2++_conformer_1.8B.yaml @@ -34,6 +34,7 @@ decoder_conf: src_attention_dropout_rate: 0.1 # hybrid CTC/attention +model: asrmodel model_conf: ctc_weight: 0.3 lsm_weight: 0.1 # label smoothing option diff --git a/examples/aishell/s0/conf/train_u2++_efficonformer_v1.yaml b/examples/aishell/s0/conf/train_u2++_efficonformer_v1.yaml index 3d0de82db..654da120b 100644 --- a/examples/aishell/s0/conf/train_u2++_efficonformer_v1.yaml +++ b/examples/aishell/s0/conf/train_u2++_efficonformer_v1.yaml @@ -39,6 +39,7 @@ decoder_conf: src_attention_dropout_rate: 0.1 # hybrid CTC/attention +model: asrmodel model_conf: ctc_weight: 0.3 lsm_weight: 0.1 # label smoothing option diff --git a/examples/aishell/s0/conf/train_u2++_efficonformer_v1_stream.yaml b/examples/aishell/s0/conf/train_u2++_efficonformer_v1_stream.yaml index 3b5a99a86..bc0cda55e 100644 --- a/examples/aishell/s0/conf/train_u2++_efficonformer_v1_stream.yaml +++ b/examples/aishell/s0/conf/train_u2++_efficonformer_v1_stream.yaml @@ -39,6 +39,7 @@ decoder_conf: src_attention_dropout_rate: 0.1 # hybrid CTC/attention +model: asrmodel model_conf: ctc_weight: 0.3 lsm_weight: 0.1 # label smoothing option diff --git a/examples/aishell/s0/conf/train_u2++_efficonformer_v2.yaml 
b/examples/aishell/s0/conf/train_u2++_efficonformer_v2.yaml index c23e1b64d..fd2087513 100644 --- a/examples/aishell/s0/conf/train_u2++_efficonformer_v2.yaml +++ b/examples/aishell/s0/conf/train_u2++_efficonformer_v2.yaml @@ -39,6 +39,7 @@ decoder_conf: src_attention_dropout_rate: 0.1 # hybrid CTC/attention +model: asrmodel model_conf: ctc_weight: 0.3 lsm_weight: 0.1 # label smoothing option diff --git a/examples/aishell/s0/conf/train_u2++_lite_conformer.yaml b/examples/aishell/s0/conf/train_u2++_lite_conformer.yaml index 1eb280de2..7433b51e5 100644 --- a/examples/aishell/s0/conf/train_u2++_lite_conformer.yaml +++ b/examples/aishell/s0/conf/train_u2++_lite_conformer.yaml @@ -34,6 +34,7 @@ decoder_conf: src_attention_dropout_rate: 0.1 # hybrid CTC/attention +model: asrmodel model_conf: ctc_weight: 0.3 lsm_weight: 0.1 # label smoothing option diff --git a/examples/aishell/s0/conf/train_u2++_transformer.yaml b/examples/aishell/s0/conf/train_u2++_transformer.yaml index 44b4d4be7..3f1cdbe9a 100644 --- a/examples/aishell/s0/conf/train_u2++_transformer.yaml +++ b/examples/aishell/s0/conf/train_u2++_transformer.yaml @@ -27,6 +27,7 @@ decoder_conf: src_attention_dropout_rate: 0.0 # hybrid CTC/attention +model: asrmodel model_conf: ctc_weight: 0.3 lsm_weight: 0.1 # label smoothing option diff --git a/examples/aishell/s0/conf/train_unified_conformer.yaml b/examples/aishell/s0/conf/train_unified_conformer.yaml index 978d3d91c..49d043850 100644 --- a/examples/aishell/s0/conf/train_unified_conformer.yaml +++ b/examples/aishell/s0/conf/train_unified_conformer.yaml @@ -33,6 +33,7 @@ decoder_conf: src_attention_dropout_rate: 0.0 # hybrid CTC/attention +model: asrmodel model_conf: ctc_weight: 0.3 lsm_weight: 0.1 # label smoothing option diff --git a/examples/aishell/s0/conf/train_unified_transformer.yaml b/examples/aishell/s0/conf/train_unified_transformer.yaml index 9d7a38687..7aa524244 100644 --- a/examples/aishell/s0/conf/train_unified_transformer.yaml +++ b/examples/aishell/s0/conf/train_unified_transformer.yaml @@ -26,6 +26,7 @@ decoder_conf: src_attention_dropout_rate: 0.0 # hybrid CTC/attention +model: asrmodel model_conf: ctc_weight: 0.3 lsm_weight: 0.1 # label smoothing option diff --git a/wenet/whisper/convert_whisper_to_wenet_config_and_ckpt.py b/wenet/whisper/convert_whisper_to_wenet_config_and_ckpt.py index 0333f9833..43dfe443a 100644 --- a/wenet/whisper/convert_whisper_to_wenet_config_and_ckpt.py +++ b/wenet/whisper/convert_whisper_to_wenet_config_and_ckpt.py @@ -44,11 +44,6 @@ def convert_to_wenet_yaml(tokenizer, dims, wenet_yaml_path: str): configs = {} - configs['whisper'] = True - configs['whisper_conf'] = {} - configs['whisper_conf']['is_multilingual'] = dims['n_vocab'] >= 51865 - configs['whisper_conf']['num_languages'] = dims['n_vocab'] - 51765 - \ - int(configs['whisper_conf']['is_multilingual']) configs['input_dim'] = dims['n_mels'] configs['output_dim'] = dims['n_vocab'] assert dims['n_vocab'] == tokenizer.encoding.n_vocab, "{} v.s. 
{}".format( @@ -94,6 +89,7 @@ def convert_to_wenet_yaml(tokenizer, dims, wenet_yaml_path: str): configs['ctc_conf'] = {} configs['ctc_conf']['ctc_blank_id'] = tokenizer.no_speech + configs['model'] = "whisper" configs['model_conf'] = {} configs['model_conf']['ctc_weight'] = 0.3 configs['model_conf']['lsm_weight'] = 0.1 From 3816cd859284f544192c84fc77b6013dbc13bd08 Mon Sep 17 00:00:00 2001 From: xingchensong Date: Sat, 9 Dec 2023 12:36:42 +0800 Subject: [PATCH 3/9] [refactor] try to pass unit test --- wenet/utils/init_model.py | 4 +- ...onvert_whisper_to_wenet_config_and_ckpt.py | 37 +++++++++++++------ 2 files changed, 27 insertions(+), 14 deletions(-) diff --git a/wenet/utils/init_model.py b/wenet/utils/init_model.py index 4f83af30a..8914e0067 100644 --- a/wenet/utils/init_model.py +++ b/wenet/utils/init_model.py @@ -122,7 +122,7 @@ def init_model(args, configs): attention_decoder=decoder, joint=joint, ctc=ctc, - special_tokens=configs['tokenizer_conf']['special_tokens'], + special_tokens=configs['tokenizer_conf'].get('special_tokens', None), **configs['model_conf']) elif configs['model'] == 'paraformer': """ NOTE(Mddct): support fintune paraformer, if there is a need for @@ -140,7 +140,7 @@ def init_model(args, configs): encoder=encoder, decoder=decoder, ctc=ctc, - special_tokens=configs['tokenizer_conf']['special_tokens'], + special_tokens=configs['tokenizer_conf'].get('special_tokens', None), **configs['model_conf']) # If specify checkpoint, load some info from checkpoint diff --git a/wenet/whisper/convert_whisper_to_wenet_config_and_ckpt.py b/wenet/whisper/convert_whisper_to_wenet_config_and_ckpt.py index 43dfe443a..4a51fb9b5 100644 --- a/wenet/whisper/convert_whisper_to_wenet_config_and_ckpt.py +++ b/wenet/whisper/convert_whisper_to_wenet_config_and_ckpt.py @@ -86,6 +86,31 @@ def convert_to_wenet_yaml(tokenizer, dims, wenet_yaml_path: str): configs['decoder_conf']['key_bias'] = False configs['decoder_conf']['activation_type'] = "gelu" + configs['tokenizer'] = 'whisper' + configs['tokenizer_conf'] = {} + configs['tokenizer_conf']['is_multilingual'] = dims['n_vocab'] >= 51865 + configs['tokenizer_conf']['num_languages'] = dims['n_vocab'] - 51765 - \ + int(configs['tokenizer_conf']['is_multilingual']) + configs['tokenizer_conf']['split_with_space'] = False + configs['tokenizer_conf']['bpe_path'] = None + configs['tokenizer_conf']['symbol_table_path'] = None + configs['tokenizer_conf']['non_lang_syms_path'] = None + configs['tokenizer_conf']['special_tokens'] = {} + configs['tokenizer_conf']['special_tokens']['sot'] = tokenizer.sot + configs['tokenizer_conf']['special_tokens']['eot'] = tokenizer.sot + configs['tokenizer_conf']['special_tokens'][ + 'sot_prev'] = tokenizer.sot_prev + configs['tokenizer_conf']['special_tokens'][ + 'transcribe'] = tokenizer.transcribe + configs['tokenizer_conf']['special_tokens'][ + 'translate'] = tokenizer.translate + configs['tokenizer_conf']['special_tokens'][ + 'no_timestamps'] = tokenizer.no_timestamps + configs['tokenizer_conf']['special_tokens'][ + 'no_speech'] = tokenizer.no_speech + configs['tokenizer_conf']['special_tokens']['timestamp_begin'] = \ + tokenizer.timestamp_begin + configs['ctc_conf'] = {} configs['ctc_conf']['ctc_blank_id'] = tokenizer.no_speech @@ -94,18 +119,6 @@ def convert_to_wenet_yaml(tokenizer, dims, wenet_yaml_path: str): configs['model_conf']['ctc_weight'] = 0.3 configs['model_conf']['lsm_weight'] = 0.1 configs['model_conf']['length_normalized_loss'] = False - configs['model_conf']['special_tokens'] = {} - 
configs['model_conf']['special_tokens']['sot'] = tokenizer.sot - configs['model_conf']['special_tokens']['eot'] = tokenizer.sot - configs['model_conf']['special_tokens']['sot_prev'] = tokenizer.sot_prev - configs['model_conf']['special_tokens'][ - 'transcribe'] = tokenizer.transcribe - configs['model_conf']['special_tokens']['translate'] = tokenizer.translate - configs['model_conf']['special_tokens'][ - 'no_timestamps'] = tokenizer.no_timestamps - configs['model_conf']['special_tokens']['no_speech'] = tokenizer.no_speech - configs['model_conf']['special_tokens']['timestamp_begin'] = \ - tokenizer.timestamp_begin configs['dataset_conf'] = {} configs['dataset_conf']['filter_conf'] = {} From 8b97127258d129f8f4100c4cf61157ba3454ade3 Mon Sep 17 00:00:00 2001 From: xingchensong Date: Sat, 9 Dec 2023 12:53:22 +0800 Subject: [PATCH 4/9] [refactor] try to pass unit test --- test/wenet/whisper/test_whisper.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/wenet/whisper/test_whisper.py b/test/wenet/whisper/test_whisper.py index 0eb0df9a7..50ed9efd9 100644 --- a/test/wenet/whisper/test_whisper.py +++ b/test/wenet/whisper/test_whisper.py @@ -361,7 +361,7 @@ def test_model(model, audio_path): # 6. Forward wenet.decoder wenet_tokens, _ = add_whisper_tokens( - configs['model_conf']['special_tokens'], + configs['tokenizer_conf']['special_tokens'], torch.tensor([dummy_tokens], dtype=torch.long), ignore_id=-1, task=task, From fb01adf55d228231e371ac8ae9d06f811b4eeea7 Mon Sep 17 00:00:00 2001 From: xingchensong Date: Sat, 9 Dec 2023 20:59:45 +0800 Subject: [PATCH 5/9] [refactor] set default model_type --- examples/aishell/NST/conf/train_conformer.yaml | 1 + examples/aishell2/rnnt/conf/conformer_u2pp_rnnt.yaml | 1 + wenet/utils/init_model.py | 11 +++++++---- 3 files changed, 9 insertions(+), 4 deletions(-) diff --git a/examples/aishell/NST/conf/train_conformer.yaml b/examples/aishell/NST/conf/train_conformer.yaml index 8499de2e9..ab4c80381 100644 --- a/examples/aishell/NST/conf/train_conformer.yaml +++ b/examples/aishell/NST/conf/train_conformer.yaml @@ -29,6 +29,7 @@ decoder_conf: src_attention_dropout_rate: 0.0 # hybrid CTC/attention +model: asrmodel model_conf: ctc_weight: 0.3 lsm_weight: 0.1 # label smoothing option diff --git a/examples/aishell2/rnnt/conf/conformer_u2pp_rnnt.yaml b/examples/aishell2/rnnt/conf/conformer_u2pp_rnnt.yaml index cfb4b18b6..1cd58586c 100644 --- a/examples/aishell2/rnnt/conf/conformer_u2pp_rnnt.yaml +++ b/examples/aishell2/rnnt/conf/conformer_u2pp_rnnt.yaml @@ -52,6 +52,7 @@ decoder_conf: src_attention_dropout_rate: 0.1 # hybrid transducer+ctc+attention +model: transducer model_conf: transducer_weight: 0.75 ctc_weight: 0.1 diff --git a/wenet/utils/init_model.py b/wenet/utils/init_model.py index 8914e0067..5517eaa0f 100644 --- a/wenet/utils/init_model.py +++ b/wenet/utils/init_model.py @@ -107,7 +107,8 @@ def init_model(args, configs): blank_id=configs['ctc_conf']['ctc_blank_id'] if 'ctc_conf' in configs else 0) - if configs['model'] == "transducer": + model_type = configs.get('model', 'asrmodel') + if model_type == "transducer": predictor_type = configs.get('predictor', 'rnn') joint_type = configs.get('joint', 'transducerjoint') predictor = WENET_PREDICTOR_CLASSES[predictor_type]( @@ -122,9 +123,10 @@ def init_model(args, configs): attention_decoder=decoder, joint=joint, ctc=ctc, - special_tokens=configs['tokenizer_conf'].get('special_tokens', None), + special_tokens=configs['tokenizer_conf'].get( + 'special_tokens', None), **configs['model_conf']) - elif 
configs['model'] == 'paraformer': + elif model_type == 'paraformer': """ NOTE(Mddct): support fintune paraformer, if there is a need for sanmencoder/decoder in the future, simplify here. """ @@ -140,7 +142,8 @@ def init_model(args, configs): encoder=encoder, decoder=decoder, ctc=ctc, - special_tokens=configs['tokenizer_conf'].get('special_tokens', None), + special_tokens=configs['tokenizer_conf'].get( + 'special_tokens', None), **configs['model_conf']) # If specify checkpoint, load some info from checkpoint From 69ad73a0b6f04d9f13c599da6d887ca852186e13 Mon Sep 17 00:00:00 2001 From: xingchensong Date: Sat, 9 Dec 2023 21:34:42 +0800 Subject: [PATCH 6/9] [refactor] update yaml --- examples/aishell/NST/conf/train_conformer.yaml | 1 - examples/aishell/s0/conf/train_conformer.yaml | 1 - examples/aishell/s0/conf/train_conformer_no_pos.yaml | 1 - examples/aishell/s0/conf/train_ebranchformer.yaml | 1 - examples/aishell/s0/conf/train_transformer.yaml | 1 - examples/aishell/s0/conf/train_u2++_branchformer.yaml | 1 - examples/aishell/s0/conf/train_u2++_conformer.yaml | 1 - examples/aishell/s0/conf/train_u2++_conformer_1.8B.yaml | 1 - examples/aishell/s0/conf/train_u2++_efficonformer_v1.yaml | 1 - .../aishell/s0/conf/train_u2++_efficonformer_v1_stream.yaml | 1 - examples/aishell/s0/conf/train_u2++_efficonformer_v2.yaml | 1 - examples/aishell/s0/conf/train_u2++_lite_conformer.yaml | 1 - examples/aishell/s0/conf/train_u2++_transformer.yaml | 1 - examples/aishell/s0/conf/train_unified_conformer.yaml | 1 - examples/aishell/s0/conf/train_unified_transformer.yaml | 1 - examples/aishell2/rnnt/conf/conformer_rnnt.yaml | 3 ++- examples/aishell2/rnnt/conf/conformer_u2pp_rnnt.yaml | 2 +- examples/librispeech/rnnt/conf/conformer_rnnt.yaml | 3 ++- 18 files changed, 5 insertions(+), 18 deletions(-) diff --git a/examples/aishell/NST/conf/train_conformer.yaml b/examples/aishell/NST/conf/train_conformer.yaml index ab4c80381..8499de2e9 100644 --- a/examples/aishell/NST/conf/train_conformer.yaml +++ b/examples/aishell/NST/conf/train_conformer.yaml @@ -29,7 +29,6 @@ decoder_conf: src_attention_dropout_rate: 0.0 # hybrid CTC/attention -model: asrmodel model_conf: ctc_weight: 0.3 lsm_weight: 0.1 # label smoothing option diff --git a/examples/aishell/s0/conf/train_conformer.yaml b/examples/aishell/s0/conf/train_conformer.yaml index d33ff0d2a..b8ce511cd 100644 --- a/examples/aishell/s0/conf/train_conformer.yaml +++ b/examples/aishell/s0/conf/train_conformer.yaml @@ -29,7 +29,6 @@ decoder_conf: src_attention_dropout_rate: 0.0 # hybrid CTC/attention -model: asrmodel model_conf: ctc_weight: 0.3 lsm_weight: 0.1 # label smoothing option diff --git a/examples/aishell/s0/conf/train_conformer_no_pos.yaml b/examples/aishell/s0/conf/train_conformer_no_pos.yaml index 1e8aba35c..a2d5d03f5 100644 --- a/examples/aishell/s0/conf/train_conformer_no_pos.yaml +++ b/examples/aishell/s0/conf/train_conformer_no_pos.yaml @@ -29,7 +29,6 @@ decoder_conf: src_attention_dropout_rate: 0.0 # hybrid CTC/attention -model: asrmodel model_conf: ctc_weight: 0.3 lsm_weight: 0.1 # label smoothing option diff --git a/examples/aishell/s0/conf/train_ebranchformer.yaml b/examples/aishell/s0/conf/train_ebranchformer.yaml index 218cd13a6..edc952295 100644 --- a/examples/aishell/s0/conf/train_ebranchformer.yaml +++ b/examples/aishell/s0/conf/train_ebranchformer.yaml @@ -32,7 +32,6 @@ decoder_conf: src_attention_dropout_rate: 0.1 # hybrid CTC/attention -model: asrmodel model_conf: ctc_weight: 0.3 lsm_weight: 0.1 # label smoothing option diff --git 
a/examples/aishell/s0/conf/train_transformer.yaml b/examples/aishell/s0/conf/train_transformer.yaml index 88cb293e1..b7d7eee83 100644 --- a/examples/aishell/s0/conf/train_transformer.yaml +++ b/examples/aishell/s0/conf/train_transformer.yaml @@ -24,7 +24,6 @@ decoder_conf: src_attention_dropout_rate: 0.0 # hybrid CTC/attention -model: asrmodel model_conf: ctc_weight: 0.3 lsm_weight: 0.1 # label smoothing option diff --git a/examples/aishell/s0/conf/train_u2++_branchformer.yaml b/examples/aishell/s0/conf/train_u2++_branchformer.yaml index 6256927f4..ef12c13a4 100644 --- a/examples/aishell/s0/conf/train_u2++_branchformer.yaml +++ b/examples/aishell/s0/conf/train_u2++_branchformer.yaml @@ -38,7 +38,6 @@ decoder_conf: src_attention_dropout_rate: 0.1 # hybrid CTC/attention -model: asrmodel model_conf: ctc_weight: 0.3 lsm_weight: 0.1 # label smoothing option diff --git a/examples/aishell/s0/conf/train_u2++_conformer.yaml b/examples/aishell/s0/conf/train_u2++_conformer.yaml index cd25fec6f..b4587bce3 100644 --- a/examples/aishell/s0/conf/train_u2++_conformer.yaml +++ b/examples/aishell/s0/conf/train_u2++_conformer.yaml @@ -34,7 +34,6 @@ decoder_conf: src_attention_dropout_rate: 0.1 # hybrid CTC/attention -model: asrmodel model_conf: ctc_weight: 0.3 lsm_weight: 0.1 # label smoothing option diff --git a/examples/aishell/s0/conf/train_u2++_conformer_1.8B.yaml b/examples/aishell/s0/conf/train_u2++_conformer_1.8B.yaml index ed91c2240..c13b4b295 100644 --- a/examples/aishell/s0/conf/train_u2++_conformer_1.8B.yaml +++ b/examples/aishell/s0/conf/train_u2++_conformer_1.8B.yaml @@ -34,7 +34,6 @@ decoder_conf: src_attention_dropout_rate: 0.1 # hybrid CTC/attention -model: asrmodel model_conf: ctc_weight: 0.3 lsm_weight: 0.1 # label smoothing option diff --git a/examples/aishell/s0/conf/train_u2++_efficonformer_v1.yaml b/examples/aishell/s0/conf/train_u2++_efficonformer_v1.yaml index 654da120b..3d0de82db 100644 --- a/examples/aishell/s0/conf/train_u2++_efficonformer_v1.yaml +++ b/examples/aishell/s0/conf/train_u2++_efficonformer_v1.yaml @@ -39,7 +39,6 @@ decoder_conf: src_attention_dropout_rate: 0.1 # hybrid CTC/attention -model: asrmodel model_conf: ctc_weight: 0.3 lsm_weight: 0.1 # label smoothing option diff --git a/examples/aishell/s0/conf/train_u2++_efficonformer_v1_stream.yaml b/examples/aishell/s0/conf/train_u2++_efficonformer_v1_stream.yaml index bc0cda55e..3b5a99a86 100644 --- a/examples/aishell/s0/conf/train_u2++_efficonformer_v1_stream.yaml +++ b/examples/aishell/s0/conf/train_u2++_efficonformer_v1_stream.yaml @@ -39,7 +39,6 @@ decoder_conf: src_attention_dropout_rate: 0.1 # hybrid CTC/attention -model: asrmodel model_conf: ctc_weight: 0.3 lsm_weight: 0.1 # label smoothing option diff --git a/examples/aishell/s0/conf/train_u2++_efficonformer_v2.yaml b/examples/aishell/s0/conf/train_u2++_efficonformer_v2.yaml index fd2087513..c23e1b64d 100644 --- a/examples/aishell/s0/conf/train_u2++_efficonformer_v2.yaml +++ b/examples/aishell/s0/conf/train_u2++_efficonformer_v2.yaml @@ -39,7 +39,6 @@ decoder_conf: src_attention_dropout_rate: 0.1 # hybrid CTC/attention -model: asrmodel model_conf: ctc_weight: 0.3 lsm_weight: 0.1 # label smoothing option diff --git a/examples/aishell/s0/conf/train_u2++_lite_conformer.yaml b/examples/aishell/s0/conf/train_u2++_lite_conformer.yaml index 7433b51e5..1eb280de2 100644 --- a/examples/aishell/s0/conf/train_u2++_lite_conformer.yaml +++ b/examples/aishell/s0/conf/train_u2++_lite_conformer.yaml @@ -34,7 +34,6 @@ decoder_conf: src_attention_dropout_rate: 0.1 # hybrid 
CTC/attention -model: asrmodel model_conf: ctc_weight: 0.3 lsm_weight: 0.1 # label smoothing option diff --git a/examples/aishell/s0/conf/train_u2++_transformer.yaml b/examples/aishell/s0/conf/train_u2++_transformer.yaml index 3f1cdbe9a..44b4d4be7 100644 --- a/examples/aishell/s0/conf/train_u2++_transformer.yaml +++ b/examples/aishell/s0/conf/train_u2++_transformer.yaml @@ -27,7 +27,6 @@ decoder_conf: src_attention_dropout_rate: 0.0 # hybrid CTC/attention -model: asrmodel model_conf: ctc_weight: 0.3 lsm_weight: 0.1 # label smoothing option diff --git a/examples/aishell/s0/conf/train_unified_conformer.yaml b/examples/aishell/s0/conf/train_unified_conformer.yaml index 49d043850..978d3d91c 100644 --- a/examples/aishell/s0/conf/train_unified_conformer.yaml +++ b/examples/aishell/s0/conf/train_unified_conformer.yaml @@ -33,7 +33,6 @@ decoder_conf: src_attention_dropout_rate: 0.0 # hybrid CTC/attention -model: asrmodel model_conf: ctc_weight: 0.3 lsm_weight: 0.1 # label smoothing option diff --git a/examples/aishell/s0/conf/train_unified_transformer.yaml b/examples/aishell/s0/conf/train_unified_transformer.yaml index 7aa524244..9d7a38687 100644 --- a/examples/aishell/s0/conf/train_unified_transformer.yaml +++ b/examples/aishell/s0/conf/train_unified_transformer.yaml @@ -26,7 +26,6 @@ decoder_conf: src_attention_dropout_rate: 0.0 # hybrid CTC/attention -model: asrmodel model_conf: ctc_weight: 0.3 lsm_weight: 0.1 # label smoothing option diff --git a/examples/aishell2/rnnt/conf/conformer_rnnt.yaml b/examples/aishell2/rnnt/conf/conformer_rnnt.yaml index aeab0b180..eda66134a 100644 --- a/examples/aishell2/rnnt/conf/conformer_rnnt.yaml +++ b/examples/aishell2/rnnt/conf/conformer_rnnt.yaml @@ -17,7 +17,7 @@ encoder_conf: pos_enc_layer_type: 'rel_pos' selfattention_layer_type: 'rel_selfattn' - +joint: transducerjoint joint_conf: join_dim: 512 prejoin_linear: True @@ -48,6 +48,7 @@ decoder_conf: src_attention_dropout_rate: 0.1 # hybrid transducer+ctc+attention +model: transducer model_conf: transducer_weight: 0.75 ctc_weight: 0.1 diff --git a/examples/aishell2/rnnt/conf/conformer_u2pp_rnnt.yaml b/examples/aishell2/rnnt/conf/conformer_u2pp_rnnt.yaml index 1cd58586c..b216f9ff2 100644 --- a/examples/aishell2/rnnt/conf/conformer_u2pp_rnnt.yaml +++ b/examples/aishell2/rnnt/conf/conformer_u2pp_rnnt.yaml @@ -21,7 +21,7 @@ encoder_conf: cnn_module_norm: 'layer_norm' # using nn.LayerNorm makes model converge faster use_dynamic_left_chunk: false - +joint: transducerjoint joint_conf: join_dim: 512 prejoin_linear: True diff --git a/examples/librispeech/rnnt/conf/conformer_rnnt.yaml b/examples/librispeech/rnnt/conf/conformer_rnnt.yaml index 8a517ccca..48edf916e 100644 --- a/examples/librispeech/rnnt/conf/conformer_rnnt.yaml +++ b/examples/librispeech/rnnt/conf/conformer_rnnt.yaml @@ -17,7 +17,7 @@ encoder_conf: pos_enc_layer_type: 'rel_pos' selfattention_layer_type: 'rel_selfattn' - +joint: transducerjoint joint_conf: join_dim: 512 prejoin_linear: True @@ -48,6 +48,7 @@ decoder_conf: src_attention_dropout_rate: 0.1 # hybrid transducer+ctc+attention +model: transducer model_conf: transducer_weight: 0.75 ctc_weight: 0.1 From 8421355ec5965dcb4a73345b92da2f3de4c6f932 Mon Sep 17 00:00:00 2001 From: xingchensong Date: Sun, 10 Dec 2023 20:08:33 +0800 Subject: [PATCH 7/9] [refactor] add unit test for init_model --- .../aishell2/rnnt/conf/conformer_rnnt.yaml | 2 ++ .../rnnt/conf/conformer_u2pp_rnnt.yaml | 2 ++ .../librispeech/rnnt/conf/conformer_rnnt.yaml | 2 ++ test/wenet/utils/test_init_model.py | 34 +++++++++++++++++++ 
wenet/utils/init_model.py | 14 ++++---- 5 files changed, 47 insertions(+), 7 deletions(-) create mode 100644 test/wenet/utils/test_init_model.py diff --git a/examples/aishell2/rnnt/conf/conformer_rnnt.yaml b/examples/aishell2/rnnt/conf/conformer_rnnt.yaml index eda66134a..e162f59d2 100644 --- a/examples/aishell2/rnnt/conf/conformer_rnnt.yaml +++ b/examples/aishell2/rnnt/conf/conformer_rnnt.yaml @@ -19,6 +19,8 @@ encoder_conf: joint: transducerjoint joint_conf: + enc_output_size: 256 + pred_output_size: 256 join_dim: 512 prejoin_linear: True postjoin_linear: false diff --git a/examples/aishell2/rnnt/conf/conformer_u2pp_rnnt.yaml b/examples/aishell2/rnnt/conf/conformer_u2pp_rnnt.yaml index b216f9ff2..17e773cad 100644 --- a/examples/aishell2/rnnt/conf/conformer_u2pp_rnnt.yaml +++ b/examples/aishell2/rnnt/conf/conformer_u2pp_rnnt.yaml @@ -23,6 +23,8 @@ encoder_conf: joint: transducerjoint joint_conf: + enc_output_size: 256 + pred_output_size: 256 join_dim: 512 prejoin_linear: True postjoin_linear: false diff --git a/examples/librispeech/rnnt/conf/conformer_rnnt.yaml b/examples/librispeech/rnnt/conf/conformer_rnnt.yaml index 48edf916e..27413d29d 100644 --- a/examples/librispeech/rnnt/conf/conformer_rnnt.yaml +++ b/examples/librispeech/rnnt/conf/conformer_rnnt.yaml @@ -19,6 +19,8 @@ encoder_conf: joint: transducerjoint joint_conf: + enc_output_size: 256 + pred_output_size: 256 join_dim: 512 prejoin_linear: True postjoin_linear: false diff --git a/test/wenet/utils/test_init_model.py b/test/wenet/utils/test_init_model.py new file mode 100644 index 000000000..f73e181ec --- /dev/null +++ b/test/wenet/utils/test_init_model.py @@ -0,0 +1,34 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# Copyright [2023-12-10] + +import glob +import yaml + +from wenet.utils.init_model import init_model + + +class DummyArguments: + jit = False + enc_init = None + checkpoint = None + + +def test_init_model(): + configs = glob.glob("examples/*/*/conf/*.yaml") + args = DummyArguments() + for c in configs: + with open(c, 'r') as fin: + config = yaml.load(fin, Loader=yaml.FullLoader) + if 'fbank_conf' in config['dataset_conf']: + input_dim = config['dataset_conf']['fbank_conf']['num_mel_bins'] + elif 'log_mel_spectrogram_conf' in config['dataset_conf']: + input_dim = config['dataset_conf']['log_mel_spectrogram_conf'][ + 'num_mel_bins'] + else: + input_dim = config['dataset_conf']['mfcc_conf']['num_mel_bins'] + config['input_dim'] = input_dim + # TODO(xcsong): fix vocab_size + config['output_dim'] = 3000 + print("checking {} {}".format(c, config)) + init_model(args, config) diff --git a/wenet/utils/init_model.py b/wenet/utils/init_model.py index 5517eaa0f..279fa28dc 100644 --- a/wenet/utils/init_model.py +++ b/wenet/utils/init_model.py @@ -75,7 +75,7 @@ def init_model(args, configs): - if configs['cmvn_file'] is not None: + if configs.get('cmvn_file', None) is not None: mean, istd = load_cmvn(configs['cmvn_file'], configs['is_json_cmvn']) global_cmvn = GlobalCMVN( torch.from_numpy(mean).float(), @@ -115,7 +115,7 @@ def init_model(args, configs): vocab_size, **configs['predictor_conf']) joint = WENET_JOINT_CLASSES[joint_type](vocab_size, **configs['joint_conf']) - model = WENET_MODEL_CLASSES[configs['model']]( + model = WENET_MODEL_CLASSES[model_type]( vocab_size=vocab_size, blank=0, predictor=predictor, @@ -123,8 +123,8 @@ def init_model(args, configs): attention_decoder=decoder, joint=joint, ctc=ctc, - special_tokens=configs['tokenizer_conf'].get( - 'special_tokens', None), + 
special_tokens=configs.get('tokenizer_conf', + {}).get('special_tokens', None), **configs['model_conf']) elif model_type == 'paraformer': """ NOTE(Mddct): support fintune paraformer, if there is a need for @@ -137,13 +137,13 @@ def init_model(args, configs): print(configs) return model, configs else: - model = WENET_MODEL_CLASSES[configs['model']]( + model = WENET_MODEL_CLASSES[model_type]( vocab_size=vocab_size, encoder=encoder, decoder=decoder, ctc=ctc, - special_tokens=configs['tokenizer_conf'].get( - 'special_tokens', None), + special_tokens=configs.get('tokenizer_conf', + {}).get('special_tokens', None), **configs['model_conf']) # If specify checkpoint, load some info from checkpoint From d24337746d56f48317d85d861513288b7d79362e Mon Sep 17 00:00:00 2001 From: xingchensong Date: Sun, 10 Dec 2023 21:03:17 +0800 Subject: [PATCH 8/9] [refactor] fix comment --- wenet/utils/init_model.py | 32 ++++++++++++++++---------------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/wenet/utils/init_model.py b/wenet/utils/init_model.py index 279fa28dc..89845b0d2 100644 --- a/wenet/utils/init_model.py +++ b/wenet/utils/init_model.py @@ -34,7 +34,7 @@ from wenet.utils.cmvn import load_cmvn from wenet.utils.checkpoint import load_checkpoint, load_trained_modules -WENET_ENCODER_CLASSES = { +_WENET_ENCODER_CLASSES = { "transformer": TransformerEncoder, "conformer": ConformerEncoder, "squeezeformer": SqueezeformerEncoder, @@ -45,26 +45,26 @@ "dual_conformer": DualConformerEncoder, } -WENET_DECODER_CLASSES = { +_WENET_DECODER_CLASSES = { "transformer": TransformerDecoder, "bitransformer": BiTransformerDecoder, } -WENET_CTC_CLASSES = { +_WENET_CTC_CLASSES = { "ctc": CTC, } -WENET_PREDICTOR_CLASSES = { +_WENET_PREDICTOR_CLASSES = { "rnn": RNNPredictor, "embedding": EmbeddingPredictor, "conv": ConvPredictor, } -WENET_JOINT_CLASSES = { +_WENET_JOINT_CLASSES = { "transducerjoint": TransducerJoint, } -WENET_MODEL_CLASSES = { +_WENET_MODEL_CLASSES = { "asrmodel": ASRModel, "ctlmodel": CTLModel, "whisper": Whisper, @@ -90,18 +90,18 @@ def init_model(args, configs): decoder_type = configs.get('decoder', 'bitransformer') ctc_type = configs.get('ctc', 'ctc') - encoder = WENET_ENCODER_CLASSES[encoder_type]( + encoder = _WENET_ENCODER_CLASSES[encoder_type]( input_dim, global_cmvn=global_cmvn, **configs['encoder_conf'], **configs['encoder_conf']['efficient_conf'] if 'efficient_conf' in configs['encoder_conf'] else {}) - decoder = WENET_DECODER_CLASSES[decoder_type](vocab_size, - encoder.output_size(), - **configs['decoder_conf']) + decoder = _WENET_DECODER_CLASSES[decoder_type](vocab_size, + encoder.output_size(), + **configs['decoder_conf']) - ctc = WENET_CTC_CLASSES[ctc_type]( + ctc = _WENET_CTC_CLASSES[ctc_type]( vocab_size, encoder.output_size(), blank_id=configs['ctc_conf']['ctc_blank_id'] @@ -111,11 +111,11 @@ def init_model(args, configs): if model_type == "transducer": predictor_type = configs.get('predictor', 'rnn') joint_type = configs.get('joint', 'transducerjoint') - predictor = WENET_PREDICTOR_CLASSES[predictor_type]( + predictor = _WENET_PREDICTOR_CLASSES[predictor_type]( vocab_size, **configs['predictor_conf']) - joint = WENET_JOINT_CLASSES[joint_type](vocab_size, - **configs['joint_conf']) - model = WENET_MODEL_CLASSES[model_type]( + joint = _WENET_JOINT_CLASSES[joint_type](vocab_size, + **configs['joint_conf']) + model = _WENET_MODEL_CLASSES[model_type]( vocab_size=vocab_size, blank=0, predictor=predictor, @@ -137,7 +137,7 @@ def init_model(args, configs): print(configs) return model, 
configs else: - model = WENET_MODEL_CLASSES[model_type]( + model = _WENET_MODEL_CLASSES[model_type]( vocab_size=vocab_size, encoder=encoder, decoder=decoder, From 54f4a95ac42b7c22b433e1b855b2509918cf3ef8 Mon Sep 17 00:00:00 2001 From: xingchensong Date: Sun, 10 Dec 2023 21:54:04 +0800 Subject: [PATCH 9/9] [refactor] fix comment --- .../aishell/rnnt/conf/conformer_rnnt.yaml | 2 +- .../rnnt/conf/conformer_u2pp_rnnt.yaml | 2 +- .../conf/example_embedding_predictor.yaml | 2 +- .../s0/conf/train_unified_conformer_ctl.yaml | 2 +- examples/aishell/s0/run.sh | 2 +- .../aishell2/rnnt/conf/conformer_rnnt.yaml | 2 +- .../rnnt/conf/conformer_u2pp_rnnt.yaml | 2 +- .../librispeech/rnnt/conf/conformer_rnnt.yaml | 2 +- wenet/utils/init_model.py | 44 +++++++++---------- 9 files changed, 30 insertions(+), 30 deletions(-) diff --git a/examples/aishell/rnnt/conf/conformer_rnnt.yaml b/examples/aishell/rnnt/conf/conformer_rnnt.yaml index e162f59d2..690743760 100644 --- a/examples/aishell/rnnt/conf/conformer_rnnt.yaml +++ b/examples/aishell/rnnt/conf/conformer_rnnt.yaml @@ -17,7 +17,7 @@ encoder_conf: pos_enc_layer_type: 'rel_pos' selfattention_layer_type: 'rel_selfattn' -joint: transducerjoint +joint: transducer_joint joint_conf: enc_output_size: 256 pred_output_size: 256 diff --git a/examples/aishell/rnnt/conf/conformer_u2pp_rnnt.yaml b/examples/aishell/rnnt/conf/conformer_u2pp_rnnt.yaml index 3bd8ab36c..5079ef988 100644 --- a/examples/aishell/rnnt/conf/conformer_u2pp_rnnt.yaml +++ b/examples/aishell/rnnt/conf/conformer_u2pp_rnnt.yaml @@ -21,7 +21,7 @@ encoder_conf: cnn_module_norm: 'layer_norm' # using nn.LayerNorm makes model converge faster use_dynamic_left_chunk: false -joint: transducerjoint +joint: transducer_joint joint_conf: enc_output_size: 256 pred_output_size: 256 diff --git a/examples/aishell/rnnt/conf/example_embedding_predictor.yaml b/examples/aishell/rnnt/conf/example_embedding_predictor.yaml index 3f1423169..9a1b4ecac 100644 --- a/examples/aishell/rnnt/conf/example_embedding_predictor.yaml +++ b/examples/aishell/rnnt/conf/example_embedding_predictor.yaml @@ -15,7 +15,7 @@ encoder_conf: pos_enc_layer_type: 'rel_pos' selfattention_layer_type: 'rel_selfattn' -joint: transducerjoint +joint: transducer_joint joint_conf: enc_output_size: 256 pred_output_size: 320 diff --git a/examples/aishell/s0/conf/train_unified_conformer_ctl.yaml b/examples/aishell/s0/conf/train_unified_conformer_ctl.yaml index ea2f548cd..8cf2b726d 100644 --- a/examples/aishell/s0/conf/train_unified_conformer_ctl.yaml +++ b/examples/aishell/s0/conf/train_unified_conformer_ctl.yaml @@ -33,7 +33,7 @@ decoder_conf: src_attention_dropout_rate: 0.0 # hybrid CTC/attention -model: ctlmodel +model: ctl_model model_conf: ctc_weight: 0.3 lsm_weight: 0.1 # label smoothing option diff --git a/examples/aishell/s0/run.sh b/examples/aishell/s0/run.sh index ee126e6d0..36d5bb7e3 100644 --- a/examples/aishell/s0/run.sh +++ b/examples/aishell/s0/run.sh @@ -305,7 +305,7 @@ if [ ${stage} -le 9 ] && [ ${stop_stage} -ge 9 ]; then tools/k2/prepare_mmi.sh data/train/ data/dev data/local/lfmmi # 9.2 Run LF-MMI training from stage 4, modify below args in train.yaml - # model: k2model + # model: k2_model # model_conf: # lfmmi_dir data/local/lfmmi diff --git a/examples/aishell2/rnnt/conf/conformer_rnnt.yaml b/examples/aishell2/rnnt/conf/conformer_rnnt.yaml index e162f59d2..690743760 100644 --- a/examples/aishell2/rnnt/conf/conformer_rnnt.yaml +++ b/examples/aishell2/rnnt/conf/conformer_rnnt.yaml @@ -17,7 +17,7 @@ encoder_conf: pos_enc_layer_type: 'rel_pos' 
selfattention_layer_type: 'rel_selfattn' -joint: transducerjoint +joint: transducer_joint joint_conf: enc_output_size: 256 pred_output_size: 256 diff --git a/examples/aishell2/rnnt/conf/conformer_u2pp_rnnt.yaml b/examples/aishell2/rnnt/conf/conformer_u2pp_rnnt.yaml index 17e773cad..454e8089d 100644 --- a/examples/aishell2/rnnt/conf/conformer_u2pp_rnnt.yaml +++ b/examples/aishell2/rnnt/conf/conformer_u2pp_rnnt.yaml @@ -21,7 +21,7 @@ encoder_conf: cnn_module_norm: 'layer_norm' # using nn.LayerNorm makes model converge faster use_dynamic_left_chunk: false -joint: transducerjoint +joint: transducer_joint joint_conf: enc_output_size: 256 pred_output_size: 256 diff --git a/examples/librispeech/rnnt/conf/conformer_rnnt.yaml b/examples/librispeech/rnnt/conf/conformer_rnnt.yaml index 27413d29d..6b7fdc906 100644 --- a/examples/librispeech/rnnt/conf/conformer_rnnt.yaml +++ b/examples/librispeech/rnnt/conf/conformer_rnnt.yaml @@ -17,7 +17,7 @@ encoder_conf: pos_enc_layer_type: 'rel_pos' selfattention_layer_type: 'rel_selfattn' -joint: transducerjoint +joint: transducer_joint joint_conf: enc_output_size: 256 pred_output_size: 256 diff --git a/wenet/utils/init_model.py b/wenet/utils/init_model.py index 89845b0d2..185b2c26c 100644 --- a/wenet/utils/init_model.py +++ b/wenet/utils/init_model.py @@ -34,7 +34,7 @@ from wenet.utils.cmvn import load_cmvn from wenet.utils.checkpoint import load_checkpoint, load_trained_modules -_WENET_ENCODER_CLASSES = { +WENET_ENCODER_CLASSES = { "transformer": TransformerEncoder, "conformer": ConformerEncoder, "squeezeformer": SqueezeformerEncoder, @@ -45,30 +45,30 @@ "dual_conformer": DualConformerEncoder, } -_WENET_DECODER_CLASSES = { +WENET_DECODER_CLASSES = { "transformer": TransformerDecoder, "bitransformer": BiTransformerDecoder, } -_WENET_CTC_CLASSES = { +WENET_CTC_CLASSES = { "ctc": CTC, } -_WENET_PREDICTOR_CLASSES = { +WENET_PREDICTOR_CLASSES = { "rnn": RNNPredictor, "embedding": EmbeddingPredictor, "conv": ConvPredictor, } -_WENET_JOINT_CLASSES = { - "transducerjoint": TransducerJoint, +WENET_JOINT_CLASSES = { + "transducer_joint": TransducerJoint, } -_WENET_MODEL_CLASSES = { - "asrmodel": ASRModel, - "ctlmodel": CTLModel, +WENET_MODEL_CLASSES = { + "asr_model": ASRModel, + "ctl_model": CTLModel, "whisper": Whisper, - "k2model": K2Model, + "k2_model": K2Model, "transducer": Transducer, } @@ -90,32 +90,32 @@ def init_model(args, configs): decoder_type = configs.get('decoder', 'bitransformer') ctc_type = configs.get('ctc', 'ctc') - encoder = _WENET_ENCODER_CLASSES[encoder_type]( + encoder = WENET_ENCODER_CLASSES[encoder_type]( input_dim, global_cmvn=global_cmvn, **configs['encoder_conf'], **configs['encoder_conf']['efficient_conf'] if 'efficient_conf' in configs['encoder_conf'] else {}) - decoder = _WENET_DECODER_CLASSES[decoder_type](vocab_size, - encoder.output_size(), - **configs['decoder_conf']) + decoder = WENET_DECODER_CLASSES[decoder_type](vocab_size, + encoder.output_size(), + **configs['decoder_conf']) - ctc = _WENET_CTC_CLASSES[ctc_type]( + ctc = WENET_CTC_CLASSES[ctc_type]( vocab_size, encoder.output_size(), blank_id=configs['ctc_conf']['ctc_blank_id'] if 'ctc_conf' in configs else 0) - model_type = configs.get('model', 'asrmodel') + model_type = configs.get('model', 'asr_model') if model_type == "transducer": predictor_type = configs.get('predictor', 'rnn') - joint_type = configs.get('joint', 'transducerjoint') - predictor = _WENET_PREDICTOR_CLASSES[predictor_type]( + joint_type = configs.get('joint', 'transducer_joint') + predictor = 
WENET_PREDICTOR_CLASSES[predictor_type]( vocab_size, **configs['predictor_conf']) - joint = _WENET_JOINT_CLASSES[joint_type](vocab_size, - **configs['joint_conf']) - model = _WENET_MODEL_CLASSES[model_type]( + joint = WENET_JOINT_CLASSES[joint_type](vocab_size, + **configs['joint_conf']) + model = WENET_MODEL_CLASSES[model_type]( vocab_size=vocab_size, blank=0, predictor=predictor, @@ -137,7 +137,7 @@ def init_model(args, configs): print(configs) return model, configs else: - model = _WENET_MODEL_CLASSES[model_type]( + model = WENET_MODEL_CLASSES[model_type]( vocab_size=vocab_size, encoder=encoder, decoder=decoder,
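
After this series, every module is selected by a string key in the train config and looked up in the registries above (`model`, `encoder`, `decoder`, `ctc`, `predictor`, `joint`), with `model` defaulting to `asr_model` so configs without the new key keep working. Below is a minimal usage sketch in the spirit of the new test/wenet/utils/test_init_model.py; the config path and the placeholder `input_dim`/`output_dim` values are illustrative assumptions, not part of the patch.

import yaml

from wenet.utils.init_model import init_model


class DummyArguments:
    # Same fields the new unit test stubs out.
    jit = False
    enc_init = None
    checkpoint = None


# Any recipe config should work; this one ends up with `model: transducer`
# and `joint: transducer_joint` after the refactor. Path is assumed.
config_path = "examples/aishell/rnnt/conf/conformer_rnnt.yaml"

with open(config_path, "r") as fin:
    configs = yaml.load(fin, Loader=yaml.FullLoader)

# Normally filled in by the training script from the data and the dict;
# placeholder values here just to keep the sketch self-contained.
configs["input_dim"] = 80
configs["output_dim"] = 4233

# init_model now dispatches on configs['model'] (e.g. "asr_model",
# "transducer", "ctl_model", "whisper", "k2_model") instead of probing
# for ad-hoc keys like 'ctlmodel' or 'predictor' in the config.
init_model(DummyArguments(), configs)

Configs that omit the new keys fall back to the defaults (`model: asr_model`, `encoder: conformer`, `decoder: bitransformer`, `ctc: ctc`), which is why patch 6/9 can drop the explicit `model: asrmodel` lines from the standard CTC/attention recipes again.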