diff --git a/wenet/transducer/transducer.py b/wenet/transducer/transducer.py index ed1e58a35..ae026636b 100644 --- a/wenet/transducer/transducer.py +++ b/wenet/transducer/transducer.py @@ -92,20 +92,16 @@ def __init__( def forward( self, - speech: torch.Tensor, - speech_lengths: torch.Tensor, - text: torch.Tensor, - text_lengths: torch.Tensor, - steps: int = 0, + batch: dict, + device: torch.device, ) -> Dict[str, Optional[torch.Tensor]]: """Frontend + Encoder + predictor + joint + loss - - Args: - speech: (Batch, Length, ...) - speech_lengths: (Batch, ) - text: (Batch, Length) - text_lengths: (Batch,) """ + speech = batch['feats'].to(device) + speech_lengths = batch['feats_lengths'].to(device) + text = batch['target'].to(device) + text_lengths = batch['target_lengths'].to(device) + steps = batch.get('steps', 0) assert text_lengths.dim() == 1, text_lengths.shape # Check that batch_size is unified assert (speech.shape[0] == speech_lengths.shape[0] == text.shape[0] ==