From 2adf6512d8d947b52ef1af6bed0d3eddf72ccb46 Mon Sep 17 00:00:00 2001 From: Dinghao Zhou Date: Fri, 16 Aug 2024 22:14:57 +0800 Subject: [PATCH] [ssl/bestrq] happy ending---stable training (#2614) * [ssl/bestrq] happy ending * fix multibook loss --- wenet/ssl/bestrq/bestrq_model.py | 48 ++++++++++++++++++-------------- 1 file changed, 27 insertions(+), 21 deletions(-) diff --git a/wenet/ssl/bestrq/bestrq_model.py b/wenet/ssl/bestrq/bestrq_model.py index d75bc6de9..b70fd13db 100644 --- a/wenet/ssl/bestrq/bestrq_model.py +++ b/wenet/ssl/bestrq/bestrq_model.py @@ -3,7 +3,7 @@ import torch from wenet.ssl.bestrq.mask import compute_mask_indices_v2 -from wenet.utils.mask import make_pad_mask +from wenet.utils.mask import make_non_pad_mask, make_pad_mask from wenet.transformer.attention import RelPositionMultiHeadedAttention from wenet.transformer.encoder_layer import ConformerEncoderLayer @@ -96,12 +96,7 @@ def __init__( # stack input: eg: fbank self.stack_frames = self.encoder.embed.right_context + 1 self.stride = self.encoder.embed.subsampling_rate - input_dim = num_mel_bins * self.stack_frames - - # norm input - self.norm = torch.nn.LayerNorm( - input_dim, eps=norm_epsilon, elementwise_affine=False - ) if self.stack_frames > 1 else torch.nn.Identity() + input_dim = num_mel_bins * self.stride # random projectoin self.projection = torch.nn.parameter.Parameter( @@ -177,11 +172,12 @@ def forward( xs, code_ids_mask = self._apply_mask_signal(xs, xs_lens) # 2.0 stack fbank - unmasked_xs = self._stack_features(input) + unmasked_xs = self._stack_features(input, xs_lens) masked_xs = xs # 2.1 get nearest embedding target_ids = self._nearest_embedding_idx(unmasked_xs) + target_ids = target_ids[:, :code_ids_mask.size(1), :] # 3 forward xxx-formaer block and its subsampling layer out, out_mask = self.encoder(masked_xs, xs_lens) @@ -258,30 +254,40 @@ def _apply_mask_signal( xs = torch.where(masks_expand, mask_emb, input) return xs, subsampling_mask - def _stack_features(self, input: torch.Tensor) -> torch.Tensor: + def _stack_features(self, input: torch.Tensor, + input_lens: torch.Tensor) -> torch.Tensor: - stack_input = input.unfold(1, size=self.stack_frames, step=self.stride) + stack_input = input.unfold(1, size=self.stride, step=self.stride) stack_input = stack_input.transpose(-1, -2) b, n, f, d = stack_input.size() stack_input = stack_input.reshape(b, n, f * d) - return stack_input + # NOTE(Mddct): important!!! + # norm stack features + mask = make_non_pad_mask(input_lens) + stack_mask = mask.unfold(1, size=self.stride, step=self.stride) + stack_mask, _ = torch.min(stack_mask, dim=-1) + + stack_input = stack_input * stack_mask.unsqueeze(2) + mean = stack_input.sum(1, keepdim=True) / stack_mask.sum( + dim=1, keepdim=True).unsqueeze(1) + std = torch.sqrt(((stack_input - mean)**2).sum(dim=1, keepdim=True) / + stack_mask.sum(dim=1, keepdim=True).unsqueeze(1)) + norm_stack_input = (stack_input - mean) / (std + 1e-5) + return norm_stack_input def _compute_loss(self, input: torch.Tensor, target: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: - log_probs = torch.log_softmax(input, dim=-1).transpose( - 1, 2) # [B, T', num_codebooks, num_embeddings] - - per_example_n_loss = -log_probs.gather(3, target.unsqueeze(3)).squeeze( - 3) # [B, T', num_codebooks] - - numerator = torch.sum(per_example_n_loss * mask.unsqueeze(2)) - denominator = torch.sum(mask) + 1e-5 - loss = numerator / (denominator * self.num_codebooks) + logits = input.transpose(1, 2).contiguous().view(-1, input.size(-1)) + loss = torch.nn.functional.cross_entropy( + logits, + target.contiguous().view(-1), + reduction='none', + ) + loss = (loss * mask.view(-1)).sum() / mask.sum() return loss def _nearest_embedding_idx(self, xs: torch.Tensor) -> torch.Tensor: - xs = self.norm(xs) xs = torch.matmul(xs, self.projection.to(xs.device)) xs = xs / (xs.norm(dim=-1, p=2, keepdim=True) + 1e-8) codebooks = self.embeddings