[ssl/bestrq] happy ending---stable training (#2614)
* [ssl/bestrq] happy ending

* fix multibook loss
Mddct authored Aug 16, 2024
1 parent 203e067 commit 2adf651
Showing 1 changed file with 27 additions and 21 deletions.
wenet/ssl/bestrq/bestrq_model.py (+27 −21)
@@ -3,7 +3,7 @@
 import torch

 from wenet.ssl.bestrq.mask import compute_mask_indices_v2
-from wenet.utils.mask import make_pad_mask
+from wenet.utils.mask import make_non_pad_mask, make_pad_mask
 from wenet.transformer.attention import RelPositionMultiHeadedAttention
 from wenet.transformer.encoder_layer import ConformerEncoderLayer

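The newly imported make_non_pad_mask feeds the normalization added to
_stack_features below. As a quick reference, a sketch of what the two helpers
in wenet/utils/mask.py return; the example values follow their docstrings and
are illustrative, not part of this commit:

    import torch
    from wenet.utils.mask import make_pad_mask, make_non_pad_mask

    lengths = torch.tensor([3, 5])
    make_pad_mask(lengths)      # [[False, False, False, True,  True],
                                #  [False, False, False, False, False]]
    make_non_pad_mask(lengths)  # the negation: True exactly on real frames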
@@ -96,12 +96,7 @@ def __init__(
         # stack input: e.g. fbank
         self.stack_frames = self.encoder.embed.right_context + 1
         self.stride = self.encoder.embed.subsampling_rate
-        input_dim = num_mel_bins * self.stack_frames
-
-        # norm input
-        self.norm = torch.nn.LayerNorm(
-            input_dim, eps=norm_epsilon, elementwise_affine=False
-        ) if self.stack_frames > 1 else torch.nn.Identity()
+        input_dim = num_mel_bins * self.stride

         # random projection
         self.projection = torch.nn.parameter.Parameter(
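This hunk drops the global LayerNorm and changes the projection input from
right_context + 1 stacked frames to stride frames; normalization moves into
_stack_features (see the later hunk). A quick dimension check with assumed
values (80-dim fbank, Conv2dSubsampling4, so subsampling_rate = 4):

    num_mel_bins, stride = 80, 4        # assumed config, not from this commit
    input_dim = num_mel_bins * stride   # 320-dim input to the random projection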
@@ -177,11 +172,12 @@ def forward(
         xs, code_ids_mask = self._apply_mask_signal(xs, xs_lens)

         # 2.0 stack fbank
-        unmasked_xs = self._stack_features(input)
+        unmasked_xs = self._stack_features(input, xs_lens)
         masked_xs = xs

         # 2.1 get nearest embedding
         target_ids = self._nearest_embedding_idx(unmasked_xs)
+        target_ids = target_ids[:, :code_ids_mask.size(1), :]

         # 3 forward xxx-former block and its subsampling layer
         out, out_mask = self.encoder(masked_xs, xs_lens)
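The added slice guards against an off-by-one between the unfold-based feature
stacking and the encoder's conv subsampling, which can produce different frame
counts for the same input length. A worked example, assuming stride = 4 and
wenet's usual Conv2dSubsampling4 length formula:

    T, stride = 100, 4
    n_stacked = (T - stride) // stride + 1   # unfold(1, 4, 4) -> 25 stacked frames
    n_encoder = ((T - 1) // 2 - 1) // 2      # conv subsampling -> 24 frames

Slicing target_ids to code_ids_mask.size(1) keeps the targets aligned with the
encoder output.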
@@ -258,30 +254,40 @@ def _apply_mask_signal(
         xs = torch.where(masks_expand, mask_emb, input)
         return xs, subsampling_mask

-    def _stack_features(self, input: torch.Tensor) -> torch.Tensor:
+    def _stack_features(self, input: torch.Tensor,
+                        input_lens: torch.Tensor) -> torch.Tensor:

-        stack_input = input.unfold(1, size=self.stack_frames, step=self.stride)
+        stack_input = input.unfold(1, size=self.stride, step=self.stride)
         stack_input = stack_input.transpose(-1, -2)
         b, n, f, d = stack_input.size()
         stack_input = stack_input.reshape(b, n, f * d)

-        return stack_input
+        # NOTE(Mddct): important!!!
+        # norm stack features
+        mask = make_non_pad_mask(input_lens)
+        stack_mask = mask.unfold(1, size=self.stride, step=self.stride)
+        stack_mask, _ = torch.min(stack_mask, dim=-1)
+
+        stack_input = stack_input * stack_mask.unsqueeze(2)
+        mean = stack_input.sum(1, keepdim=True) / stack_mask.sum(
+            dim=1, keepdim=True).unsqueeze(1)
+        std = torch.sqrt(((stack_input - mean)**2).sum(dim=1, keepdim=True) /
+                         stack_mask.sum(dim=1, keepdim=True).unsqueeze(1))
+        norm_stack_input = (stack_input - mean) / (std + 1e-5)
+        return norm_stack_input
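This is the core of the stability fix: the old global LayerNorm is replaced by
per-utterance mean/variance statistics computed only over real (non-padded)
frames, so batch padding no longer leaks into the normalization. A stacked
frame counts as valid only if all of its stride raw frames are valid (the
torch.min over the unfolded mask). A self-contained sketch of the same idea
with assumed shapes; unlike the hunk above, the variance term here is masked
as well:

    import torch

    def masked_normalize(x: torch.Tensor, mask: torch.Tensor, eps: float = 1e-5):
        # x: [B, N, D] stacked features; mask: [B, N], 1 on valid frames, 0 on pad
        mask = mask.unsqueeze(2).to(x.dtype)         # [B, N, 1]
        x = x * mask                                 # zero out padded frames
        n_valid = mask.sum(dim=1, keepdim=True)      # [B, 1, 1] frames per utterance
        mean = x.sum(dim=1, keepdim=True) / n_valid  # per-utterance, per-dim mean
        var = (((x - mean) ** 2) * mask).sum(dim=1, keepdim=True) / n_valid
        return (x - mean) / (var.sqrt() + eps)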

     def _compute_loss(self, input: torch.Tensor, target: torch.Tensor,
                       mask: torch.Tensor) -> torch.Tensor:
-        log_probs = torch.log_softmax(input, dim=-1).transpose(
-            1, 2)  # [B, T', num_codebooks, num_embeddings]
-
-        per_example_n_loss = -log_probs.gather(3, target.unsqueeze(3)).squeeze(
-            3)  # [B, T', num_codebooks]
-
-        numerator = torch.sum(per_example_n_loss * mask.unsqueeze(2))
-        denominator = torch.sum(mask) + 1e-5
-        loss = numerator / (denominator * self.num_codebooks)
+        logits = input.transpose(1, 2).contiguous().view(-1, input.size(-1))
+        loss = torch.nn.functional.cross_entropy(
+            logits,
+            target.contiguous().view(-1),
+            reduction='none',
+        )
+        loss = (loss * mask.view(-1)).sum() / mask.sum()
         return loss
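The rewritten loss replaces the manual log_softmax/gather with a single
cross_entropy call over every (frame, codebook) prediction, normalized once by
the mask sum (the "fix multibook loss" part of the commit message). A runnable
sketch with made-up shapes; the mask is assumed to already cover the codebook
dimension:

    import torch
    import torch.nn.functional as F

    B, K, T, V = 2, 4, 50, 8192              # batch, codebooks, frames, codebook size
    logits = torch.randn(B, K, T, V)         # per-codebook classification logits
    target = torch.randint(0, V, (B, T, K))  # nearest-embedding ids
    mask = torch.ones(B, T, K)               # 1.0 where a frame counts toward the loss

    flat = logits.transpose(1, 2).contiguous().view(-1, V)           # [B*T*K, V]
    loss = F.cross_entropy(flat, target.view(-1), reduction='none')
    loss = (loss * mask.view(-1)).sum() / mask.sum()                 # mean over valid slots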

     def _nearest_embedding_idx(self, xs: torch.Tensor) -> torch.Tensor:
-        xs = self.norm(xs)
         xs = torch.matmul(xs, self.projection.to(xs.device))
         xs = xs / (xs.norm(dim=-1, p=2, keepdim=True) + 1e-8)
         codebooks = self.embeddings
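With the global LayerNorm gone, xs = self.norm(xs) is dropped here; the input
already arrives normalized from _stack_features. For context, a rough sketch
of BEST-RQ's codebook lookup for a single codebook (assumed shapes and dims,
not the exact wenet code); since both sides are unit vectors, the code with
the highest cosine similarity is also the Euclidean-nearest one:

    import torch

    x = torch.randn(2, 50, 320)        # [B, T, input_dim] normalized stacked fbank
    projection = torch.randn(320, 16)  # frozen random projection (assumed dims)
    codebook = torch.randn(8192, 16)   # frozen random codebook (assumed size)

    h = torch.matmul(x, projection)
    h = h / (h.norm(dim=-1, p=2, keepdim=True) + 1e-8)
    c = codebook / (codebook.norm(dim=-1, p=2, keepdim=True) + 1e-8)
    ids = torch.matmul(h, c.t()).argmax(dim=-1)   # nearest code id, [B, T]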
