[transformer] fix warning: ignore(True) has been deprecated (#2492)
* [transformer] fix warning: ignore(True) has been deprecated
xingchensong authored Apr 18, 2024
1 parent 69a084f commit 5b4e2f4
Showing 9 changed files with 15 additions and 15 deletions.
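
Background on the change (not part of the commit itself): recent PyTorch releases warn that torch.jit.ignore(drop=True) is deprecated and recommend torch.jit.unused for methods that should be dropped from TorchScript compilation. The sketch below is illustrative only (ToyModel, infer, and the tensor shapes are hypothetical, not WeNet code); it shows the pattern this commit applies: marking a training-only entry point @torch.jit.unused so the module still scripts cleanly, while the decorated method becomes a stub that raises if called on the scripted module.

import torch


class ToyModel(torch.nn.Module):
    """Minimal sketch; names and shapes are illustrative, not WeNet code."""

    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    @torch.jit.unused
    def forward(self, batch: dict, device: torch.device) -> dict:
        # Training-only path: uses Python-level constructs (a dict batch)
        # that TorchScript does not need to compile.
        feats = batch['feats'].to(device)
        return {'loss': self.linear(feats).sum()}

    @torch.jit.export
    def infer(self, feats: torch.Tensor) -> torch.Tensor:
        # Inference path that stays fully scriptable.
        return self.linear(feats)


# Scripting succeeds because `forward` is skipped, which mirrors the old
# drop=True behaviour; calling `forward` on the scripted module would raise
# a runtime error instead.
scripted = torch.jit.script(ToyModel())
print(scripted.infer(torch.randn(2, 4)))

The same decorator swap is applied mechanically to every occurrence in the files below.
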
2 changes: 1 addition & 1 deletion wenet/ctl_model/asr_model_ctl.py
@@ -67,7 +67,7 @@ def __init__(
         self.ctl_weight = ctl_weight
         self.logit_temp = logit_temp
 
-    @torch.jit.ignore(drop=True)
+    @torch.jit.unused
     def forward(
         self,
         batch: dict,

6 changes: 3 additions & 3 deletions wenet/k2/model.py
@@ -54,7 +54,7 @@ def __init__(
         if self.lfmmi_dir != '':
             self.load_lfmmi_resource()
 
-    @torch.jit.ignore(drop=True)
+    @torch.jit.unused
     def _forward_ctc(
             self, encoder_out: torch.Tensor, encoder_mask: torch.Tensor,
             text: torch.Tensor,
@@ -63,7 +63,7 @@ def _forward_ctc(
                                                      text)
         return loss_ctc, ctc_probs
 
-    @torch.jit.ignore(drop=True)
+    @torch.jit.unused
     def load_lfmmi_resource(self):
         try:
             import icefall
@@ -94,7 +94,7 @@ def load_lfmmi_resource(self):
                 assert len(arr) == 2
                 self.word_table[int(arr[1])] = arr[0]
 
-    @torch.jit.ignore(drop=True)
+    @torch.jit.unused
     def _calc_lfmmi_loss(self, encoder_out, encoder_mask, text):
         try:
             import k2

4 changes: 2 additions & 2 deletions wenet/paraformer/layers.py
@@ -282,7 +282,7 @@ def forward_layers(self, xs: torch.Tensor, chunk_masks: torch.Tensor,
             xs, _, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad)
         return xs
 
-    @torch.jit.ignore(drop=True)
+    @torch.jit.unused
     def forward_layers_checkpointed(self, xs: torch.Tensor,
                                     chunk_masks: torch.Tensor,
                                     pos_emb: torch.Tensor,
@@ -471,7 +471,7 @@ def forward_layers(self, x: torch.Tensor, tgt_mask: torch.Tensor,
             x = layer(x)
         return x
 
-    @torch.jit.ignore(drop=True)
+    @torch.jit.unused
     def forward_layers_checkpointed(self, x: torch.Tensor,
                                     tgt_mask: torch.Tensor,
                                     memory: torch.Tensor,

4 changes: 2 additions & 2 deletions wenet/paraformer/paraformer.py
@@ -148,7 +148,7 @@ def __init__(self,
         # labels: 你 好 we@@ net eos
         self.add_eos = add_eos
 
-    @torch.jit.ignore(drop=True)
+    @torch.jit.unused
     def forward(
         self,
         batch: Dict,
@@ -232,7 +232,7 @@ def _calc_att_loss(
                                  ignore_label=self.ignore_id)
         return loss_att, acc_att
 
-    @torch.jit.ignore(drop=True)
+    @torch.jit.unused
     def _sampler(self, encoder_out, encoder_out_mask, ys_pad, ys_pad_lens,
                  pre_acoustic_embeds):
         device = encoder_out.device

2 changes: 1 addition & 1 deletion wenet/ssl/w2vbert/w2vbert_model.py
@@ -158,7 +158,7 @@ def _reset_parameter(module: torch.nn.Module):
         _reset_parameter(conv1)
         _reset_parameter(conv2)
 
-    @torch.jit.ignore(drop=True)
+    @torch.jit.unused
     def forward(
         self,
         batch: Dict,

2 changes: 1 addition & 1 deletion wenet/ssl/wav2vec2/wav2vec2_model.py
@@ -217,7 +217,7 @@ def _reset_parameter(module: torch.nn.Module):
         _reset_parameter(conv1)
         _reset_parameter(conv2)
 
-    @torch.jit.ignore(drop=True)
+    @torch.jit.unused
     def forward(
         self,
         batch: Dict,

6 changes: 3 additions & 3 deletions wenet/transformer/asr_model.py
@@ -74,7 +74,7 @@ def __init__(
             normalize_length=length_normalized_loss,
         )
 
-    @torch.jit.ignore(drop=True)
+    @torch.jit.unused
     def forward(
         self,
         batch: dict,
@@ -133,7 +133,7 @@ def forward(
             "th_accuracy": acc_att,
         }
 
-    @torch.jit.ignore(drop=True)
+    @torch.jit.unused
     def _forward_ctc(
             self, encoder_out: torch.Tensor, encoder_mask: torch.Tensor,
             text: torch.Tensor,
@@ -231,7 +231,7 @@ def _forward_encoder(
         )  # (B, maxlen, encoder_dim)
         return encoder_out, encoder_mask
 
-    @torch.jit.ignore(drop=True)
+    @torch.jit.unused
     def ctc_logprobs(self,
                      encoder_out: torch.Tensor,
                      blank_penalty: float = 0.0,

2 changes: 1 addition & 1 deletion wenet/transformer/decoder.py
@@ -207,7 +207,7 @@ def forward_layers(self, x: torch.Tensor, tgt_mask: torch.Tensor,
                                                memory_mask)
         return x
 
-    @torch.jit.ignore(drop=True)
+    @torch.jit.unused
     def forward_layers_checkpointed(self, x: torch.Tensor,
                                     tgt_mask: torch.Tensor,
                                     memory: torch.Tensor,

2 changes: 1 addition & 1 deletion wenet/transformer/encoder.py
@@ -184,7 +184,7 @@ def forward_layers(self, xs: torch.Tensor, chunk_masks: torch.Tensor,
             xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad)
         return xs
 
-    @torch.jit.ignore(drop=True)
+    @torch.jit.unused
     def forward_layers_checkpointed(self, xs: torch.Tensor,
                                     chunk_masks: torch.Tensor,
                                     pos_emb: torch.Tensor,

