diff --git a/wenet/transformer/attention.py b/wenet/transformer/attention.py
index c9d8f07b4b..43a7c80911 100644
--- a/wenet/transformer/attention.py
+++ b/wenet/transformer/attention.py
@@ -155,14 +155,15 @@ def forward_attention(
             # For last chunk, time2 might be larger than scores.size(-1)
             mask = mask[..., :scores.size(-1)]  # (batch, 1, *, time2)
             scores = scores.masked_fill(mask, -float('inf'))
-            attn = torch.softmax(scores, dim=-1).masked_fill(
-                mask, 0.0)  # (batch, head, time1, time2)
+            attn = torch.softmax(scores.float(),
+                                 dim=-1).type_as(value).masked_fill(
+                                     mask, 0.0)  # (batch, head, time1, time2)
         # NOTE(xcsong): When will `if mask.size(2) > 0` be False?
         #   1. onnx(16/-1, -1/-1, 16/0)
         #   2. jit (16/-1, -1/-1, 16/0, 16/4)
         else:
-            attn = torch.softmax(scores,
-                                 dim=-1)  # (batch, ..., head, time1, time2)
+            attn = torch.softmax(scores.float(), dim=-1).type_as(
+                value)  # (batch, ..., head, time1, time2)
 
         p_attn = self.dropout(attn)
         x = torch.matmul(p_attn, value)  # (batch, ..., head, time1, d_k)
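
Below is a minimal, standalone sketch (not part of the patch) of the dtype pattern this diff introduces: compute the attention softmax in float32 and cast the result back to the dtype of `value` with `.type_as(value)`, so the subsequent matmul stays in half precision. The tensor shapes, names, and device selection are illustrative assumptions, and it presumes a recent PyTorch build where fp16 kernels are available on the chosen device.

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
batch, head, time1, time2, d_k = 2, 4, 6, 6, 8

# Illustrative half-precision attention scores and values (not wenet shapes/names).
scores = torch.randn(batch, head, time1, time2, dtype=torch.float16, device=device)
value = torch.randn(batch, head, time2, d_k, dtype=torch.float16, device=device)

# Boolean padding mask, True = masked. Mask one query row completely to show
# why the trailing masked_fill(mask, 0.0) matters: softmax over a row that is
# entirely -inf produces NaN, which is then zeroed out.
mask = torch.zeros(batch, 1, time1, time2, dtype=torch.bool, device=device)
mask[..., -1, :] = True

scores = scores.masked_fill(mask, -float('inf'))

# Patched pattern: softmax in float32 for numerical stability, cast back to
# the dtype of `value`, then zero the masked positions.
attn = torch.softmax(scores.float(), dim=-1).type_as(value).masked_fill(mask, 0.0)

x = torch.matmul(attn, value)   # (batch, head, time1, d_k), still float16
print(attn.dtype, x.dtype)      # torch.float16 torch.float16
print(torch.isnan(attn).any())  # False: NaNs from the fully masked row were zeroed
```

The cast back via `.type_as(value)` keeps the behavior unchanged for float32 training while avoiding the reduced range and precision of a half-precision softmax under fp16/bf16 autocast.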