Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add BestRQ pretraining #873

Draft
wants to merge 4 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
207 changes: 201 additions & 6 deletions src/fairseq2/datasets/speech.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.

Check failure on line 1 in src/fairseq2/datasets/speech.py

View workflow job for this annotation

GitHub Actions / Lint Python / Lint

would reformat
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
Expand All @@ -8,18 +8,32 @@

from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any, Literal, final
from typing import Any, Literal, cast, final

import torch
from typing_extensions import override

from fairseq2.assets import AssetCard, AssetError
from fairseq2.data.audio import AudioDecoder
from fairseq2.data.text import StrSplitter, TextTokenizer, read_text

Check failure on line 18 in src/fairseq2/datasets/speech.py

View workflow job for this annotation

GitHub Actions / Lint Python / Lint

'fairseq2.data.text.TextTokenizer' imported but unused
from fairseq2.datasets.batching import Batching
from fairseq2.datasets.data_reader import DataPipelineReader, DataReader
from fairseq2.datasets.loader import AbstractDatasetLoader, DelegatingDatasetLoader
from fairseq2.gang import Gang
from fairseq2.models.sequence import SequenceBatch
from fairseq2.typing import DataType
from fairseq2.nn.padding import get_seqs_and_padding_mask
from fairseq2.datasets.batching import Batching, LengthBatching, StaticBatching

Check failure on line 26 in src/fairseq2/datasets/speech.py

View workflow job for this annotation

GitHub Actions / Lint Python / Lint

redefinition of unused 'Batching' from line 19
from fairseq2.data import (

Check failure on line 27 in src/fairseq2/datasets/speech.py

View workflow job for this annotation

GitHub Actions / Lint Python / Lint

'fairseq2.data.CollateOptionsOverride' imported but unused
CollateOptionsOverride,
Collater,
DataPipeline,
DataPipelineBuilder,
FileMapper,
SequenceData,
create_bucket_sizes,
read_sequence,
)


class SpeechDataset(ABC):
Expand Down Expand Up @@ -107,18 +121,42 @@

load_speech_dataset = DelegatingDatasetLoader[SpeechDataset]()

# TODO: FIX, INFER
npc = 10

@final
class GenericSpeechDataset(SpeechDataset):
"""Represents a generic manifest-based Speech dataset."""
"""Represents a generic manifest-based ASR dataset."""

def __init__(self) -> None:
pass
_manifest_dir: Path
_splits: set[str]

def __init__(self, manifest_dir: Path, splits: set[str]) -> None:
"""
:param manifest_dir:
The directory under which the manifest files resides.
:param splits:
The available splits.
"""
self._manifest_dir = manifest_dir
self._splits = splits

@classmethod
def from_path(cls, path: Path) -> GenericSpeechDataset:
"""Load a :class:`GenericSpeechDataset` from ``path``."""
return GenericSpeechDataset()
path = path.expanduser().resolve()
if not path.is_dir():
return GenericSpeechDataset(manifest_dir=path.parent, splits={path.stem})

try:
splits = {f.stem for f in path.glob("*.tsv")}
except OSError as ex:
raise RuntimeError(
"The splits cannot be determined. See nested exception for details."
) from ex

return GenericSpeechDataset(path, splits)

@override
def create_reader(
Expand Down Expand Up @@ -148,11 +186,168 @@
The maximum number of file descriptors to keep open while reading
audio files.
"""
raise RuntimeError("not supported yet.")
if split not in self._splits:
raise ValueError(
f"`split` must be one of the following splits, but is '{split}' instead: {', '.join(sorted(self._splits))}"
)

audio_dir = self._retrieve_data_directory(split)

builder = self._read_manifest(split)

# Shuffle examples. Must be consistent across all processes.
if example_shuffle_window != 1:
builder.shuffle(example_shuffle_window, seed)

seed += 1

# Shard.
builder.shard(gang.rank, gang.size, allow_uneven=True)

seed += gang.rank

if isinstance(batching, LengthBatching):
# Bucket by the audio length.
bucket_sizes = create_bucket_sizes(
max_seq_len=max_audio_len,
min_seq_len=min_audio_len,
max_num_elements=batching.max_num_elements,
num_seqs_multiple_of=8,
)

builder.bucket_by_length(
bucket_sizes,
selector="audio_size",
min_data_len=min_audio_len,
skip_below_min_examples=True,
skip_above_max_examples=True,
drop_remainder=drop_remainder,
)
elif isinstance(batching, StaticBatching):
# Filter out out-of-range audios.
def skip(example: dict[str, Any]) -> bool:
audio_len = cast(int, example["audio_size"])

return audio_len >= min_audio_len and audio_len <= max_audio_len

builder.filter(skip)

# Bucket `batch_size` examples.
builder.bucket(batching.batch_size, drop_remainder=drop_remainder)
else:
raise RuntimeError(f"`{batching}` is not supported.")

# Shuffle buckets.
if batch_shuffle_window != 1:
builder.shuffle(batch_shuffle_window, seed)

seed += 1

# Memory map audio files.
file_mapper = FileMapper(audio_dir, cached_fd_count=cached_fd_count)

builder.map(file_mapper, selector="[*].audio")

# Decode audio.
audio_decoder = AudioDecoder(dtype=torch.float32 if normalize_audio else dtype)

builder.map(audio_decoder, selector="[*].audio.data")

# TODO(balioglu): Check/adjust sample size

# Normalize audio if requested.
def normalize(waveform: Tensor) -> Tensor:

Check failure on line 259 in src/fairseq2/datasets/speech.py

View workflow job for this annotation

GitHub Actions / Lint Python / Lint

undefined name 'Tensor'

Check failure on line 259 in src/fairseq2/datasets/speech.py

View workflow job for this annotation

GitHub Actions / Lint Python / Lint

undefined name 'Tensor'
with torch.no_grad():
waveform = layer_norm(waveform, waveform.shape)

Check failure on line 261 in src/fairseq2/datasets/speech.py

View workflow job for this annotation

GitHub Actions / Lint Python / Lint

undefined name 'layer_norm'

return waveform.to(dtype)

if normalize_audio:
builder.map(normalize, selector="[*].audio.data.waveform")

collater = Collater(pad_value=0)

builder.map(collater, num_parallel_calls=npc)

# Return only the first `max_num_batches`.
if max_num_batches is not None:
builder.take(max_num_batches)

# Prefetch `num_prefetch` batches in background.
builder.prefetch(num_prefetch)

# Wrap examples with `Seq2SeqBatch`.
def to_batch(example: dict[str, Any]) -> SequenceBatch:
source_data = cast(SequenceData, example["audio"]["data"]["waveform"])

source_seqs, source_padding_mask = get_seqs_and_padding_mask(
source_data, gang.device
)

return SequenceBatch(
seqs=source_seqs,
padding_mask=source_padding_mask,
example=example,
)

pipeline = builder.map(to_batch).and_return()

return DataPipelineReader[SequenceBatch](
pipeline,
gang,
num_accumulate=num_accumulate,
drop_remainder=drop_remainder,
sync_batches=sync_batches,
sync_mode=sync_mode,
)

def _retrieve_data_directory(self, split: str) -> Path:
manifest_file = self._manifest_dir.joinpath(f"{split}.tsv")

try:
with manifest_file.open() as fp:
line = fp.readline().rstrip()
except OSError as ex:
raise DatasetError(

Check failure on line 311 in src/fairseq2/datasets/speech.py

View workflow job for this annotation

GitHub Actions / Lint Python / Lint

undefined name 'DatasetError'
f"{manifest_file} cannot be read. See nested exception for details."
) from ex

try:
return Path(line)
except ValueError:
raise DatasetError(

Check failure on line 318 in src/fairseq2/datasets/speech.py

View workflow job for this annotation

GitHub Actions / Lint Python / Lint

undefined name 'DatasetError'
f"The first line of {manifest_file} must point to a data directory."
) from None

def _read_manifest(self, split: str) -> DataPipelineBuilder:
def read_tsv_file() -> DataPipelineBuilder:
tsv_file = self._manifest_dir.joinpath(f"{split}.tsv")

builder = read_text(tsv_file, rtrim=True, memory_map=True)

builder.skip(1) # Path to the data directory.

field_splitter = StrSplitter(names=["audio", "audio_size"])

builder.map(field_splitter, num_parallel_calls=npc)

return builder

tsv_pipeline = read_tsv_file().and_return()

builder = DataPipeline.zip([tsv_pipeline], flatten=True)

# Cast audio size to integer.
builder.map(int, selector="audio_size")

# TODO(balioglu): Use `cache()` op.
manifest = list(builder.and_return())

return read_sequence(manifest)

@override
def splits(self) -> set[str]:
return set()
return self._splits


@final
Expand Down
2 changes: 2 additions & 0 deletions src/fairseq2/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.

Check failure on line 1 in src/fairseq2/models/__init__.py

View workflow job for this annotation

GitHub Actions / Lint Python / Lint

Imports are incorrectly sorted and/or formatted.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
Expand Down Expand Up @@ -32,6 +32,7 @@
from fairseq2.models.transformer import register_transformer
from fairseq2.models.w2vbert import register_w2vbert
from fairseq2.models.wav2vec2 import register_wav2vec2
from fairseq2.models.bestrq import register_bestrq
from fairseq2.models.wav2vec2.asr import register_wav2vec2_asr


Expand All @@ -44,3 +45,4 @@
register_w2vbert(container)
register_wav2vec2(container)
register_wav2vec2_asr(container)
register_bestrq(container)
40 changes: 40 additions & 0 deletions src/fairseq2/models/bestrq/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.

Check failure on line 1 in src/fairseq2/models/bestrq/__init__.py

View workflow job for this annotation

GitHub Actions / Lint Python / Lint

Imports are incorrectly sorted and/or formatted.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import annotations

from fairseq2.models.bestrq.factory import BESTRQ_FAMILY as BESTRQ_FAMILY
from fairseq2.models.bestrq.factory import BestRQBuilder as BestRQBuilder
from fairseq2.models.bestrq.factory import BestRQConfig as BestRQConfig
from fairseq2.models.bestrq.factory import (
BestRQEncoderBuilder as BestRQEncoderBuilder,
)
from fairseq2.models.bestrq.factory import (
create_bestrq_model as create_bestrq_model,
)
from fairseq2.models.bestrq.factory import bestrq_arch as bestrq_arch
from fairseq2.models.bestrq.factory import bestrq_archs as bestrq_archs
from fairseq2.models.bestrq.factory import (
bestrq_encoder_arch as bestrq_encoder_arch,
)
from fairseq2.models.bestrq.factory import (
bestrq_encoder_archs as bestrq_encoder_archs,
)
from fairseq2.models.bestrq.masker import RandomNoiseMasker as RandomNoiseMasker
from fairseq2.models.bestrq.model import BestRQFeatures as BestRQFeatures
from fairseq2.models.bestrq.model import BestRQLoss as BestRQLoss
from fairseq2.models.bestrq.model import BestRQModel as BestRQModel
from fairseq2.models.bestrq.model import BestRQOutput as BestRQOutput
from fairseq2.models.bestrq.quantizer import MultiRandomVectorQuantizerOutput

# isort: split

from fairseq2.dependency import DependencyContainer
from fairseq2.models.bestrq.archs import register_archs


def register_bestrq(container: DependencyContainer) -> None:
register_archs()
29 changes: 29 additions & 0 deletions src/fairseq2/models/bestrq/archs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.

Check failure on line 1 in src/fairseq2/models/bestrq/archs.py

View workflow job for this annotation

GitHub Actions / Lint Python / Lint

Imports are incorrectly sorted and/or formatted.

Check failure on line 1 in src/fairseq2/models/bestrq/archs.py

View workflow job for this annotation

GitHub Actions / Lint Python / Lint

would reformat
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import annotations

from fairseq2.models.bestrq.factory import (
BestRQConfig,
bestrq_arch,
bestrq_encoder_arch,
)
from fairseq2.models.wav2vec2.factory import (
Wav2Vec2EncoderConfig,
)
from fairseq2.nn.transformer import TransformerNormOrder

Check failure on line 17 in src/fairseq2/models/bestrq/archs.py

View workflow job for this annotation

GitHub Actions / Lint Python / Lint

'fairseq2.nn.transformer.TransformerNormOrder' imported but unused


def register_archs() -> None:
@bestrq_arch("base")
def _base() -> BestRQConfig:
return BestRQConfig()

@bestrq_encoder_arch("base")
def _base_encoder() -> Wav2Vec2EncoderConfig:
config = _base()

return config.encoder_config

Check failure on line 29 in src/fairseq2/models/bestrq/archs.py

View workflow job for this annotation

GitHub Actions / Lint Python / Lint

no newline at end of file
Loading
Loading