From 1c66043d728148ee13fbfa0c37296e941fdb6e32 Mon Sep 17 00:00:00 2001 From: Binbin Zhang Date: Thu, 4 Jan 2024 10:47:54 +0800 Subject: [PATCH 1/7] [vits] add vits support --- wenet/tts/vits/commons.py | 154 ++++++ wenet/tts/vits/losses.py | 58 ++ wenet/tts/vits/mel_processing.py | 107 ++++ wenet/tts/vits/models.py | 870 ++++++++++++++++++++++++++++++ wenet/tts/vits/modules.py | 513 ++++++++++++++++++ wenet/tts/vits/monotonic_align.py | 57 ++ wenet/tts/vits/transforms.py | 206 +++++++ wenet/utils/init_model.py | 49 +- 8 files changed, 1991 insertions(+), 23 deletions(-) create mode 100644 wenet/tts/vits/commons.py create mode 100644 wenet/tts/vits/losses.py create mode 100644 wenet/tts/vits/mel_processing.py create mode 100644 wenet/tts/vits/models.py create mode 100644 wenet/tts/vits/modules.py create mode 100644 wenet/tts/vits/monotonic_align.py create mode 100644 wenet/tts/vits/transforms.py diff --git a/wenet/tts/vits/commons.py b/wenet/tts/vits/commons.py new file mode 100644 index 000000000..943889cc0 --- /dev/null +++ b/wenet/tts/vits/commons.py @@ -0,0 +1,154 @@ +import math + +import torch +from torch.nn import functional as F + + +def init_weights(m, mean=0.0, std=0.01): + classname = m.__class__.__name__ + if classname.find("Conv") != -1: + m.weight.data.normal_(mean, std) + + +def get_padding(kernel_size, dilation=1): + return int((kernel_size * dilation - dilation) / 2) + + +def convert_pad_shape(pad_shape): + pad_shape = [item for sublist in reversed(pad_shape) for item in sublist] + return pad_shape + + +def kl_divergence(m_p, logs_p, m_q, logs_q): + """KL(P||Q)""" + kl = (logs_q - logs_p) - 0.5 + kl += (0.5 * (torch.exp(2.0 * logs_p) + ((m_p - m_q)**2)) * + torch.exp(-2.0 * logs_q)) + return kl + + +def rand_gumbel(shape): + """Sample from the Gumbel distribution, protect from overflows.""" + uniform_samples = torch.rand(shape) * 0.99998 + 0.00001 + return -torch.log(-torch.log(uniform_samples)) + + +def rand_gumbel_like(x): + g = rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device) + return g + + +def slice_segments(x, ids_str, segment_size=4): + ret = torch.zeros_like(x[:, :, :segment_size]) + for i in range(x.size(0)): + idx_str = ids_str[i] + idx_end = idx_str + segment_size + ret[i] = x[i, :, idx_str:idx_end] + return ret + + +def rand_slice_segments(x, x_lengths=None, segment_size=4): + b, d, t = x.size() + if x_lengths is None: + x_lengths = t + ids_str_max = x_lengths - segment_size + 1 + ids_str = (torch.rand([b]).to(device=x.device) * + ids_str_max).to(dtype=torch.long) + ret = slice_segments(x, ids_str, segment_size) + return ret, ids_str + + +def get_timing_signal_1d(length, + channels, + min_timescale=1.0, + max_timescale=1.0e4): + position = torch.arange(length, dtype=torch.float) + num_timescales = channels // 2 + log_timescale_increment = math.log( + float(max_timescale) / float(min_timescale)) / (num_timescales - 1) + inv_timescales = min_timescale * torch.exp( + torch.arange(num_timescales, dtype=torch.float) * + -log_timescale_increment) + scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1) + signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0) + signal = F.pad(signal, [0, 0, 0, channels % 2]) + signal = signal.view(1, channels, length) + return signal + + +def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4): + b, channels, length = x.size() + signal = get_timing_signal_1d(length, channels, min_timescale, + max_timescale) + return x + signal.to(dtype=x.dtype, device=x.device) + + +def 
cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1): + b, channels, length = x.size() + signal = get_timing_signal_1d(length, channels, min_timescale, + max_timescale) + return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis) + + +def subsequent_mask(length): + mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0) + return mask + + +@torch.jit.script +def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels): + n_channels_int = n_channels[0] + in_act = input_a + input_b + t_act = torch.tanh(in_act[:, :n_channels_int, :]) + s_act = torch.sigmoid(in_act[:, n_channels_int:, :]) + acts = t_act * s_act + return acts + + +def shift_1d(x): + x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1] + return x + + +def sequence_mask(length, max_length=None): + if max_length is None: + max_length = length.max() + x = torch.arange(max_length, dtype=length.dtype, device=length.device) + return x.unsqueeze(0) < length.unsqueeze(1) + + +def generate_path(duration, mask): + """ + duration: [b, 1, t_x] + mask: [b, 1, t_y, t_x] + """ + device = duration.device + + b, _, t_y, t_x = mask.shape + cum_duration = torch.cumsum(duration, -1) + + cum_duration_flat = cum_duration.view(b * t_x) + path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype) + path = path.view(b, t_x, t_y) + path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0] + ]))[:, :-1] + path = path.unsqueeze(1).transpose(2, 3) * mask + return path + + +def clip_grad_value_(parameters, clip_value, norm_type=2): + if isinstance(parameters, torch.Tensor): + parameters = [parameters] + parameters = list(filter(lambda p: p.grad is not None, parameters)) + norm_type = float(norm_type) + if clip_value is not None: + clip_value = float(clip_value) + + total_norm = 0 + for p in parameters: + param_norm = p.grad.data.norm(norm_type) + total_norm += param_norm.item()**norm_type + if clip_value is not None: + p.grad.data.clamp_(min=-clip_value, max=clip_value) + total_norm = total_norm**(1.0 / norm_type) + return total_norm diff --git a/wenet/tts/vits/losses.py b/wenet/tts/vits/losses.py new file mode 100644 index 000000000..470abec8a --- /dev/null +++ b/wenet/tts/vits/losses.py @@ -0,0 +1,58 @@ +import torch + + +def feature_loss(fmap_r, fmap_g): + loss = 0 + for dr, dg in zip(fmap_r, fmap_g): + for rl, gl in zip(dr, dg): + rl = rl.float().detach() + gl = gl.float() + loss += torch.mean(torch.abs(rl - gl)) + + return loss * 2 + + +def discriminator_loss(disc_real_outputs, disc_generated_outputs): + loss = 0 + r_losses = [] + g_losses = [] + for dr, dg in zip(disc_real_outputs, disc_generated_outputs): + dr = dr.float() + dg = dg.float() + r_loss = torch.mean((1 - dr)**2) + g_loss = torch.mean(dg**2) + loss += r_loss + g_loss + r_losses.append(r_loss.item()) + g_losses.append(g_loss.item()) + + return loss, r_losses, g_losses + + +def generator_loss(disc_outputs): + loss = 0 + gen_losses = [] + for dg in disc_outputs: + dg = dg.float() + l = torch.mean((1 - dg)**2) + gen_losses.append(l) + loss += l + + return loss, gen_losses + + +def kl_loss(z_p, logs_q, m_p, logs_p, z_mask): + """ + z_p, logs_q: [b, h, t_t] + m_p, logs_p: [b, h, t_t] + """ + z_p = z_p.float() + logs_q = logs_q.float() + m_p = m_p.float() + logs_p = logs_p.float() + z_mask = z_mask.float() + + kl = logs_p - logs_q - 0.5 + kl += 0.5 * ((z_p - m_p)**2) * torch.exp(-2.0 * logs_p) + kl = torch.sum(kl * z_mask) + l = kl / torch.sum(z_mask) + return l diff --git a/wenet/tts/vits/mel_processing.py 
b/wenet/tts/vits/mel_processing.py new file mode 100644 index 000000000..28bab7336 --- /dev/null +++ b/wenet/tts/vits/mel_processing.py @@ -0,0 +1,107 @@ +import torch +import torch.nn.functional as F +import torch.utils.data +from librosa.filters import mel as librosa_mel_fn + +MAX_WAV_VALUE = 32768.0 + + +def dynamic_range_compression_torch(x, C=1, clip_val=1e-5): + """ + PARAMS + ------ + C: compression factor + """ + return torch.log(torch.clamp(x, min=clip_val) * C) + + +def dynamic_range_decompression_torch(x, C=1): + """ + PARAMS + ------ + C: compression factor used to compress + """ + return torch.exp(x) / C + + +def spectral_normalize_torch(magnitudes): + output = dynamic_range_compression_torch(magnitudes) + return output + + +def spectral_de_normalize_torch(magnitudes): + output = dynamic_range_decompression_torch(magnitudes) + return output + + +mel_basis = {} +hann_window = {} + + +def spectrogram_torch(y, + n_fft, + sampling_rate, + hop_size, + win_size, + center=False): + if torch.min(y) < -1.0: + print("min value is ", torch.min(y)) + if torch.max(y) > 1.0: + print("max value is ", torch.max(y)) + + global hann_window + dtype_device = str(y.dtype) + "_" + str(y.device) + wnsize_dtype_device = str(win_size) + "_" + dtype_device + if wnsize_dtype_device not in hann_window: + hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to( + dtype=y.dtype, device=y.device) + + y = F.pad( + y.unsqueeze(1), + (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), + mode="reflect", + ) + y = y.squeeze(1) + + spec = torch.stft( + y, + n_fft, + hop_length=hop_size, + win_length=win_size, + window=hann_window[wnsize_dtype_device], + center=center, + pad_mode="reflect", + normalized=False, + onesided=True, + return_complex=True, + ) + spec = torch.view_as_real(spec) + + spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6) + return spec + + +def spec_to_mel_torch(spec, n_fft, n_mels, sampling_rate): + global mel_basis + dtype_device = str(spec.dtype) + "_" + str(spec.device) + if dtype_device not in mel_basis: + mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=n_mels) + mel_basis[dtype_device] = torch.from_numpy(mel).to(dtype=spec.dtype, + device=spec.device) + spec = torch.matmul(mel_basis[dtype_device], spec) + spec = spectral_normalize_torch(spec) + return spec + + +def mel_spectrogram_torch(y, + n_fft, + n_mels, + sampling_rate, + hop_size, + win_size, + center=False): + spec = spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, + center) + spec = spec_to_mel_torch(spec, n_fft, n_mels, sampling_rate) + + return spec diff --git a/wenet/tts/vits/models.py b/wenet/tts/vits/models.py new file mode 100644 index 000000000..2f52914ac --- /dev/null +++ b/wenet/tts/vits/models.py @@ -0,0 +1,870 @@ +import math +import time + +import torch +from torch import nn +from torch.nn import functional as F +from torch.nn import Conv1d, ConvTranspose1d, Conv2d +from torch.nn.utils import remove_weight_norm, spectral_norm +from torch.nn.utils import weight_norm + +import wenet.tts.vits.commons as commons +import wenet.tts.vits.modules as modules +import wenet.tts.vits.attentions as attentions +import wenet.tts.vits.monotonic_align as monotonic_align +from wenet.tts.vits.commons import init_weights, get_padding +from wenet.tts.vits.losses import generator_loss, discriminator_loss, feature_loss, kl_loss +from wenet.tts.vits.mel_processing import mel_spectrogram_torch + + +class StochasticDurationPredictor(nn.Module): + + def __init__( + self, + in_channels, + filter_channels, + 
kernel_size, + p_dropout, + n_flows=4, + gin_channels=256, + ): + super().__init__() + filter_channels = in_channels # it needs to be removed from future version. + self.in_channels = in_channels + self.filter_channels = filter_channels + self.kernel_size = kernel_size + self.p_dropout = p_dropout + self.n_flows = n_flows + self.gin_channels = gin_channels + + self.log_flow = modules.Log() + self.flows = nn.ModuleList() + self.flows.append(modules.ElementwiseAffine(2)) + for i in range(n_flows): + self.flows.append( + modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3)) + self.flows.append(modules.Flip()) + + self.post_pre = nn.Conv1d(1, filter_channels, 1) + self.post_proj = nn.Conv1d(filter_channels, filter_channels, 1) + self.post_convs = modules.DDSConv(filter_channels, + kernel_size, + n_layers=3, + p_dropout=p_dropout) + self.post_flows = nn.ModuleList() + self.post_flows.append(modules.ElementwiseAffine(2)) + for i in range(4): + self.post_flows.append( + modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3)) + self.post_flows.append(modules.Flip()) + + self.pre = nn.Conv1d(in_channels, filter_channels, 1) + self.proj = nn.Conv1d(filter_channels, filter_channels, 1) + self.convs = modules.DDSConv(filter_channels, + kernel_size, + n_layers=3, + p_dropout=p_dropout) + self.cond = nn.Conv1d(gin_channels, filter_channels, 1) + + def forward(self, + x, + x_mask, + w=None, + g=None, + reverse=False, + noise_scale=1.0): + x = torch.detach(x) + x = self.pre(x) + g = torch.detach(g) + x = x + self.cond(g) + x = self.convs(x, x_mask) + x = self.proj(x) * x_mask + + if not reverse: + flows = self.flows + assert w is not None + + logdet_tot_q = 0 + h_w = self.post_pre(w) + h_w = self.post_convs(h_w, x_mask) + h_w = self.post_proj(h_w) * x_mask + e_q = (torch.randn(w.size(0), 2, w.size(2)).to( + device=x.device, dtype=x.dtype) * x_mask) + z_q = e_q + for flow in self.post_flows: + z_q, logdet_q = flow(z_q, x_mask, g=(x + h_w)) + logdet_tot_q += logdet_q + z_u, z1 = torch.split(z_q, [1, 1], 1) + u = torch.sigmoid(z_u) * x_mask + z0 = (w - u) * x_mask + logdet_tot_q += torch.sum( + (F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1, 2]) + logq = ( + torch.sum(-0.5 * (math.log(2 * math.pi) + + (e_q**2)) * x_mask, [1, 2]) - logdet_tot_q) + + logdet_tot = 0 + z0, logdet = self.log_flow(z0, x_mask) + logdet_tot += logdet + z = torch.cat([z0, z1], 1) + for flow in flows: + z, logdet = flow(z, x_mask, g=x, reverse=reverse) + logdet_tot = logdet_tot + logdet + nll = (torch.sum(0.5 * (math.log(2 * math.pi) + + (z**2)) * x_mask, [1, 2]) - logdet_tot) + return nll + logq # [b] + else: + flows = list(reversed(self.flows)) + flows = flows[:-2] + [flows[-1]] # remove a useless vflow + z = (torch.randn(x.size(0), 2, x.size(2)).to( + device=x.device, dtype=x.dtype) * noise_scale) + for flow in flows: + z = flow(z, x_mask, g=x, reverse=reverse) + z0, z1 = torch.split(z, [1, 1], 1) + logw = z0 + return logw + + +class DurationPredictor(nn.Module): + + def __init__(self, in_channels, filter_channels, kernel_size, p_dropout, + gin_channels): + super().__init__() + + self.in_channels = in_channels + self.filter_channels = filter_channels + self.kernel_size = kernel_size + self.p_dropout = p_dropout + self.gin_channels = gin_channels + + self.drop = nn.Dropout(p_dropout) + self.conv_1 = nn.Conv1d(in_channels, + filter_channels, + kernel_size, + padding=kernel_size // 2) + self.norm_1 = modules.LayerNorm(filter_channels) + self.conv_2 = nn.Conv1d(filter_channels, + filter_channels, + kernel_size, + 
padding=kernel_size // 2) + self.norm_2 = modules.LayerNorm(filter_channels) + self.proj = nn.Conv1d(filter_channels, 1, 1) + self.cond = nn.Conv1d(gin_channels, in_channels, 1) + + def forward(self, x, x_mask, g=None): + x = torch.detach(x) + g = torch.detach(g) + x = x + self.cond(g) + x = self.conv_1(x * x_mask) + x = torch.relu(x) + x = self.norm_1(x) + x = self.drop(x) + x = self.conv_2(x * x_mask) + x = torch.relu(x) + x = self.norm_2(x) + x = self.drop(x) + x = self.proj(x * x_mask) + return x * x_mask + + +class TextEncoder(nn.Module): + + def __init__( + self, + n_vocab, + out_channels, + hidden_channels, + filter_channels, + n_heads, + n_layers, + kernel_size, + p_dropout, + ): + super().__init__() + self.n_vocab = n_vocab + self.out_channels = out_channels + self.hidden_channels = hidden_channels + self.filter_channels = filter_channels + self.n_heads = n_heads + self.n_layers = n_layers + self.kernel_size = kernel_size + self.p_dropout = p_dropout + + self.emb = nn.Embedding(n_vocab, hidden_channels) + nn.init.normal_(self.emb.weight, 0.0, hidden_channels**-0.5) + + self.encoder = attentions.Encoder(hidden_channels, filter_channels, + n_heads, n_layers, kernel_size, + p_dropout) + self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1) + + def forward(self, x, x_lengths): + x = self.emb(x) * math.sqrt(self.hidden_channels) # [b, t, h] + x = torch.transpose(x, 1, -1) # [b, h, t] + x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), + 1).to(x.dtype) + + x = self.encoder(x * x_mask, x_mask) + stats = self.proj(x) * x_mask + + m, logs = torch.split(stats, self.out_channels, dim=1) + return x, m, logs, x_mask + + +class ResidualCouplingBlock(nn.Module): + + def __init__( + self, + channels, + hidden_channels, + kernel_size, + dilation_rate, + n_layers, + n_flows=4, + gin_channels=256, + ): + super().__init__() + self.channels = channels + self.hidden_channels = hidden_channels + self.kernel_size = kernel_size + self.dilation_rate = dilation_rate + self.n_layers = n_layers + self.n_flows = n_flows + self.gin_channels = gin_channels + + self.flows = nn.ModuleList() + for i in range(n_flows): + self.flows.append( + modules.ResidualCouplingLayer( + channels, + hidden_channels, + kernel_size, + dilation_rate, + n_layers, + gin_channels=gin_channels, + mean_only=True, + )) + self.flows.append(modules.Flip()) + + def forward(self, x, x_mask, g=None, reverse=False): + if not reverse: + for flow in self.flows: + x, _ = flow(x, x_mask, g=g, reverse=reverse) + else: + for flow in reversed(self.flows): + x = flow(x, x_mask, g=g, reverse=reverse) + return x + + def remove_weight_norm(self): + for i, l in enumerate(self.flows): + if i % 2 == 0: + l.remove_weight_norm() + + +class PosteriorEncoder(nn.Module): + + def __init__( + self, + in_channels, + out_channels, + hidden_channels, + kernel_size, + dilation_rate, + n_layers, + gin_channels, + ): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.hidden_channels = hidden_channels + self.kernel_size = kernel_size + self.dilation_rate = dilation_rate + self.n_layers = n_layers + self.gin_channels = gin_channels + + self.pre = nn.Conv1d(in_channels, hidden_channels, 1) + self.enc = modules.WN( + hidden_channels, + kernel_size, + dilation_rate, + n_layers, + gin_channels=gin_channels, + ) + self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1) + + def forward(self, x, x_lengths, g=None): + x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), + 1).to(x.dtype) + x 
= self.pre(x) * x_mask + x = self.enc(x, x_mask, g=g) + stats = self.proj(x) * x_mask + m, logs = torch.split(stats, self.out_channels, dim=1) + z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask + return z, m, logs, x_mask + + +class Generator(torch.nn.Module): + + def __init__( + self, + initial_channel, + resblock, + resblock_kernel_sizes, + resblock_dilation_sizes, + upsample_rates, + upsample_initial_channel, + upsample_kernel_sizes, + gin_channels, + ): + super(Generator, self).__init__() + self.num_kernels = len(resblock_kernel_sizes) + self.num_upsamples = len(upsample_rates) + self.conv_pre = Conv1d(initial_channel, + upsample_initial_channel, + 7, + 1, + padding=3) + resblock = modules.ResBlock1 if resblock == "1" else modules.ResBlock2 + + self.ups = nn.ModuleList() + for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)): + self.ups.append( + weight_norm( + ConvTranspose1d( + upsample_initial_channel // (2**i), + upsample_initial_channel // (2**(i + 1)), + k, + u, + padding=(k - u) // 2, + ))) + + self.resblocks = nn.ModuleList() + for i in range(len(self.ups)): + ch = upsample_initial_channel // (2**(i + 1)) + for j, (k, d) in enumerate( + zip(resblock_kernel_sizes, resblock_dilation_sizes)): + self.resblocks.append(resblock(ch, k, d)) + + self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False) + self.ups.apply(init_weights) + + self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1) + + def forward(self, x, g=None): + x = self.conv_pre(x) + x = x + self.cond(g) + + for i in range(self.num_upsamples): + x = F.leaky_relu(x, modules.LRELU_SLOPE) + x = self.ups[i](x) + xs = None + for j in range(self.num_kernels): + if xs is None: + xs = self.resblocks[i * self.num_kernels + j](x) + else: + xs += self.resblocks[i * self.num_kernels + j](x) + x = xs / self.num_kernels + x = F.leaky_relu(x) + x = self.conv_post(x) + x = torch.tanh(x) + + return x + + def remove_weight_norm(self): + for l in self.ups: + remove_weight_norm(l) + for l in self.resblocks: + l.remove_weight_norm() + + +class ConvNeXtLayer(nn.Module): + + def __init__(self, channels, h_channels, scale): + super().__init__() + self.dw_conv = nn.Conv1d( + channels, + channels, + kernel_size=3, + padding=1, + groups=channels, + ) + self.norm = modules.LayerNorm(channels) + self.pw_conv1 = nn.Conv1d(channels, h_channels, 1) + self.pw_conv2 = nn.Conv1d(h_channels, channels, 1) + self.scale = nn.Parameter(torch.full(size=(1, channels, 1), + fill_value=scale), + requires_grad=True) + + def forward(self, x): + res = x + x = self.dw_conv(x) + x = self.norm(x) + x = self.pw_conv1(x) + x = F.gelu(x) + x = self.pw_conv2(x) + x = self.scale * x + x = res + x + return x + + +class DiscriminatorP(torch.nn.Module): + + def __init__(self, + period, + kernel_size=5, + stride=3, + use_spectral_norm=False): + super(DiscriminatorP, self).__init__() + self.period = period + self.use_spectral_norm = use_spectral_norm + norm_f = weight_norm if use_spectral_norm is False else spectral_norm + self.convs = nn.ModuleList([ + norm_f( + Conv2d( + 1, + 32, + (kernel_size, 1), + (stride, 1), + padding=(get_padding(kernel_size, 1), 0), + )), + norm_f( + Conv2d( + 32, + 128, + (kernel_size, 1), + (stride, 1), + padding=(get_padding(kernel_size, 1), 0), + )), + norm_f( + Conv2d( + 128, + 512, + (kernel_size, 1), + (stride, 1), + padding=(get_padding(kernel_size, 1), 0), + )), + norm_f( + Conv2d( + 512, + 1024, + (kernel_size, 1), + (stride, 1), + padding=(get_padding(kernel_size, 1), 0), + )), + norm_f( + Conv2d( + 1024, 
+ 1024, + (kernel_size, 1), + 1, + padding=(get_padding(kernel_size, 1), 0), + )), + ]) + self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0))) + + def forward(self, x): + fmap = [] + + # 1d to 2d + b, c, t = x.shape + if t % self.period != 0: # pad first + n_pad = self.period - (t % self.period) + x = F.pad(x, (0, n_pad), "reflect") + t = t + n_pad + x = x.view(b, c, t // self.period, self.period) + + for l in self.convs: + x = l(x) + x = F.leaky_relu(x, modules.LRELU_SLOPE) + fmap.append(x) + x = self.conv_post(x) + fmap.append(x) + x = torch.flatten(x, 1, -1) + + return x, fmap + + +class DiscriminatorS(torch.nn.Module): + + def __init__(self, use_spectral_norm=False): + super(DiscriminatorS, self).__init__() + norm_f = weight_norm if use_spectral_norm is False else spectral_norm + self.convs = nn.ModuleList([ + norm_f(Conv1d(1, 16, 15, 1, padding=7)), + norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)), + norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)), + norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)), + norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)), + norm_f(Conv1d(1024, 1024, 5, 1, padding=2)), + ]) + self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1)) + + def forward(self, x): + fmap = [] + + for l in self.convs: + x = l(x) + x = F.leaky_relu(x, modules.LRELU_SLOPE) + fmap.append(x) + x = self.conv_post(x) + fmap.append(x) + x = torch.flatten(x, 1, -1) + + return x, fmap + + +class MultiPeriodDiscriminator(torch.nn.Module): + + def __init__(self, use_spectral_norm=False): + super(MultiPeriodDiscriminator, self).__init__() + periods = [2, 3, 5, 7, 11] + + discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)] + discs = discs + [ + DiscriminatorP(i, use_spectral_norm=use_spectral_norm) + for i in periods + ] + self.discriminators = nn.ModuleList(discs) + + def forward(self, y, y_hat): + y_d_rs = [] + y_d_gs = [] + fmap_rs = [] + fmap_gs = [] + for i, d in enumerate(self.discriminators): + y_d_r, fmap_r = d(y) + y_d_g, fmap_g = d(y_hat) + y_d_rs.append(y_d_r) + y_d_gs.append(y_d_g) + fmap_rs.append(fmap_r) + fmap_gs.append(fmap_g) + + return y_d_rs, y_d_gs, fmap_rs, fmap_gs + + +class SynthesizerTrn(nn.Module): + """ + Synthesizer for Training + """ + + def __init__(self, + n_vocab, + spec_channels, + segment_size, + inter_channels=192, + hidden_channels=192, + filter_channels=768, + n_heads=2, + n_layers=6, + kernel_size=3, + p_dropout=0.1, + resblock="1", + resblock_kernel_sizes=[3, 7, 11], + resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]], + upsample_rates=[8, 8, 2, 2], + upsample_initial_channel=512, + upsample_kernel_sizes=[16, 16, 4, 4], + n_speakers=1, + gin_channels=256, + use_sdp=True, + **kwargs): + super().__init__() + self.n_vocab = n_vocab + self.spec_channels = spec_channels + self.inter_channels = inter_channels + self.hidden_channels = hidden_channels + self.filter_channels = filter_channels + self.n_heads = n_heads + self.n_layers = n_layers + self.kernel_size = kernel_size + self.p_dropout = p_dropout + self.resblock = resblock + self.resblock_kernel_sizes = resblock_kernel_sizes + self.resblock_dilation_sizes = resblock_dilation_sizes + self.upsample_rates = upsample_rates + self.upsample_initial_channel = upsample_initial_channel + self.upsample_kernel_sizes = upsample_kernel_sizes + self.segment_size = segment_size + self.n_speakers = n_speakers + self.gin_channels = gin_channels + self.use_sdp = use_sdp + + self.enc_p = TextEncoder( + n_vocab, + inter_channels, + hidden_channels, + filter_channels, 
+ n_heads, + n_layers, + kernel_size, + p_dropout, + ) + self.dec = Generator( + inter_channels, + resblock, + resblock_kernel_sizes, + resblock_dilation_sizes, + upsample_rates, + upsample_initial_channel, + upsample_kernel_sizes, + gin_channels=gin_channels, + ) + self.enc_q = PosteriorEncoder( + spec_channels, + inter_channels, + hidden_channels, + 5, + 1, + 16, + gin_channels=gin_channels, + ) + self.flow = ResidualCouplingBlock(inter_channels, + hidden_channels, + 5, + 1, + 4, + gin_channels=gin_channels) + + if use_sdp: + self.dp = StochasticDurationPredictor(hidden_channels, + 192, + 3, + 0.5, + 4, + gin_channels=gin_channels) + else: + self.dp = DurationPredictor(hidden_channels, + 256, + 3, + 0.5, + gin_channels=gin_channels) + + self.emb_g = nn.Embedding(n_speakers, gin_channels) + + def forward(self, x, x_lengths, y, y_lengths, sid=None): + x, m_p, logs_p, x_mask = self.enc_p(x, x_lengths) + g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1] + z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g) + z_p = self.flow(z, y_mask, g=g) + + with torch.no_grad(): + # negative cross-entropy + s_p_sq_r = torch.exp(-2 * logs_p) # [b, d, t] + neg_cent1 = torch.sum(-0.5 * math.log(2 * math.pi) - logs_p, [1], + keepdim=True) # [b, 1, t_s] + neg_cent2 = torch.matmul( + -0.5 * (z_p**2).transpose(1, 2), + s_p_sq_r) # [b, t_t, d] x [b, d, t_s] = [b, t_t, t_s] + neg_cent3 = torch.matmul( + z_p.transpose(1, 2), + (m_p * s_p_sq_r)) # [b, t_t, d] x [b, d, t_s] = [b, t_t, t_s] + neg_cent4 = torch.sum(-0.5 * (m_p**2) * s_p_sq_r, [1], + keepdim=True) # [b, 1, t_s] + neg_cent = neg_cent1 + neg_cent2 + neg_cent3 + neg_cent4 + + attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze( + y_mask, -1) + attn = (monotonic_align.maximum_path( + neg_cent, attn_mask.squeeze(1)).unsqueeze(1).detach()) + + w = attn.sum(2) + if self.use_sdp: + l_length = self.dp(x, x_mask, w, g=g) + l_length = l_length / torch.sum(x_mask) + else: + logw_ = torch.log(w + 1e-6) * x_mask + logw = self.dp(x, x_mask, g=g) + l_length = torch.sum( + (logw - logw_)**2, [1, 2]) / torch.sum(x_mask) # for averaging + + # expand prior + m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, + 2)).transpose(1, 2) + logs_p = torch.matmul(attn.squeeze(1), + logs_p.transpose(1, 2)).transpose(1, 2) + + z_slice, ids_slice = commons.rand_slice_segments( + z, y_lengths, self.segment_size) + o = self.dec(z_slice, g=g) + return ( + o, + l_length, + attn, + ids_slice, + x_mask, + y_mask, + (z, z_p, m_p, logs_p, m_q, logs_q), + ) + + def infer( + self, + x, + x_lengths, + sid=None, + noise_scale=1, + length_scale=1, + noise_scale_w=1.0, + max_len=None, + ): + t1 = time.time() + x, m_p, logs_p, x_mask = self.enc_p(x, x_lengths) + g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1] + t2 = time.time() + if self.use_sdp: + logw = self.dp(x, + x_mask, + g=g, + reverse=True, + noise_scale=noise_scale_w) + else: + logw = self.dp(x, x_mask, g=g) + t3 = time.time() + w = torch.exp(logw) * x_mask * length_scale + w_ceil = torch.ceil(w) + y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long() + y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, None), + 1).to(x_mask.dtype) + attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1) + attn = commons.generate_path(w_ceil, attn_mask) + + m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose( + 1, 2) # [b, t', t], [b, t, d] -> [b, d, t'] + logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose( + 1, 2)).transpose(1, 2) # [b, t', t], [b, t, d] -> [b, d, t'] + + z_p = m_p + 
torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale + t4 = time.time() + z = self.flow(z_p, y_mask, g=g, reverse=True) + t5 = time.time() + o = self.dec((z * y_mask)[:, :, :max_len], g=g) + t6 = time.time() + print("TextEncoder: {}s DurationPredictor: {}s Flow: {}s Decoder: {}s". + format( + round(t2 - t1, 3), + round(t3 - t2, 3), + round(t5 - t4, 3), + round(t6 - t5, 3), + )) + return o, attn, y_mask, (z, z_p, m_p, logs_p) + + def export_forward(self, x, x_lengths, scales, sid): + # shape of scales: Bx3, make triton happy + audio, *_ = self.infer( + x, + x_lengths, + sid, + noise_scale=scales[0][0], + length_scale=scales[0][1], + noise_scale_w=scales[0][2], + ) + return audio + + def voice_conversion(self, y, y_lengths, sid_src, sid_tgt): + g_src = self.emb_g(sid_src).unsqueeze(-1) + g_tgt = self.emb_g(sid_tgt).unsqueeze(-1) + z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g_src) + z_p = self.flow(z, y_mask, g=g_src) + z_hat = self.flow(z_p, y_mask, g=g_tgt, reverse=True) + o_hat = self.dec(z_hat * y_mask, g=g_tgt) + return o_hat, y_mask, (z, z_p, z_hat) + + +class VitsModel(nn.Module): + + def __init__(self, n_vocab, spec_channels, **kwargs): + super().__init__() + self.filter_length = kwargs.get('filter_length', 1024) + self.n_mel_channels = kwargs.get('n_mel_channels', 80) + self.sampling_rate = kwargs.get('sampling_rate', 16000) + self.win_length = kwargs.get('win_length', 1024) + self.hop_length = kwargs.get('hop_length', 256) + self.segment_size = kwargs.get('segment_size', 8192) + self.c_mel = kwargs.get('c_mel', 45) + self.c_kl = kwargs.get('c_kl', 1.0) + self.d_interval = kwargs.get('d_interval', 2) + self.g = SynthesizerTrn(n_vocab, spec_channels, + self.segment_size // self.hop_length, + **kwargs['generator']) + self.d = MultiPeriodDiscriminator(**kwargs['discriminator']) + self.step = 0 + + def forward(self, batch: dict, device: torch.device): + x = batch['target'].to(device) + x_lengths = batch['target_lengths'].to(device) + spec = batch['feats'].to(device) + spec_lengths = batch['feats_lengths'].to(device) + spec = spec.transpose(1, 2) + y = batch['pcm'].to(device) + y = y.unsqueeze(1) + y_lengths = batch['pcm_length'].to(device) + + batch_size = x.size(0) + sid = torch.zeros(batch_size, device=device, dtype=torch.long) + (y_hat, l_length, attn, ids_slice, x_mask, z_mask, + (z, z_p, m_p, logs_p, m_q, logs_q)) = self.g(x, x_lengths, spec, + spec_lengths, sid) + # mel = spec_to_mel_torch( + # spec, + # self.filter_length, + # self.n_mel_channels, + # self.sampling_rate, + # ) + mel = mel_spectrogram_torch( + y.squeeze(1), + self.filter_length, + self.n_mel_channels, + self.sampling_rate, + self.hop_length, + self.win_length, + ) + y_mel = commons.slice_segments(mel, ids_slice, + self.segment_size // self.hop_length) + y_hat_mel = mel_spectrogram_torch( + y_hat.squeeze(1), + self.filter_length, + self.n_mel_channels, + self.sampling_rate, + self.hop_length, + self.win_length, + ) + y = commons.slice_segments(y, ids_slice * self.hop_length, + self.segment_size) + # Train generator and discriminator alternately + if self.step % self.d_interval == 0: + # Discriminator + y_d_hat_r, y_d_hat_g, _, _ = self.d(y, y_hat.detach()) + loss_disc, losses_disc_r, losses_disc_g = discriminator_loss( + y_d_hat_r, y_d_hat_g) + loss_disc_all = loss_disc + losses = {'loss': loss_disc_all, 'loss_disc': loss_disc_all} + else: + # Generator loss + y_d_hat_r, y_d_hat_g, fmap_r, fmap_g = self.d(y, y_hat) + loss_dur = torch.sum(l_length.float()) + loss_mel = F.l1_loss(y_mel, y_hat_mel) * 
self.c_mel + loss_kl = kl_loss(z_p, logs_q, m_p, logs_p, z_mask) * self.c_kl + loss_fm = feature_loss(fmap_r, fmap_g) + loss_gen, losses_gen = generator_loss(y_d_hat_g) + loss_gen_all = loss_gen + loss_fm + loss_mel + loss_dur + loss_kl + losses = { + 'loss': loss_gen_all, + 'loss_gen': loss_gen, + 'loss_fm': loss_fm, + 'loss_mel': loss_mel, + 'loss_dur': loss_dur, + 'loss_kl': loss_kl, + } + + self.step += 1 + return losses + + def infer(self, text: torch.Tensor): + assert text.dim() == 1 + device = text.device + x_length = torch.tensor([text.size(0)], + device=device, + dtype=torch.long) + x = text.unsqueeze(0) + sid = torch.zeros(1, device=device, dtype=torch.long) + audio, *_ = self.g.infer(x, x_length, sid=sid) + return audio diff --git a/wenet/tts/vits/modules.py b/wenet/tts/vits/modules.py new file mode 100644 index 000000000..1678ffb90 --- /dev/null +++ b/wenet/tts/vits/modules.py @@ -0,0 +1,513 @@ +import math + +import torch +from torch import nn +from torch.nn import functional as F +from torch.nn import Conv1d +from torch.nn.utils import remove_weight_norm +from torch.nn.utils import weight_norm + +import wenet.tts.vits.commons as commons +from wenet.tts.vits.commons import init_weights, get_padding +from wenet.tts.vits.transforms import piecewise_rational_quadratic_transform + +LRELU_SLOPE = 0.1 + + +class LayerNorm(nn.Module): + + def __init__(self, channels, eps=1e-5): + super().__init__() + self.channels = channels + self.eps = eps + + self.gamma = nn.Parameter(torch.ones(channels)) + self.beta = nn.Parameter(torch.zeros(channels)) + + def forward(self, x): + x = x.transpose(1, -1) + x = F.layer_norm(x, (self.channels, ), self.gamma, self.beta, self.eps) + return x.transpose(1, -1) + + +class ConvReluNorm(nn.Module): + + def __init__( + self, + in_channels, + hidden_channels, + out_channels, + kernel_size, + n_layers, + p_dropout, + ): + super().__init__() + self.in_channels = in_channels + self.hidden_channels = hidden_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.n_layers = n_layers + self.p_dropout = p_dropout + assert n_layers > 1, "Number of layers should be larger than 0." 
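+        # a stack of conv -> LayerNorm -> ReLU+dropout blocks; the final 1x1
+        # projection is zero-initialized and added back to the input residually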
+ + self.conv_layers = nn.ModuleList() + self.norm_layers = nn.ModuleList() + self.conv_layers.append( + nn.Conv1d(in_channels, + hidden_channels, + kernel_size, + padding=kernel_size // 2)) + self.norm_layers.append(LayerNorm(hidden_channels)) + self.relu_drop = nn.Sequential(nn.ReLU(), nn.Dropout(p_dropout)) + for _ in range(n_layers - 1): + self.conv_layers.append( + nn.Conv1d( + hidden_channels, + hidden_channels, + kernel_size, + padding=kernel_size // 2, + )) + self.norm_layers.append(LayerNorm(hidden_channels)) + self.proj = nn.Conv1d(hidden_channels, out_channels, 1) + self.proj.weight.data.zero_() + self.proj.bias.data.zero_() + + def forward(self, x, x_mask): + x_org = x + for i in range(self.n_layers): + x = self.conv_layers[i](x * x_mask) + x = self.norm_layers[i](x) + x = self.relu_drop(x) + x = x_org + self.proj(x) + return x * x_mask + + +class DDSConv(nn.Module): + """ + Dialted and Depth-Separable Convolution + """ + + def __init__(self, channels, kernel_size, n_layers, p_dropout=0.0): + super().__init__() + self.channels = channels + self.kernel_size = kernel_size + self.n_layers = n_layers + self.p_dropout = p_dropout + + self.drop = nn.Dropout(p_dropout) + self.convs_sep = nn.ModuleList() + self.convs_1x1 = nn.ModuleList() + self.norms_1 = nn.ModuleList() + self.norms_2 = nn.ModuleList() + for i in range(n_layers): + dilation = kernel_size**i + padding = (kernel_size * dilation - dilation) // 2 + self.convs_sep.append( + nn.Conv1d( + channels, + channels, + kernel_size, + groups=channels, + dilation=dilation, + padding=padding, + )) + self.convs_1x1.append(nn.Conv1d(channels, channels, 1)) + self.norms_1.append(LayerNorm(channels)) + self.norms_2.append(LayerNorm(channels)) + + def forward(self, x, x_mask, g=None): + if g is not None: + x = x + g + for i in range(self.n_layers): + y = self.convs_sep[i](x * x_mask) + y = self.norms_1[i](y) + y = F.gelu(y) + y = self.convs_1x1[i](y) + y = self.norms_2[i](y) + y = F.gelu(y) + y = self.drop(y) + x = x + y + return x * x_mask + + +class WN(torch.nn.Module): + + def __init__( + self, + hidden_channels, + kernel_size, + dilation_rate, + n_layers, + gin_channels, + p_dropout=0, + ): + super(WN, self).__init__() + assert kernel_size % 2 == 1 + self.hidden_channels = hidden_channels + self.kernel_size = (kernel_size, ) + self.dilation_rate = dilation_rate + self.n_layers = n_layers + self.gin_channels = gin_channels + self.p_dropout = p_dropout + + self.in_layers = torch.nn.ModuleList() + self.res_skip_layers = torch.nn.ModuleList() + self.drop = nn.Dropout(p_dropout) + + cond_layer = torch.nn.Conv1d(gin_channels, + 2 * hidden_channels * n_layers, 1) + self.cond_layer = weight_norm(cond_layer, name="weight") + + for i in range(n_layers): + dilation = dilation_rate**i + padding = int((kernel_size * dilation - dilation) / 2) + in_layer = torch.nn.Conv1d( + hidden_channels, + 2 * hidden_channels, + kernel_size, + dilation=dilation, + padding=padding, + ) + in_layer = weight_norm(in_layer, name="weight") + self.in_layers.append(in_layer) + + # last one is not necessary + if i < n_layers - 1: + res_skip_channels = 2 * hidden_channels + else: + res_skip_channels = hidden_channels + + res_skip_layer = torch.nn.Conv1d(hidden_channels, + res_skip_channels, 1) + res_skip_layer = weight_norm(res_skip_layer, name="weight") + self.res_skip_layers.append(res_skip_layer) + + def forward(self, x, x_mask, g=None, **kwargs): + output = torch.zeros_like(x) + n_channels_tensor = torch.IntTensor([self.hidden_channels]) + + g = self.cond_layer(g) + + 
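+        # cond_layer produces 2 * hidden_channels activations per layer;
+        # each iteration below slices its own chunk for the gated tanh/sigmoid unit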
for i in range(self.n_layers): + x_in = self.in_layers[i](x) + cond_offset = i * 2 * self.hidden_channels + g_l = g[:, cond_offset:cond_offset + 2 * self.hidden_channels, :] + + acts = commons.fused_add_tanh_sigmoid_multiply( + x_in, g_l, n_channels_tensor) + acts = self.drop(acts) + + res_skip_acts = self.res_skip_layers[i](acts) + if i < self.n_layers - 1: + res_acts = res_skip_acts[:, :self.hidden_channels, :] + x = (x + res_acts) * x_mask + output = output + res_skip_acts[:, self.hidden_channels:, :] + else: + output = output + res_skip_acts + return output * x_mask + + def remove_weight_norm(self): + torch.nn.utils.remove_weight_norm(self.cond_layer) + for l in self.in_layers: + torch.nn.utils.remove_weight_norm(l) + for l in self.res_skip_layers: + torch.nn.utils.remove_weight_norm(l) + + +class ResBlock1(torch.nn.Module): + + def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)): + super(ResBlock1, self).__init__() + self.convs1 = nn.ModuleList([ + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]), + )), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]), + )), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[2], + padding=get_padding(kernel_size, dilation[2]), + )), + ]) + self.convs1.apply(init_weights) + + self.convs2 = nn.ModuleList([ + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + )), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + )), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + )), + ]) + self.convs2.apply(init_weights) + + def forward(self, x, x_mask=None): + for c1, c2 in zip(self.convs1, self.convs2): + xt = F.leaky_relu(x, LRELU_SLOPE) + if x_mask is not None: + xt = xt * x_mask + xt = c1(xt) + xt = F.leaky_relu(xt, LRELU_SLOPE) + if x_mask is not None: + xt = xt * x_mask + xt = c2(xt) + x = xt + x + if x_mask is not None: + x = x * x_mask + return x + + def remove_weight_norm(self): + for l in self.convs1: + remove_weight_norm(l) + for l in self.convs2: + remove_weight_norm(l) + + +class ResBlock2(torch.nn.Module): + + def __init__(self, channels, kernel_size=3, dilation=(1, 3)): + super(ResBlock2, self).__init__() + self.convs = nn.ModuleList([ + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]), + )), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]), + )), + ]) + self.convs.apply(init_weights) + + def forward(self, x, x_mask=None): + for c in self.convs: + xt = F.leaky_relu(x, LRELU_SLOPE) + if x_mask is not None: + xt = xt * x_mask + xt = c(xt) + x = xt + x + if x_mask is not None: + x = x * x_mask + return x + + def remove_weight_norm(self): + for l in self.convs: + remove_weight_norm(l) + + +class Log(nn.Module): + + def forward(self, x, x_mask, reverse=False, **kwargs): + if not reverse: + y = torch.log(torch.clamp_min(x, 1e-5)) * x_mask + logdet = torch.sum(-y, [1, 2]) + return y, logdet + else: + x = torch.exp(x) * x_mask + return x + + +class Flip(nn.Module): + + def forward(self, x, *args, reverse=False, **kwargs): + x = 
torch.flip(x, [1]) + if not reverse: + logdet = torch.zeros(x.size(0)).to(dtype=x.dtype, device=x.device) + return x, logdet + else: + return x + + +class ElementwiseAffine(nn.Module): + + def __init__(self, channels): + super().__init__() + self.channels = channels + self.m = nn.Parameter(torch.zeros(channels, 1)) + self.logs = nn.Parameter(torch.zeros(channels, 1)) + + def forward(self, x, x_mask, reverse=False, **kwargs): + if not reverse: + y = self.m + torch.exp(self.logs) * x + y = y * x_mask + logdet = torch.sum(self.logs * x_mask, [1, 2]) + return y, logdet + else: + x = (x - self.m) * torch.exp(-self.logs) * x_mask + return x + + +class ResidualCouplingLayer(nn.Module): + + def __init__( + self, + channels, + hidden_channels, + kernel_size, + dilation_rate, + n_layers, + p_dropout=0, + gin_channels=256, + mean_only=False, + ): + assert channels % 2 == 0, "channels should be divisible by 2" + super().__init__() + self.channels = channels + self.hidden_channels = hidden_channels + self.kernel_size = kernel_size + self.dilation_rate = dilation_rate + self.n_layers = n_layers + self.half_channels = channels // 2 + self.mean_only = mean_only + + self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1) + self.enc = WN( + hidden_channels, + kernel_size, + dilation_rate, + n_layers, + p_dropout=p_dropout, + gin_channels=gin_channels, + ) + self.post = nn.Conv1d(hidden_channels, + self.half_channels * (2 - mean_only), 1) + self.post.weight.data.zero_() + self.post.bias.data.zero_() + + def forward(self, x, x_mask, g=None, reverse=False): + x0, x1 = torch.split(x, [self.half_channels] * 2, 1) + h = self.pre(x0) * x_mask + h = self.enc(h, x_mask, g=g) + stats = self.post(h) * x_mask + if not self.mean_only: + m, logs = torch.split(stats, [self.half_channels] * 2, 1) + else: + m = stats + logs = torch.zeros_like(m) + + if not reverse: + x1 = m + x1 * torch.exp(logs) * x_mask + x = torch.cat([x0, x1], 1) + logdet = torch.sum(logs, [1, 2]) + return x, logdet + else: + x1 = (x1 - m) * torch.exp(-logs) * x_mask + x = torch.cat([x0, x1], 1) + return x + + def remove_weight_norm(self): + self.enc.remove_weight_norm() + + +class ConvFlow(nn.Module): + + def __init__( + self, + in_channels, + filter_channels, + kernel_size, + n_layers, + num_bins=10, + tail_bound=5.0, + ): + super().__init__() + self.in_channels = in_channels + self.filter_channels = filter_channels + self.kernel_size = kernel_size + self.n_layers = n_layers + self.num_bins = num_bins + self.tail_bound = tail_bound + self.half_channels = in_channels // 2 + + self.pre = nn.Conv1d(self.half_channels, filter_channels, 1) + self.convs = DDSConv(filter_channels, + kernel_size, + n_layers, + p_dropout=0.0) + self.proj = nn.Conv1d(filter_channels, + self.half_channels * (num_bins * 3 - 1), 1) + self.proj.weight.data.zero_() + self.proj.bias.data.zero_() + + def forward(self, x, x_mask, g=None, reverse=False): + x0, x1 = torch.split(x, [self.half_channels] * 2, 1) + h = self.pre(x0) + h = self.convs(h, x_mask, g=g) + h = self.proj(h) * x_mask + + b, c, t = x0.shape + h = h.reshape(b, c, -1, t).permute(0, 1, 3, + 2) # [b, cx?, t] -> [b, c, t, ?] 
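+        # the trailing dim packs the rational-quadratic spline parameters per
+        # half-channel: num_bins widths, num_bins heights, num_bins - 1 derivatives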
+ + unnormalized_widths = h[..., :self.num_bins] / math.sqrt( + self.filter_channels) + unnormalized_heights = h[..., + self.num_bins:2 * self.num_bins] / math.sqrt( + self.filter_channels) + unnormalized_derivatives = h[..., 2 * self.num_bins:] + + x1, logabsdet = piecewise_rational_quadratic_transform( + x1, + unnormalized_widths, + unnormalized_heights, + unnormalized_derivatives, + inverse=reverse, + tails="linear", + tail_bound=self.tail_bound, + ) + + x = torch.cat([x0, x1], 1) * x_mask + logdet = torch.sum(logabsdet * x_mask, [1, 2]) + if not reverse: + return x, logdet + else: + return x diff --git a/wenet/tts/vits/monotonic_align.py b/wenet/tts/vits/monotonic_align.py new file mode 100644 index 000000000..b56b24bbc --- /dev/null +++ b/wenet/tts/vits/monotonic_align.py @@ -0,0 +1,57 @@ +import torch +import numba +import numpy as np + + +def maximum_path(neg_cent: torch.Tensor, mask: torch.Tensor): + """numba optimized version. + neg_cent: [b, t_t, t_s] + mask: [b, t_t, t_s] + """ + device = neg_cent.device + dtype = neg_cent.dtype + neg_cent = neg_cent.data.cpu().numpy().astype(np.float32) + path = np.zeros(neg_cent.shape, dtype=np.int32) + + t_t_max = mask.sum(1)[:, 0].data.cpu().numpy().astype(np.int32) + t_s_max = mask.sum(2)[:, 0].data.cpu().numpy().astype(np.int32) + maximum_path_jit(path, neg_cent, t_t_max, t_s_max) + return torch.from_numpy(path).to(device=device, dtype=dtype) + + +@numba.jit(numba.void(numba.int32[:, :, ::1], numba.float32[:, :, ::1], + numba.int32[::1], numba.int32[::1]), + nopython=True, + nogil=True) +def maximum_path_jit(paths, values, t_ys, t_xs): + b = paths.shape[0] + max_neg_val = -1e9 + for i in range(int(b)): + path = paths[i] + value = values[i] + t_y = t_ys[i] + t_x = t_xs[i] + + v_prev = v_cur = 0.0 + index = t_x - 1 + + for y in range(t_y): + for x in range(max(0, t_x + y - t_y), min(t_x, y + 1)): + if x == y: + v_cur = max_neg_val + else: + v_cur = value[y - 1, x] + if x == 0: + if y == 0: + v_prev = 0.0 + else: + v_prev = max_neg_val + else: + v_prev = value[y - 1, x - 1] + value[y, x] += max(v_prev, v_cur) + + for y in range(t_y - 1, -1, -1): + path[y, index] = 1 + if index != 0 and (index == y or + value[y - 1, index] < value[y - 1, index - 1]): + index = index - 1 diff --git a/wenet/tts/vits/transforms.py b/wenet/tts/vits/transforms.py new file mode 100644 index 000000000..1b99fd7b5 --- /dev/null +++ b/wenet/tts/vits/transforms.py @@ -0,0 +1,206 @@ +import numpy as np +import torch +from torch.nn import functional as F + +DEFAULT_MIN_BIN_WIDTH = 1e-3 +DEFAULT_MIN_BIN_HEIGHT = 1e-3 +DEFAULT_MIN_DERIVATIVE = 1e-3 + + +def piecewise_rational_quadratic_transform( + inputs, + unnormalized_widths, + unnormalized_heights, + unnormalized_derivatives, + inverse=False, + tails=None, + tail_bound=1.0, + min_bin_width=DEFAULT_MIN_BIN_WIDTH, + min_bin_height=DEFAULT_MIN_BIN_HEIGHT, + min_derivative=DEFAULT_MIN_DERIVATIVE, +): + if tails is None: + spline_fn = rational_quadratic_spline + spline_kwargs = {} + else: + spline_fn = unconstrained_rational_quadratic_spline + spline_kwargs = {"tails": tails, "tail_bound": tail_bound} + + outputs, logabsdet = spline_fn( + inputs=inputs, + unnormalized_widths=unnormalized_widths, + unnormalized_heights=unnormalized_heights, + unnormalized_derivatives=unnormalized_derivatives, + inverse=inverse, + min_bin_width=min_bin_width, + min_bin_height=min_bin_height, + min_derivative=min_derivative, + **spline_kwargs) + return outputs, logabsdet + + +def searchsorted(bin_locations, inputs, eps=1e-6): + 
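+    # bump the last bin edge by eps so inputs equal to the upper bound fall in
+    # the final bin, then count how many edges each input has passed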
bin_locations[..., bin_locations.size(-1) - 1] += eps + return torch.sum(inputs[..., None] >= bin_locations, dim=-1) - 1 + + +def unconstrained_rational_quadratic_spline( + inputs, + unnormalized_widths, + unnormalized_heights, + unnormalized_derivatives, + inverse=False, + tails="linear", + tail_bound=1.0, + min_bin_width=DEFAULT_MIN_BIN_WIDTH, + min_bin_height=DEFAULT_MIN_BIN_HEIGHT, + min_derivative=DEFAULT_MIN_DERIVATIVE, +): + inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound) + outside_interval_mask = ~inside_interval_mask + + outputs = torch.zeros_like(inputs) + logabsdet = torch.zeros_like(inputs) + + if tails == "linear": + unnormalized_derivatives = F.pad(unnormalized_derivatives, pad=(1, 1)) + constant = np.log(np.exp(1 - min_derivative) - 1) + unnormalized_derivatives[..., 0] = constant + unnormalized_derivatives[..., + unnormalized_derivatives.size(-1) - + 1] = constant + + outputs[outside_interval_mask] = inputs[outside_interval_mask] + logabsdet[outside_interval_mask] = 0 + else: + raise RuntimeError("{} tails are not implemented.".format(tails)) + + ( + outputs[inside_interval_mask], + logabsdet[inside_interval_mask], + ) = rational_quadratic_spline( + inputs=inputs[inside_interval_mask], + unnormalized_widths=unnormalized_widths[inside_interval_mask, :], + unnormalized_heights=unnormalized_heights[inside_interval_mask, :], + unnormalized_derivatives=unnormalized_derivatives[ + inside_interval_mask, :], + inverse=inverse, + left=-tail_bound, + right=tail_bound, + bottom=-tail_bound, + top=tail_bound, + min_bin_width=min_bin_width, + min_bin_height=min_bin_height, + min_derivative=min_derivative, + ) + + return outputs, logabsdet + + +def rational_quadratic_spline( + inputs, + unnormalized_widths, + unnormalized_heights, + unnormalized_derivatives, + inverse=False, + left=0.0, + right=1.0, + bottom=0.0, + top=1.0, + min_bin_width=DEFAULT_MIN_BIN_WIDTH, + min_bin_height=DEFAULT_MIN_BIN_HEIGHT, + min_derivative=DEFAULT_MIN_DERIVATIVE, +): + if torch.min(inputs) < left or torch.max(inputs) > right: + raise ValueError("Input to a transform is not within its domain") + + num_bins = unnormalized_widths.shape[-1] + + if min_bin_width * num_bins > 1.0: + raise ValueError("Minimal bin width too large for the number of bins") + if min_bin_height * num_bins > 1.0: + raise ValueError("Minimal bin height too large for the number of bins") + + widths = F.softmax(unnormalized_widths, dim=-1) + widths = min_bin_width + (1 - min_bin_width * num_bins) * widths + cumwidths = torch.cumsum(widths, dim=-1) + cumwidths = F.pad(cumwidths, pad=(1, 0), mode="constant", value=0.0) + cumwidths = (right - left) * cumwidths + left + cumwidths[..., 0] = left + cumwidths[..., cumwidths.size(-1) - 1] = right + widths = cumwidths[..., 1:] - cumwidths[..., :-1] + + derivatives = min_derivative + F.softplus(unnormalized_derivatives) + + heights = F.softmax(unnormalized_heights, dim=-1) + heights = min_bin_height + (1 - min_bin_height * num_bins) * heights + cumheights = torch.cumsum(heights, dim=-1) + cumheights = F.pad(cumheights, pad=(1, 0), mode="constant", value=0.0) + cumheights = (top - bottom) * cumheights + bottom + cumheights[..., 0] = bottom + cumheights[..., cumheights.size(-1) - 1] = top + heights = cumheights[..., 1:] - cumheights[..., :-1] + + if inverse: + bin_idx = searchsorted(cumheights, inputs)[..., None] + else: + bin_idx = searchsorted(cumwidths, inputs)[..., None] + + input_cumwidths = cumwidths.gather(-1, bin_idx)[..., 0] + input_bin_widths = widths.gather(-1, 
bin_idx)[..., 0] + + input_cumheights = cumheights.gather(-1, bin_idx)[..., 0] + delta = heights / widths + input_delta = delta.gather(-1, bin_idx)[..., 0] + + input_derivatives = derivatives.gather(-1, bin_idx)[..., 0] + input_derivatives_plus_one = derivatives[..., 1:].gather(-1, bin_idx)[..., + 0] + + input_heights = heights.gather(-1, bin_idx)[..., 0] + + if inverse: + a = (inputs - input_cumheights) * ( + input_derivatives + input_derivatives_plus_one - 2 * + input_delta) + input_heights * (input_delta - input_derivatives) + b = input_heights * input_derivatives - (inputs - input_cumheights) * ( + input_derivatives + input_derivatives_plus_one - 2 * input_delta) + c = -input_delta * (inputs - input_cumheights) + + discriminant = b.pow(2) - 4 * a * c + assert (discriminant >= 0).all() + + root = (2 * c) / (-b - torch.sqrt(discriminant)) + outputs = root * input_bin_widths + input_cumwidths + + theta_one_minus_theta = root * (1 - root) + denominator = input_delta + ( + (input_derivatives + input_derivatives_plus_one - 2 * input_delta) + * theta_one_minus_theta) + derivative_numerator = input_delta.pow(2) * ( + input_derivatives_plus_one * root.pow(2) + + 2 * input_delta * theta_one_minus_theta + input_derivatives * + (1 - root).pow(2)) + logabsdet = torch.log( + derivative_numerator) - 2 * torch.log(denominator) + + return outputs, -logabsdet + else: + theta = (inputs - input_cumwidths) / input_bin_widths + theta_one_minus_theta = theta * (1 - theta) + + numerator = input_heights * (input_delta * theta.pow(2) + + input_derivatives * theta_one_minus_theta) + denominator = input_delta + ( + (input_derivatives + input_derivatives_plus_one - 2 * input_delta) + * theta_one_minus_theta) + outputs = input_cumheights + numerator / denominator + + derivative_numerator = input_delta.pow(2) * ( + input_derivatives_plus_one * theta.pow(2) + + 2 * input_delta * theta_one_minus_theta + input_derivatives * + (1 - theta).pow(2)) + logabsdet = torch.log( + derivative_numerator) - 2 * torch.log(denominator) + + return outputs, logabsdet diff --git a/wenet/utils/init_model.py b/wenet/utils/init_model.py index 4b116bb57..4ad0c869a 100644 --- a/wenet/utils/init_model.py +++ b/wenet/utils/init_model.py @@ -36,6 +36,7 @@ from wenet.whisper.whisper import Whisper from wenet.utils.cmvn import load_cmvn from wenet.utils.checkpoint import load_checkpoint, load_trained_modules +from wenet.tts.vits.models import VitsModel WENET_ENCODER_CLASSES = { "transformer": TransformerEncoder, @@ -95,29 +96,31 @@ def init_model(args, configs): input_dim = configs['input_dim'] vocab_size = configs['output_dim'] - encoder_type = configs.get('encoder', 'conformer') - decoder_type = configs.get('decoder', 'bitransformer') - ctc_type = configs.get('ctc', 'ctc') - - encoder = WENET_ENCODER_CLASSES[encoder_type]( - input_dim, - global_cmvn=global_cmvn, - **configs['encoder_conf'], - **configs['encoder_conf']['efficient_conf'] - if 'efficient_conf' in configs['encoder_conf'] else {}) - - decoder = WENET_DECODER_CLASSES[decoder_type](vocab_size, - encoder.output_size(), - **configs['decoder_conf']) - - ctc = WENET_CTC_CLASSES[ctc_type]( - vocab_size, - encoder.output_size(), - blank_id=configs['ctc_conf']['ctc_blank_id'] - if 'ctc_conf' in configs else 0) - model_type = configs.get('model', 'asr_model') - if model_type == "transducer": + if model_type in ['asr_model', 'paraformer', 'transducer']: + encoder_type = configs.get('encoder', 'conformer') + decoder_type = configs.get('decoder', 'bitransformer') + ctc_type = configs.get('ctc', 
'ctc') + + encoder = WENET_ENCODER_CLASSES[encoder_type]( + input_dim, + global_cmvn=global_cmvn, + **configs['encoder_conf'], + **configs['encoder_conf']['efficient_conf'] + if 'efficient_conf' in configs['encoder_conf'] else {}) + + decoder = WENET_DECODER_CLASSES[decoder_type]( + vocab_size, encoder.output_size(), **configs['decoder_conf']) + + ctc = WENET_CTC_CLASSES[ctc_type]( + vocab_size, + encoder.output_size(), + blank_id=configs['ctc_conf']['ctc_blank_id'] + if 'ctc_conf' in configs else 0) + + if model_type == 'vits': + model = VitsModel(vocab_size, input_dim, **configs['model_conf']) + elif model_type == "transducer": predictor_type = configs.get('predictor', 'rnn') joint_type = configs.get('joint', 'transducer_joint') predictor = WENET_PREDICTOR_CLASSES[predictor_type]( @@ -170,7 +173,7 @@ def init_model(args, configs): print(configs) # Tie emb.weight to decoder.output_layer.weight - if model.decoder.tie_word_embedding: + if hasattr(model, 'decoder') and model.decoder.tie_word_embedding: model.decoder.tie_or_clone_weights(jit_mode=args.jit) return model, configs From 492fff58d57937e2ac1926cc7882302f72ab14bc Mon Sep 17 00:00:00 2001 From: Binbin Zhang Date: Thu, 4 Jan 2024 16:48:33 +0800 Subject: [PATCH 2/7] add attention file --- wenet/tts/vits/attentions.py | 410 +++++++++++++++++++++++++++++++++++ 1 file changed, 410 insertions(+) create mode 100644 wenet/tts/vits/attentions.py diff --git a/wenet/tts/vits/attentions.py b/wenet/tts/vits/attentions.py new file mode 100644 index 000000000..4739217b3 --- /dev/null +++ b/wenet/tts/vits/attentions.py @@ -0,0 +1,410 @@ +import math + +import torch +from torch import nn +from torch.nn import functional as F + +import wenet.tts.vits.commons as commons +from wenet.tts.vits.modules import LayerNorm + + +class Encoder(nn.Module): + + def __init__(self, + hidden_channels, + filter_channels, + n_heads, + n_layers, + kernel_size=1, + p_dropout=0.0, + window_size=4, + **kwargs): + super().__init__() + self.hidden_channels = hidden_channels + self.filter_channels = filter_channels + self.n_heads = n_heads + self.n_layers = n_layers + self.kernel_size = kernel_size + self.p_dropout = p_dropout + self.window_size = window_size + + self.drop = nn.Dropout(p_dropout) + self.attn_layers = nn.ModuleList() + self.norm_layers_1 = nn.ModuleList() + self.ffn_layers = nn.ModuleList() + self.norm_layers_2 = nn.ModuleList() + for i in range(self.n_layers): + self.attn_layers.append( + MultiHeadAttention( + hidden_channels, + hidden_channels, + n_heads, + p_dropout=p_dropout, + window_size=window_size, + )) + self.norm_layers_1.append(LayerNorm(hidden_channels)) + self.ffn_layers.append( + FFN( + hidden_channels, + hidden_channels, + filter_channels, + kernel_size, + p_dropout=p_dropout, + )) + self.norm_layers_2.append(LayerNorm(hidden_channels)) + + def forward(self, x, x_mask): + attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1) + x = x * x_mask + for i in range(self.n_layers): + y = self.attn_layers[i](x, x, attn_mask) + y = self.drop(y) + x = self.norm_layers_1[i](x + y) + + y = self.ffn_layers[i](x, x_mask) + y = self.drop(y) + x = self.norm_layers_2[i](x + y) + x = x * x_mask + return x + + +class Decoder(nn.Module): + + def __init__(self, + hidden_channels, + filter_channels, + n_heads, + n_layers, + kernel_size=1, + p_dropout=0.0, + proximal_bias=False, + proximal_init=True, + **kwargs): + super().__init__() + self.hidden_channels = hidden_channels + self.filter_channels = filter_channels + self.n_heads = n_heads + self.n_layers = n_layers + 
self.kernel_size = kernel_size + self.p_dropout = p_dropout + self.proximal_bias = proximal_bias + self.proximal_init = proximal_init + + self.drop = nn.Dropout(p_dropout) + self.self_attn_layers = nn.ModuleList() + self.norm_layers_0 = nn.ModuleList() + self.encdec_attn_layers = nn.ModuleList() + self.norm_layers_1 = nn.ModuleList() + self.ffn_layers = nn.ModuleList() + self.norm_layers_2 = nn.ModuleList() + for i in range(self.n_layers): + self.self_attn_layers.append( + MultiHeadAttention( + hidden_channels, + hidden_channels, + n_heads, + p_dropout=p_dropout, + proximal_bias=proximal_bias, + proximal_init=proximal_init, + )) + self.norm_layers_0.append(LayerNorm(hidden_channels)) + self.encdec_attn_layers.append( + MultiHeadAttention(hidden_channels, + hidden_channels, + n_heads, + p_dropout=p_dropout)) + self.norm_layers_1.append(LayerNorm(hidden_channels)) + self.ffn_layers.append( + FFN( + hidden_channels, + hidden_channels, + filter_channels, + kernel_size, + p_dropout=p_dropout, + causal=True, + )) + self.norm_layers_2.append(LayerNorm(hidden_channels)) + + def forward(self, x, x_mask, h, h_mask): + """ + x: decoder input + h: encoder output + """ + self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to( + device=x.device, dtype=x.dtype) + encdec_attn_mask = h_mask.unsqueeze(2) * x_mask.unsqueeze(-1) + x = x * x_mask + for i in range(self.n_layers): + y = self.self_attn_layers[i](x, x, self_attn_mask) + y = self.drop(y) + x = self.norm_layers_0[i](x + y) + + y = self.encdec_attn_layers[i](x, h, encdec_attn_mask) + y = self.drop(y) + x = self.norm_layers_1[i](x + y) + + y = self.ffn_layers[i](x, x_mask) + y = self.drop(y) + x = self.norm_layers_2[i](x + y) + x = x * x_mask + return x + + +class MultiHeadAttention(nn.Module): + + def __init__( + self, + channels, + out_channels, + n_heads, + p_dropout=0.0, + window_size=None, + heads_share=True, + block_length=None, + proximal_bias=False, + proximal_init=False, + ): + super().__init__() + assert channels % n_heads == 0 + + self.channels = channels + self.out_channels = out_channels + self.n_heads = n_heads + self.p_dropout = p_dropout + self.window_size = window_size + self.heads_share = heads_share + self.block_length = block_length + self.proximal_bias = proximal_bias + self.proximal_init = proximal_init + self.attn = None + + self.k_channels = channels // n_heads + self.conv_q = nn.Conv1d(channels, channels, 1) + self.conv_k = nn.Conv1d(channels, channels, 1) + self.conv_v = nn.Conv1d(channels, channels, 1) + self.conv_o = nn.Conv1d(channels, out_channels, 1) + self.drop = nn.Dropout(p_dropout) + + if window_size is not None: + n_heads_rel = 1 if heads_share else n_heads + rel_stddev = self.k_channels**-0.5 + self.emb_rel_k = nn.Parameter( + torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) + * rel_stddev) + self.emb_rel_v = nn.Parameter( + torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) + * rel_stddev) + + nn.init.xavier_uniform_(self.conv_q.weight) + nn.init.xavier_uniform_(self.conv_k.weight) + nn.init.xavier_uniform_(self.conv_v.weight) + if proximal_init: + with torch.no_grad(): + self.conv_k.weight.copy_(self.conv_q.weight) + self.conv_k.bias.copy_(self.conv_q.bias) + + def forward(self, x, c, attn_mask=None): + q = self.conv_q(x) + k = self.conv_k(c) + v = self.conv_v(c) + + x, self.attn = self.attention(q, k, v, mask=attn_mask) + + x = self.conv_o(x) + return x + + def attention(self, query, key, value, mask=None): + # reshape [b, d, t] -> [b, n_h, t, d_k] + b, d, t_s, t_t = (*key.size(), 
query.size(2)) + query = query.view(b, self.n_heads, self.k_channels, + t_t).transpose(2, 3) + key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3) + value = value.view(b, self.n_heads, self.k_channels, + t_s).transpose(2, 3) + + scores = torch.matmul(query / math.sqrt(self.k_channels), + key.transpose(-2, -1)) + if self.window_size is not None: + msg = "Relative attention is only available for self-attention." + assert t_s == t_t, msg + key_relative_embeddings = self._get_relative_embeddings( + self.emb_rel_k, t_s) + rel_logits = self._matmul_with_relative_keys( + query / math.sqrt(self.k_channels), key_relative_embeddings) + scores_local = self._relative_position_to_absolute_position( + rel_logits) + scores = scores + scores_local + if self.proximal_bias: + msg = "Proximal bias is only available for self-attention." + assert t_s == t_t, msg + scores = scores + self._attention_bias_proximal(t_s).to( + device=scores.device, dtype=scores.dtype) + if mask is not None: + scores = scores.masked_fill(mask == 0, -1e4) + if self.block_length is not None: + msg = "Local attention is only available for self-attention." + assert t_s == t_t, msg + block_mask = ( + torch.ones_like(scores).triu(-self.block_length).tril( + self.block_length)) + scores = scores.masked_fill(block_mask == 0, -1e4) + p_attn = F.softmax(scores, dim=-1) # [b, n_h, t_t, t_s] + p_attn = self.drop(p_attn) + output = torch.matmul(p_attn, value) + if self.window_size is not None: + relative_weights = self._absolute_position_to_relative_position( + p_attn) + value_relative_embeddings = self._get_relative_embeddings( + self.emb_rel_v, t_s) + output = output + self._matmul_with_relative_values( + relative_weights, value_relative_embeddings) + output = (output.transpose(2, 3).contiguous().view(b, d, t_t) + ) # [b, n_h, t_t, d_k] -> [b, d, t_t] + return output, p_attn + + def _matmul_with_relative_values(self, x, y): + """ + x: [b, h, l, m] + y: [h or 1, m, d] + ret: [b, h, l, d] + """ + ret = torch.matmul(x, y.unsqueeze(0)) + return ret + + def _matmul_with_relative_keys(self, x, y): + """ + x: [b, h, l, d] + y: [h or 1, m, d] + ret: [b, h, l, m] + """ + ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1)) + return ret + + def _get_relative_embeddings(self, relative_embeddings, length): + max_relative_position = 2 * self.window_size + 1 + # Pad first before slice to avoid using cond ops. + pad_length = max(length - (self.window_size + 1), 0) + slice_start_position = max((self.window_size + 1) - length, 0) + slice_end_position = slice_start_position + 2 * length - 1 + if pad_length > 0: + padded_relative_embeddings = F.pad( + relative_embeddings, + commons.convert_pad_shape([[0, 0], [pad_length, pad_length], + [0, 0]]), + ) + else: + padded_relative_embeddings = relative_embeddings + used_relative_embeddings = padded_relative_embeddings[:, + slice_start_position: + slice_end_position] + return used_relative_embeddings + + def _relative_position_to_absolute_position(self, x): + """ + x: [b, h, l, 2*l-1] + ret: [b, h, l, l] + """ + batch, heads, length, _ = x.size() + # Concat columns of pad to shift from relative to absolute indexing. + x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, + 1]])) + + # Concat extra elements so to add up to shape (len+1, 2*len-1). + x_flat = x.view([batch, heads, length * 2 * length]) + x_flat = F.pad( + x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [0, + length - 1]])) + + # Reshape and slice out the padded elements. 
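+ # The pad / flatten / pad / view sequence below shifts row i of the relative logits + # right by i slots, so after the final slice the entry at relative index k for query i + # lands at absolute key position i + k - (length - 1), yielding the [b, h, l, l] scores.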
+ x_final = x_flat.view([batch, heads, length + 1, + 2 * length - 1])[:, :, :length, length - 1:] + return x_final + + def _absolute_position_to_relative_position(self, x): + """ + x: [b, h, l, l] + ret: [b, h, l, 2*l-1] + """ + batch, heads, length, _ = x.size() + # padd along column + x = F.pad( + x, + commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, + length - 1]])) + x_flat = x.view([batch, heads, length**2 + length * (length - 1)]) + # add 0's in the beginning that will skew the elements after reshape + x_flat = F.pad( + x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [length, 0]])) + x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:] + return x_final + + def _attention_bias_proximal(self, length): + """Bias for self-attention to encourage attention to close positions. + Args: + length: an integer scalar. + Returns: + a Tensor with shape [1, 1, length, length] + """ + r = torch.arange(length, dtype=torch.float32) + diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1) + return torch.unsqueeze( + torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0) + + +class FFN(nn.Module): + + def __init__( + self, + in_channels, + out_channels, + filter_channels, + kernel_size, + p_dropout=0.0, + activation=None, + causal=False, + ): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.filter_channels = filter_channels + self.kernel_size = kernel_size + self.p_dropout = p_dropout + self.activation = activation + self.causal = causal + + if causal: + self.padding = self._causal_padding + else: + self.padding = self._same_padding + + self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size) + self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size) + self.drop = nn.Dropout(p_dropout) + + def forward(self, x, x_mask): + x = self.conv_1(self.padding(x * x_mask)) + if self.activation == "gelu": + x = x * torch.sigmoid(1.702 * x) + else: + x = torch.relu(x) + x = self.drop(x) + x = self.conv_2(self.padding(x * x_mask)) + return x * x_mask + + def _causal_padding(self, x): + if self.kernel_size == 1: + return x + pad_l = self.kernel_size - 1 + pad_r = 0 + padding = [[0, 0], [0, 0], [pad_l, pad_r]] + x = F.pad(x, commons.convert_pad_shape(padding)) + return x + + def _same_padding(self, x): + if self.kernel_size == 1: + return x + pad_l = (self.kernel_size - 1) // 2 + pad_r = self.kernel_size // 2 + padding = [[0, 0], [0, 0], [pad_l, pad_r]] + x = F.pad(x, commons.convert_pad_shape(padding)) + return x From 34b0554d7bb7379e5ec102575b375fef5b6fd2cf Mon Sep 17 00:00:00 2001 From: Binbin Zhang Date: Thu, 4 Jan 2024 18:43:36 +0800 Subject: [PATCH 3/7] mask pad value to 0 --- wenet/tts/vits/models.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/wenet/tts/vits/models.py b/wenet/tts/vits/models.py index 2f52914ac..f43867f99 100644 --- a/wenet/tts/vits/models.py +++ b/wenet/tts/vits/models.py @@ -15,6 +15,7 @@ from wenet.tts.vits.commons import init_weights, get_padding from wenet.tts.vits.losses import generator_loss, discriminator_loss, feature_loss, kl_loss from wenet.tts.vits.mel_processing import mel_spectrogram_torch +from wenet.utils.mask import make_pad_mask class StochasticDurationPredictor(nn.Module): @@ -791,6 +792,8 @@ def __init__(self, n_vocab, spec_channels, **kwargs): def forward(self, batch: dict, device: torch.device): x = batch['target'].to(device) x_lengths = batch['target_lengths'].to(device) + x_mask = make_pad_mask(x_lengths) + x = x.masked_fill(x_mask, 0) # change pad value(IGNORE_ID = -1) to 
0 spec = batch['feats'].to(device) spec_lengths = batch['feats_lengths'].to(device) spec = spec.transpose(1, 2) From c3a80f9d894d02b0448c39cb1a18b98da87c72fb Mon Sep 17 00:00:00 2001 From: Binbin Zhang Date: Thu, 4 Jan 2024 20:26:25 +0800 Subject: [PATCH 4/7] add megabyte --- wenet/transformer/embedding.py | 2 +- wenet/tts/megabyte.py | 248 +++++++++++++++++++++++++++++++++ wenet/utils/init_model.py | 6 + 3 files changed, 255 insertions(+), 1 deletion(-) create mode 100644 wenet/tts/megabyte.py diff --git a/wenet/transformer/embedding.py b/wenet/transformer/embedding.py index 17d8810ff..e41bf593c 100644 --- a/wenet/transformer/embedding.py +++ b/wenet/transformer/embedding.py @@ -36,7 +36,7 @@ class PositionalEncoding(torch.nn.Module): def __init__(self, d_model: int, dropout_rate: float, - max_len: int = 5000, + max_len: int = 10000, reverse: bool = False): """Construct an PositionalEncoding object.""" super().__init__() diff --git a/wenet/tts/megabyte.py b/wenet/tts/megabyte.py new file mode 100644 index 000000000..3793b15db --- /dev/null +++ b/wenet/tts/megabyte.py @@ -0,0 +1,248 @@ +# Copyright (c) 2023 Binbin Zhang(binbzha@qq.com) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Dict, Optional + +import torch +import torchaudio +import torch.nn.functional as F +import torch.nn as nn +from torch.nn.utils.rnn import pad_sequence +from encodec import EncodecModel + +from wenet.utils.common import (IGNORE_ID, th_accuracy) +from wenet.utils.class_utils import WENET_EMB_CLASSES +from wenet.utils.mask import make_pad_mask, subsequent_mask + + +class MegaByte(nn.Module): + + def __init__(self, + vocab_size: int, + g_num_layers: int = 12, + g_nhead: int = 8, + g_d_model: int = 512, + g_dim_feedforward: int = 2048, + l_num_layers: int = 6, + l_nhead: int = 8, + l_d_model: int = 256, + l_dim_feedforward: int = 1024, + ctc_weight: float = 0.3): + super().__init__() + self.audio_size = 1024 + 1 # 1 is last one + self.num_quantizer = 8 + self.text_sos = 2 + self.text_eos = 2 + self.audio_sos = 1024 + self.audio_eos = 1024 + self.ignore_id = IGNORE_ID + self.g_nhead = g_nhead + assert g_d_model % self.num_quantizer == 0 + self.g_embedding_size = int(g_d_model / self.num_quantizer) + self.g_model = nn.TransformerEncoder( + nn.TransformerEncoderLayer(d_model=g_d_model, + nhead=self.g_nhead, + dim_feedforward=g_dim_feedforward, + batch_first=True), + num_layers=g_num_layers, + norm=nn.LayerNorm(g_d_model, eps=1e-5), + ) + self.l_model = nn.TransformerEncoder( + nn.TransformerEncoderLayer(d_model=l_d_model, + nhead=l_nhead, + dim_feedforward=l_dim_feedforward, + batch_first=True), + num_layers=l_num_layers, + norm=nn.LayerNorm(l_d_model, eps=1e-5), + ) + self.g_audio_embedding = nn.Sequential( + nn.Embedding(self.audio_size, self.g_embedding_size), + WENET_EMB_CLASSES['abs_pos'](self.g_embedding_size, 0.1), + ) + self.l_audio_embedding = nn.Sequential( + nn.Embedding(self.audio_size, l_d_model), + WENET_EMB_CLASSES['abs_pos'](l_d_model, 0.1), + ) + self.text_embedding = nn.Sequential( + 
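# token embedding followed by WeNet's absolute positional encoding; the 'abs_pos' module returns an (output, position_embedding) tuple, which is why forward() unpacks text_emb, _ = self.text_embedding(text). +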
nn.Embedding(vocab_size, g_d_model), + WENET_EMB_CLASSES['abs_pos'](g_d_model, 0.1), + ) + self.g2l_linear = nn.Linear(self.g_embedding_size, l_d_model) + self.projection = nn.Linear(l_d_model, self.audio_size) + self.codec = EncodecModel.encodec_model_24khz() + self.codec.set_target_bandwidth(6.0) + + def forward(self, batch: dict, + device: torch.device) -> Dict[str, Optional[torch.Tensor]]: + text = batch['target'].to(device) + text_lengths = batch['target_lengths'].to(device) + wavs = batch['pcm'] + # 1. on-the-fly quantization + audio = [] + for wav in wavs: + wav = wav.to(device).unsqueeze(0) + wav = torchaudio.functional.resample(wav, 16000, + self.codec.sample_rate) + wav = wav.unsqueeze(0) + with torch.no_grad(): + encoded_frames = self.codec.encode(wav) + vq = encoded_frames[0][0][0].transpose(0, 1) + audio.append(vq) + audio_lengths = torch.tensor([x.size(0) for x in audio], + dtype=torch.int32, + device=device) + audio = pad_sequence(audio, + batch_first=True, + padding_value=self.audio_eos) + text_mask = make_pad_mask(text_lengths) + text = text.masked_fill(text_mask, self.text_eos) + text = F.pad(text, (1, 1), value=self.text_eos) # eos same as sos + text_lengths = text_lengths + 2 + text_pad_mask = make_pad_mask(text_lengths) + audio_pad_mask = make_pad_mask(audio_lengths + 1) # add sos + text_audio_pad_mask = torch.concat([text_pad_mask, audio_pad_mask], + dim=1) + text_len, audio_len = text.size(1), audio.size(1) + 1 + text_audio_len = text_len + audio_len + batch_size = text.size(0) + # 2. Global model + text_emb, _ = self.text_embedding(text) + g_audio = torch.concat( + [torch.ones_like(audio[:, :1, :]) * self.audio_sos, audio], + dim=1) # add sos + g_audio_emb, _ = self.g_audio_embedding(g_audio.view(batch_size, -1)) + g_audio_emb = g_audio_emb.view(batch_size, audio_len, -1) + text_audio_emb = torch.concat([text_emb, g_audio_emb], dim=1) + text_attn_mask = F.pad( + torch.zeros((text_len, text_len), dtype=torch.bool, device=device), + (0, audio_len), + value=True, + ) + audio_attn_mask = F.pad( + torch.triu( + torch.ones(audio_len, + audio_len, + dtype=torch.bool, + device=device), + diagonal=1, + ), + (text_len, 0), + value=False, + ) + attn_mask = torch.concat([text_attn_mask, audio_attn_mask], dim=0) + pad_mask = text_audio_pad_mask.view(batch_size, 1, 1, text_audio_len) + pad_mask = pad_mask.expand(-1, self.g_nhead, -1, -1) + pad_mask = pad_mask.reshape(batch_size * self.g_nhead, 1, + text_audio_len) + attn_mask = attn_mask.logical_or(pad_mask) + f_mask = torch.zeros_like(attn_mask, dtype=torch.float) + f_mask = f_mask.masked_fill(attn_mask, float('-inf')) + g_output = self.g_model(text_audio_emb, + f_mask)[:, text_len:, :].contiguous() + g_output = g_output.view(batch_size * audio_len, self.num_quantizer, + -1) + g_logits = self.g2l_linear(g_output) + # 3. 
Local model + l_audio = torch.concat( + [audio, torch.ones_like(audio[:, :1, :]) * self.audio_eos], + dim=1) # add global eos + l_label = l_audio.masked_fill(audio_pad_mask.unsqueeze(-1), + self.ignore_id) + l_label = l_label.view(batch_size * audio_len, self.num_quantizer) + l_audio = l_audio.view(batch_size * audio_len, self.num_quantizer) + l_input = F.pad(l_audio[:, :-1], (1, 0), + value=self.audio_sos) # add local sos + l_input, _ = self.l_audio_embedding(l_input) + l_input = l_input + g_logits + mask = ~subsequent_mask(self.num_quantizer, device) + l_logits = self.l_model(l_input, mask) + l_logits = self.projection(l_logits) + loss = F.cross_entropy(l_logits.permute(0, 2, 1), + l_label, + ignore_index=self.ignore_id) + acc = th_accuracy(l_logits.view(-1, self.audio_size), + l_label, + ignore_label=self.ignore_id) + return { + 'loss': loss, + 'acc': torch.tensor(acc), + } + + def inference(self, audio: torch.Tensor, ref_text: torch.Tensor, + syn_text: torch.Tensor, device: torch.device): + batch_size = audio.size(0) + assert batch_size == 1 + text = torch.concat([ref_text, syn_text], dim=1) + print(text) + text = F.pad(text, (1, 1), value=self.text_eos) # add sos & eos + text_len = text.size(1) + text_emb, _ = self.text_embedding(text) + + max_len = 75 * 1 # 2 seconds + src_audio = audio + # TODO(Binbin Zhang): Add cache + for step in range(max_len): + # Global + g_audio = torch.concat( + [torch.ones_like(audio[:, :1, :]) * self.audio_sos, audio], + dim=1) # add sos + audio_len = g_audio.size(1) + g_audio_emb, _ = self.g_audio_embedding( + g_audio.view(batch_size, -1)) + g_audio_emb = g_audio_emb.view(batch_size, audio_len, -1) + text_audio_emb = torch.concat([text_emb, g_audio_emb], dim=1) + text_attn_mask = F.pad( + torch.zeros((text_len, text_len), + dtype=torch.bool, + device=device), + (0, audio_len), + value=True, + ) + audio_attn_mask = F.pad( + torch.triu( + torch.ones(audio_len, + audio_len, + dtype=torch.bool, + device=device), + diagonal=1, + ), + (text_len, 0), + value=False, + ) + attn_mask = torch.concat([text_attn_mask, audio_attn_mask], dim=0) + g_output = self.g_model(text_audio_emb, + attn_mask)[:, -1, :].contiguous() + g_output = g_output.view(batch_size, self.num_quantizer, + -1) # 1, 8, g_emb + g_logits = self.g2l_linear(g_output) # 1, 8, l_d_model + # Local + la = [self.audio_sos] + for i in range(self.num_quantizer): + l_input = torch.tensor(la, dtype=torch.long, + device=device).unsqueeze(0) + l_input, _ = self.l_audio_embedding(l_input) + l_input = l_input + g_logits[:, :i + 1, :] + mask = ~subsequent_mask(i + 1, device) + l_logits = self.l_model(l_input, mask) + l_logits = self.projection(l_logits) + pred = l_logits[0, -1, :].argmax().item() + la.append(pred) + print(step, la[1:]) + if self.audio_eos in la[1:]: + break + gen = torch.tensor(la[1:], dtype=torch.long, device=device) + gen = gen.view(1, 1, self.num_quantizer) + audio = torch.concat([audio, gen], dim=1) + print(audio.size()) + return audio diff --git a/wenet/utils/init_model.py b/wenet/utils/init_model.py index 4ad0c869a..83a55d2d4 100644 --- a/wenet/utils/init_model.py +++ b/wenet/utils/init_model.py @@ -37,6 +37,8 @@ from wenet.utils.cmvn import load_cmvn from wenet.utils.checkpoint import load_checkpoint, load_trained_modules from wenet.tts.vits.models import VitsModel +from wenet.tts.megabyte import MegaByte +from wenet.tts.valle import VallE WENET_ENCODER_CLASSES = { "transformer": TransformerEncoder, @@ -120,6 +122,10 @@ def init_model(args, configs): if model_type == 'vits': model = 
VitsModel(vocab_size, input_dim, **configs['model_conf']) + elif model_type == 'megabyte': + model = MegaByte(vocab_size, **configs['model_conf']) + elif model_type == 'valle': + model = VallE(vocab_size, **configs['model_conf']) elif model_type == "transducer": predictor_type = configs.get('predictor', 'rnn') joint_type = configs.get('joint', 'transducer_joint') From 6455487dae75fd56874d6cf4e5f9f26905bd0b66 Mon Sep 17 00:00:00 2001 From: Binbin Zhang Date: Thu, 4 Jan 2024 20:36:49 +0800 Subject: [PATCH 5/7] add valle --- wenet/tts/valle.py | 189 ++++++++++++++++++++++++++++++++++++++ wenet/utils/init_model.py | 6 +- 2 files changed, 192 insertions(+), 3 deletions(-) create mode 100644 wenet/tts/valle.py diff --git a/wenet/tts/valle.py b/wenet/tts/valle.py new file mode 100644 index 000000000..67fa16514 --- /dev/null +++ b/wenet/tts/valle.py @@ -0,0 +1,189 @@ +# Copyright (c) 2023 Binbin Zhang(binbzha@qq.com) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import random +from typing import Dict, Optional + +import torch +import torchaudio +import torch.nn.functional as F +import torch.nn as nn +from torch.nn.utils.rnn import pad_sequence +from encodec import EncodecModel + +from wenet.utils.common import (IGNORE_ID, add_sos_eos, th_accuracy) +from wenet.utils.class_utils import WENET_EMB_CLASSES +from wenet.utils.mask import make_pad_mask + + +class VallE(nn.Module): + + def __init__(self, + vocab_size: int, + tie_word_embedding: bool = True, + num_blocks: int = 12, + attention_heads: int = 16, + attention_dim: int = 1024, + linear_units: int = 4096, + dropout_rate: float = 0.1, + ctc_weight: float = 0.3): + super().__init__() + self.audio_size = 1024 + 1 # 1 is last one + self.num_quantizer = 8 + self.text_sos = 2 + self.text_eos = 2 + self.audio_sos = 1024 + self.audio_eos = 1024 + self.ignore_id = IGNORE_ID + self.nhead = attention_heads + self.ar_decoder = nn.TransformerEncoder( + nn.TransformerEncoderLayer(d_model=attention_dim, + nhead=self.nhead, + dim_feedforward=linear_units, + batch_first=True), + num_layers=num_blocks, + norm=nn.LayerNorm(attention_dim, eps=1e-5), + ) + self.nar_decoder = nn.TransformerEncoder( + nn.TransformerEncoderLayer(d_model=attention_dim, + nhead=self.nhead, + dim_feedforward=linear_units, + batch_first=True), + num_layers=num_blocks, + norm=nn.LayerNorm(attention_dim, eps=1e-5), + ) + self.ar_text_embedding = nn.Sequential( + nn.Embedding(vocab_size, attention_dim), + WENET_EMB_CLASSES['abs_pos'](attention_dim, 0.1), + ) + self.nar_text_embedding = nn.Sequential( + nn.Embedding(vocab_size, attention_dim), + WENET_EMB_CLASSES['abs_pos'](attention_dim, 0.1), + ) + self.audio_embedding = nn.ModuleList([ + nn.Sequential( + nn.Embedding(self.audio_size, attention_dim), + WENET_EMB_CLASSES['abs_pos'](attention_dim, 0.1), + ) for i in range(self.num_quantizer) + ]) + self.projection = nn.ModuleList([ + nn.Linear(attention_dim, self.audio_size) + for i in range(self.num_quantizer) + ]) + if tie_word_embedding: + for i in range(self.num_quantizer): + 
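# Weight tying: each quantizer's output projection shares its weight matrix with the corresponding input embedding (nn.Linear(d, V).weight and nn.Embedding(V, d).weight are both [V, d]). +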
self.projection[i].weight = self.audio_embedding[i][0].weight + self.codec = EncodecModel.encodec_model_24khz() + self.codec.set_target_bandwidth(6.0) + + def forward(self, batch: dict, + device: torch.device) -> Dict[str, Optional[torch.Tensor]]: + text = batch['target'].to(device) + text_lengths = batch['target_lengths'].to(device) + wavs = batch['pcm'] + # 1. on-the-fly quantization + audio = [] + for wav in wavs: + wav = wav.to(device).unsqueeze(0) + wav = torchaudio.functional.resample(wav, 16000, + self.codec.sample_rate) + wav = wav.unsqueeze(0) + with torch.no_grad(): + encoded_frames = self.codec.encode(wav) + vq = encoded_frames[0][0][0].transpose(0, 1) + audio.append(vq) + audio_lengths = torch.tensor([x.size(0) for x in audio], + dtype=torch.int32, + device=device) + audio = pad_sequence(audio, batch_first=True, padding_value=0) + text_mask = make_pad_mask(text_lengths) + text = text.masked_fill(text_mask, self.text_eos) + text = F.pad(text, (1, 1), value=self.text_eos) # eos same as sos + text_lengths = text_lengths + 2 + text_pad_mask = make_pad_mask(text_lengths) + audio_pad_mask = make_pad_mask(audio_lengths + 1) # add sos/eos + text_audio_pad_mask = torch.concat([text_pad_mask, audio_pad_mask], + dim=1) + text_len, audio_len = text.size(1), audio.size(1) + 1 + text_audio_len = text_len + audio_len + batch_size = text.size(0) + + # 2-1. AR decoder branch + ar_text_emb, _ = self.ar_text_embedding(text) + ar_audio_in, ar_audio_out = add_sos_eos(audio[:, :, 0], self.audio_sos, + self.audio_eos, self.ignore_id) + ar_audio_emb, _ = self.audio_embedding[0](ar_audio_in) + ar_text_audio_emb = torch.concat([ar_text_emb, ar_audio_emb], dim=1) + text_attn_mask = F.pad( + torch.zeros((text_len, text_len), dtype=torch.bool, device=device), + (0, audio_len), + value=True, + ) + audio_attn_mask = F.pad( + torch.triu( + torch.ones(audio_len, + audio_len, + dtype=torch.bool, + device=device), + diagonal=1, + ), + (text_len, 0), + value=False, + ) + text_audio_attn_mask = torch.concat([text_attn_mask, audio_attn_mask], + dim=0) + pad_mask = text_audio_pad_mask.view(batch_size, 1, 1, text_audio_len) + pad_mask = pad_mask.expand(-1, self.nhead, -1, -1) + pad_mask = pad_mask.reshape(batch_size * self.nhead, 1, text_audio_len) + text_audio_attn_mask = text_audio_attn_mask.logical_or(pad_mask) + fmask = torch.zeros_like(text_audio_attn_mask, dtype=torch.float) + fmask = fmask.masked_fill(text_audio_attn_mask, float('-inf')) + ar_decoder_out = self.ar_decoder(ar_text_audio_emb, fmask) + ar_decoder_out = self.projection[0]( + ar_decoder_out)[:, text_len:, :].contiguous() + ar_loss = F.cross_entropy(ar_decoder_out.permute(0, 2, 1), + ar_audio_out, + ignore_index=self.ignore_id) + ar_acc = th_accuracy(ar_decoder_out.view(-1, self.audio_size), + ar_audio_out, + ignore_label=self.ignore_id) + # 2-2. 
NAR decoder branch, random sample one to train + k = random.randint(1, self.num_quantizer - 1) + nar_text_emb, _ = self.nar_text_embedding(text) + nar_audio_in, nar_audio_out = audio[:, :, k - 1], audio[:, :, k] + nar_audio_emb, _ = self.audio_embedding[k](nar_audio_in) + nar_text_audio_emb = torch.concat([nar_text_emb, nar_audio_emb], dim=1) + audio_pad_mask = make_pad_mask(audio_lengths) + nar_audio_out = nar_audio_out.masked_fill(audio_pad_mask, + self.ignore_id) + text_audio_mask = torch.concat([text_pad_mask, audio_pad_mask], dim=1) + nar_decoder_out = self.nar_decoder( + nar_text_audio_emb, src_key_padding_mask=text_audio_mask) + nar_decoder_out = self.projection[k]( + nar_decoder_out)[:, text_len:, :].contiguous() + nar_loss = F.cross_entropy(nar_decoder_out.permute(0, 2, 1), + nar_audio_out, + ignore_index=self.ignore_id) + nar_acc = th_accuracy(nar_decoder_out.view(-1, self.audio_size), + nar_audio_out, + ignore_label=self.ignore_id) + + loss = ar_loss + nar_loss + return { + 'loss': loss, + 'ar_loss': ar_loss, + 'nar_loss': nar_loss, + 'ar_acc': torch.tensor(ar_acc), + 'nar_acc': torch.tensor(nar_acc), + } diff --git a/wenet/utils/init_model.py b/wenet/utils/init_model.py index 83a55d2d4..68340165a 100644 --- a/wenet/utils/init_model.py +++ b/wenet/utils/init_model.py @@ -36,9 +36,6 @@ from wenet.whisper.whisper import Whisper from wenet.utils.cmvn import load_cmvn from wenet.utils.checkpoint import load_checkpoint, load_trained_modules -from wenet.tts.vits.models import VitsModel -from wenet.tts.megabyte import MegaByte -from wenet.tts.valle import VallE WENET_ENCODER_CLASSES = { "transformer": TransformerEncoder, @@ -121,10 +118,13 @@ def init_model(args, configs): if 'ctc_conf' in configs else 0) if model_type == 'vits': + from wenet.tts.vits.models import VitsModel model = VitsModel(vocab_size, input_dim, **configs['model_conf']) elif model_type == 'megabyte': + from wenet.tts.megabyte import MegaByte model = MegaByte(vocab_size, **configs['model_conf']) elif model_type == 'valle': + from wenet.tts.valle import VallE model = VallE(vocab_size, **configs['model_conf']) elif model_type == "transducer": predictor_type = configs.get('predictor', 'rnn') From 25c60ba2b6533eefcd4236267ef146f2d8e603a8 Mon Sep 17 00:00:00 2001 From: Binbin Zhang Date: Mon, 8 Jan 2024 18:17:07 +0800 Subject: [PATCH 6/7] support multiple optimizer and scheduler --- wenet/bin/train.py | 6 ++-- wenet/tts/vits/models.py | 57 ++++++++++++++++++++++++++++---------- wenet/utils/executor.py | 28 +++++++++++-------- wenet/utils/train_utils.py | 23 ++++++++++++--- 4 files changed, 81 insertions(+), 33 deletions(-) diff --git a/wenet/bin/train.py b/wenet/bin/train.py index a69dd1ffc..376e6f3d1 100644 --- a/wenet/bin/train.py +++ b/wenet/bin/train.py @@ -34,7 +34,7 @@ add_trace_args, init_distributed, init_dataset_and_dataloader, check_modify_and_save_config, init_optimizer_and_scheduler, trace_and_print_model, wrap_cuda_model, init_summarywriter, save_model, - log_per_epoch) + log_per_epoch, get_lr) def get_args(): @@ -130,7 +130,7 @@ def main(): train_dataset.set_epoch(epoch) configs['epoch'] = epoch - lr = optimizer.param_groups[0]['lr'] + lr = get_lr(optimizer) logging.info('Epoch {} TRAIN info lr {} rank {}'.format( epoch, lr, rank)) @@ -148,7 +148,7 @@ def main(): total_loss, num_seen_utts = executor.cv(model, cv_data_loader, configs) cv_loss = total_loss / num_seen_utts - lr = optimizer.param_groups[0]['lr'] + lr = get_lr(optimizer) logging.info('Epoch {} CV info lr {} cv_loss {} rank {}'.format( epoch, lr, 
cv_loss, rank)) info_dict = { diff --git a/wenet/tts/vits/models.py b/wenet/tts/vits/models.py index f43867f99..a5664ad77 100644 --- a/wenet/tts/vits/models.py +++ b/wenet/tts/vits/models.py @@ -16,6 +16,7 @@ from wenet.tts.vits.losses import generator_loss, discriminator_loss, feature_loss, kl_loss from wenet.tts.vits.mel_processing import mel_spectrogram_torch from wenet.utils.mask import make_pad_mask +from wenet.utils.scheduler import WarmupLR class StochasticDurationPredictor(nn.Module): @@ -787,19 +788,45 @@ def __init__(self, n_vocab, spec_channels, **kwargs): self.segment_size // self.hop_length, **kwargs['generator']) self.d = MultiPeriodDiscriminator(**kwargs['discriminator']) - self.step = 0 + + def get_optimizer(self): + optim_d = torch.optim.AdamW(self.d.parameters(), + 0.0002, + betas=[0.8, 0.99], + eps=1.0e-9) + optim_g = torch.optim.AdamW(self.g.parameters(), + 0.0002, + betas=[0.8, 0.99], + eps=1.0e-9) + return [optim_d, optim_g] + + def get_scheduler(self, optimizer): + scheduler_d = WarmupLR(optimizer[0], warmup_steps=250) + scheduler_g = WarmupLR(optimizer[1], warmup_steps=250) + return [scheduler_d, scheduler_g] def forward(self, batch: dict, device: torch.device): x = batch['target'].to(device) x_lengths = batch['target_lengths'].to(device) x_mask = make_pad_mask(x_lengths) x = x.masked_fill(x_mask, 0) # change pad value(IGNORE_ID = -1) to 0 - spec = batch['feats'].to(device) - spec_lengths = batch['feats_lengths'].to(device) - spec = spec.transpose(1, 2) + # spec = batch['feats'].to(device) + # spec_lengths = batch['feats_lengths'].to(device) + # spec = spec.transpose(1, 2) y = batch['pcm'].to(device) y = y.unsqueeze(1) y_lengths = batch['pcm_length'].to(device) + optimizer_idx = batch.get('optimizer_idx', 0) + + spec = mel_spectrogram_torch( + y.squeeze(1), + self.filter_length, + self.n_mel_channels, + self.sampling_rate, + self.hop_length, + self.win_length, + ) + spec_lengths = (y_lengths - self.win_length) // self.hop_length + 1 batch_size = x.size(0) sid = torch.zeros(batch_size, device=device, dtype=torch.long) @@ -812,14 +839,15 @@ def forward(self, batch: dict, device: torch.device): # self.n_mel_channels, # self.sampling_rate, # ) - mel = mel_spectrogram_torch( - y.squeeze(1), - self.filter_length, - self.n_mel_channels, - self.sampling_rate, - self.hop_length, - self.win_length, - ) + # mel = mel_spectrogram_torch( + # y.squeeze(1), + # self.filter_length, + # self.n_mel_channels, + # self.sampling_rate, + # self.hop_length, + # self.win_length, + # ) + mel = spec y_mel = commons.slice_segments(mel, ids_slice, self.segment_size // self.hop_length) y_hat_mel = mel_spectrogram_torch( @@ -833,7 +861,7 @@ def forward(self, batch: dict, device: torch.device): y = commons.slice_segments(y, ids_slice * self.hop_length, self.segment_size) # Train generator and discriminator alternately - if self.step % self.d_interval == 0: + if optimizer_idx == 0: # Discriminator y_d_hat_r, y_d_hat_g, _, _ = self.d(y, y_hat.detach()) loss_disc, losses_disc_r, losses_disc_g = discriminator_loss( @@ -857,8 +885,7 @@ def forward(self, batch: dict, device: torch.device): 'loss_dur': loss_dur, 'loss_kl': loss_kl, } - - self.step += 1 + print('optimizer_idx', optimizer_idx) return losses def infer(self, text: torch.Tensor): diff --git a/wenet/utils/executor.py b/wenet/utils/executor.py index 7cf148758..ca8def4a3 100644 --- a/wenet/utils/executor.py +++ b/wenet/utils/executor.py @@ -23,7 +23,7 @@ from wenet.utils.train_utils import (wenet_join, batch_forward, batch_backward, 
update_parameter_and_lr, log_per_step, - save_model) + save_model, get_lr) class Executor: @@ -69,15 +69,21 @@ def train(self, model, optimizer, scheduler, train_data_loader, # processes. else: context = nullcontext - - with context(): - info_dict = batch_forward(model, batch_dict, scaler, - info_dict) - info_dict = batch_backward(model, scaler, info_dict) - - info_dict = update_parameter_and_lr(model, optimizer, - scheduler, scaler, - info_dict) + num_opt = len(optimizer) if isinstance(optimizer, list) else 1 + for opt_idx in range(num_opt): + batch_dict['optimizer_idx'] = opt_idx + with context(): + info_dict = batch_forward(model, batch_dict, scaler, + info_dict) + info_dict = batch_backward(model, scaler, info_dict) + + info_dict = update_parameter_and_lr( + model, + optimizer[opt_idx] if num_opt > 1 else optimizer, + scheduler[opt_idx] if num_opt > 1 else scheduler, + scaler, + info_dict, + ) save_interval = info_dict.get('save_interval', 10000) if self.step % save_interval == 0 and self.step != 0 \ and (batch_idx + 1) % info_dict["accum_grad"] == 0: @@ -92,7 +98,7 @@ def train(self, model, optimizer, scheduler, train_data_loader, "save_time": datetime.datetime.now().strftime('%d/%m/%Y %H:%M:%S'), "lr": - optimizer.param_groups[0]['lr'] + get_lr(optimizer) }) save_model(model, info_dict) log_per_step(writer, info_dict) diff --git a/wenet/utils/train_utils.py b/wenet/utils/train_utils.py index c46a7d667..6f12b9b5f 100644 --- a/wenet/utils/train_utils.py +++ b/wenet/utils/train_utils.py @@ -310,7 +310,9 @@ def wrap_cuda_model(args, model): def init_optimizer_and_scheduler(args, configs, model): - if configs['optim'] == 'adam': + if hasattr(model.module, 'get_optimizer'): + optimizer = model.module.get_optimizer() + elif configs['optim'] == 'adam': optimizer = optim.Adam(model.parameters(), **configs['optim_conf']) elif configs['optim'] == 'adamw': optimizer = optim.AdamW(model.parameters(), **configs['optim_conf']) @@ -318,7 +320,9 @@ def init_optimizer_and_scheduler(args, configs, model): raise ValueError("unknown optimizer: " + configs['optim']) scheduler_type = None - if configs['scheduler'] == 'warmuplr': + if hasattr(model.module, 'get_scheduler'): + scheduler = model.module.get_scheduler(optimizer) + elif configs['scheduler'] == 'warmuplr': scheduler_type = WarmupLR scheduler = WarmupLR(optimizer, **configs['scheduler_conf']) elif configs['scheduler'] == 'NoamHoldAnnealing': @@ -354,7 +358,11 @@ def scheduler(opt): model_parameters=model.parameters()) step = configs["init_infos"].get("step", -1) - scheduler.set_step(step) + if isinstance(scheduler, list): + for s in scheduler: + s.set_step(step) + else: + scheduler.set_step(step) return model, optimizer, scheduler @@ -497,6 +505,13 @@ def batch_backward(model, scaler, info_dict): return info_dict +def get_lr(optimizer): + if isinstance(optimizer, list): + return optimizer[0].param_groups[0]['lr'] + else: + return optimizer.param_groups[0]['lr'] + + def update_parameter_and_lr(model, optimizer, scheduler, scaler, info_dict): rank = int(os.environ.get('RANK', 0)) train_engine = info_dict.get("train_engine", "torch_ddp") @@ -543,7 +558,7 @@ def update_parameter_and_lr(model, optimizer, scheduler, scaler, info_dict): optimizer.zero_grad() scheduler.step() - info_dict["lr"] = optimizer.param_groups[0]['lr'] + info_dict["lr"] = get_lr(optimizer) info_dict["grad_norm"] = grad_norm return info_dict From 89a4a5dbb96c99aa8714ad80d01b04612dfa2b26 Mon Sep 17 00:00:00 2001 From: Binbin Zhang Date: Tue, 9 Jan 2024 09:48:44 +0800 Subject: [PATCH 
7/7] add generator and discriminator log info --- wenet/tts/vits/models.py | 1 - wenet/utils/executor.py | 3 +++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/wenet/tts/vits/models.py b/wenet/tts/vits/models.py index a5664ad77..a236851ab 100644 --- a/wenet/tts/vits/models.py +++ b/wenet/tts/vits/models.py @@ -885,7 +885,6 @@ def forward(self, batch: dict, device: torch.device): 'loss_dur': loss_dur, 'loss_kl': loss_kl, } - print('optimizer_idx', optimizer_idx) return losses def infer(self, text: torch.Tensor): diff --git a/wenet/utils/executor.py b/wenet/utils/executor.py index ca8def4a3..384a7527b 100644 --- a/wenet/utils/executor.py +++ b/wenet/utils/executor.py @@ -70,6 +70,7 @@ def train(self, model, optimizer, scheduler, train_data_loader, else: context = nullcontext num_opt = len(optimizer) if isinstance(optimizer, list) else 1 + loss_dict = {} for opt_idx in range(num_opt): batch_dict['optimizer_idx'] = opt_idx with context(): @@ -84,6 +85,8 @@ def train(self, model, optimizer, scheduler, train_data_loader, scaler, info_dict, ) + loss_dict.update(info_dict['loss_dict']) + info_dict['loss_dict'] = loss_dict save_interval = info_dict.get('save_interval', 10000) if self.step % save_interval == 0 and self.step != 0 \ and (batch_idx + 1) % info_dict["accum_grad"] == 0:
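Note on the training flow introduced in patches 6 and 7: when a model exposes get_optimizer()/get_scheduler() (as VitsModel does), the executor runs one forward/backward/step per optimizer for every batch, routing the branch through batch_dict['optimizer_idx'] and merging the per-branch loss dicts for logging. Below is a minimal sketch of that flow with hypothetical variable names; it leaves out AMP, gradient accumulation, clipping and the distributed handling of the real executor, and assumes the returned dict keeps the total loss for the branch under 'loss'.

    # Simplified alternating update, mirroring Executor.train after patches 6/7.
    optimizers = model.get_optimizer()            # e.g. [optim_d, optim_g] for VITS
    schedulers = model.get_scheduler(optimizers)  # e.g. [sched_d, sched_g]
    for batch_dict in train_data_loader:
        merged = {}
        for opt_idx, (opt, sched) in enumerate(zip(optimizers, schedulers)):
            batch_dict['optimizer_idx'] = opt_idx  # 0 -> discriminator, 1 -> generator
            loss_dict = model(batch_dict, device)  # forward for this branch only
            loss_dict['loss'].backward()           # 'loss' assumed to hold the branch total
            opt.step()
            opt.zero_grad()
            sched.step()
            merged.update(loss_dict)               # both branches end up in one log entry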