[Not for merge] Implement FastEmit #1069

Open
wants to merge 6 commits into
base: master
25 changes: 21 additions & 4 deletions k2/python/k2/mutual_information.py
@@ -35,6 +35,7 @@ def forward(
py: torch.Tensor,
pxy_grads: List[Optional[torch.Tensor]],
boundary: Optional[torch.Tensor] = None,
fast_emit_scale: float = 0.0,
return_grad: bool = False,
) -> torch.Tensor:
"""
@@ -109,6 +110,10 @@ def forward(
all sequences are
of the same length.

fast_emit_scale:
Implements the FastEmit regularization proposed in https://arxiv.org/pdf/2010.11148.pdf.
The idea is to scale ``px_grad`` by ``(1 + fast_emit_scale)``.

return_grad:
Whether to return grads of ``px`` and ``py``, this grad standing
for the occupation probability is the output of the backward with a
@@ -163,6 +168,7 @@ def forward(
ans_grad = torch.ones(B, device=px.device, dtype=px.dtype)
(px_grad, py_grad) = _k2.mutual_information_backward(
px, py, boundary, p, ans_grad)
px_grad *= (1 + fast_emit_scale)
ctx.save_for_backward(px_grad, py_grad)
assert len(pxy_grads) == 2
pxy_grads[0] = px_grad
@@ -173,19 +179,20 @@
@staticmethod
def backward(
ctx, ans_grad: Tensor
) -> Tuple[torch.Tensor, torch.Tensor, None, None, None]:
) -> Tuple[torch.Tensor, torch.Tensor, None, None, None, None]:
(px_grad, py_grad) = ctx.saved_tensors
(B,) = ans_grad.shape
ans_grad = ans_grad.reshape(B, 1, 1) # (B, 1, 1)
px_grad *= ans_grad
py_grad *= ans_grad
return (px_grad, py_grad, None, None, None)
return (px_grad, py_grad, None, None, None, None)


def mutual_information_recursion(
px: Tensor,
py: Tensor,
boundary: Optional[Tensor] = None,
fast_emit_scale: float = 0.0,
return_grad: bool = False,
) -> Union[Tuple[Tensor, Tuple[Tensor, Tensor]], Tensor]:
"""A recursion that is useful in computing mutual information between two
@@ -248,6 +255,10 @@ def mutual_information_recursion(
``y`` sequences respectively, and can be used if not all sequences are
of the same length.

fast_emit_scale:
Implements the FastEmit regularization proposed in https://arxiv.org/pdf/2010.11148.pdf.
The idea is to scale ``px_grad`` by ``(1 + fast_emit_scale)``.

return_grad:
Whether to return grads of ``px`` and ``py``, this grad standing for the
occupation probability is the output of the backward with a
@@ -292,8 +303,14 @@ def mutual_information_recursion(
px, py = px.contiguous(), py.contiguous()

pxy_grads = [None, None]
scores = MutualInformationRecursionFunction.apply(px, py, pxy_grads,
boundary, return_grad)
scores = MutualInformationRecursionFunction.apply(
px,
py,
pxy_grads,
boundary,
fast_emit_scale,
return_grad
)
px_grad, py_grad = pxy_grads
return (scores, (px_grad, py_grad)) if return_grad else scores

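The whole mechanism in this file boils down to the single added line ``px_grad *= (1 + fast_emit_scale)`` in the forward pass above: ``px_grad`` is the occupation probability for emitting a real symbol, and up-weighting it relative to ``py_grad`` (the termination-symbol path) biases training towards earlier emission. A minimal sketch of that step outside of k2 follows; the helper name and the toy shapes are illustrative only, not part of this PR.

```python
import torch

def fastemit_scale(px_grad: torch.Tensor,
                   py_grad: torch.Tensor,
                   fast_emit_scale: float = 0.0):
    # FastEmit (https://arxiv.org/pdf/2010.11148.pdf): boost the gradient of
    # symbol-emitting transitions by (1 + fast_emit_scale) while leaving the
    # termination-symbol gradient untouched.
    return px_grad * (1.0 + fast_emit_scale), py_grad

# Toy tensors in the (B, S, T + 1) / (B, S + 1, T) layout used by
# mutual_information_recursion; the values are random.
px_grad = torch.rand(2, 3, 6)
py_grad = torch.rand(2, 4, 5)
scaled_px, scaled_py = fastemit_scale(px_grad, py_grad, fast_emit_scale=0.1)
assert torch.allclose(scaled_px, px_grad * 1.1)
assert torch.equal(scaled_py, py_grad)
```

With ``fast_emit_scale = 0.0`` (the default) the grads are unchanged, so existing callers are unaffected.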
46 changes: 42 additions & 4 deletions k2/python/k2/rnnt_loss.py
@@ -202,6 +202,7 @@ def rnnt_loss_simple(
termination_symbol: int,
boundary: Optional[Tensor] = None,
modified: bool = False,
fast_emit_scale: float = 0.0,
reduction: Optional[str] = "mean",
return_grad: bool = False,
) -> Union[Tensor, Tuple[Tensor, Tuple[Tensor, Tensor]]]:
@@ -228,6 +229,9 @@
Most likely you will want begin_symbol and begin_frame to be zero.
modified: if True, each time a real symbol is consumed a frame will
also be consumed, so at most 1 symbol can appear per frame.
fast_emit_scale:
Implements the FastEmit regularization proposed in https://arxiv.org/pdf/2010.11148.pdf.
The idea is to scale ``px_grad`` by ``(1 + fast_emit_scale)``.
reduction:
Specifies the reduction to apply to the output: `none`, `mean` or `sum`.
`none`: no reduction will be applied.
@@ -258,8 +262,13 @@ def rnnt_loss_simple(
boundary=boundary,
modified=modified,
)

scores_and_grads = mutual_information_recursion(
px=px, py=py, boundary=boundary, return_grad=return_grad
px=px,
py=py,
boundary=boundary,
fast_emit_scale=fast_emit_scale,
return_grad=return_grad
)
negated_loss = scores_and_grads[0] if return_grad else scores_and_grads
if reduction == "none":
@@ -376,6 +385,7 @@ def rnnt_loss(
termination_symbol: int,
boundary: Optional[Tensor] = None,
modified: bool = False,
fast_emit_scale: float = 0.0,
reduction: Optional[str] = "mean",
) -> Tensor:
"""A normal RNN-T loss, which uses a 'joiner' network output as input,
@@ -397,6 +407,9 @@
Most likely you will want begin_symbol and begin_frame to be zero.
modified: if True, each time a real symbol is consumed a frame will
also be consumed, so at most 1 symbol can appear per frame.
fast_emit_scale:
Implements the FastEmit regularization proposed in https://arxiv.org/pdf/2010.11148.pdf.
The idea is to scale ``px_grad`` by ``(1 + fast_emit_scale)``.
reduction:
Specifies the reduction to apply to the output: `none`, `mean` or `sum`.
`none`: no reduction will be applied.
@@ -416,7 +429,13 @@
boundary=boundary,
modified=modified,
)
negated_loss = mutual_information_recursion(px=px, py=py, boundary=boundary)

negated_loss = mutual_information_recursion(
px=px,
py=py,
boundary=boundary,
fast_emit_scale=fast_emit_scale
)
if reduction == "none":
return -negated_loss
elif reduction == "mean":
@@ -957,6 +976,7 @@ def rnnt_loss_pruned(
termination_symbol: int,
boundary: Tensor = None,
modified: bool = False,
fast_emit_scale: float = 0.0,
reduction: Optional[str] = "mean",
) -> Tensor:
"""A RNN-T loss with pruning, which uses a pruned 'joiner' network output
@@ -987,6 +1007,9 @@
Most likely you will want begin_symbol and begin_frame to be zero.
modified: if True, each time a real symbol is consumed a frame will
also be consumed, so at most 1 symbol can appear per frame.
fast_emit_scale:
Implements the FastEmit regularization proposed in https://arxiv.org/pdf/2010.11148.pdf.
The idea is to scale ``px_grad`` by ``(1 + fast_emit_scale)``.
reduction:
Specifies the reduction to apply to the output: `none`, `mean` or `sum`.
`none`: no reduction will be applied.
@@ -1006,7 +1029,13 @@
boundary=boundary,
modified=modified,
)
negated_loss = mutual_information_recursion(px=px, py=py, boundary=boundary)

negated_loss = mutual_information_recursion(
px=px,
py=py,
boundary=boundary,
fast_emit_scale=fast_emit_scale
)
if reduction == "none":
return -negated_loss
elif reduction == "mean":
@@ -1250,6 +1279,7 @@ def rnnt_loss_smoothed(
am_only_scale: float = 0.1,
boundary: Optional[Tensor] = None,
modified: bool = False,
fast_emit_scale: float = 0.0,
reduction: Optional[str] = "mean",
return_grad: bool = False,
) -> Union[Tuple[Tensor, Tuple[Tensor, Tensor]], Tensor]:
@@ -1283,6 +1313,9 @@
Most likely you will want begin_symbol and begin_frame to be zero.
modified: if True, each time a real symbol is consumed a frame will
also be consumed, so at most 1 symbol can appear per frame.
fast_emit_scale:
Implements the FastEmit regularization proposed in https://arxiv.org/pdf/2010.11148.pdf.
The idea is to scale ``px_grad`` by ``(1 + fast_emit_scale)``.
reduction:
Specifies the reduction to apply to the output: `none`, `mean` or `sum`.
`none`: no reduction will be applied.
@@ -1315,8 +1348,13 @@
boundary=boundary,
modified=modified,
)

scores_and_grads = mutual_information_recursion(
px=px, py=py, boundary=boundary, return_grad=return_grad
px=px,
py=py,
boundary=boundary,
fast_emit_scale=fast_emit_scale,
return_grad=return_grad
)
negated_loss = scores_and_grads[0] if return_grad else scores_and_grads
if reduction == "none":
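With this branch, the new argument is threaded through all four loss wrappers (``rnnt_loss_simple``, ``rnnt_loss``, ``rnnt_loss_pruned``, ``rnnt_loss_smoothed``) in the same way. A usage sketch for ``rnnt_loss_simple``, assuming the existing ``lm``/``am``/``symbols`` arguments of the current k2 API and the shapes from its docstring; the values are random and ``fast_emit_scale=0.01`` is just an example setting:

```python
import torch
import k2

B, S, T, C = 2, 3, 10, 50              # batch, symbols, frames, vocab size
lm = torch.randn(B, S + 1, C)          # decoder ("language model") output
am = torch.randn(B, T, C)              # encoder ("acoustic model") output
symbols = torch.randint(1, C, (B, S))  # reference label sequences

loss = k2.rnnt_loss_simple(
    lm=lm,
    am=am,
    symbols=symbols,
    termination_symbol=0,
    fast_emit_scale=0.01,  # new argument from this PR; 0.0 keeps the old behaviour
    reduction="mean",
)
```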