Skip to content

Commit

Permalink
mask pad value to 0
Browse files Browse the repository at this point in the history
  • Loading branch information
robin1001 committed Jan 4, 2024
1 parent 2d873ee commit e801927
Showing 1 changed file with 3 additions and 0 deletions.
3 changes: 3 additions & 0 deletions wenet/tts/vits/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from wenet.tts.vits.commons import init_weights, get_padding
from wenet.tts.vits.losses import generator_loss, discriminator_loss, feature_loss, kl_loss
from wenet.tts.vits.mel_processing import mel_spectrogram_torch
from wenet.utils.mask import make_pad_mask


class StochasticDurationPredictor(nn.Module):
Expand Down Expand Up @@ -791,6 +792,8 @@ def __init__(self, n_vocab, spec_channels, **kwargs):
def forward(self, batch: dict, device: torch.device):
x = batch['target'].to(device)
x_lengths = batch['target_lengths'].to(device)
x_mask = make_pad_mask(x_lengths)
x = x.masked_fill(x_mask, 0) # change pad value(IGNORE_ID = -1) to 0
spec = batch['feats'].to(device)
spec_lengths = batch['feats_lengths'].to(device)
spec = spec.transpose(1, 2)
Expand Down

0 comments on commit e801927

Please sign in to comment.