[transformer] keep high precision in softmax (#2508)
Mddct authored Apr 29, 2024
1 parent f42ddb2 commit 47b6cfb
Showing 1 changed file with 5 additions and 4 deletions.
9 changes: 5 additions & 4 deletions wenet/transformer/attention.py
@@ -155,14 +155,15 @@ def forward_attention(
             # For last chunk, time2 might be larger than scores.size(-1)
             mask = mask[..., :scores.size(-1)]  # (batch, 1, *, time2)
             scores = scores.masked_fill(mask, -float('inf'))
-            attn = torch.softmax(scores, dim=-1).masked_fill(
-                mask, 0.0)  # (batch, head, time1, time2)
+            attn = torch.softmax(scores.float(),
+                                 dim=-1).type_as(value).masked_fill(
+                                     mask, 0.0)  # (batch, head, time1, time2)
             # NOTE(xcsong): When will `if mask.size(2) > 0` be False?
             #   1. onnx(16/-1, -1/-1, 16/0)
             #   2. jit (16/-1, -1/-1, 16/0, 16/4)
         else:
-            attn = torch.softmax(scores,
-                                 dim=-1)  # (batch, ..., head, time1, time2)
+            attn = torch.softmax(scores.float(), dim=-1).type_as(
+                value)  # (batch, ..., head, time1, time2)
 
         p_attn = self.dropout(attn)
         x = torch.matmul(p_attn, value)  # (batch, ..., head, time1, d_k)
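For context: the change upcasts the attention scores to float32 before the softmax so the normalization stays numerically stable under fp16/bf16 mixed precision, then casts the result back to the dtype of `value` with `type_as` so downstream matmuls remain in low precision. Below is a minimal standalone sketch of that pattern (not part of the commit); the shapes and the bfloat16 dtype are illustrative assumptions.

# Sketch of the high-precision-softmax pattern (hypothetical shapes, bfloat16 inputs).
import torch

batch, head, time1, time2, d_k = 2, 4, 6, 6, 8
scores = torch.randn(batch, head, time1, time2, dtype=torch.bfloat16)
value = torch.randn(batch, head, time2, d_k, dtype=torch.bfloat16)
mask = torch.zeros(batch, 1, time1, time2, dtype=torch.bool)  # nothing masked in this toy case

scores = scores.masked_fill(mask, -float('inf'))
# Run the softmax reduction in float32, then cast back to value's dtype so the
# following matmul stays in low precision; re-zero masked positions afterwards.
attn = torch.softmax(scores.float(), dim=-1).type_as(value).masked_fill(mask, 0.0)
x = torch.matmul(attn, value)  # (batch, head, time1, d_k), dtype bfloat16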
