From 1b3cb56699e965cb30977e482334d12d00a61eb2 Mon Sep 17 00:00:00 2001
From: Mddct
Date: Tue, 12 Dec 2023 23:45:45 +0800
Subject: [PATCH] [paraformer] refine model class

---
 ...ert_paraformer_to_wenet_config_and_ckpt.py | 35 ++++++++++++-------
 wenet/paraformer/paraformer.py                |  6 ++--
 wenet/utils/init_ali_paraformer.py            | 34 ------------------
 wenet/utils/init_model.py                     | 29 +++++++++++----
 4 files changed, 49 insertions(+), 55 deletions(-)
 delete mode 100644 wenet/utils/init_ali_paraformer.py

diff --git a/wenet/paraformer/convert_paraformer_to_wenet_config_and_ckpt.py b/wenet/paraformer/convert_paraformer_to_wenet_config_and_ckpt.py
index 12f198da4..7b219b041 100644
--- a/wenet/paraformer/convert_paraformer_to_wenet_config_and_ckpt.py
+++ b/wenet/paraformer/convert_paraformer_to_wenet_config_and_ckpt.py
@@ -13,6 +13,7 @@
 import yaml
 
 from wenet.utils.checkpoint import save_checkpoint
+from wenet.utils.init_model import init_model
 
 
 def _load_paraformer_cmvn(cmvn_file) -> Tuple[List, List]:
@@ -88,24 +89,35 @@ def convert_to_wenet_tokenizer_conf(symbol_table_path, seg_dict, configs,
     configs['tokenizer_conf'] = {}
     configs['tokenizer_conf']['symbol_table_path'] = symbol_table_path
     configs['tokenizer_conf']['seg_dict_path'] = output_path
+    configs['tokenizer_conf']['special_tokens'] = {}
+    configs['tokenizer_conf']['special_tokens']['<eos>'] = 2
+    configs['tokenizer_conf']['special_tokens']['<sos>'] = 1
+    configs['tokenizer_conf']['special_tokens']['<blank>'] = 0
+    configs['tokenizer_conf']['special_tokens']['<unk>'] = 8403
+
     shutil.copy(seg_dict, output_path)
 
 
 def convert_to_wenet_yaml(configs, wenet_yaml_path: str,
                           fields_to_keep: List[str]) -> Dict:
     configs = _filter_dict_fields(configs, fields_to_keep)
-    configs['encoder'] = 'SanmEncoder'
+    configs['encoder'] = 'sanm_encoder'
     configs['encoder_conf']['input_layer'] = 'conv2d'
-    configs['decoder'] = 'SanmDecoder'
+    configs['decoder'] = 'sanm_decoder'
     configs['lfr_conf'] = {'lfr_m': 7, 'lfr_n': 6}
-    configs['cif_predictor_conf'] = configs.pop('predictor_conf')
-    configs['cif_predictor_conf']['cnn_groups'] = 1
-    configs['cif_predictor_conf']['residual'] = False
+    configs['input_dim'] = configs['lfr_conf']['lfr_m'] * 80
+    configs['predictor'] = 'cif_predictor'
+    configs['predictor_conf'] = configs.pop('predictor_conf')
+    configs['predictor_conf']['cnn_groups'] = 1
+    configs['predictor_conf']['residual'] = False
 
     # This type not use
     del configs['encoder_conf']['selfattention_layer_type'], configs[
         'encoder_conf']['pos_enc_class']
 
+    configs['ctc_conf'] = {}
+    configs['ctc_conf']['ctc_blank_id'] = 8403
+
     configs['dataset_conf'] = {}
     configs['dataset_conf']['filter_conf'] = {}
     configs['dataset_conf']['filter_conf']['max_length'] = 20000
@@ -145,9 +157,9 @@ def convert_to_wenet_yaml(configs, wenet_yaml_path: str,
     return configs
 
 
-def convert_to_wenet_state_dict(configs, paraformer_path, wenet_model_path):
-    from wenet.utils.init_ali_paraformer import init_model
-    model, _ = init_model(configs, paraformer_path)
+def convert_to_wenet_state_dict(args, configs, wenet_model_path):
+    args.checkpoint = args.paraformer_model
+    model, _ = init_model(args, configs)
     save_checkpoint(model, wenet_model_path)
 
 
@@ -254,9 +266,9 @@ def main():
     configs['model'] = 'paraformer'
     configs['is_json_cmvn'] = True
     configs['cmvn_file'] = json_cmvn_path
-    configs['input_dim'] = 80
+    # configs['input_dim'] = 80
     fields_to_keep = [
-        'encoder_conf', 'decoder_conf', 'predictor_conf', 'input_dim',
+        'model', 'encoder_conf', 'decoder_conf', 'predictor_conf', 'input_dim',
         'output_dim', 'cmvn_file', 'is_json_cmvn', 'model_conf', 'paraformer',
         'optim', 'optim_conf', 'scheduler', 'scheduler_conf', 'tokenizer',
         'tokenizer_conf'
@@ -266,8 +278,7 @@ def main():
                                           fields_to_keep)
 
     wenet_model_path = os.path.join(args.output_dir, "wenet_paraformer.pt")
-    convert_to_wenet_state_dict(wenet_configs, args.paraformer_model,
-                                wenet_model_path)
+    convert_to_wenet_state_dict(args, wenet_configs, wenet_model_path)
 
     print("Please check {} {} {} {} {} in {}".format(json_cmvn_path,
                                                      wenet_train_yaml,
diff --git a/wenet/paraformer/paraformer.py b/wenet/paraformer/paraformer.py
index 7f7c7513f..4ba354fff 100644
--- a/wenet/paraformer/paraformer.py
+++ b/wenet/paraformer/paraformer.py
@@ -53,6 +53,7 @@ def __init__(self,
                  sampler: bool = True,
                  sampling_ratio: float = 0.75,
                  add_eos: bool = True,
+                 special_tokens: Optional[Dict] = None,
                  **kwargs):
         assert isinstance(encoder, SanmEncoder), isinstance(decoder,
                                                             SanmDecoder)
@@ -67,8 +68,9 @@ def __init__(self,
 
         self.lfr = LFR()
 
-        self.sos = 1
-        self.eos = 2
+        assert special_tokens is not None
+        self.sos = special_tokens['<sos>']
+        self.eos = special_tokens['<eos>']
         self.ignore_id = ignore_id
 
         self.criterion_att = LabelSmoothingLoss(
diff --git a/wenet/utils/init_ali_paraformer.py b/wenet/utils/init_ali_paraformer.py
deleted file mode 100644
index c0db11357..000000000
--- a/wenet/utils/init_ali_paraformer.py
+++ /dev/null
@@ -1,34 +0,0 @@
-import torch
-from wenet.paraformer.cif import Cif
-from wenet.paraformer.layers import (SanmDecoder, SanmEncoder)
-from wenet.paraformer.paraformer import Paraformer
-from wenet.transformer.cmvn import GlobalCMVN
-from wenet.utils.checkpoint import load_checkpoint
-from wenet.utils.cmvn import load_cmvn
-
-
-def init_model(configs, checkpoint_path=None):
-    mean, istd = load_cmvn(configs['cmvn_file'], True)
-    global_cmvn = GlobalCMVN(
-        torch.from_numpy(mean).float(),
-        torch.from_numpy(istd).float())
-    input_dim = configs['input_dim']
-    vocab_size = configs['output_dim']
-    encoder = SanmEncoder(global_cmvn=global_cmvn,
-                          input_size=configs['lfr_conf']['lfr_m'] * input_dim,
-                          **configs['encoder_conf'])
-    decoder = SanmDecoder(vocab_size=vocab_size,
-                          encoder_output_size=encoder.output_size(),
-                          **configs['decoder_conf'])
-    predictor = Cif(**configs['cif_predictor_conf'])
-    model = Paraformer(
-        vocab_size=vocab_size,
-        encoder=encoder,
-        decoder=decoder,
-        predictor=predictor,
-        **configs['model_conf'],
-    )
-
-    if checkpoint_path is not None:
-        load_checkpoint(model, checkpoint_path)
-    return model, configs
diff --git a/wenet/utils/init_model.py b/wenet/utils/init_model.py
index 185b2c26c..963bb64ee 100644
--- a/wenet/utils/init_model.py
+++ b/wenet/utils/init_model.py
@@ -15,6 +15,9 @@
 import torch
 
 from wenet.k2.model import K2Model
+from wenet.paraformer.cif import Cif
+from wenet.paraformer.layers import SanmDecoder, SanmEncoder
+from wenet.paraformer.paraformer import Paraformer
 from wenet.transducer.joint import TransducerJoint
 from wenet.transducer.predictor import (ConvPredictor, EmbeddingPredictor,
                                         RNNPredictor)
@@ -43,11 +46,13 @@
     "e_branchformer": EBranchformerEncoder,
     "dual_transformer": DualTransformerEncoder,
     "dual_conformer": DualConformerEncoder,
+    'sanm_encoder': SanmEncoder,
 }
 
 WENET_DECODER_CLASSES = {
     "transformer": TransformerDecoder,
     "bitransformer": BiTransformerDecoder,
+    "sanm_decoder": SanmDecoder,
 }
 
 WENET_CTC_CLASSES = {
@@ -58,6 +63,7 @@
     "rnn": RNNPredictor,
     "embedding": EmbeddingPredictor,
     "conv": ConvPredictor,
+    "cif_predictor": Cif,
 }
 
 WENET_JOINT_CLASSES = {
@@ -70,6 +76,7 @@
     "whisper": Whisper,
     "k2_model": K2Model,
     "transducer": Transducer,
+    'paraformer': Paraformer,
 }
 
 
@@ -131,11 +138,19 @@
         sanmencoder/decoder in the future, simplify here.
         """
         # TODO(Mddct): refine this
-        from wenet.utils.init_ali_paraformer import (init_model as
-                                                     init_ali_paraformer_model)
-        model, configs = init_ali_paraformer_model(configs, args.checkpoint)
-        print(configs)
-        return model, configs
+        predictor_type = configs.get('predictor', 'cif')
+        predictor = WENET_PREDICTOR_CLASSES[predictor_type](
+            **configs['predictor_conf'])
+        model = WENET_MODEL_CLASSES[model_type](
+            vocab_size=vocab_size,
+            encoder=encoder,
+            decoder=decoder,
+            predictor=predictor,
+            ctc=ctc,
+            **configs['model_conf'],
+            special_tokens=configs.get('tokenizer_conf',
+                                       {}).get('special_tokens', None),
+        )
     else:
         model = WENET_MODEL_CLASSES[model_type](
             vocab_size=vocab_size,
@@ -147,9 +162,9 @@
             **configs['model_conf'])
 
     # If specify checkpoint, load some info from checkpoint
-    if args.checkpoint is not None:
+    if hasattr(args, 'checkpoint') and args.checkpoint is not None:
         infos = load_checkpoint(model, args.checkpoint)
-    elif args.enc_init is not None:
+    elif hasattr(args, 'enc_init') and args.enc_init is not None:
         infos = load_trained_modules(model, args)
     else:
         infos = {}