Skip to content

Commit

Permalink
add examples
Browse files Browse the repository at this point in the history
  • Loading branch information
Mddct authored Sep 27, 2023
1 parent dda5a69 commit a392d05
Showing 1 changed file with 17 additions and 2 deletions.
19 changes: 17 additions & 2 deletions wenet/utils/mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,7 +322,7 @@ def causal_or_lookahead_mask(
[1, 1, 1, 0],
[1, 1, 1, 1]]
>>> causal_or_lookahead_mask(seq_mask.unsqueeze(1), 0, 2)
[[1, 0, 0, 0],
[[[1, 0, 0, 0],
[1, 1, 0, 0],
[0, 0, 0, 0],
[0, 0, 0, 0]],
Expand All @@ -335,7 +335,22 @@ def causal_or_lookahead_mask(
[[1, 0, 0, 0],
[1, 1, 0, 0],
[1, 1, 1, 0],
[0, 1, 1, 1]]
[0, 1, 1, 1]]]
>>> causal_or_lookahead_mask(seq_mask.unsqueeze(1), 1, 2)
[[[1, 1, 0, 0],
[1, 1, 0, 0],
[0, 0, 0, 0],
[0, 0, 0, 0]],
[[1, 1, 0, 0],
[1, 1, 1, 0],
[1, 1, 1, 0],
[0, 0, 0, 0]],
[[1, 1, 0, 0],
[1, 1, 1, 0],
[1, 1, 1, 1],
[0, 1, 1, 1]]]
"""
_, _, T = mask.size()
indices = torch.arange(T, device=mask.device)
Expand Down

0 comments on commit a392d05

Please sign in to comment.