Refactor asr.py into ASREvaluator
Ahmedsaed committed Aug 1, 2024
1 parent 10cbde6 commit c3be495
Showing 2 changed files with 150 additions and 144 deletions.
24 changes: 22 additions & 2 deletions src/fairseq2/recipes/eval/__init__.py
@@ -4,18 +4,22 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from pathlib import Path
from typing import Any, Callable

from fairseq2.logging import get_log_writer
from fairseq2.recipes.cli import Cli, CliGroup, RecipeCommandHandler
from fairseq2.recipes.eval.asr import AsrEvalConfig
from fairseq2.recipes.eval.configs import wav2vec2_presets, whisper_presets

log = get_log_writer(__name__)


def _add_wav2vev2_asr_eval_cli(group: CliGroup) -> None:
-    from fairseq2.recipes.eval.asr import load_wav2vec2_asr_evaluator
+    from fairseq2.recipes.eval.asr import ASREvaluator

    handler = RecipeCommandHandler(
-        load_wav2vec2_asr_evaluator,
+        ASREvaluator(),
        preset_configs=wav2vec2_presets,
        default_preset="librispeech_asr",
    )
@@ -26,6 +30,21 @@ def _add_wav2vev2_asr_eval_cli(group: CliGroup) -> None:
    )


+def _add_whisper_asr_eval_cli(group: CliGroup) -> None:
+    from fairseq2.recipes.eval.asr import ASREvaluator
+
+    handler = RecipeCommandHandler(
+        ASREvaluator(),
+        preset_configs=whisper_presets,
+        default_preset="librispeech_asr",
+    )
+    group.add_command(
+        "whisper-asr",
+        handler,
+        help="evaluate a whisper ASR model on a downstream benchmark",
+    )


def has_datasets() -> bool:
    try:
        import datasets  # type: ignore[attr-defined,import-untyped,import-not-found]
@@ -57,3 +76,4 @@ def _setup_eval_cli(cli: Cli) -> None:

    if all((has_datasets(), has_evaluate())):
        _add_wav2vev2_asr_eval_cli(group)
+        # _add_whisper_asr_eval_cli(group)
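
For context (an editor's note, not part of the diff): RecipeCommandHandler's first argument only needs to be callable with a config and an output directory, which is why the module-level load_wav2vec2_asr_evaluator function can be swapped for an ASREvaluator instance. A minimal sketch of that protocol, with the name RecipeLoader assumed here for illustration and the signature taken from ASREvaluator.__call__ in the second file:

from pathlib import Path
from typing import Any, Callable, Protocol


class RecipeLoader(Protocol):
    # A plain function (the old load_wav2vec2_asr_evaluator) and an instance
    # of a class that defines __call__ (the new ASREvaluator) both satisfy
    # this shape, so RecipeCommandHandler can accept either one.
    def __call__(self, config: Any, output_dir: Path) -> Callable[[], None]: ...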
270 changes: 128 additions & 142 deletions src/fairseq2/recipes/eval/asr.py
@@ -8,7 +8,7 @@
import math
from dataclasses import dataclass
from pathlib import Path
-from typing import Any, List, Optional, Tuple, cast
+from typing import Any, Callable, List, Optional, Tuple, cast

import torch
from datasets import ( # type: ignore[attr-defined,import-untyped,import-not-found]
@@ -18,15 +18,18 @@

from fairseq2.data.data_pipeline import SequenceData
from fairseq2.data.text import load_text_tokenizer
-from fairseq2.data.text.text_tokenizer import TextTokenEncoder, TextTokenizer
from fairseq2.datasets.batching import StaticBatching
from fairseq2.datasets.huggingface import Example, create_hf_reader
from fairseq2.logging import get_log_writer
from fairseq2.models.seq2seq import Seq2SeqBatch
from fairseq2.models.sequence import SequenceBatch
from fairseq2.models.wav2vec2.asr import load_wav2vec2_asr_model
from fairseq2.nn.padding import get_seqs_and_padding_mask
-from fairseq2.recipes.eval.configs import HFEvalConfig, wav2vec2_presets, whisper_presets
+from fairseq2.recipes.eval.configs import (
+    HFEvalConfig,
+    wav2vec2_presets,
+    whisper_presets,
+)
from fairseq2.recipes.evaluator import HFEvaluator
from fairseq2.recipes.utils.setup import setup_root_gang
from fairseq2.typing import META, DataType
@@ -39,9 +42,6 @@
class AsrEvalConfig(HFEvalConfig):
    """Holds the configuration of an ASR evaluation recipe."""

-    # converter: Callable[[Example], Seq2SeqBatch]
-    # """The converter function to convert collated data into Seq2SeqBatch"""

    tokenizer_name: str = "librispeech_asr"
    """The tokenizer to use."""

@@ -74,148 +74,134 @@ class AsrEvalConfig(HFEvalConfig):
    """The data type of the model."""


-def _librispeech_asr_to_batch(examples: Example) -> Seq2SeqBatch:
-    """
-    Converts a collated batch of examples into a Seq2SeqBatch.
-    Args:
-        examples (dict): A dictionary containing "audio" and "text" keys.
-    Returns:
-        Seq2SeqBatch: A batch of audio and text sequences.
-    """
-    source_data = cast(SequenceData, examples["audio"])
-    target_data = cast(SequenceData, examples["text"])
-
-    source_seqs, source_padding_mask = get_seqs_and_padding_mask(source_data)
-    target_seqs, target_padding_mask = get_seqs_and_padding_mask(target_data)
-
-    return Seq2SeqBatch(
-        source_seqs,
-        source_padding_mask,
-        target_seqs,
-        target_padding_mask,
-        examples,
-    )


-def _preprocess_example(
-    example: Example, encoder: TextTokenEncoder, device: torch.device
-) -> Example:
-    """
-    Preprocesses an individual example by converting the audio array to a PyTorch tensor
-    and encoding the text.
-    Args:
-        example (dict): A dictionary containing "audio" and "text" keys.
-    Returns:
-        dict: A dictionary with "audio" and "text" as PyTorch tensors.
-    """
-    audio_tensor = (
-        torch.from_numpy(example["audio"]["array"]).to(torch.float16).to(device)
-    )
-    text_tensor = encoder(example["text"].lower()).to(device)
-    return {"audio": audio_tensor, "text": text_tensor}


-def seq2seq_preprocessor(batch: Seq2SeqBatch) -> Tuple[SequenceBatch, SequenceBatch]:
-    return SequenceBatch(batch.source_seqs, batch.source_padding_mask), SequenceBatch(
-        batch.target_seqs, batch.target_padding_mask
-    )


-def postprocesser(
-    outputs: Any, targets: SequenceBatch, tokenizer: TextTokenizer
-) -> Tuple[List[str], List[str]]:
-    decoder = tokenizer.create_decoder()
-    pad_idx = tokenizer.vocab_info.pad_idx
-
-    hypotheses, _ = outputs.generate_hypotheses(pad_idx=pad_idx)
-    predictions = [decoder(item) for item in hypotheses]
-    references = [decoder(item) for item in targets.seqs.to(torch.int32)]
-
-    return predictions, references


-@hf_presets.decorator("librispeech_asr")
-def _librispeech_asr_config() -> AsrEvalConfig:
+@wav2vec2_presets.decorator("librispeech_asr")
+def _wav2vec2_librispeech_asr_config() -> AsrEvalConfig:
    return AsrEvalConfig(
        dataset_name="librispeech_asr",
        model_name="wav2vec2_asr_base_10h",
-        split="test.other"
-        # converter=librispeech_asr_to_batch,
+        split="test.other",
    )


-def load_wav2vec2_asr_evaluator(
-    config: HFEvalConfig, output_dir: Path
-) -> HFEvaluator[Seq2SeqBatch]:
-    """
-    Load the evaluator used for downstream evaluation of the model
-    in a downstream dataset and report BLEU scores
-    Args:
-        config (HFEvalConfig): The configuration for the evaluation.
-        output_dir (Path): The output directory to store the evaluation results.
-    Returns:
-        HFEvaluator: Evaluation process results.
-    """
-    if not isinstance(config, AsrEvalConfig):
-        raise ValueError(f"Expect AsrEvalConfig, get {type(config)}")
-
-    iterable_ds = load_dataset(config.dataset_name, split=config.split, streaming=True)
-    max_samples = config.max_samples if config.max_samples is not None else math.inf
-    # Load a subset of the dataset if max_samples is set
-    ds = Dataset.from_generator(
-        lambda: (
-            yield from (
-                item for idx, item in enumerate(iterable_ds) if idx < max_samples
-            )
-        ),
-        features=iterable_ds.features,
-    )
-
-    gang = setup_root_gang(log)
-
-    if gang.rank == 0:
-        init_device = gang.device
-    else:
-        init_device = META
-
-    tokenizer = load_text_tokenizer(config.tokenizer_name)
-    encoder = tokenizer.create_encoder(device=init_device)
-
-    ds = ds.map(lambda x: _preprocess_example(x, encoder, init_device))
-    format = {
-        "type": "torch",
-        "format_kwargs": {"dtype": torch.float16, "device": init_device},
-    }
-    ds.set_format(**format, columns=["audio", "text"])
-
-    pipeline_reader = create_hf_reader(
-        dataset=ds,
-        gang=gang,
-        converter=_librispeech_asr_to_batch,
-        batching=StaticBatching(config.max_num_elements),
-        num_prefetch=config.num_prefetch,
-        pad_value=tokenizer.vocab_info.pad_idx,
-        max_seq_len=config.max_audio_len,
-    )
-
-    model = load_wav2vec2_asr_model(
-        config.model_name, device=init_device, dtype=config.dtype
+@whisper_presets.decorator("librispeech_asr")
+def _whisper_librispeech_asr_config() -> AsrEvalConfig:
+    return AsrEvalConfig(
+        dataset_name="librispeech_asr", model_name="whisper", split="test.other"
    )

-    wall_watch = Stopwatch(start=True, device=init_device)
-
-    return HFEvaluator[Seq2SeqBatch](
-        model=model,
-        metrics=["bleu"],
-        gang=gang,
-        data_reader=pipeline_reader,
-        wall_watch=wall_watch,
-        preprocessor=seq2seq_preprocessor,
-        postprocessor=lambda x, y: postprocesser(x, y, tokenizer),
-    )
+class ASREvaluator:
+    def __init__(self) -> None:
+        self.gang = setup_root_gang(log)
+        self.init_device = self.gang.device if self.gang.rank == 0 else META
+
+    def _librispeech_asr_to_batch(self, examples: Example) -> Seq2SeqBatch:
+        source_data = cast(SequenceData, examples["audio"])
+        target_data = cast(SequenceData, examples["text"])
+
+        source_seqs, source_padding_mask = get_seqs_and_padding_mask(source_data)
+        target_seqs, target_padding_mask = get_seqs_and_padding_mask(target_data)
+
+        return Seq2SeqBatch(
+            source_seqs,
+            source_padding_mask,
+            target_seqs,
+            target_padding_mask,
+            examples,
+        )
+
+    def _preprocess_example(self, example: Example) -> Example:
+        audio_tensor = (
+            torch.from_numpy(example["audio"]["array"])
+            .to(torch.float16)
+            .to(self.init_device)
+        )
+        text_tensor = self.encoder(example["text"].lower()).to(self.init_device)
+        return {"audio": audio_tensor, "text": text_tensor}
+
+    def seq2seq_preprocessor(
+        self, batch: Seq2SeqBatch
+    ) -> Tuple[SequenceBatch, SequenceBatch]:
+        return SequenceBatch(
+            batch.source_seqs, batch.source_padding_mask
+        ), SequenceBatch(batch.target_seqs, batch.target_padding_mask)
+
+    def postprocesser(
+        self, outputs: Any, targets: SequenceBatch
+    ) -> Tuple[List[str], List[str]]:
+        decoder = self.tokenizer.create_decoder()
+        pad_idx = self.tokenizer.vocab_info.pad_idx
+
+        hypotheses, _ = outputs.generate_hypotheses(pad_idx=pad_idx)
+        predictions = [decoder(item) for item in hypotheses]
+        references = [decoder(item) for item in targets.seqs.to(torch.int32)]
+
+        return predictions, references
+
+    def _load_evaluator(self) -> HFEvaluator[Seq2SeqBatch]:
+        iterable_ds = load_dataset(
+            self.config.dataset_name, split=self.config.split, streaming=True
+        )
+        max_samples = (
+            self.config.max_samples if self.config.max_samples is not None else math.inf
+        )
+
+        ds = Dataset.from_generator(
+            lambda: (
+                yield from (
+                    item for idx, item in enumerate(iterable_ds) if idx < max_samples
+                )
+            ),
+            features=iterable_ds.features,
+        )
+
+        ds = ds.map(self._preprocess_example)
+        format = {
+            "type": "torch",
+            "format_kwargs": {"dtype": torch.float16, "device": self.init_device},
+        }
+        ds.set_format(**format, columns=["audio", "text"])
+
+        pipeline_reader = create_hf_reader(
+            dataset=ds,
+            gang=self.gang,
+            converter=self._librispeech_asr_to_batch,
+            batching=StaticBatching(self.config.max_num_elements),
+            num_prefetch=self.config.num_prefetch,
+            pad_value=self.tokenizer.vocab_info.pad_idx,
+            max_seq_len=self.config.max_audio_len,
+        )
+
+        model = load_wav2vec2_asr_model(
+            self.config.model_name, device=self.init_device, dtype=self.config.dtype
+        )
+
+        wall_watch = Stopwatch(start=True, device=self.init_device)
+
+        return HFEvaluator[Seq2SeqBatch](
+            model=model,
+            metrics=["bleu"],
+            gang=self.gang,
+            data_reader=pipeline_reader,
+            wall_watch=wall_watch,
+            preprocessor=self.seq2seq_preprocessor,
+            postprocessor=lambda x, y: self.postprocesser(x, y),
+        )
+
+    def __call__(self, config: HFEvalConfig, output_dir: Path) -> Callable[[], None]:
+        """Validate the config, bind tokenizer state, and return a callable that runs the evaluation."""
+        if not isinstance(config, AsrEvalConfig):
+            raise ValueError(f"Expected AsrEvalConfig, got {type(config)}")
+
+        self.config = config
+        self.output_dir = output_dir
+
+        self.tokenizer = load_text_tokenizer(config.tokenizer_name)
+        self.encoder = self.tokenizer.create_encoder(device=self.init_device)
+
+        return self._load_evaluator()

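Usage sketch (an editor's illustration, not part of the commit): the refactored evaluator is constructed once and then called with a config and an output directory. The config values mirror the wav2vec2 preset above; the output path is a placeholder.

from pathlib import Path

from fairseq2.recipes.eval.asr import ASREvaluator, AsrEvalConfig

config = AsrEvalConfig(
    dataset_name="librispeech_asr",
    model_name="wav2vec2_asr_base_10h",
    split="test.other",
)

evaluator = ASREvaluator()  # sets up the gang and the init device
recipe = evaluator(config, Path("/tmp/asr_eval"))  # binds tokenizer/encoder, returns the HFEvaluator
recipe()  # runs the evaluation and reports the metrics

Keeping construction separate from invocation leaves the distributed setup in __init__, while all config-dependent state (tokenizer, encoder, dataset reader, model) is bound per call, so a single instance can back both the wav2vec2-asr and the whisper-asr commands.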