Skip to content

Commit

Permalink
added slakh processing, tests, and test data, tox passing.
Browse files Browse the repository at this point in the history
  • Loading branch information
bgenchel-avail committed Jul 16, 2024
1 parent 91d220b commit 5cfa7fe
Show file tree
Hide file tree
Showing 95 changed files with 798 additions and 0 deletions.
226 changes: 226 additions & 0 deletions basic_pitch/data/datasets/slakh.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,226 @@
#!/usr/bin/env python
# encoding: utf-8
#
# Copyright 2024 Spotify AB
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import logging
import os
import time

from typing import List, Tuple

import apache_beam as beam
import mirdata

from basic_pitch.data import commandline, pipeline


class SlakhFilterInvalidTracks(beam.DoFn):
    """Beam DoFn that drops unusable Slakh tracks and tags the rest by split.

    Each input element is a ``(track_id, split)`` pair. The track id is
    emitted as a ``TaggedOutput`` (tag = split name) only when all of the
    following hold:

    * the split is not ``"omitted"``;
    * every path listed in ``DOWNLOAD_ATTRIBUTES`` is defined for the track;
    * the track is not a drum track;
    * sox can convert the audio to ``AUDIO_SAMPLE_RATE`` / ``AUDIO_N_CHANNELS``;
    * the track has at least one note annotation.
    """

    # Track attributes naming the files that are copied locally before validation.
    DOWNLOAD_ATTRIBUTES = ["audio_path", "metadata_path", "midi_path"]

    def __init__(self, source: str) -> None:
        """Remember the dataset location.

        Args:
            source: Passed to mirdata as ``data_home`` for the source dataset.
        """
        self.source = source

    def setup(self) -> None:
        """Create the per-worker mirdata handle and the Beam filesystem accessor."""
        import mirdata

        self.slakh_remote = mirdata.initialize("slakh", data_home=self.source)
        self.filesystem = beam.io.filesystems.FileSystems()

    def _copy_to_local(self, source: str, dest: str) -> None:
        """Stream ``source`` (opened through Beam's FileSystems) into local file ``dest``."""
        import shutil

        os.makedirs(os.path.dirname(dest), exist_ok=True)
        with self.filesystem.open(source) as s, open(dest, "wb") as d:
            # Copy in chunks so large audio files are never held fully in memory.
            shutil.copyfileobj(s, d)

    def process(self, element: Tuple[str, str]):
        """Validate one track and emit it under its split tag if it is usable.

        Args:
            element: ``(track_id, split)`` pair.

        Yields:
            ``beam.pvalue.TaggedOutput(split, track_id)`` for a valid track;
            nothing for a track that is filtered out.
        """
        import tempfile

        import apache_beam as beam
        import sox

        from basic_pitch.constants import (
            AUDIO_N_CHANNELS,
            AUDIO_SAMPLE_RATE,
        )

        track_id, split = element
        if split == "omitted":
            return

        logging.info(f"Processing (track_id, split): ({track_id}, {split})")

        track_remote = self.slakh_remote.track(track_id)

        with tempfile.TemporaryDirectory() as local_tmp_dir:
            # A second mirdata handle rooted at the temp dir gives the local
            # destination paths that mirror the source layout.
            slakh_local = mirdata.initialize("slakh", local_tmp_dir)
            track_local = slakh_local.track(track_id)

            for attr in self.DOWNLOAD_ATTRIBUTES:
                source = getattr(track_remote, attr)
                dest = getattr(track_local, attr)
                if not dest:
                    # The track has no such file; it cannot be processed.
                    return
                logging.info(f"Downloading {attr} from {source} to {dest}")
                self._copy_to_local(source, dest)

            if track_local.is_drum:
                return

            local_wav_path = "{}_tmp.wav".format(track_local.audio_path)
            tfm = sox.Transformer()
            tfm.rate(AUDIO_SAMPLE_RATE)
            tfm.channels(AUDIO_N_CHANNELS)
            try:
                tfm.build(track_local.audio_path, local_wav_path)
            except Exception as e:
                # Deliberately broad: any conversion failure marks the track as
                # invalid rather than failing the whole pipeline.
                logging.warning(f"Could not process {local_wav_path}. Exception: {e}")
                return

            # if there are no notes, skip this track
            if track_local.notes is None or len(track_local.notes.intervals) == 0:
                return

        yield beam.pvalue.TaggedOutput(split, track_id)


class SlakhToTfExample(beam.DoFn):
    """Beam DoFn that turns a batch of Slakh track ids into transcription examples.

    For every track id in the incoming batch, the files listed in
    ``DOWNLOAD_ATTRIBUTES`` are copied to a temporary directory, the audio is
    converted with sox, and the note / onset / contour annotations are
    converted to sparse indices and handed to
    ``tf_example_serialization.to_transcription_tfexample``.
    """

    # Track attributes naming the files that are copied locally before processing.
    DOWNLOAD_ATTRIBUTES = ["audio_path", "metadata_path", "midi_path"]

    def __init__(self, source: str, download: bool) -> None:
        """Remember the dataset location and whether to download it.

        Args:
            source: Passed to mirdata as ``data_home`` for the source dataset.
            download: When true, ``setup`` calls mirdata's ``download()``.
        """
        self.source = source
        self.download = download

    def setup(self) -> None:
        """Create the per-worker mirdata handle and optionally download the dataset."""
        import apache_beam as beam
        import mirdata

        self.slakh_remote = mirdata.initialize("slakh", data_home=self.source)
        self.filesystem = beam.io.filesystems.FileSystems()  # TODO: replace with fsspec
        if self.download:
            self.slakh_remote.download()

    def _copy_to_local(self, source: str, dest: str) -> None:
        """Stream ``source`` (opened through Beam's FileSystems) into local file ``dest``."""
        import shutil

        os.makedirs(os.path.dirname(dest), exist_ok=True)
        with self.filesystem.open(source) as s, open(dest, "wb") as d:
            # Copy in chunks so large audio files are never held fully in memory.
            shutil.copyfileobj(s, d)

    def process(self, element: List[str]):
        """Build one example per track id in ``element``.

        Args:
            element: Batch of Slakh track ids.

        Returns:
            A one-item list containing the list of examples, so Beam emits the
            whole batch as a single output element.
        """
        import tempfile

        import numpy as np
        import sox

        from basic_pitch.constants import (
            AUDIO_N_CHANNELS,
            AUDIO_SAMPLE_RATE,
            FREQ_BINS_CONTOURS,
            FREQ_BINS_NOTES,
            ANNOTATION_HOP,
            N_FREQ_BINS_NOTES,
            N_FREQ_BINS_CONTOURS,
        )
        from basic_pitch.data import tf_example_serialization

        logging.info(f"Processing {element}")
        batch = []

        for track_id in element:
            track_remote = self.slakh_remote.track(track_id)

            with tempfile.TemporaryDirectory() as local_tmp_dir:
                # A second mirdata handle rooted at the temp dir gives the local
                # destination paths that mirror the source layout.
                slakh_local = mirdata.initialize("slakh", local_tmp_dir)
                track_local = slakh_local.track(track_id)

                for attr in self.DOWNLOAD_ATTRIBUTES:
                    source = getattr(track_remote, attr)
                    dest = getattr(track_local, attr)
                    logging.info(f"Downloading {attr} from {source} to {dest}")
                    self._copy_to_local(source, dest)

                local_wav_path = "{}_tmp.wav".format(track_local.audio_path)
                tfm = sox.Transformer()
                tfm.rate(AUDIO_SAMPLE_RATE)
                tfm.channels(AUDIO_N_CHANNELS)
                tfm.build(track_local.audio_path, local_wav_path)

                # Frame times in seconds, spaced ANNOTATION_HOP apart; the stop is
                # pushed out by one hop so the grid reaches the end of the audio.
                duration = sox.file_info.duration(local_wav_path)
                time_scale = np.arange(0, duration + ANNOTATION_HOP, ANNOTATION_HOP)
                n_time_frames = len(time_scale)

                note_indices, note_values = track_local.notes.to_sparse_index(time_scale, "s", FREQ_BINS_NOTES, "hz")
                onset_indices, onset_values = track_local.notes.to_sparse_index(
                    time_scale, "s", FREQ_BINS_NOTES, "hz", onsets_only=True
                )
                contour_indices, contour_values = track_local.multif0.to_sparse_index(
                    time_scale, "s", FREQ_BINS_CONTOURS, "hz"
                )

                # Serialization happens inside the temp-dir context because it is
                # given the path of the converted wav file.
                batch.append(
                    tf_example_serialization.to_transcription_tfexample(
                        track_id,
                        "slakh",
                        local_wav_path,
                        note_indices,
                        note_values,
                        onset_indices,
                        onset_values,
                        contour_indices,
                        contour_values,
                        (n_time_frames, N_FREQ_BINS_NOTES),
                        (n_time_frames, N_FREQ_BINS_CONTOURS),
                    )
                )

        logging.info(f"Finished processing batch of length {len(batch)}")
        return [batch]


def create_input_data() -> List[Tuple[str, str]]:
    """Return a ``(track_id, data_split)`` pair for every track mirdata loads for Slakh."""
    dataset = mirdata.initialize("slakh")
    pairs: List[Tuple[str, str]] = []
    for identifier, track in dataset.load_tracks().items():
        pairs.append((identifier, track.data_split))
    return pairs


def main(known_args, pipeline_args):
    """Assemble the Beam options and launch the Slakh tfrecord pipeline.

    Args:
        known_args: Parsed command-line namespace; this function reads its
            ``runner``, ``sdk_container_image``, ``job_endpoint``, ``source``
            and ``batch_size`` attributes and passes it to
            ``commandline.resolve_destination``.
        pipeline_args: Remaining command-line arguments, forwarded unchanged
            to ``pipeline.run``.
    """
    # A single timestamp names both the output destination and the job.
    timestamp = int(time.time())
    output_destination = commandline.resolve_destination(known_args, timestamp)
    tracks_and_splits = create_input_data()

    options = dict(
        runner=known_args.runner,
        job_name=f"slakh-tfrecords-{timestamp}",
        machine_type="e2-standard-4",
        num_workers=25,
        disk_size_gb=128,
        experiments=["use_runner_v2"],
        save_main_session=True,
        sdk_container_image=known_args.sdk_container_image,
        job_endpoint=known_args.job_endpoint,
        environment_type="DOCKER",
        environment_config=known_args.sdk_container_image,
    )

    to_tf_example = SlakhToTfExample(known_args.source, download=True)
    filter_invalid = SlakhFilterInvalidTracks(known_args.source)

    pipeline.run(
        options,
        pipeline_args,
        tracks_and_splits,
        to_tf_example,
        filter_invalid,
        output_destination,
        known_args.batch_size,
    )


if __name__ == "__main__":
    arg_parser = argparse.ArgumentParser()
    # This file's name without directory or extension is handed to add_default.
    script_stem = os.path.basename(os.path.splitext(__file__)[0])
    commandline.add_default(arg_parser, script_stem)
    commandline.add_split(arg_parser)

    # parse_known_args returns the recognised options plus the leftover
    # arguments; both are forwarded to main.
    known_args, pipeline_args = arg_parser.parse_known_args()
    main(known_args, pipeline_args)
144 changes: 144 additions & 0 deletions tests/data/test_slakh.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
#!/usr/bin/env python
# encoding: utf-8
#
# Copyright 2024 Spotify AB
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import apache_beam as beam
import itertools
import os
import pathlib

from typing import List

from apache_beam.testing.test_pipeline import TestPipeline

from basic_pitch.data.datasets.slakh import (
SlakhFilterInvalidTracks,
SlakhToTfExample,
create_input_data,
)
from basic_pitch.data.pipeline import WriteBatchToTfRecord

# "resources" directory that sits beside the directory containing this test file.
RESOURCES_PATH = pathlib.Path(__file__).parent.parent / "resources"

# Slakh track ids used as fixtures. The tests pair each TRAIN_/VALID_/TEST_ id
# with the "train"/"validation"/"test" split, and expect the *_DRUMS ids to be
# filtered out. NOTE(review): the PIANO/DRUMS instrument labels and the OMITTED
# split come from the constant names only — confirm against the fixture data.
TRAIN_PIANO_TRACK_ID = "Track00001-S02"
TRAIN_DRUMS_TRACK_ID = "Track00001-S01"
VALID_PIANO_TRACK_ID = "Track01501-S06"
VALID_DRUMS_TRACK_ID = "Track01501-S03"
TEST_PIANO_TRACK_ID = "Track01876-S01"
TEST_DRUMS_TRACK_ID = "Track01876-S08"
OMITTED_PIANO_TRACK_ID = "Track00049-S05"
OMITTED_DRUMS_TRACK_ID = "Track00049-S06"


def test_slakh_to_tf_example(tmpdir: str) -> None:
    """Run SlakhToTfExample on one track and expect a single non-empty .tfrecord file."""
    slakh_root = str(RESOURCES_PATH / "data" / "slakh")
    track_batch: List[str] = [TRAIN_PIANO_TRACK_ID]

    with TestPipeline() as test_pipeline:
        (
            test_pipeline
            | "Create PCollection of track IDs" >> beam.Create([track_batch])
            | "Create tf.Example" >> beam.ParDo(SlakhToTfExample(slakh_root, download=False))
            | "Write to tfrecord" >> beam.ParDo(WriteBatchToTfRecord(tmpdir))
        )

    # The pipeline has finished once the context manager exits.
    written_files = os.listdir(tmpdir)
    assert len(written_files) == 1

    record_name = written_files[0]
    assert os.path.splitext(record_name)[-1] == ".tfrecord"

    with open(os.path.join(tmpdir, record_name), "rb") as record_file:
        contents = record_file.read()
    assert len(contents) != 0


def test_slakh_invalid_tracks(tmpdir: str) -> None:
    """Each piano fixture track should be written to the output of its own split."""
    slakh_root = str(RESOURCES_PATH / "data" / "slakh")
    split_labels = ["train", "validation", "test"]
    tracks_with_splits = [
        (TRAIN_PIANO_TRACK_ID, "train"),
        (VALID_PIANO_TRACK_ID, "validation"),
        (TEST_PIANO_TRACK_ID, "test"),
    ]

    def output_path(label: str) -> str:
        # One text file per split tag.
        return os.path.join(tmpdir, f"output_{label}.txt")

    with TestPipeline() as test_pipeline:
        tagged = (
            test_pipeline
            | "Create PCollection" >> beam.Create(tracks_with_splits)
            | "Tag it" >> beam.ParDo(SlakhFilterInvalidTracks(slakh_root)).with_outputs(*split_labels)
        )

        for label in split_labels:
            (
                getattr(tagged, label)
                | f"Write {label} to text" >> beam.io.WriteToText(output_path(label), shard_name_template="")
            )

    for expected_track_id, label in tracks_with_splits:
        with open(output_path(label), "r") as result_file:
            written = result_file.read().strip()
        assert written == expected_track_id


def test_slakh_invalid_tracks_omitted(tmpdir: str) -> None:
    """A track in the "omitted" split must produce no output; the train track must pass."""
    slakh_root = str(RESOURCES_PATH / "data" / "slakh")
    split_labels = ["train", "omitted"]
    tracks_with_splits = [
        (TRAIN_PIANO_TRACK_ID, "train"),
        (OMITTED_PIANO_TRACK_ID, "omitted"),
    ]
    train_output = os.path.join(tmpdir, "output_train.txt")
    omitted_output = os.path.join(tmpdir, "output_omitted.txt")

    with TestPipeline() as test_pipeline:
        tagged = (
            test_pipeline
            | "Create PCollection" >> beam.Create(tracks_with_splits)
            | "Tag it" >> beam.ParDo(SlakhFilterInvalidTracks(slakh_root)).with_outputs(*split_labels)
        )

        for label in split_labels:
            destination = os.path.join(tmpdir, f"output_{label}.txt")
            (
                getattr(tagged, label)
                | f"Write {label} to text" >> beam.io.WriteToText(destination, shard_name_template="")
            )

    with open(train_output, "r") as train_file:
        train_written = train_file.read().strip()
    assert train_written == TRAIN_PIANO_TRACK_ID

    with open(omitted_output, "r") as omitted_file:
        omitted_written = omitted_file.read().strip()
    assert omitted_written == ""


def test_slakh_invalid_tracks_drums(tmpdir: str) -> None:
    """Drum fixture tracks must be filtered out of every split's output."""
    slakh_root = str(RESOURCES_PATH / "data" / "slakh")
    split_labels = ["train", "validation", "test"]
    drum_tracks_with_splits = [
        (TRAIN_DRUMS_TRACK_ID, "train"),
        (VALID_DRUMS_TRACK_ID, "validation"),
        (TEST_DRUMS_TRACK_ID, "test"),
    ]

    def output_path(label: str) -> str:
        # One text file per split tag.
        return os.path.join(tmpdir, f"output_{label}.txt")

    with TestPipeline() as test_pipeline:
        tagged = (
            test_pipeline
            | "Create PCollection" >> beam.Create(drum_tracks_with_splits)
            | "Tag it" >> beam.ParDo(SlakhFilterInvalidTracks(slakh_root)).with_outputs(*split_labels)
        )

        for label in split_labels:
            (
                getattr(tagged, label)
                | f"Write {label} to text" >> beam.io.WriteToText(output_path(label), shard_name_template="")
            )

    # Only the split matters here: every output file is expected to be empty.
    for _, label in drum_tracks_with_splits:
        with open(output_path(label), "r") as result_file:
            written = result_file.read().strip()
        assert written == ""


def test_create_input_data() -> None:
    """Check that create_input_data returns a non-empty list of unique (track_id, split) pairs."""
    data = create_input_data()

    # groupby never yields an empty group and yields nothing for empty input,
    # so emptiness has to be checked on the list itself.
    assert len(data) > 0

    # Unpacking verifies every element is a pair; ids must not repeat.
    track_ids = [track_id for track_id, _ in data]
    assert len(set(track_ids)) == len(track_ids)

    # Every run of consecutive elements sharing a split contains at least one track.
    for key, group in itertools.groupby(data, lambda el: el[1]):
        assert len(list(group))
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Loading

0 comments on commit 5cfa7fe

Please sign in to comment.