Skip to content

Commit

Permalink
added validation jams file for guitarset, removed wavs, modified test…
Browse files Browse the repository at this point in the history
…s to use generated wav files, added tests for example deserialize
  • Loading branch information
bgenchel committed Aug 11, 2024
1 parent 805ec25 commit c7db2bb
Show file tree
Hide file tree
Showing 8 changed files with 70,408 additions and 25 deletions.
1 change: 0 additions & 1 deletion basic_pitch/data/datasets/guitarset.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,6 @@ def process(self, element: List[str], *args: Tuple[Any, Any], **kwargs: Dict[str
duration = sox.file_info.duration(local_wav_path)
time_scale = np.arange(0, duration + ANNOTATION_HOP, ANNOTATION_HOP)
n_time_frames = len(time_scale)

note_indices, note_values = track_local.notes_all.to_sparse_index(
time_scale, "s", FREQ_BINS_NOTES, "hz"
)
Expand Down
4 changes: 2 additions & 2 deletions basic_pitch/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,8 +151,8 @@ def main(
def console_entry_point() -> None:
"""From pip installed script."""
parser = argparse.ArgumentParser(description="")
parser.add_argument("--source", help="Path to directory containing train/validation splits.")
parser.add_argument("--output", help="Directory to save the model in.")
parser.add_argument("--source", required=True, help="Path to directory containing train/validation splits.")
parser.add_argument("--output", required=True, help="Directory to save the model in.")
parser.add_argument("-e", "--epochs", type=int, default=500, help="Number of training epochs.")
parser.add_argument(
"-b",
Expand Down
34 changes: 26 additions & 8 deletions tests/data/test_guitarset.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import itertools
import os
import pathlib
import shutil

from apache_beam.testing.test_pipeline import TestPipeline
from typing import List
Expand All @@ -29,24 +30,41 @@
)
from basic_pitch.data.pipeline import WriteBatchToTfRecord

from utils import create_mock_wav

RESOURCES_PATH = pathlib.Path(__file__).parent.parent / "resources"
TRACK_ID = "00_BN1-129-Eb_comp"


def test_guitarset_to_tf_example(tmpdir: str) -> None:
def test_guitarset_to_tf_example(tmp_path: pathlib.Path) -> None:
    """Run the GuitarSet beam pipeline end to end against a synthesized wav and a
    checked-in .jams annotation, asserting exactly one non-empty tfrecord is written.

    Fixed: removed stale duplicate pipeline stages and assertions left over from an
    earlier version of this test (they referenced an undefined ``tmpdir`` and added
    the "Create tf.Example" / "Write to tfrecord" stages twice).
    """
    mock_guitarset_home = tmp_path / "guitarset"
    mock_guitarset_audio = mock_guitarset_home / "audio_mono-mic"
    mock_guitarset_annotations = mock_guitarset_home / "annotation"
    output_dir = tmp_path / "output"

    mock_guitarset_audio.mkdir(parents=True)
    mock_guitarset_annotations.mkdir(parents=True)
    output_dir.mkdir()

    # Synthesize the audio; the annotation is a real fixture copied from resources.
    create_mock_wav(mock_guitarset_audio / f"{TRACK_ID}_mic.wav", duration_min=1)
    shutil.copy(
        RESOURCES_PATH / "data" / "guitarset" / "annotation" / f"{TRACK_ID}.jams",
        mock_guitarset_annotations / f"{TRACK_ID}.jams",
    )

    input_data: List[str] = [TRACK_ID]
    with TestPipeline() as p:
        (
            p
            | "Create PCollection of track IDs" >> beam.Create([input_data])
            | "Create tf.Example" >> beam.ParDo(GuitarSetToTfExample(str(mock_guitarset_home), download=False))
            | "Write to tfrecord" >> beam.ParDo(WriteBatchToTfRecord(str(output_dir)))
        )

    listdir = os.listdir(output_dir)
    assert len(listdir) == 1
    assert os.path.splitext(listdir[0])[-1] == ".tfrecord"
    with open(output_dir / listdir[0], "rb") as fp:
        data = fp.read()
    assert len(data) != 0

Expand Down Expand Up @@ -77,7 +95,7 @@ def test_guitarset_create_input_data() -> None:
data = create_input_data(train_percent=0.33, validation_percent=0.33)
data.sort(key=lambda el: el[1]) # sort by split
tolerance = 0.1
for key, group in itertools.groupby(data, lambda el: el[1]):
for _, group in itertools.groupby(data, lambda el: el[1]):
assert (0.33 - tolerance) * len(data) <= len(list(group)) <= (0.33 + tolerance) * len(data)


Expand Down
125 changes: 111 additions & 14 deletions tests/data/test_tf_example_deserialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,29 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import apache_beam as beam
import numpy as np
import os
import pathlib
import shutil
import tensorflow as tf

from apache_beam.testing.test_pipeline import TestPipeline
from typing import List

from basic_pitch.data.tf_example_deserialization import transcription_dataset, transcription_file_generator
from basic_pitch.data.datasets.guitarset import GuitarSetToTfExample
from basic_pitch.data.pipeline import WriteBatchToTfRecord
from basic_pitch.data.tf_example_deserialization import (
prepare_datasets,
prepare_visualization_datasets,
sample_datasets,
transcription_file_generator,
)

from utils import create_mock_wav

RESOURCES_PATH = pathlib.Path(__file__).parent.parent / "resources"
TRAIN_TRACK_ID = "00_BN1-129-Eb_comp"
VALID_TRACK_ID = "00_BN1-147-Gb_comp"


def create_empty_tfrecord(filepath: pathlib.Path) -> None:
Expand All @@ -30,24 +46,105 @@ def create_empty_tfrecord(filepath: pathlib.Path) -> None:
writer.write("")


# def test_prepare_dataset() -> None:
# pass
def create_tfrecord(input_data: List[str], dataset_home: str, output_dir: str) -> None:
    """Serialize the given GuitarSet track IDs to tfrecord files under ``output_dir``.

    Runs a small three-stage beam TestPipeline: create the id collection,
    convert each track to a tf.Example, and write the batch to disk.
    """
    with TestPipeline() as pipeline:
        track_ids = pipeline | "Create PCollection of track IDs" >> beam.Create([input_data])
        examples = track_ids | "Create tf.Example" >> beam.ParDo(
            GuitarSetToTfExample(dataset_home, download=False)
        )
        examples | "Write to tfrecord" >> beam.ParDo(WriteBatchToTfRecord(output_dir))


def setup_test_resources(tmp_path: pathlib.Path) -> pathlib.Path:
    """Build a miniature GuitarSet layout under ``tmp_path`` and serialize one
    track per split ("train" and "validation") to tfrecords.

    Returns:
        The output home directory containing ``guitarset/splits/<split>``,
        suitable as ``datasets_base_path`` for the deserialization helpers.

    Fixed: removed dead commented-out test stubs that had been left sitting
    inside the function body.
    """
    mock_guitarset_home = tmp_path / "guitarset"
    mock_guitarset_audio = mock_guitarset_home / "audio_mono-mic"
    mock_guitarset_annotations = mock_guitarset_home / "annotation"

    mock_guitarset_audio.mkdir(parents=True)
    mock_guitarset_annotations.mkdir(parents=True)

    output_home = tmp_path / "data" / "basic_pitch"
    output_splits_dir = output_home / "guitarset" / "splits"

    def mock_and_process(split: str, track_id: str) -> None:
        # Synthesize the audio; the .jams annotation is a checked-in fixture.
        create_mock_wav(mock_guitarset_audio / f"{track_id}_mic.wav", duration_min=1)
        shutil.copy(
            RESOURCES_PATH / "data" / "guitarset" / "annotation" / f"{track_id}.jams",
            mock_guitarset_annotations / f"{track_id}.jams",
        )

        output_dir = output_splits_dir / split
        output_dir.mkdir(parents=True)

        create_tfrecord(input_data=[track_id], dataset_home=str(mock_guitarset_home), output_dir=str(output_dir))

    mock_and_process("train", TRAIN_TRACK_ID)
    mock_and_process("validation", VALID_TRACK_ID)

    return output_home


# def test_transcription_dataset(tmp_path: pathlib.Path) -> None:
# dataset_path = tmp_path / "test_ds" / "splits" / "train"
# dataset_path.mkdir(parents=True)
# create_empty_tfrecord(dataset_path / "test.tfrecord")
def test_prepare_datasets(tmp_path: pathlib.Path) -> None:
    """prepare_datasets should return a (train, validation) pair of
    tf.data.Datasets when pointed at the mocked GuitarSet splits.

    Fixed: removed dead commented-out leftovers from an earlier version and the
    redundant ``is not None`` guards (``isinstance(None, ...)`` is already False).
    """
    datasets_home = setup_test_resources(tmp_path)

    ds_train, ds_valid = prepare_datasets(
        datasets_base_path=str(datasets_home),
        training_shuffle_buffer_size=1,
        batch_size=1,
        validation_steps=1,
        datasets_to_use=["guitarset"],
        dataset_sampling_frequency=np.array([1]),
    )

    # isinstance rejects None, so these also guard against None returns.
    assert isinstance(ds_train, tf.data.Dataset)
    assert isinstance(ds_valid, tf.data.Dataset)


def test_prepare_visualization_dataset(tmp_path: pathlib.Path) -> None:
    """prepare_visualization_datasets should return a (train, validation) pair of
    tf.data.Datasets when pointed at the mocked GuitarSet splits.

    Fixed: the second assertion had a copy-paste bug — it re-checked ``ds_train``
    instead of ``ds_valid``, so an invalid validation dataset would never be caught.
    Also dropped the redundant ``is not None`` guards (isinstance rejects None).
    """
    datasets_home = setup_test_resources(tmp_path)

    ds_train, ds_valid = prepare_visualization_datasets(
        datasets_base_path=str(datasets_home),
        batch_size=1,
        validation_steps=1,
        datasets_to_use=["guitarset"],
        dataset_sampling_frequency=np.array([1]),
    )

    assert isinstance(ds_train, tf.data.Dataset)
    assert isinstance(ds_valid, tf.data.Dataset)  # was isinstance(ds_train, ...)


def test_sample_datasets(tmp_path: pathlib.Path) -> None:
    """Smoke-test sample_datasets over the mocked GuitarSet tfrecords.

    touches the following methods:
    - transcription_dataset
    - parse_transcription_tfexample
    - is_not_bad_shape
    - sparse2dense
    - reduce_transcription_inputs
    - get_sample_weights
    - _infer_time_size
    - get_transcription_chunks
    - extract_random_window
    - extract_window
    - trim_time
    - is_not_all_silent_annotations
    - to_transcription_training_input

    Fixed: removed a dead commented-out call left over from an earlier version.
    """
    datasets_home = setup_test_resources(tmp_path)

    ds = sample_datasets(
        split="train",
        datasets_base_path=str(datasets_home),
        datasets=["guitarset"],
        dataset_sampling_frequency=np.array([1]),
        n_shuffle=1,
        n_samples_per_track=1,
        pairs=True,
    )

    # isinstance rejects None, so this also guards against a None return.
    assert isinstance(ds, tf.data.Dataset)


def test_transcription_file_generator_train(tmp_path: pathlib.Path) -> None:
Expand Down
5 changes: 5 additions & 0 deletions tests/data/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@


def create_mock_wav(output_fpath: pathlib.Path, duration_min: int) -> None:
assert output_fpath.suffix == ".wav"

duration_seconds = duration_min * 60
sample_rate = 44100
n_channels = 2 # Stereo
Expand All @@ -45,6 +47,8 @@ def create_mock_wav(output_fpath: pathlib.Path, duration_min: int) -> None:


def create_mock_flac(output_fpath: pathlib.Path) -> None:
assert output_fpath.suffix == ".flac"

frequency = 440 # A4
duration = 2 # seconds
sample_rate = 44100 # standard
Expand All @@ -60,6 +64,7 @@ def create_mock_flac(output_fpath: pathlib.Path) -> None:


def create_mock_midi(output_fpath: pathlib.Path) -> None:
assert output_fpath.suffix in (".mid", ".midi")
# Create a new MIDI file with one track
mid = MidiFile()
track = MidiTrack()
Expand Down
Loading

0 comments on commit c7db2bb

Please sign in to comment.