From 83e4637ffde1819d4d0400ab59152528d776068d Mon Sep 17 00:00:00 2001
From: durson
Date: Fri, 25 Aug 2023 12:53:06 +0200
Subject: [PATCH] added return_grad for all types of rnnt loss (#29)

* added return_grad for all types of rnnt loss

* lifted T >= S for regular case

* black reformat

* black -l80 reformat

* fixed s_range adjustment rule
---
 fast_rnnt/python/fast_rnnt/rnnt_loss.py      | 124 ++++++++++++------
 .../python/tests/mutual_information_test.py  |   1 -
 fast_rnnt/python/tests/rnnt_loss_test.py     |  97 +++++++++++++-
 3 files changed, 182 insertions(+), 40 deletions(-)
 mode change 100644 => 100755 fast_rnnt/python/tests/rnnt_loss_test.py

diff --git a/fast_rnnt/python/fast_rnnt/rnnt_loss.py b/fast_rnnt/python/fast_rnnt/rnnt_loss.py
index 622aa46..986dc54 100644
--- a/fast_rnnt/python/fast_rnnt/rnnt_loss.py
+++ b/fast_rnnt/python/fast_rnnt/rnnt_loss.py
@@ -22,6 +22,26 @@
 from .mutual_information import mutual_information_recursion
 
 
+def validate_st_lengths(
+    S: int,
+    T: int,
+    is_rnnt_type_regular: bool,
+    boundary: Optional[Tensor] = None,
+):
+    if boundary is None:
+        assert S >= 1, S
+        assert (
+            is_rnnt_type_regular or T >= S
+        ), f"Modified transducer requires T >= S, but got T={T} and S={S}"
+    else:
+        Ss = boundary[:, 2]
+        Ts = boundary[:, 3]
+        assert (Ss >= 1).all(), Ss
+        assert (
+            is_rnnt_type_regular or (Ts >= Ss).all()
+        ), f"Modified transducer requires T >= S, but got T={Ts} and S={Ss}"
+
+
 def fix_for_boundary(px: Tensor, boundary: Optional[Tensor] = None) -> Tensor:
     """
     Insert -inf's into `px` in appropriate places if `boundary` is not
@@ -145,8 +165,8 @@ def get_rnnt_logprobs(
     (B, T, C) = am.shape
     S = lm.shape[1] - 1
     assert symbols.shape == (B, S), symbols.shape
-    assert S >= 1, S
-    assert T >= S, (T, S)
+
+    validate_st_lengths(S, T, rnnt_type == "regular", boundary)
     assert rnnt_type in ["regular", "modified", "constrained"], rnnt_type
 
     # subtracting am_max and lm_max is to ensure the probs are in a good range
@@ -391,8 +411,8 @@ def get_rnnt_logprobs_joint(
     (B, T, S1, C) = logits.shape
     S = S1 - 1
     assert symbols.shape == (B, S), symbols.shape
-    assert S >= 1, S
-    assert T >= S, (T, S)
+
+    validate_st_lengths(S, T, rnnt_type == "regular", boundary)
     assert rnnt_type in ["regular", "modified", "constrained"], rnnt_type
 
     normalizers = torch.logsumexp(logits, dim=3)
@@ -437,6 +457,7 @@ def rnnt_loss(
     rnnt_type: str = "regular",
     delay_penalty: float = 0.0,
     reduction: Optional[str] = "mean",
+    return_grad: bool = False,
 ) -> Tensor:
     """A normal RNN-T loss, which uses a 'joiner' network output as input,
     i.e. a 4 dimensions tensor.
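Usage sketch for the new flag (illustrative only, not part of the patch; the sizes and `termination_symbol=0` are assumptions): with `return_grad=True`, `rnnt_loss` returns a `(loss, (px_grad, py_grad))` tuple instead of a bare loss, and the two gradient tensors are the occupation probabilities later consumed by `get_rnnt_prune_ranges`.

    import torch
    import fast_rnnt

    B, T, S, C = 2, 10, 5, 8  # assumed sizes: batch, frames, symbols, vocab
    logits = torch.randn(B, T, S + 1, C)  # 'joiner' output
    symbols = torch.randint(0, C, (B, S))
    boundary = torch.tensor([[0, 0, S, T]] * B, dtype=torch.int64)

    loss, (px_grad, py_grad) = fast_rnnt.rnnt_loss(
        logits=logits,
        symbols=symbols,
        termination_symbol=0,  # assumed blank id
        boundary=boundary,
        return_grad=True,
        reduction="none",
    )
    # px_grad: (B, S, T+1), py_grad: (B, S+1, T) for the regular topology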
@@ -509,20 +530,24 @@ def rnnt_loss(
         penalty = penalty * delay_penalty
         px += penalty.to(px.dtype)
 
-    negated_loss = mutual_information_recursion(px=px, py=py, boundary=boundary)
+    scores_and_grads = mutual_information_recursion(
+        px=px, py=py, boundary=boundary, return_grad=return_grad
+    )
+    negated_loss = scores_and_grads[0] if return_grad else scores_and_grads
     if reduction == "none":
-        return -negated_loss
+        loss = -negated_loss
     elif reduction == "mean":
-        return -torch.mean(negated_loss)
+        loss = -torch.mean(negated_loss)
     elif reduction == "sum":
-        return -torch.sum(negated_loss)
+        loss = -torch.sum(negated_loss)
     else:
         raise ValueError(
             f"reduction should be ('none' | 'mean' | 'sum'), given {reduction}"
         )
+    return (loss, scores_and_grads[1]) if return_grad else loss
 
 
-def _monotonic_lower_bound(x: torch.Tensor) -> torch.Tensor:
+def _monotonic_lower_bound(x: Tensor) -> Tensor:
     """Compute a monotonically increasing lower bound of the tensor `x`
     on the last dimension. The basic idea is: we traverse the tensor in
     reverse order, and update current element with the following statement,
@@ -556,9 +581,7 @@ def _monotonic_lower_bound(x: torch.Tensor) -> torch.Tensor:
     return x
 
 
-def _adjust_pruning_lower_bound(
-    s_begin: torch.Tensor, s_range: int
-) -> torch.Tensor:
+def _adjust_pruning_lower_bound(s_begin: Tensor, s_range: int) -> Tensor:
     """Adjust s_begin (pruning lower bounds) to make it satisfy the following
     constraints
 
@@ -613,11 +636,11 @@ def _adjust_pruning_lower_bound(
 # chapter 3.2 (Pruning bounds) of our Pruned RNN-T paper
 # (https://arxiv.org/pdf/2206.13236.pdf)
 def get_rnnt_prune_ranges(
-    px_grad: torch.Tensor,
-    py_grad: torch.Tensor,
-    boundary: torch.Tensor,
+    px_grad: Tensor,
+    py_grad: Tensor,
+    boundary: Tensor,
     s_range: int,
-) -> torch.Tensor:
+) -> Tensor:
     """Get the pruning ranges of normal rnnt loss according to the grads
     of px and py returned by mutual_information_recursion.
 
@@ -661,28 +684,44 @@ def get_rnnt_prune_ranges(
     """
     (B, S, T1) = px_grad.shape
     T = py_grad.shape[-1]
+
+    is_regular = T1 != T
+
     assert T1 in [T, T + 1], T1
     S1 = S + 1
     assert py_grad.shape == (B, S + 1, T), py_grad.shape
     assert boundary.shape == (B, 4), boundary.shape
-    assert S >= 1, S
-    assert T >= S, (T, S)
+    validate_st_lengths(S, T, is_regular, boundary)
+
+    # in the regular case s_range must be at least the smallest integer
+    # satisfying `(s_range - 1) * T + 1 >= S + 1`, i.e. ceil(S / T) + 1
+    if is_regular:
+        Ss = boundary[:, 2]
+        Ts = boundary[:, 3]
+        s_range_min = (
+            Ss.sub(1).div(Ts, rounding_mode="trunc").add(2).max().item()
+        )
+        if s_range < s_range_min:
+            print(
+                f"Warning: get_rnnt_prune_ranges - got s_range={s_range} "
+                f"for boundaries S={Ss}, T={Ts}. Adjusting to {s_range_min}"
+            )
+            s_range = s_range_min
 
     # s_range > S means we won't prune out any symbols. To make indexing with
     # ranges run normally, s_range should be equal to or less than ``S + 1``.
     if s_range > S:
         s_range = S + 1
 
-    if T1 == T:
-        assert (
-            s_range >= 1
-        ), "Pruning range for modified RNN-T should be equal to or greater than 1, or no valid paths could survive pruning."
-
-    else:
+    if is_regular:
         assert (
             s_range >= 2
         ), "Pruning range for standard RNN-T should be equal to or greater than 2, or no valid paths could survive pruning."
+    else:
+        assert (
+            s_range >= 1
+        ), "Pruning range for modified RNN-T should be equal to or greater than 1, or no valid paths could survive pruning."
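Sanity check on the adjustment rule above (illustrative, not part of the patch): for the non-negative values involved, `div(..., rounding_mode="trunc")` coincides with floor division, so `s_range_min` is `floor((S - 1) / T) + 2`, which equals `ceil(S / T) + 1`, the smallest integer satisfying `(s_range - 1) * T + 1 >= S + 1`:

    import math

    def s_range_min(S: int, T: int) -> int:
        # pure-Python mirror of Ss.sub(1).div(Ts, rounding_mode="trunc").add(2)
        return (S - 1) // T + 2

    for S in range(1, 64):
        for T in range(1, 64):
            r = s_range_min(S, T)
            assert r == math.ceil(S / T) + 1
            assert (r - 1) * T + 1 >= S + 1  # r is large enough ...
            assert (r - 2) * T + 1 < S + 1   # ... and minimal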
     (B_stride, S_stride, T_stride) = py_grad.stride()
     blk_grad = torch.as_strided(
 
@@ -739,8 +778,8 @@
 
 
 def do_rnnt_pruning(
-    am: torch.Tensor, lm: torch.Tensor, ranges: torch.Tensor
-) -> Tuple[torch.Tensor, torch.Tensor]:
+    am: Tensor, lm: Tensor, ranges: Tensor
+) -> Tuple[Tensor, Tensor]:
     """Prune the output of encoder(am) and prediction network(lm) with ranges
     generated by `get_rnnt_prune_ranges`.
 
@@ -779,7 +818,7 @@ def do_rnnt_pruning(
     return am_pruning, lm_pruning
 
 
-def _roll_by_shifts(src: torch.Tensor, shifts: torch.LongTensor):
+def _roll_by_shifts(src: Tensor, shifts: torch.LongTensor):
     """Roll tensor with different shifts for each row.
 
     Note:
@@ -819,7 +858,7 @@ def get_rnnt_logprobs_pruned(
     symbols: Tensor,
     ranges: Tensor,
     termination_symbol: int,
-    boundary: Tensor,
+    boundary: Optional[Tensor] = None,
     rnnt_type: str = "regular",
 ) -> Tuple[Tensor, Tensor]:
     """Construct px, py for mutual_information_recursion with pruned output.
@@ -888,10 +927,14 @@ def get_rnnt_logprobs_pruned(
     # ranges (B, T, s_range)
     assert logits.ndim == 4, logits.ndim
     (B, T, s_range, C) = logits.shape
-    assert ranges.shape == (B, T, s_range), ranges.shape
+    assert ranges.shape == (
+        B,
+        T,
+        s_range,
+    ), f"{ranges.shape} != ({B}, {T}, {s_range})"
     (B, S) = symbols.shape
-    assert S >= 1, S
-    assert T >= S, (T, S)
+
+    validate_st_lengths(S, T, rnnt_type == "regular", boundary)
     assert rnnt_type in ["regular", "modified", "constrained"], rnnt_type
 
     normalizers = torch.logsumexp(logits, dim=3)
@@ -986,10 +1029,11 @@ def rnnt_loss_pruned(
     symbols: Tensor,
     ranges: Tensor,
     termination_symbol: int,
-    boundary: Tensor = None,
+    boundary: Optional[Tensor] = None,
     rnnt_type: str = "regular",
     delay_penalty: float = 0.0,
     reduction: Optional[str] = "mean",
+    return_grad: bool = False,
 ) -> Tensor:
     """A RNN-T loss with pruning, which uses the output of a pruned 'joiner'
     network as input, i.e.
a 4 dimensions tensor with shape (B, T, s_range, C),
@@ -1071,17 +1115,21 @@ def rnnt_loss_pruned(
     penalty = penalty * delay_penalty
     px += penalty.to(px.dtype)
 
-    negated_loss = mutual_information_recursion(px=px, py=py, boundary=boundary)
+    scores_and_grads = mutual_information_recursion(
+        px=px, py=py, boundary=boundary, return_grad=return_grad
+    )
+    negated_loss = scores_and_grads[0] if return_grad else scores_and_grads
     if reduction == "none":
-        return -negated_loss
+        loss = -negated_loss
     elif reduction == "mean":
-        return -torch.mean(negated_loss)
+        loss = -torch.mean(negated_loss)
     elif reduction == "sum":
-        return -torch.sum(negated_loss)
+        loss = -torch.sum(negated_loss)
     else:
         raise ValueError(
             f"reduction should be ('none' | 'mean' | 'sum'), given {reduction}"
         )
+    return (loss, scores_and_grads[1]) if return_grad else loss
 
 
 def get_rnnt_logprobs_smoothed(
@@ -1202,8 +1250,8 @@ def get_rnnt_logprobs_smoothed(
     (B, T, C) = am.shape
     S = lm.shape[1] - 1
     assert symbols.shape == (B, S), symbols.shape
-    assert S >= 1, S
-    assert T >= S, (T, S)
+
+    validate_st_lengths(S, T, rnnt_type == "regular", boundary)
     assert rnnt_type in ["regular", "modified", "constrained"], rnnt_type
 
     # Caution: some parts of this code are a little less clear than they could
diff --git a/fast_rnnt/python/tests/mutual_information_test.py b/fast_rnnt/python/tests/mutual_information_test.py
index 6a335f6..20dff35 100644
--- a/fast_rnnt/python/tests/mutual_information_test.py
+++ b/fast_rnnt/python/tests/mutual_information_test.py
@@ -206,7 +206,6 @@ def test_mutual_information_deriv(self):
 
         for dtype in self.dtypes:
             for device in self.devices:
-
                 if random_boundary:
 
                     def get_boundary_row():
diff --git a/fast_rnnt/python/tests/rnnt_loss_test.py b/fast_rnnt/python/tests/rnnt_loss_test.py
old mode 100644
new mode 100755
index 7edaf8c..d469b19
--- a/fast_rnnt/python/tests/rnnt_loss_test.py
+++ b/fast_rnnt/python/tests/rnnt_loss_test.py
@@ -343,7 +343,6 @@ def test_rnnt_loss_gradient(self):
         boundary_[:, 3] = frames
 
         for device in self.devices:
-
             # lm: [B][S+1][C]
             lm = lm_.to(device)
             # am: [B][T][C]
@@ -609,6 +608,102 @@ def test_rnnt_loss_pruned_small_symbols_number(self):
             )
             print(f"Pruned loss with range {r} : {pruned_loss}")
 
+    # Test low s_range values with large S and small T;
+    # in this circumstance s_range is not enough to cover the
+    # whole sequence length (in regular rnnt mode) and, without
+    # the s_range adjustment rule, would result in an inf loss
+    def test_rnnt_loss_pruned_small_s_range(self):
+        B = 2
+        T = 2
+        S = 10
+        C = 10
+
+        frames = torch.randint(1, T, (B,))
+        seq_lengths = torch.randint(1, S, (B,))
+        T = torch.max(frames)
+        S = torch.max(seq_lengths)
+
+        am_ = torch.randn((B, T, C), dtype=torch.float64)
+        lm_ = torch.randn((B, S + 1, C), dtype=torch.float64)
+        symbols_ = torch.randint(0, C, (B, S))
+        terminal_symbol = C - 1
+
+        boundary_ = torch.zeros((B, 4), dtype=torch.int64)
+        boundary_[:, 2] = seq_lengths
+        boundary_[:, 3] = frames
+
+        print(f"B = {B}, T = {T}, S = {S}, C = {C}")
+
+        for rnnt_type in ["regular"]:
+            for device in self.devices:
+                # normal rnnt
+                am = am_.to(device)
+                lm = lm_.to(device)
+                symbols = symbols_.to(device)
+                boundary = boundary_.to(device)
+
+                logits = am.unsqueeze(2) + lm.unsqueeze(1)
+                logits = logits.float()
+
+                # nonlinear transform
+                logits = torch.sigmoid(logits)
+
+                loss = fast_rnnt.rnnt_loss(
+                    logits=logits,
+                    symbols=symbols,
+                    termination_symbol=terminal_symbol,
+                    boundary=boundary,
+                    rnnt_type=rnnt_type,
+                    reduction="none",
+                )
+
+                print(f"Unpruned rnnt loss with {rnnt_type} rnnt : {loss}")
+
+                # pruning
+                simple_loss, (px_grad, py_grad) = fast_rnnt.rnnt_loss_simple(
+                    lm=lm,
+                    am=am,
+                    symbols=symbols,
+                    termination_symbol=terminal_symbol,
+                    boundary=boundary,
+                    rnnt_type=rnnt_type,
+                    return_grad=True,
+                    reduction="none",
+                )
+
+                S0 = 2
+
+                for r in range(S0, S + 2):
+                    ranges = fast_rnnt.get_rnnt_prune_ranges(
+                        px_grad=px_grad,
+                        py_grad=py_grad,
+                        boundary=boundary,
+                        s_range=r,
+                    )
+                    # (B, T, r, C)
+                    pruned_am, pruned_lm = fast_rnnt.do_rnnt_pruning(
+                        am=am, lm=lm, ranges=ranges
+                    )
+
+                    logits = pruned_am + pruned_lm
+
+                    # nonlinear transform
+                    logits = torch.sigmoid(logits)
+
+                    pruned_loss = fast_rnnt.rnnt_loss_pruned(
+                        logits=logits,
+                        symbols=symbols,
+                        ranges=ranges,
+                        termination_symbol=terminal_symbol,
+                        boundary=boundary,
+                        rnnt_type=rnnt_type,
+                        reduction="none",
+                    )
+                    assert (
+                        not pruned_loss.isinf().any()
+                    ), f"Pruned loss is inf for r={r}, S={S}, T={T}: {pruned_loss}"
+                    print(f"Pruned loss with range {r} : {pruned_loss}")
+
 
 if __name__ == "__main__":
     unittest.main()
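Taken together, `return_grad` lets one call produce both the simple loss and the occupation gradients that drive pruning, as the test above exercises. A minimal end-to-end sketch (illustrative only, not part of the patch; the sizes, `termination_symbol=0`, `s_range=3`, and the sigmoid stand-in for a real joiner are assumptions):

    import torch
    import fast_rnnt

    B, T, S, C = 2, 10, 5, 8  # assumed sizes: batch, frames, symbols, vocab
    am = torch.randn(B, T, C)
    lm = torch.randn(B, S + 1, C)
    symbols = torch.randint(0, C, (B, S))
    boundary = torch.tensor([[0, 0, S, T]] * B, dtype=torch.int64)

    # one call yields both the simple loss and the gradients for pruning
    simple_loss, (px_grad, py_grad) = fast_rnnt.rnnt_loss_simple(
        lm=lm,
        am=am,
        symbols=symbols,
        termination_symbol=0,  # assumed blank id
        boundary=boundary,
        return_grad=True,
        reduction="sum",
    )

    # s_range=3 is arbitrary; for regular rnnt it is raised automatically
    # if it falls below ceil(S / T) + 1
    ranges = fast_rnnt.get_rnnt_prune_ranges(
        px_grad=px_grad, py_grad=py_grad, boundary=boundary, s_range=3
    )
    pruned_am, pruned_lm = fast_rnnt.do_rnnt_pruning(am=am, lm=lm, ranges=ranges)

    # a real model would apply its joiner here; sigmoid is a stand-in
    logits = torch.sigmoid(pruned_am + pruned_lm)

    pruned_loss = fast_rnnt.rnnt_loss_pruned(
        logits=logits,
        symbols=symbols,
        ranges=ranges,
        termination_symbol=0,
        boundary=boundary,
        reduction="sum",
    )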