From c3be4959e0a9bc1edc83410a7f1ca2b9ea2396e1 Mon Sep 17 00:00:00 2001 From: Ahmed Saed <37080003+Ahmedsaed@users.noreply.github.com> Date: Thu, 1 Aug 2024 15:10:36 +0000 Subject: [PATCH] Refactor asr.py into ASREvaluator --- src/fairseq2/recipes/eval/__init__.py | 24 ++- src/fairseq2/recipes/eval/asr.py | 270 ++++++++++++-------------- 2 files changed, 150 insertions(+), 144 deletions(-) diff --git a/src/fairseq2/recipes/eval/__init__.py b/src/fairseq2/recipes/eval/__init__.py index c21e2ca32..1e1dd478c 100644 --- a/src/fairseq2/recipes/eval/__init__.py +++ b/src/fairseq2/recipes/eval/__init__.py @@ -4,18 +4,22 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +from pathlib import Path +from typing import Any, Callable + from fairseq2.logging import get_log_writer from fairseq2.recipes.cli import Cli, CliGroup, RecipeCommandHandler +from fairseq2.recipes.eval.asr import AsrEvalConfig from fairseq2.recipes.eval.configs import wav2vec2_presets, whisper_presets log = get_log_writer(__name__) def _add_wav2vev2_asr_eval_cli(group: CliGroup) -> None: - from fairseq2.recipes.eval.asr import load_wav2vec2_asr_evaluator + from fairseq2.recipes.eval.asr import ASREvaluator handler = RecipeCommandHandler( - load_wav2vec2_asr_evaluator, + ASREvaluator(), preset_configs=wav2vec2_presets, default_preset="librispeech_asr", ) @@ -26,6 +30,21 @@ def _add_wav2vev2_asr_eval_cli(group: CliGroup) -> None: ) +def _add_whisper_asr_eval_cli(group: CliGroup) -> None: + from fairseq2.recipes.eval.asr import ASREvaluator + + handler = RecipeCommandHandler( + ASREvaluator(), + preset_configs=whisper_presets, + default_preset="librispeech_asr", + ) + group.add_command( + "whisper-asr", + handler, + help="evaluate a whisper ASR model in downstream benchmark", + ) + + def has_datasets() -> bool: try: import datasets # type: ignore[attr-defined,import-untyped,import-not-found] @@ -57,3 +76,4 @@ def _setup_eval_cli(cli: Cli) -> None: if all((has_datasets(), has_evaluate())): _add_wav2vev2_asr_eval_cli(group) + # _add_whisper_asr_eval_cli(group) diff --git a/src/fairseq2/recipes/eval/asr.py b/src/fairseq2/recipes/eval/asr.py index 7b297c8fe..de11bad50 100644 --- a/src/fairseq2/recipes/eval/asr.py +++ b/src/fairseq2/recipes/eval/asr.py @@ -8,7 +8,7 @@ import math from dataclasses import dataclass from pathlib import Path -from typing import Any, List, Optional, Tuple, cast +from typing import Any, Callable, List, Optional, Tuple, cast import torch from datasets import ( # type: ignore[attr-defined,import-untyped,import-not-found] @@ -18,7 +18,6 @@ from fairseq2.data.data_pipeline import SequenceData from fairseq2.data.text import load_text_tokenizer -from fairseq2.data.text.text_tokenizer import TextTokenEncoder, TextTokenizer from fairseq2.datasets.batching import StaticBatching from fairseq2.datasets.huggingface import Example, create_hf_reader from fairseq2.logging import get_log_writer @@ -26,7 +25,11 @@ from fairseq2.models.sequence import SequenceBatch from fairseq2.models.wav2vec2.asr import load_wav2vec2_asr_model from fairseq2.nn.padding import get_seqs_and_padding_mask -from fairseq2.recipes.eval.configs import HFEvalConfig, wav2vec2_presets, whisper_presets +from fairseq2.recipes.eval.configs import ( + HFEvalConfig, + wav2vec2_presets, + whisper_presets, +) from fairseq2.recipes.evaluator import HFEvaluator from fairseq2.recipes.utils.setup import setup_root_gang from fairseq2.typing import META, DataType @@ -39,9 +42,6 @@ class AsrEvalConfig(HFEvalConfig): """Holds the configuration of a ASR evaluation recipe.""" - # converter: Callable[[Example], Seq2SeqBatch] - # """The converter function to convert collated data into Seq2SeqBatch""" - tokenizer_name: str = "librispeech_asr" """The tokenizer to use.""" @@ -74,148 +74,134 @@ class AsrEvalConfig(HFEvalConfig): """The data type of the model.""" -def _librispeech_asr_to_batch(examples: Example) -> Seq2SeqBatch: - """ - Converts a collated batch of examples into a Seq2SeqBatch. - - Args: - examples (dict): A dictionary containing "audio" and "text" keys. - - Returns: - Seq2SeqBatch: A batch of audio and text sequences. - """ - source_data = cast(SequenceData, examples["audio"]) - target_data = cast(SequenceData, examples["text"]) - - source_seqs, source_padding_mask = get_seqs_and_padding_mask(source_data) - target_seqs, target_padding_mask = get_seqs_and_padding_mask(target_data) - - return Seq2SeqBatch( - source_seqs, - source_padding_mask, - target_seqs, - target_padding_mask, - examples, - ) - - -def _preprocess_example( - example: Example, encoder: TextTokenEncoder, device: torch.device -) -> Example: - """ - Preprocesses an individual example by converting the audio array to a PyTorch tensor - and encoding the text. - - Args: - example (dict): A dictionary containing "audio" and "text" keys. - - Returns: - dict: A dictionary with "audio" and "text" as PyTorch tensors. - """ - audio_tensor = ( - torch.from_numpy(example["audio"]["array"]).to(torch.float16).to(device) - ) - text_tensor = encoder(example["text"].lower()).to(device) - return {"audio": audio_tensor, "text": text_tensor} - - -def seq2seq_preprocessor(batch: Seq2SeqBatch) -> Tuple[SequenceBatch, SequenceBatch]: - return SequenceBatch(batch.source_seqs, batch.source_padding_mask), SequenceBatch( - batch.target_seqs, batch.target_padding_mask - ) - - -def postprocesser( - outputs: Any, targets: SequenceBatch, tokenizer: TextTokenizer -) -> Tuple[List[str], List[str]]: - decoder = tokenizer.create_decoder() - pad_idx = tokenizer.vocab_info.pad_idx - - hypotheses, _ = outputs.generate_hypotheses(pad_idx=pad_idx) - predictions = [decoder(item) for item in hypotheses] - references = [decoder(item) for item in targets.seqs.to(torch.int32)] - - return predictions, references - - -@hf_presets.decorator("librispeech_asr") -def _librispeech_asr_config() -> AsrEvalConfig: +@wav2vec2_presets.decorator("librispeech_asr") +def _wav2vec2_librispeech_asr_config() -> AsrEvalConfig: return AsrEvalConfig( dataset_name="librispeech_asr", model_name="wav2vec2_asr_base_10h", - split="test.other" - # converter=librispeech_asr_to_batch, + split="test.other", ) -def load_wav2vec2_asr_evaluator( - config: HFEvalConfig, output_dir: Path -) -> HFEvaluator[Seq2SeqBatch]: - """ - Load the evaluator used for downstream evaluation of the model - in a downstream dataset and report BLEU scores - - Args: - config (HFEvalConfig): The configuration for the evaluation. - output_dir (Path): The output directory to store the evaluation results. - - Returns: - HFEvaluator: Evaluation process results. - """ - if not isinstance(config, AsrEvalConfig): - raise ValueError(f"Expect AsrEvalConfig, get {type(config)}") - - iterable_ds = load_dataset(config.dataset_name, split=config.split, streaming=True) - max_samples = config.max_samples if config.max_samples is not None else math.inf - # Load a subset of the dataset if max_samples is set - ds = Dataset.from_generator( - lambda: ( - yield from ( - item for idx, item in enumerate(iterable_ds) if idx < max_samples - ) - ), - features=iterable_ds.features, - ) - - gang = setup_root_gang(log) - - if gang.rank == 0: - init_device = gang.device - else: - init_device = META - - tokenizer = load_text_tokenizer(config.tokenizer_name) - encoder = tokenizer.create_encoder(device=init_device) - - ds = ds.map(lambda x: _preprocess_example(x, encoder, init_device)) - format = { - "type": "torch", - "format_kwargs": {"dtype": torch.float16, "device": init_device}, - } - ds.set_format(**format, columns=["audio", "text"]) - - pipeline_reader = create_hf_reader( - dataset=ds, - gang=gang, - converter=_librispeech_asr_to_batch, - batching=StaticBatching(config.max_num_elements), - num_prefetch=config.num_prefetch, - pad_value=tokenizer.vocab_info.pad_idx, - max_seq_len=config.max_audio_len, - ) - - model = load_wav2vec2_asr_model( - config.model_name, device=init_device, dtype=config.dtype +@whisper_presets.decorator("librispeech_asr") +def _whisper_librispeech_asr_config() -> AsrEvalConfig: + return AsrEvalConfig( + dataset_name="librispeech_asr", model_name="whisper", split="test.other" ) - wall_watch = Stopwatch(start=True, device=init_device) - return HFEvaluator[Seq2SeqBatch]( - model=model, - metrics=["bleu"], - gang=gang, - data_reader=pipeline_reader, - wall_watch=wall_watch, - preprocessor=seq2seq_preprocessor, - postprocessor=lambda x, y: postprocesser(x, y, tokenizer), - ) +class ASREvaluator: + def __init__(self) -> None: + self.gang = setup_root_gang(log) + self.init_device = self.gang.device if self.gang.rank == 0 else META + + def _librispeech_asr_to_batch(self, examples: Example) -> Seq2SeqBatch: + source_data = cast(SequenceData, examples["audio"]) + target_data = cast(SequenceData, examples["text"]) + + source_seqs, source_padding_mask = get_seqs_and_padding_mask(source_data) + target_seqs, target_padding_mask = get_seqs_and_padding_mask(target_data) + + return Seq2SeqBatch( + source_seqs, + source_padding_mask, + target_seqs, + target_padding_mask, + examples, + ) + + def _preprocess_example(self, example: Example) -> Example: + audio_tensor = ( + torch.from_numpy(example["audio"]["array"]) + .to(torch.float16) + .to(self.init_device) + ) + text_tensor = self.encoder(example["text"].lower()).to(self.init_device) + return {"audio": audio_tensor, "text": text_tensor} + + def seq2seq_preprocessor( + self, batch: Seq2SeqBatch + ) -> Tuple[SequenceBatch, SequenceBatch]: + return SequenceBatch( + batch.source_seqs, batch.source_padding_mask + ), SequenceBatch(batch.target_seqs, batch.target_padding_mask) + + def postprocesser( + self, outputs: Any, targets: SequenceBatch + ) -> Tuple[List[str], List[str]]: + decoder = self.tokenizer.create_decoder() + pad_idx = self.tokenizer.vocab_info.pad_idx + + hypotheses, _ = outputs.generate_hypotheses(pad_idx=pad_idx) + predictions = [decoder(item) for item in hypotheses] + references = [decoder(item) for item in targets.seqs.to(torch.int32)] + + return predictions, references + + def _load_evaluator(self) -> HFEvaluator[Seq2SeqBatch]: + iterable_ds = load_dataset( + self.config.dataset_name, split=self.config.split, streaming=True + ) + max_samples = ( + self.config.max_samples if self.config.max_samples is not None else math.inf + ) + + ds = Dataset.from_generator( + lambda: ( + yield from ( + item for idx, item in enumerate(iterable_ds) if idx < max_samples + ) + ), + features=iterable_ds.features, + ) + + ds = ds.map(self._preprocess_example) + format = { + "type": "torch", + "format_kwargs": {"dtype": torch.float16, "device": self.init_device}, + } + ds.set_format(**format, columns=["audio", "text"]) + + pipeline_reader = create_hf_reader( + dataset=ds, + gang=self.gang, + converter=self._librispeech_asr_to_batch, + batching=StaticBatching(self.config.max_num_elements), + num_prefetch=self.config.num_prefetch, + pad_value=self.tokenizer.vocab_info.pad_idx, + max_seq_len=self.config.max_audio_len, + ) + + model = load_wav2vec2_asr_model( + self.config.model_name, device=self.init_device, dtype=self.config.dtype + ) + + wall_watch = Stopwatch(start=True, device=self.init_device) + + return HFEvaluator[Seq2SeqBatch]( + model=model, + metrics=["bleu"], + gang=self.gang, + data_reader=pipeline_reader, + wall_watch=wall_watch, + preprocessor=self.seq2seq_preprocessor, + postprocessor=lambda x, y: self.postprocesser(x, y), + ) + + def __call__(self, config: HFEvalConfig, output_dir: Path) -> Callable[[], None]: + """ + This method will run the evaluation process. + + Returns: + A callable that will run the evaluation process + """ + if not isinstance(config, AsrEvalConfig): + raise ValueError(f"Expect AsrEvalConfig, get {type(config)}") + + self.config = config + self.output_dir = output_dir + + self.tokenizer = load_text_tokenizer(config.tokenizer_name) + self.encoder = self.tokenizer.create_encoder(device=self.init_device) + + return self._load_evaluator()