Skip to content

Commit

Permalink
[transducer] refine forward args to dict (#2245)
Browse files Browse the repository at this point in the history
  • Loading branch information
Mddct authored Dec 14, 2023
1 parent d508589 commit 1a02467
Showing 1 changed file with 7 additions and 11 deletions.
18 changes: 7 additions & 11 deletions wenet/transducer/transducer.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,20 +92,16 @@ def __init__(

def forward(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
text: torch.Tensor,
text_lengths: torch.Tensor,
steps: int = 0,
batch: dict,
device: torch.device,
) -> Dict[str, Optional[torch.Tensor]]:
"""Frontend + Encoder + predictor + joint + loss
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
text: (Batch, Length)
text_lengths: (Batch,)
"""
speech = batch['feats'].to(device)
speech_lengths = batch['feats_lengths'].to(device)
text = batch['target'].to(device)
text_lengths = batch['target_lengths'].to(device)
steps = batch.get('steps', 0)
assert text_lengths.dim() == 1, text_lengths.shape
# Check that batch_size is unified
assert (speech.shape[0] == speech_lengths.shape[0] == text.shape[0] ==
Expand Down

0 comments on commit 1a02467

Please sign in to comment.