[paraformer] refine model class
Mddct committed Dec 12, 2023
1 parent 5887279 commit 1b3cb56
Showing 4 changed files with 49 additions and 55 deletions.
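
At a glance: this commit retires the Paraformer-only loader in `wenet/utils/init_ali_paraformer.py` and routes Paraformer through the shared `init_model` registry, reading the `<sos>`/`<eos>` ids from the tokenizer config instead of hard-coding them. A minimal before/after sketch of the call site, with shapes taken from the diff below:

```python
# Before: a dedicated Paraformer entry point
from wenet.utils.init_ali_paraformer import init_model as init_ali_paraformer_model
model, configs = init_ali_paraformer_model(configs, args.checkpoint)

# After: the shared registry path used by every other wenet model type
from wenet.utils.init_model import init_model
model, configs = init_model(args, configs)  # dispatches on configs['model'] == 'paraformer'
```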
35 changes: 23 additions & 12 deletions wenet/paraformer/convert_paraformer_to_wenet_config_and_ckpt.py
@@ -13,6 +13,7 @@
import yaml

from wenet.utils.checkpoint import save_checkpoint
from wenet.utils.init_model import init_model


def _load_paraformer_cmvn(cmvn_file) -> Tuple[List, List]:
@@ -88,24 +89,35 @@ def convert_to_wenet_tokenizer_conf(symbol_table_path, seg_dict, configs,
configs['tokenizer_conf'] = {}
configs['tokenizer_conf']['symbol_table_path'] = symbol_table_path
configs['tokenizer_conf']['seg_dict_path'] = output_path
configs['tokenizer_conf']['special_tokens'] = {}
configs['tokenizer_conf']['special_tokens']['<eos>'] = 2
configs['tokenizer_conf']['special_tokens']['<sos>'] = 1
configs['tokenizer_conf']['special_tokens']['<blank>'] = 0
configs['tokenizer_conf']['special_tokens']['<unk>'] = 8403

shutil.copy(seg_dict, output_path)
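
The hard-coded token ids move out of the model class and into `tokenizer_conf['special_tokens']`; the values written here (blank 0, sos 1, eos 2, unk 8403) follow the Paraformer vocabulary, and 8403 reappears below as the CTC blank id. A short sketch of how the model side reads them back (see `paraformer.py` further down):

```python
# Sketch: Paraformer.__init__ now resolves sos/eos from the converter's config
special_tokens = configs['tokenizer_conf']['special_tokens']
sos = special_tokens['<sos>']  # 1
eos = special_tokens['<eos>']  # 2
```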


def convert_to_wenet_yaml(configs, wenet_yaml_path: str,
fields_to_keep: List[str]) -> Dict:
configs = _filter_dict_fields(configs, fields_to_keep)
configs['encoder'] = 'SanmEncoder'
configs['encoder'] = 'sanm_encoder'
configs['encoder_conf']['input_layer'] = 'conv2d'
configs['decoder'] = 'SanmDecoder'
configs['decoder'] = 'sanm_decoder'
configs['lfr_conf'] = {'lfr_m': 7, 'lfr_n': 6}

configs['cif_predictor_conf'] = configs.pop('predictor_conf')
configs['cif_predictor_conf']['cnn_groups'] = 1
configs['cif_predictor_conf']['residual'] = False
configs['input_dim'] = configs['lfr_conf']['lfr_m'] * 80
configs['predictor'] = 'cif_predictor'
configs['predictor_conf'] = configs.pop('predictor_conf')
configs['predictor_conf']['cnn_groups'] = 1
configs['predictor_conf']['residual'] = False
# These options are not used
del configs['encoder_conf']['selfattention_layer_type'], configs[
'encoder_conf']['pos_enc_class']

configs['ctc_conf'] = {}
configs['ctc_conf']['ctc_blank_id'] = 8403

configs['dataset_conf'] = {}
configs['dataset_conf']['filter_conf'] = {}
configs['dataset_conf']['filter_conf']['max_length'] = 20000
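
`input_dim` is no longer the raw 80-dim fbank size: LFR stacks `lfr_m` consecutive frames, so the encoder sees `lfr_m * 80 = 560` features per step. For reference, the fields this function now emits, condensed into a Python dict:

```python
# Condensed view of the converted config (values taken from the code above)
converted = {
    'encoder': 'sanm_encoder',
    'decoder': 'sanm_decoder',
    'predictor': 'cif_predictor',
    'lfr_conf': {'lfr_m': 7, 'lfr_n': 6},
    'input_dim': 7 * 80,                 # 560: lfr_m stacked 80-dim fbank frames
    'ctc_conf': {'ctc_blank_id': 8403},  # same id the tokenizer maps to '<unk>' above
}
```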
@@ -145,9 +157,9 @@ def convert_to_wenet_yaml(configs, wenet_yaml_path: str,
return configs


def convert_to_wenet_state_dict(configs, paraformer_path, wenet_model_path):
from wenet.utils.init_ali_paraformer import init_model
model, _ = init_model(configs, paraformer_path)
def convert_to_wenet_state_dict(args, configs, wenet_model_path):
args.checkpoint = args.paraformer_model
model, _ = init_model(args, configs)
save_checkpoint(model, wenet_model_path)


@@ -254,9 +266,9 @@ def main():
configs['model'] = 'paraformer'
configs['is_json_cmvn'] = True
configs['cmvn_file'] = json_cmvn_path
configs['input_dim'] = 80
# configs['input_dim'] = 80
fields_to_keep = [
'encoder_conf', 'decoder_conf', 'predictor_conf', 'input_dim',
'model', 'encoder_conf', 'decoder_conf', 'predictor_conf', 'input_dim',
'output_dim', 'cmvn_file', 'is_json_cmvn', 'model_conf', 'paraformer',
'optim', 'optim_conf', 'scheduler', 'scheduler_conf', 'tokenizer',
'tokenizer_conf'
@@ -266,8 +278,7 @@
fields_to_keep)

wenet_model_path = os.path.join(args.output_dir, "wenet_paraformer.pt")
convert_to_wenet_state_dict(wenet_configs, args.paraformer_model,
wenet_model_path)
convert_to_wenet_state_dict(args, wenet_configs, wenet_model_path)

print("Please check {} {} {} {} {} in {}".format(json_cmvn_path,
wenet_train_yaml,
6 changes: 4 additions & 2 deletions wenet/paraformer/paraformer.py
@@ -53,6 +53,7 @@ def __init__(self,
sampler: bool = True,
sampling_ratio: float = 0.75,
add_eos: bool = True,
special_tokens: Optional[Dict] = None,
**kwargs):
assert isinstance(encoder, SanmEncoder)
assert isinstance(decoder, SanmDecoder)
@@ -67,8 +68,9 @@

self.lfr = LFR()

self.sos = 1
self.eos = 2
assert special_tokens is not None
self.sos = special_tokens['<sos>']
self.eos = special_tokens['<eos>']
self.ignore_id = ignore_id

self.criterion_att = LabelSmoothingLoss(
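For context, the `LFR()` front-end the class instantiates stacks `lfr_m` consecutive fbank frames and advances by `lfr_n`, which is why the converter sets `input_dim = lfr_m * 80`. A simplified sketch of the idea, ignoring the edge padding and batching the real `wenet/paraformer` module handles:

```python
import torch

def lfr_sketch(x: torch.Tensor, lfr_m: int = 7, lfr_n: int = 6) -> torch.Tensor:
    """Toy low-frame-rate transform: (T, 80) fbank -> (T', lfr_m * 80)."""
    frames = [x[i:i + lfr_m].reshape(-1)  # stack lfr_m frames into one vector
              for i in range(0, x.size(0) - lfr_m + 1, lfr_n)]
    return torch.stack(frames)

feats = torch.randn(100, 80)
print(lfr_sketch(feats).shape)  # torch.Size([16, 560])
```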
34 changes: 0 additions & 34 deletions wenet/utils/init_ali_paraformer.py

This file was deleted.

29 changes: 22 additions & 7 deletions wenet/utils/init_model.py
@@ -15,6 +15,9 @@
import torch

from wenet.k2.model import K2Model
from wenet.paraformer.cif import Cif
from wenet.paraformer.layers import SanmDecoder, SanmEncoder
from wenet.paraformer.paraformer import Paraformer
from wenet.transducer.joint import TransducerJoint
from wenet.transducer.predictor import (ConvPredictor, EmbeddingPredictor,
RNNPredictor)
@@ -43,11 +46,13 @@
"e_branchformer": EBranchformerEncoder,
"dual_transformer": DualTransformerEncoder,
"dual_conformer": DualConformerEncoder,
"sanm_encoder": SanmEncoder,
}

WENET_DECODER_CLASSES = {
"transformer": TransformerDecoder,
"bitransformer": BiTransformerDecoder,
"sanm_decoder": SanmDecoder,
}

WENET_CTC_CLASSES = {
Expand All @@ -58,6 +63,7 @@
"rnn": RNNPredictor,
"embedding": EmbeddingPredictor,
"conv": ConvPredictor,
"cif_predictor": Cif,
}
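
The new `"cif_predictor"` entry maps to `Cif`, a continuous integrate-and-fire predictor: it accumulates a per-frame weight and fires a token boundary whenever the running sum crosses a threshold, which is how Paraformer estimates token count without an explicit alignment. A toy version of the firing rule, not the wenet implementation:

```python
def cif_fire(weights, threshold=1.0):
    """Toy integrate-and-fire: indices of frames where a token fires."""
    acc, fires = 0.0, []
    for t, w in enumerate(weights):
        acc += w
        if acc >= threshold:
            fires.append(t)
            acc -= threshold  # carry the remainder into the next token
    return fires

print(cif_fire([0.4, 0.5, 0.3, 0.9]))  # [2, 3] -> two tokens predicted
```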

WENET_JOINT_CLASSES = {
@@ -70,6 +76,7 @@
"whisper": Whisper,
"k2_model": K2Model,
"transducer": Transducer,
"paraformer": Paraformer,
}
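
With `'paraformer'` registered here, every component of the model resolves through the same string-keyed lookup; condensed, the dispatch below these tables amounts to:

```python
# Condensed sketch of the lookup pattern in init_model (names from this file)
encoder_cls = WENET_ENCODER_CLASSES[configs['encoder']]  # 'sanm_encoder' -> SanmEncoder
decoder_cls = WENET_DECODER_CLASSES[configs['decoder']]  # 'sanm_decoder' -> SanmDecoder
model_cls = WENET_MODEL_CLASSES[configs['model']]        # 'paraformer'   -> Paraformer
```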


@@ -131,11 +138,19 @@ def init_model(args, configs):
sanmencoder/decoder in the future, simplify here.
"""
# TODO(Mddct): refine this
from wenet.utils.init_ali_paraformer import (init_model as
init_ali_paraformer_model)
model, configs = init_ali_paraformer_model(configs, args.checkpoint)
print(configs)
return model, configs
predictor_type = configs.get('predictor', 'cif_predictor')
predictor = WENET_PREDICTOR_CLASSES[predictor_type](
**configs['predictor_conf'])
model = WENET_MODEL_CLASSES[model_type](
vocab_size=vocab_size,
encoder=encoder,
decoder=decoder,
predictor=predictor,
ctc=ctc,
**configs['model_conf'],
special_tokens=configs.get('tokenizer_conf',
{}).get('special_tokens', None),
)
else:
model = WENET_MODEL_CLASSES[model_type](
vocab_size=vocab_size,
Expand All @@ -147,9 +162,9 @@ def init_model(args, configs):
**configs['model_conf'])

# If specify checkpoint, load some info from checkpoint
if args.checkpoint is not None:
if hasattr(args, 'checkpoint') and args.checkpoint is not None:
infos = load_checkpoint(model, args.checkpoint)
elif args.enc_init is not None:
elif hasattr(args, 'enc_init') and args.enc_init is not None:
infos = load_trained_modules(model, args)
else:
infos = {}
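
The `hasattr` guards matter because the converter script calls `init_model` with its own argparse namespace, which carries `checkpoint` (aliased from `paraformer_model` above) but not necessarily the training-only attributes. A minimal illustration with a hypothetical namespace:

```python
import argparse

# Hypothetical args as built by the converter: checkpoint is set, enc_init is absent
args = argparse.Namespace(checkpoint='exp/wenet_paraformer.pt')
assert hasattr(args, 'checkpoint') and not hasattr(args, 'enc_init')
```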