From 7bece3a32de5a36e05281dad5d192256178319a3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xingchen=20Song=28=E5=AE=8B=E6=98=9F=E8=BE=B0=29?= Date: Fri, 8 Dec 2023 20:19:29 +0800 Subject: [PATCH] fix(code): fix lint (#2207) --- wenet/k2/model.py | 10 ++++++---- wenet/paraformer/paraformer.py | 19 +++++++++++-------- wenet/transformer/asr_model.py | 25 ++++++++++--------------- wenet/transformer/ctc.py | 1 - wenet/utils/train_utils.py | 26 ++++++++++++++------------ 5 files changed, 41 insertions(+), 40 deletions(-) diff --git a/wenet/k2/model.py b/wenet/k2/model.py index 271e450da..60a4fab30 100644 --- a/wenet/k2/model.py +++ b/wenet/k2/model.py @@ -47,10 +47,12 @@ def __init__( self.load_lfmmi_resource() @torch.jit.ignore(drop=True) - def _forward_ctc(self, encoder_out: torch.Tensor, - encoder_mask: torch.Tensor, text: torch.Tensor, - text_lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: - loss_ctc, ctc_probs = self._calc_lfmmi_loss(encoder_out, encoder_mask, text) + def _forward_ctc( + self, encoder_out: torch.Tensor, encoder_mask: torch.Tensor, + text: torch.Tensor, + text_lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + loss_ctc, ctc_probs = self._calc_lfmmi_loss(encoder_out, encoder_mask, + text) return loss_ctc, ctc_probs @torch.jit.ignore(drop=True) diff --git a/wenet/paraformer/paraformer.py b/wenet/paraformer/paraformer.py index 32df2dde1..7f7c7513f 100644 --- a/wenet/paraformer/paraformer.py +++ b/wenet/paraformer/paraformer.py @@ -137,8 +137,8 @@ def forward( loss_ctc: Optional[torch.Tensor] = None if self.ctc_weight != 0.0: loss_ctc, ctc_probs = self._forward_ctc(encoder_out, - encoder_out_mask, - text, text_lengths) + encoder_out_mask, text, + text_lengths) # TODO(Mddc): thu acc loss_decoder = self.criterion_att(decoder_out, ys_pad) loss = loss_decoder @@ -151,13 +151,16 @@ def forward( } @torch.jit.ignore(drop=True) - def _forward_ctc(self, encoder_out: torch.Tensor, - encoder_mask: torch.Tensor, text: torch.Tensor, - text_lengths: torch.Tensor, - ) -> Tuple[torch.Tensor, torch.Tensor]: + def _forward_ctc( + self, + encoder_out: torch.Tensor, + encoder_mask: torch.Tensor, + text: torch.Tensor, + text_lengths: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: encoder_out_lens = encoder_mask.squeeze(1).sum(1) - loss_ctc, ctc_probs = self.ctc(encoder_out, encoder_out_lens, - text, text_lengths) + loss_ctc, ctc_probs = self.ctc(encoder_out, encoder_out_lens, text, + text_lengths) return loss_ctc, ctc_probs @torch.jit.ignore(drop=True) diff --git a/wenet/transformer/asr_model.py b/wenet/transformer/asr_model.py index bac9fafea..9cd2aff4f 100644 --- a/wenet/transformer/asr_model.py +++ b/wenet/transformer/asr_model.py @@ -53,16 +53,10 @@ def __init__( super().__init__() # note that eos is the same as sos (equivalent ID) - self.sos = ( - vocab_size - 1 - if special_tokens is None - else special_tokens.get("sos", vocab_size - 1) - ) - self.eos = ( - vocab_size - 1 - if special_tokens is None - else special_tokens.get("eos", vocab_size - 1) - ) + self.sos = (vocab_size - 1 if special_tokens is None else + special_tokens.get("sos", vocab_size - 1)) + self.eos = (vocab_size - 1 if special_tokens is None else + special_tokens.get("eos", vocab_size - 1)) self.vocab_size = vocab_size self.special_tokens = special_tokens self.ignore_id = ignore_id @@ -136,12 +130,13 @@ def forward( return {"loss": loss, "loss_att": loss_att, "loss_ctc": loss_ctc} @torch.jit.ignore(drop=True) - def _forward_ctc(self, encoder_out: torch.Tensor, - encoder_mask: torch.Tensor, text: torch.Tensor, - text_lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + def _forward_ctc( + self, encoder_out: torch.Tensor, encoder_mask: torch.Tensor, + text: torch.Tensor, + text_lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: encoder_out_lens = encoder_mask.squeeze(1).sum(1) - loss_ctc, ctc_probs = self.ctc(encoder_out, encoder_out_lens, - text, text_lengths) + loss_ctc, ctc_probs = self.ctc(encoder_out, encoder_out_lens, text, + text_lengths) return loss_ctc, ctc_probs def filter_blank_embedding( diff --git a/wenet/transformer/ctc.py b/wenet/transformer/ctc.py index e1e2d2d89..a1d32cfcb 100644 --- a/wenet/transformer/ctc.py +++ b/wenet/transformer/ctc.py @@ -50,7 +50,6 @@ def __init__( def forward(self, hs_pad: torch.Tensor, hlens: torch.Tensor, ys_pad: torch.Tensor, ys_lens: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: - """Calculate CTC loss. Args: diff --git a/wenet/utils/train_utils.py b/wenet/utils/train_utils.py index 93a8b012b..4fbddb807 100644 --- a/wenet/utils/train_utils.py +++ b/wenet/utils/train_utils.py @@ -66,15 +66,18 @@ def add_model_args(parser): default=None, type=str, help="Pre-trained model to initialize encoder") - parser.add_argument('--enc_init_mods', - default="encoder.", - type=lambda s: [str(mod) for mod in s.split(",") if s != ""], - help="List of encoder modules \ + parser.add_argument( + '--enc_init_mods', + default="encoder.", + type=lambda s: [str(mod) for mod in s.split(",") if s != ""], + help="List of encoder modules \ to initialize ,separated by a comma") - parser.add_argument('--freeze_modules', - default="", - type=lambda s: [str(mod) for mod in s.split(",") if s != ""], - help='free module names',) + parser.add_argument( + '--freeze_modules', + default="", + type=lambda s: [str(mod) for mod in s.split(",") if s != ""], + help='free module names', + ) parser.add_argument('--lfmmi_dir', default='', required=False, @@ -243,10 +246,8 @@ def check_modify_and_save_config(args, configs, symbol_table): fout.write(data) if configs["model_conf"]["apply_non_blank_embedding"]: - logging.warn( - 'Had better load a well trained model' - 'if apply_non_blank_embedding is true !!!' - ) + logging.warn('Had better load a well trained model' + 'if apply_non_blank_embedding is true !!!') return configs @@ -611,6 +612,7 @@ def log_per_epoch(writer, info_dict): writer.add_scalar('epoch/cv_loss', info_dict["cv_loss"], epoch) writer.add_scalar('epoch/lr', info_dict["lr"], epoch) + def freeze_modules(model, args): for name, param in model.named_parameters(): for module_name in args.freeze_modules: