Commit 7bece3a

fix(code): fix lint (#2207)
xingchensong authored Dec 8, 2023
1 parent 2894f7c commit 7bece3a
Showing 5 changed files with 41 additions and 40 deletions.
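
All of the changes below are formatting-only reflows; no behavior changes. They read like output of the project's Python formatter/linter, most plausibly a yapf/flake8-style line-length rule (the exact tool and column limit are not visible in this commit, so the 80-character figure below is an assumption). A minimal sketch of the kind of check the reflowed lines are meant to satisfy:

    # Hypothetical helper, assuming the lint rule at issue is an 80-character
    # line limit; the real linter configuration is not part of this diff.
    def find_long_lines(path: str, limit: int = 80):
        """Return (line_number, length) for each line longer than `limit`."""
        offenders = []
        with open(path, encoding="utf-8") as f:
            for lineno, line in enumerate(f, start=1):
                length = len(line.rstrip("\n"))
                if length > limit:
                    offenders.append((lineno, length))
        return offenders

    # Example: run against a checkout before and after this commit;
    # the reflowed files should report no offenders.
    print(find_long_lines("wenet/transformer/asr_model.py"))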

wenet/k2/model.py (10 changes: 6 additions & 4 deletions)

@@ -47,10 +47,12 @@ def __init__(
         self.load_lfmmi_resource()

     @torch.jit.ignore(drop=True)
-    def _forward_ctc(self, encoder_out: torch.Tensor,
-                     encoder_mask: torch.Tensor, text: torch.Tensor,
-                     text_lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
-        loss_ctc, ctc_probs = self._calc_lfmmi_loss(encoder_out, encoder_mask, text)
+    def _forward_ctc(
+            self, encoder_out: torch.Tensor, encoder_mask: torch.Tensor,
+            text: torch.Tensor,
+            text_lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        loss_ctc, ctc_probs = self._calc_lfmmi_loss(encoder_out, encoder_mask,
+                                                    text)
         return loss_ctc, ctc_probs

     @torch.jit.ignore(drop=True)
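
A note on the decorator that wraps these methods: `@torch.jit.ignore(drop=True)` tells TorchScript not to compile the decorated method and, because drop=True, to drop it when the module is exported, so the rest of the model can still be scripted and saved even though the training-loss code is not scriptable. A minimal sketch, not taken from wenet (TinyModel and _training_only are made-up names):

    import torch


    class TinyModel(torch.nn.Module):

        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(4, 2)

        @torch.jit.ignore(drop=True)
        def _training_only(self, x: torch.Tensor) -> torch.Tensor:
            # Arbitrary Python that TorchScript does not need to compile;
            # in wenet this is where the training losses live.
            return x.sum()

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.linear(x)


    # Scripting succeeds because _training_only is ignored/dropped;
    # only forward() is compiled.
    scripted = torch.jit.script(TinyModel())
    print(scripted(torch.randn(3, 4)).shape)  # torch.Size([3, 2])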

wenet/paraformer/paraformer.py (19 changes: 11 additions & 8 deletions)

@@ -137,8 +137,8 @@ def forward(
         loss_ctc: Optional[torch.Tensor] = None
         if self.ctc_weight != 0.0:
             loss_ctc, ctc_probs = self._forward_ctc(encoder_out,
-                                                    encoder_out_mask,
-                                                    text, text_lengths)
+                                                    encoder_out_mask, text,
+                                                    text_lengths)
         # TODO(Mddc): thu acc
         loss_decoder = self.criterion_att(decoder_out, ys_pad)
         loss = loss_decoder
@@ -151,13 +151,16 @@
         }

     @torch.jit.ignore(drop=True)
-    def _forward_ctc(self, encoder_out: torch.Tensor,
-                     encoder_mask: torch.Tensor, text: torch.Tensor,
-                     text_lengths: torch.Tensor,
-                     ) -> Tuple[torch.Tensor, torch.Tensor]:
+    def _forward_ctc(
+        self,
+        encoder_out: torch.Tensor,
+        encoder_mask: torch.Tensor,
+        text: torch.Tensor,
+        text_lengths: torch.Tensor,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         encoder_out_lens = encoder_mask.squeeze(1).sum(1)
-        loss_ctc, ctc_probs = self.ctc(encoder_out, encoder_out_lens,
-                                       text, text_lengths)
+        loss_ctc, ctc_probs = self.ctc(encoder_out, encoder_out_lens, text,
+                                       text_lengths)
         return loss_ctc, ctc_probs

     @torch.jit.ignore(drop=True)
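
Both this file and asr_model.py below recover per-utterance lengths from the encoder padding mask with `encoder_mask.squeeze(1).sum(1)` before handing them to the CTC loss. A small illustration of that one line, using toy values and the (batch, 1, time) mask layout implied by the squeeze:

    import torch

    # encoder_mask: (batch, 1, time); 1 marks a real frame, 0 marks padding.
    encoder_mask = torch.tensor([[[1, 1, 1, 0, 0]],
                                 [[1, 1, 1, 1, 1]]])
    encoder_out_lens = encoder_mask.squeeze(1).sum(1)
    print(encoder_out_lens)  # tensor([3, 5])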

wenet/transformer/asr_model.py (25 changes: 10 additions & 15 deletions)

@@ -53,16 +53,10 @@ def __init__(

         super().__init__()
         # note that eos is the same as sos (equivalent ID)
-        self.sos = (
-            vocab_size - 1
-            if special_tokens is None
-            else special_tokens.get("sos", vocab_size - 1)
-        )
-        self.eos = (
-            vocab_size - 1
-            if special_tokens is None
-            else special_tokens.get("eos", vocab_size - 1)
-        )
+        self.sos = (vocab_size - 1 if special_tokens is None else
+                    special_tokens.get("sos", vocab_size - 1))
+        self.eos = (vocab_size - 1 if special_tokens is None else
+                    special_tokens.get("eos", vocab_size - 1))
         self.vocab_size = vocab_size
         self.special_tokens = special_tokens
         self.ignore_id = ignore_id
@@ -136,12 +130,13 @@ def forward(
         return {"loss": loss, "loss_att": loss_att, "loss_ctc": loss_ctc}

     @torch.jit.ignore(drop=True)
-    def _forward_ctc(self, encoder_out: torch.Tensor,
-                     encoder_mask: torch.Tensor, text: torch.Tensor,
-                     text_lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+    def _forward_ctc(
+            self, encoder_out: torch.Tensor, encoder_mask: torch.Tensor,
+            text: torch.Tensor,
+            text_lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
         encoder_out_lens = encoder_mask.squeeze(1).sum(1)
-        loss_ctc, ctc_probs = self.ctc(encoder_out, encoder_out_lens,
-                                       text, text_lengths)
+        loss_ctc, ctc_probs = self.ctc(encoder_out, encoder_out_lens, text,
+                                       text_lengths)
         return loss_ctc, ctc_probs

     def filter_blank_embedding(
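
The reflowed sos/eos assignments keep the same semantics: fall back to vocab_size - 1 when no special_tokens mapping is supplied, otherwise look the token up with that same fallback. A quick illustration with made-up numbers:

    vocab_size = 5000

    special_tokens = None
    sos = (vocab_size - 1 if special_tokens is None else
           special_tokens.get("sos", vocab_size - 1))
    assert sos == 4999  # no mapping given: use the last vocabulary id

    special_tokens = {"sos": 2, "eos": 2}
    sos = (vocab_size - 1 if special_tokens is None else
           special_tokens.get("sos", vocab_size - 1))
    assert sos == 2  # mapping present: its value wins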

wenet/transformer/ctc.py (1 change: 0 additions & 1 deletion)

@@ -50,7 +50,6 @@ def __init__(
     def forward(self, hs_pad: torch.Tensor, hlens: torch.Tensor,
                 ys_pad: torch.Tensor,
                 ys_lens: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
-
         """Calculate CTC loss.
         Args:

wenet/utils/train_utils.py (26 changes: 14 additions & 12 deletions)

@@ -66,15 +66,18 @@ def add_model_args(parser):
                         default=None,
                         type=str,
                         help="Pre-trained model to initialize encoder")
-    parser.add_argument('--enc_init_mods',
-                        default="encoder.",
-                        type=lambda s: [str(mod) for mod in s.split(",") if s != ""],
-                        help="List of encoder modules \
+    parser.add_argument(
+        '--enc_init_mods',
+        default="encoder.",
+        type=lambda s: [str(mod) for mod in s.split(",") if s != ""],
+        help="List of encoder modules \
                         to initialize ,separated by a comma")
-    parser.add_argument('--freeze_modules',
-                        default="",
-                        type=lambda s: [str(mod) for mod in s.split(",") if s != ""],
-                        help='free module names',)
+    parser.add_argument(
+        '--freeze_modules',
+        default="",
+        type=lambda s: [str(mod) for mod in s.split(",") if s != ""],
+        help='free module names',
+    )
     parser.add_argument('--lfmmi_dir',
                         default='',
                         required=False,
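
The two reformatted add_argument calls above both turn a comma-separated string into a list of module-name prefixes via the type=lambda converter (note the filter tests the whole string s, not each piece, so it only serves to map the empty default to an empty list). A short, self-contained reproduction of that behavior with made-up argument values:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--freeze_modules',
        default="",
        type=lambda s: [str(mod) for mod in s.split(",") if s != ""],
    )

    print(parser.parse_args([]).freeze_modules)  # []
    args = parser.parse_args(['--freeze_modules', 'encoder.embed,ctc.'])
    print(args.freeze_modules)  # ['encoder.embed', 'ctc.']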

@@ -243,10 +246,8 @@ def check_modify_and_save_config(args, configs, symbol_table):
             fout.write(data)

     if configs["model_conf"]["apply_non_blank_embedding"]:
-        logging.warn(
-            'Had better load a well trained model'
-            'if apply_non_blank_embedding is true !!!'
-        )
+        logging.warn('Had better load a well trained model'
+                     'if apply_non_blank_embedding is true !!!')

     return configs

@@ -611,6 +612,7 @@ def log_per_epoch(writer, info_dict):
     writer.add_scalar('epoch/cv_loss', info_dict["cv_loss"], epoch)
     writer.add_scalar('epoch/lr', info_dict["lr"], epoch)

+
 def freeze_modules(model, args):
     for name, param in model.named_parameters():
         for module_name in args.freeze_modules:
