
Commit

changed asr models outputs to be consistent (#11818)
* changed asr models outputs to be consistent

Signed-off-by: Ssofja <[email protected]>

* Apply isort and black reformatting

Signed-off-by: Ssofja <[email protected]>
Signed-off-by: Ssofja <[email protected]>

* Apply isort and black reformatting

Signed-off-by: Ssofja <[email protected]>

* adding needed changes

Signed-off-by: Ssofja <[email protected]>

* Apply isort and black reformatting

Signed-off-by: Ssofja <[email protected]>

* Small fixes

* Returned previous names of return_hypotheses

Signed-off-by: Ssofja <[email protected]>

* Apply isort and black reformatting

Signed-off-by: Ssofja <[email protected]>

---------

Signed-off-by: Ssofja <[email protected]>
Signed-off-by: Ssofja <[email protected]>
Co-authored-by: Ssofja <[email protected]>
Ssofja and Ssofja authored Feb 12, 2025
1 parent 6b59ab8 commit ee543c2
Showing 51 changed files with 647 additions and 677 deletions.
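
At a high level, this commit unifies what the ASR decoding entry points return: instead of a (best_hypotheses, all_hypotheses) tuple whose second element was frequently discarded, each decoder now returns a single list of Hypothesis objects (a list of lists when all beam candidates are requested). A rough before/after sketch of the calling convention, following the names used in the diffs below (simplified; the real methods take additional arguments):

    # Before this commit: tuple unpacking, second element often ignored.
    best_hyps, _ = asr_model.decoding.ctc_decoder_predictions_tensor(
        log_probs, decoder_lengths=encoded_len, return_hypotheses=True
    )
    print(best_hyps[0].text)

    # After this commit: one list of Hypothesis objects.
    hyps = asr_model.decoding.ctc_decoder_predictions_tensor(
        log_probs, decoder_lengths=encoded_len, return_hypotheses=True
    )
    print(hyps[0].text)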
8 changes: 5 additions & 3 deletions nemo/collections/asr/metrics/bleu.py
@@ -161,14 +161,16 @@ def update(
target = targets_cpu_tensor[ind][:tgt_len].numpy().tolist()
reference = self.decoding.decode_tokens_to_str(target)
references.append(reference)
hypotheses, _ = self.decode(predictions, predictions_lengths, predictions_mask, input_ids, targets)
hypotheses = self.decode(predictions, predictions_lengths, predictions_mask, input_ids, targets)

if self.log_prediction:
logging.info(f"\n")
logging.info("\n")
logging.info(f"reference:{references[0]}")
logging.info(f"predicted:{hypotheses[0]}")

super().update(hypotheses, [references]) # Note: [references] since BLEU allows multiple references.
super().update(
[h.text for h in hypotheses], [references]
) # Note: [references] since BLEU allows multiple references.

def compute(self, return_all_metrics=True, prefix="", suffix=""):
"""
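Because self.decode(...) now yields Hypothesis objects rather than strings, the BLEU update extracts .text before delegating to the parent metric. The same extraction step in isolation (a standalone illustration with a stand-in Hypothesis type, not the metric class itself):

    from collections import namedtuple

    Hyp = namedtuple("Hyp", "text")                   # stand-in for rnnt_utils.Hypothesis
    hypotheses = [Hyp("hello world"), Hyp("good morning")]
    references = ["hello world", "good morning"]

    hypothesis_texts = [h.text for h in hypotheses]   # what gets handed to the parent BLEU update,
    assert all(isinstance(t, str) for t in hypothesis_texts)  # wrapped as [references] to allow multiple reference sets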
10 changes: 5 additions & 5 deletions nemo/collections/asr/metrics/wer.py
@@ -323,19 +323,19 @@ def update(
target = targets_cpu_tensor[ind][:tgt_len].numpy().tolist()
reference = self.decoding.decode_tokens_to_str(target)
references.append(reference)
hypotheses, _ = self.decode(predictions, predictions_lengths, predictions_mask, input_ids, targets)
hypotheses = self.decode(predictions, predictions_lengths, predictions_mask, input_ids, targets)

if self.log_prediction:
logging.info(f"\n")
logging.info("\n")
logging.info(f"reference:{references[0]}")
logging.info(f"predicted:{hypotheses[0]}")
logging.info(f"predicted:{hypotheses[0].text}")

for h, r in zip(hypotheses, references):
if self.use_cer:
h_list = list(h)
h_list = list(h.text)
r_list = list(r)
else:
h_list = h.split()
h_list = h.text.split()
r_list = r.split()
words += len(r_list)
# Compute Levenstein's distance
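The WER metric follows suit: h is now a Hypothesis, so both character-level (CER) and word-level tokenization read from h.text. A sketch of the accumulation this hunk feeds, with a stand-in Hypothesis type (editdistance is used purely for illustration; the real module computes the Levenshtein distance itself):

    from collections import namedtuple
    import editdistance  # illustration only

    Hyp = namedtuple("Hyp", "text")                      # stand-in for rnnt_utils.Hypothesis
    hypotheses = [Hyp("the cat sat"), Hyp("on a mat")]   # what the decoder now returns
    references = ["the cat sat", "on the mat"]
    use_cer = False

    words, distance = 0, 0
    for h, r in zip(hypotheses, references):
        h_list = list(h.text) if use_cer else h.text.split()
        r_list = list(r) if use_cer else r.split()
        words += len(r_list)
        distance += editdistance.eval(h_list, r_list)
    print(distance / max(words, 1))                      # word error rate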
21 changes: 7 additions & 14 deletions nemo/collections/asr/models/aed_multitask_models.py
@@ -43,7 +43,6 @@
from nemo.collections.asr.parts.utils.rnnt_utils import Hypothesis
from nemo.collections.common import tokenizers
from nemo.collections.common.data.lhotse.dataloader import get_lhotse_dataloader_from_config
from nemo.collections.common.data.prompt_fn import get_prompt_format_fn
from nemo.collections.common.metrics import GlobalAverageLossMetric
from nemo.collections.common.parts import transformer_weights_init
from nemo.collections.common.parts.preprocessing.manifest import get_full_path
@@ -60,7 +59,6 @@
SpectrogramType,
)
from nemo.utils import logging, model_utils
from nemo.utils.decorators import deprecated

__all__ = ['EncDecMultiTaskModel']

@@ -310,7 +308,7 @@ def change_vocabulary(
)

if new_tokenizer_type.lower() not in ('bpe', 'wpe'):
raise ValueError(f'New tokenizer type must be either `bpe` or `wpe`')
raise ValueError('New tokenizer type must be either `bpe` or `wpe`')

tokenizer_cfg = OmegaConf.create({'dir': new_tokenizer_dir, 'type': new_tokenizer_type})

@@ -821,7 +819,7 @@ def _transcribe_on_begin(self, audio, trcfg: MultiTaskTranscriptionConfig):

if isinstance(audio, list):
logging.debug(f"Found 'audio' to be a list of {len(audio)} items.")
logging.debug(f"Assuming each item in 'audio' is a path to audio file.")
logging.debug("Assuming each item in 'audio' is a path to audio file.")

if isinstance(self.tokenizer, tokenizers.AggregateTokenizer):
if hasattr(trcfg, '_internal') and hasattr(trcfg._internal, 'primary_language'):
@@ -929,10 +927,6 @@ def _transcribe_forward(
decoder_input_ids=decoder_input_ids,
)

@deprecated(
explanation='The return type of args will be updated in the upcoming release to ensure a consistent \
output format across all decoder types, such that a Hypothesis object is always returned.'
)
def _transcribe_output_processing(self, outputs, trcfg: MultiTaskTranscriptionConfig) -> GenericTranscriptionType:
"""
Internal function to process the model's outputs to return the results to the user. This function is called by
@@ -944,7 +938,7 @@ def _transcribe_output_processing(self, outputs, trcfg: MultiTaskTranscriptionCo
Returns:
The output can be a list of
objects, list of list of objects, tuple of objects, tuple of list of objects, or a dict of list of objects.
objects, list of list of objects.
Its type is defined in `TranscriptionReturnType`.
"""
log_probs = outputs.pop('log_probs')
@@ -955,17 +949,16 @@ def _transcribe_output_processing(self, outputs, trcfg: MultiTaskTranscriptionCo

del log_probs, encoded_len

best_hypotheses, all_hypotheses = self.decoding.decode_predictions_tensor(
hypotheses = self.decoding.decode_predictions_tensor(
encoder_hidden_states=enc_states,
encoder_input_mask=enc_mask,
decoder_input_ids=decoder_input_ids,
return_hypotheses=trcfg.return_hypotheses,
)

del enc_states, enc_mask, decoder_input_ids
if all_hypotheses is None:
return best_hypotheses
return best_hypotheses, all_hypotheses

return hypotheses

def _setup_transcribe_dataloader(self, config: Dict) -> 'torch.utils.data.DataLoader':
"""
@@ -1092,7 +1085,7 @@ def predict_step(
encoder_input_mask=enc_mask,
decoder_input_ids=batch.prompt,
return_hypotheses=False,
)[0]
)
if batch.cuts:
return list(zip(batch.cuts, text))
else:
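For the multitask (AED) model, the deprecation shim is removed and _transcribe_output_processing returns the hypotheses directly, so downstream code no longer needs to branch on whether a tuple came back. A consumer-side sketch (the pretrained model name is an assumption for illustration; any EncDecMultiTaskModel checkpoint applies):

    from nemo.collections.asr.models import EncDecMultiTaskModel

    model = EncDecMultiTaskModel.from_pretrained("nvidia/canary-1b")  # assumed checkpoint name
    hyps = model.transcribe(["sample.wav"], return_hypotheses=True)   # one Hypothesis per utterance
    print(hyps[0].text)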
32 changes: 11 additions & 21 deletions nemo/collections/asr/models/ctc_models.py
@@ -34,14 +34,14 @@
from nemo.collections.asr.parts.preprocessing.segment import ChannelSelectorType
from nemo.collections.asr.parts.submodules.ctc_decoding import CTCDecoding, CTCDecodingConfig
from nemo.collections.asr.parts.utils.asr_batching import get_semi_sorted_batch_sampler
from nemo.collections.asr.parts.utils.rnnt_utils import Hypothesis
from nemo.collections.asr.parts.utils.transcribe_utils import process_timestamp_outputs
from nemo.collections.common.data.lhotse import get_lhotse_dataloader_from_config
from nemo.collections.common.parts.preprocessing.parsers import make_parser
from nemo.core.classes.common import PretrainedModelInfo, typecheck
from nemo.core.classes.mixins import AccessMixin
from nemo.core.neural_types import AudioSignal, LabelsType, LengthsType, LogprobsType, NeuralType, SpectrogramType
from nemo.utils import logging
from nemo.utils.decorators import deprecated

__all__ = ['EncDecCTCModel']

@@ -612,7 +612,7 @@ def predict_step(self, batch, batch_idx, dataloader_idx=0):
else:
log_probs, encoded_len, predictions = self.forward(input_signal=signal, input_signal_length=signal_len)

transcribed_texts, _ = self.wer.decoding.ctc_decoder_predictions_tensor(
transcribed_texts = self.wer.decoding.ctc_decoder_predictions_tensor(
decoder_outputs=log_probs,
decoder_lengths=encoded_len,
return_hypotheses=False,
@@ -703,15 +703,11 @@ def _transcribe_forward(self, batch: Any, trcfg: TranscribeConfig):
del greedy_predictions
return output

@deprecated(
explanation='The return type of args will be updated in the upcoming release to ensure a consistent output \
format across all decoder types, such that a Hypothesis object is always returned.'
)
def _transcribe_output_processing(self, outputs, trcfg: TranscribeConfig) -> GenericTranscriptionType:
logits = outputs.pop('logits')
logits_len = outputs.pop('logits_len')

current_hypotheses, all_hyp = self.decoding.ctc_decoder_predictions_tensor(
hypotheses = self.decoding.ctc_decoder_predictions_tensor(
logits,
decoder_lengths=logits_len,
return_hypotheses=trcfg.return_hypotheses,
@@ -732,30 +728,24 @@ def _transcribe_output_processing(self, outputs, trcfg: TranscribeConfig) -> Gen
# cudaMallocHost()-allocated tensor to be floating
# around. Were that to be the case, then the pinned
# memory cache would always miss.
current_hypotheses[idx].y_sequence = logits_cpu[idx, : logits_len[idx]].clone()
if current_hypotheses[idx].alignments is None:
current_hypotheses[idx].alignments = current_hypotheses[idx].y_sequence
hypotheses[idx].y_sequence = logits_cpu[idx, : logits_len[idx]].clone()
if hypotheses[idx].alignments is None:
hypotheses[idx].alignments = hypotheses[idx].y_sequence
del logits_cpu

# cleanup memory
del logits, logits_len

if trcfg.timestamps:
current_hypotheses = process_timestamp_outputs(
current_hypotheses, self.encoder.subsampling_factor, self.cfg['preprocessor']['window_stride']
hypotheses = process_timestamp_outputs(
hypotheses, self.encoder.subsampling_factor, self.cfg['preprocessor']['window_stride']
)
all_hyp = process_timestamp_outputs(
all_hyp, self.encoder.subsampling_factor, self.cfg['preprocessor']['window_stride']
)

hypotheses = []
if all_hyp is None:
hypotheses += current_hypotheses
else:
hypotheses += all_hyp

return hypotheses

def get_best_hyptheses(self, all_hypothesis: list[list[Hypothesis]]):
return [hyp[0] for hyp in all_hypothesis]

def _setup_transcribe_dataloader(self, config: Dict) -> 'torch.utils.data.DataLoader':
"""
Setup function for a temporary data loader which wraps the provided audio file.
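On the CTC path the single hypotheses list is returned as-is; when return_hypotheses is set, the per-frame log-probabilities are attached to y_sequence (and mirrored into alignments if those are empty), and timestamp post-processing now runs once on that list. A hedged usage sketch (pretrained model name assumed; field behavior follows this diff and may differ across releases):

    import nemo.collections.asr as nemo_asr

    ctc_model = nemo_asr.models.EncDecCTCModelBPE.from_pretrained("stt_en_conformer_ctc_small")  # assumed checkpoint
    hyps = ctc_model.transcribe(["sample.wav"], return_hypotheses=True)

    best = hyps[0]                 # one Hypothesis per utterance
    print(best.text)               # decoded transcript
    print(best.y_sequence.shape)   # per-frame log-probs kept on the hypothesis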
37 changes: 11 additions & 26 deletions nemo/collections/asr/models/hybrid_rnnt_ctc_models.py
@@ -13,15 +13,11 @@
# limitations under the License.

import copy
import json
import os
import tempfile
from typing import Any, List, Optional, Tuple
from typing import Any, List, Optional, Union

import torch
from lightning.pytorch import Trainer
from omegaconf import DictConfig, OmegaConf, open_dict
from tqdm.auto import tqdm

from nemo.collections.asr.data.audio_to_text_dali import DALIOutputs
from nemo.collections.asr.losses.ctc import CTCLoss
@@ -31,6 +27,7 @@
from nemo.collections.asr.parts.mixins.transcription import TranscriptionReturnType
from nemo.collections.asr.parts.preprocessing.segment import ChannelSelectorType
from nemo.collections.asr.parts.submodules.ctc_decoding import CTCDecoding, CTCDecodingConfig
from nemo.collections.asr.parts.utils.rnnt_utils import Hypothesis
from nemo.collections.asr.parts.utils.transcribe_utils import process_timestamp_outputs
from nemo.core.classes.common import PretrainedModelInfo
from nemo.core.classes.mixins import AccessMixin
@@ -200,15 +197,15 @@ def _transcribe_forward(self, batch: Any, trcfg: TranscribeConfig):

def _transcribe_output_processing(
self, outputs, trcfg: TranscribeConfig
) -> Tuple[List['Hypothesis'], List['Hypothesis']]:
) -> Union[List['Hypothesis'], List[List['Hypothesis']]]:
if self.cur_decoder == "rnnt":
return super()._transcribe_output_processing(outputs, trcfg)

# CTC Path
logits = outputs.pop('logits')
encoded_len = outputs.pop('encoded_len')

best_hyp, all_hyp = self.ctc_decoding.ctc_decoder_predictions_tensor(
hypotheses = self.ctc_decoding.ctc_decoder_predictions_tensor(
logits,
encoded_len,
return_hypotheses=trcfg.return_hypotheses,
@@ -218,35 +215,23 @@ def _transcribe_output_processing(
if trcfg.return_hypotheses:
# dump log probs per file
for idx in range(logits.shape[0]):
best_hyp[idx].y_sequence = logits[idx][: encoded_len[idx]]
if best_hyp[idx].alignments is None:
best_hyp[idx].alignments = best_hyp[idx].y_sequence
hypotheses[idx].y_sequence = logits[idx][: encoded_len[idx]]
if hypotheses[idx].alignments is None:
hypotheses[idx].alignments = hypotheses[idx].y_sequence

# DEPRECATED?
# if logprobs:
# for logit, elen in zip(logits, encoded_len):
# logits_list.append(logit[:elen])

if trcfg.timestamps:
best_hyp = process_timestamp_outputs(
best_hyp, self.encoder.subsampling_factor, self.cfg['preprocessor']['window_stride']
)
all_hyp = process_timestamp_outputs(
all_hyp, self.encoder.subsampling_factor, self.cfg['preprocessor']['window_stride']
hypotheses = process_timestamp_outputs(
hypotheses, self.encoder.subsampling_factor, self.cfg['preprocessor']['window_stride']
)

del logits, encoded_len

hypotheses = []
all_hypotheses = []

hypotheses += best_hyp
if all_hyp is not None:
all_hypotheses += all_hyp
else:
all_hypotheses += best_hyp

return (hypotheses, all_hypotheses)
return hypotheses

def change_vocabulary(
self,
@@ -515,7 +500,7 @@ def predict_step(self, batch, batch_idx, dataloader_idx=0):
encoded, encoded_len = self.forward(input_signal=signal, input_signal_length=signal_len)
del signal

best_hyp_text, all_hyp_text = self.decoding.rnnt_decoder_predictions_tensor(
best_hyp_text = self.decoding.rnnt_decoder_predictions_tensor(
encoder_output=encoded, encoded_lengths=encoded_len, return_hypotheses=False
)
if isinstance(sample_id, torch.Tensor):
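For the hybrid RNNT/CTC model, both decoder branches now return the same shape of result, so switching cur_decoder no longer changes how the output is unpacked. A sketch (method and argument names assumed from the public hybrid-model API; verify against the NeMo release you run):

    hybrid_model.change_decoding_strategy(decoder_type="ctc")   # or "rnnt"
    hyps = hybrid_model.transcribe(["sample.wav"], return_hypotheses=True)
    print(hyps[0].text)                                          # same access pattern for either branch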
32 changes: 8 additions & 24 deletions nemo/collections/asr/models/rnnt_models.py
@@ -15,7 +15,7 @@
import copy
import os
from math import ceil
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Any, Dict, List, Optional, Union

import numpy as np
import torch
@@ -40,14 +40,14 @@
from nemo.collections.asr.parts.preprocessing.segment import ChannelSelectorType
from nemo.collections.asr.parts.submodules.rnnt_decoding import RNNTDecoding, RNNTDecodingConfig
from nemo.collections.asr.parts.utils.asr_batching import get_semi_sorted_batch_sampler
from nemo.collections.asr.parts.utils.rnnt_utils import Hypothesis
from nemo.collections.asr.parts.utils.transcribe_utils import process_timestamp_outputs
from nemo.collections.common.data.lhotse import get_lhotse_dataloader_from_config
from nemo.collections.common.parts.preprocessing.parsers import make_parser
from nemo.core.classes.common import PretrainedModelInfo, typecheck
from nemo.core.classes.mixins import AccessMixin
from nemo.core.neural_types import AcousticEncodedRepresentation, AudioSignal, LengthsType, NeuralType, SpectrogramType
from nemo.utils import logging
from nemo.utils.decorators import deprecated


class EncDecRNNTModel(ASRModel, ASRModuleMixin, ExportableEncDecModel, ASRTranscriptionMixin):
@@ -814,7 +814,7 @@ def predict_step(self, batch, batch_idx, dataloader_idx=0):
encoded, encoded_len = self.forward(input_signal=signal, input_signal_length=signal_len)
del signal

best_hyp_text, all_hyp_text = self.decoding.rnnt_decoder_predictions_tensor(
best_hyp_text = self.decoding.rnnt_decoder_predictions_tensor(
encoder_output=encoded, encoded_lengths=encoded_len, return_hypotheses=False
)

@@ -936,17 +936,13 @@ def _transcribe_forward(self, batch: Any, trcfg: TranscribeConfig):
output = dict(encoded=encoded, encoded_len=encoded_len)
return output

@deprecated(
explanation='The return type of args will be updated in the upcoming release to ensure a consistent \
output format across all decoder types, such that a "Hypothesis" object is always returned.'
)
def _transcribe_output_processing(
self, outputs, trcfg: TranscribeConfig
) -> Tuple[List['Hypothesis'], List['Hypothesis']]:
) -> Union[List['Hypothesis'], List[List['Hypothesis']]]:
encoded = outputs.pop('encoded')
encoded_len = outputs.pop('encoded_len')

best_hyp, all_hyp = self.decoding.rnnt_decoder_predictions_tensor(
hyp = self.decoding.rnnt_decoder_predictions_tensor(
encoded,
encoded_len,
return_hypotheses=trcfg.return_hypotheses,
@@ -956,23 +952,11 @@ def _transcribe_output_processing(
del encoded, encoded_len

if trcfg.timestamps:
best_hyp = process_timestamp_outputs(
best_hyp, self.encoder.subsampling_factor, self.cfg['preprocessor']['window_stride']
hyp = process_timestamp_outputs(
hyp, self.encoder.subsampling_factor, self.cfg['preprocessor']['window_stride']
)
all_hyp = process_timestamp_outputs(
all_hyp, self.encoder.subsampling_factor, self.cfg['preprocessor']['window_stride']
)

hypotheses = []
all_hypotheses = []

hypotheses += best_hyp
if all_hyp is not None:
all_hypotheses += all_hyp
else:
all_hypotheses += best_hyp

return (hypotheses, all_hypotheses)
return hyp

def _setup_transcribe_dataloader(self, config: Dict) -> 'torch.utils.data.DataLoader':
"""
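The RNNT model mirrors the CTC change: rnnt_decoder_predictions_tensor yields one hypotheses list, the deprecation decorator is dropped, and timestamp post-processing is applied once to that list. A sketch of consuming timestamped output (the timestamp field layout is an assumption based on the process_timestamp_outputs usage above and may vary by release):

    hyps = rnnt_model.transcribe(["sample.wav"], timestamps=True)
    for h in hyps:
        print(h.text)
        ts = getattr(h, "timestamp", None)            # guarded: attribute name assumed
        if isinstance(ts, dict):
            for entry in ts.get("word", []):          # e.g. {"word": ..., "start": ..., "end": ...} (assumed layout)
                print(entry)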