diff --git a/wenet/transformer/embedding.py b/wenet/transformer/embedding.py
index 2efa2f5fd..db8a41333 100644
--- a/wenet/transformer/embedding.py
+++ b/wenet/transformer/embedding.py
@@ -208,9 +208,9 @@ def __init__(self,
                  rope_theta=10000.0):
         super().__init__(d_model, dropout_rate=dropout_rate, max_len=max_len)
         delattr(self, 'pe')
-
-        pe = precompute_freqs_cis(head_dim, max_len * 2, rope_theta)
-        self.register_buffer("pe", pe.unsqueeze(0))
+        self.max_len = max_len * 2
+        pe = precompute_freqs_cis(head_dim, self.max_len, rope_theta)
+        self.register_buffer("pe", torch.view_as_real(pe.unsqueeze(0)))
         self.dropout_rate = dropout_rate
 
     def forward(
@@ -219,13 +219,34 @@ def forward(
             offset: Union[int,
                           torch.Tensor] = 0) -> Tuple[torch.Tensor, torch.Tensor]:
 
-        pos_emb = self.position_encoding(offset, x.size(1), False)
+        pos_emb = self.position_encoding(offset, x.size(1), True)
         pos_emb = pos_emb.unsqueeze(1)  # [1, 1, seq, head_dim//2]
         # NOTE(Mddct): some model don't scale
         # TODO(Mddct): fix
         x = x * self.xscale
-        # NOTE(Mddct) dropout don't suuport complex float for pos_emb
-        return self.dropout(x), self.dropout_complex(pos_emb)
+        return self.dropout(x), pos_emb
+
+    def position_encoding(self,
+                          offset: Union[int, torch.Tensor],
+                          size: int,
+                          apply_dropout: bool = True) -> torch.Tensor:
+
+        pe = torch.view_as_complex(self.pe)
+        if isinstance(offset, int):
+            assert offset + size <= self.max_len
+            pos_emb = pe[:, offset:offset + size]
+        else:
+            assert torch.max(offset) + size <= self.max_len
+            index = offset.unsqueeze(1) + torch.arange(0, size).to(
+                offset.device)  # B X T
+            flag = index > 0
+            # remove negative offset
+            index = index * flag
+            pos_emb = F.embedding(index, pe[0])  # B X T X head_dim//2
+        if apply_dropout:
+            # NOTE(Mddct): dropout doesn't support complex float for pos_emb
+            pos_emb = self.dropout_complex(pos_emb)
+        return pos_emb
 
     def dropout_complex(self, x):
         mask = torch.nn.functional.dropout(