From 8f7a8f3f9f6674b058d0dc9a79cd29e2775ca6f2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xingchen=20Song=28=E5=AE=8B=E6=98=9F=E8=BE=B0=29?= Date: Fri, 24 Nov 2023 11:54:18 +0800 Subject: [PATCH] fix(whisper): support arbitrary ctc blank id (#2157) --- wenet/transformer/ctc.py | 4 +++- wenet/utils/init_model.py | 3 ++- wenet/utils/train_utils.py | 11 +++++++++++ .../convert_whisper_to_wenet_config_and_ckpt.py | 3 +++ 4 files changed, 19 insertions(+), 2 deletions(-) diff --git a/wenet/transformer/ctc.py b/wenet/transformer/ctc.py index c30ffbf08..ef6662060 100644 --- a/wenet/transformer/ctc.py +++ b/wenet/transformer/ctc.py @@ -25,6 +25,7 @@ def __init__( encoder_output_size: int, dropout_rate: float = 0.0, reduce: bool = True, + blank_id: int = 0, ): """ Construct CTC module Args: @@ -32,6 +33,7 @@ def __init__( encoder_output_size: number of encoder projection units dropout_rate: dropout rate (0.0 ~ 1.0) reduce: reduce the CTC loss into a scalar + blank_id: blank label. """ super().__init__() eprojs = encoder_output_size @@ -39,7 +41,7 @@ def __init__( self.ctc_lo = torch.nn.Linear(eprojs, odim) reduction_type = "sum" if reduce else "none" - self.ctc_loss = torch.nn.CTCLoss(reduction=reduction_type) + self.ctc_loss = torch.nn.CTCLoss(blank=blank_id, reduction=reduction_type) def forward(self, hs_pad: torch.Tensor, hlens: torch.Tensor, ys_pad: torch.Tensor, ys_lens: torch.Tensor) -> torch.Tensor: diff --git a/wenet/utils/init_model.py b/wenet/utils/init_model.py index 6eb652473..2e4a16e54 100644 --- a/wenet/utils/init_model.py +++ b/wenet/utils/init_model.py @@ -99,7 +99,8 @@ def init_model(args, configs): assert configs['decoder_conf']['r_num_blocks'] > 0 decoder = BiTransformerDecoder(vocab_size, encoder.output_size(), **configs['decoder_conf']) - ctc = CTC(vocab_size, encoder.output_size()) + ctc = CTC(vocab_size, encoder.output_size(), + blank_id=configs['ctc_conf']['ctc_blank_id']) # Init joint CTC/Attention or Transducer model if 'predictor' in configs: diff 
--git a/wenet/utils/train_utils.py b/wenet/utils/train_utils.py index 4c2aba516..6952cb190 100644 --- a/wenet/utils/train_utils.py +++ b/wenet/utils/train_utils.py @@ -216,6 +216,17 @@ def check_modify_and_save_config(args, configs): symbol_table = read_symbol_table(args.symbol_table) vocab_size = len(symbol_table) + if 'ctc_conf' not in configs: + configs['ctc_conf'] = {} + + if '<blank>' in symbol_table: + if 'ctc_blank_id' in configs['ctc_conf']: + assert configs['ctc_conf']['ctc_blank_id'] == symbol_table['<blank>'] + else: + configs['ctc_conf']['ctc_blank_id'] = symbol_table['<blank>'] + else: + assert 'ctc_blank_id' in configs['ctc_conf'], "PLZ set ctc_blank_id in yaml" + configs['input_dim'] = input_dim configs['output_dim'] = configs.get('output_dim', vocab_size) configs['cmvn_file'] = args.cmvn diff --git a/wenet/whisper/convert_whisper_to_wenet_config_and_ckpt.py b/wenet/whisper/convert_whisper_to_wenet_config_and_ckpt.py index 45c36d970..909f43c92 100644 --- a/wenet/whisper/convert_whisper_to_wenet_config_and_ckpt.py +++ b/wenet/whisper/convert_whisper_to_wenet_config_and_ckpt.py @@ -89,6 +89,9 @@ def convert_to_wenet_yaml(tokenizer, dims, wenet_yaml_path: str): configs['decoder_conf']['key_bias'] = False configs['decoder_conf']['activation_type'] = "gelu" + configs['ctc_conf'] = {} + configs['ctc_conf']['ctc_blank_id'] = 50362 # <|nocaptions|> + configs['model_conf'] = {} configs['model_conf']['ctc_weight'] = 0.3 configs['model_conf']['lsm_weight'] = 0.1