Skip to content

Commit

Permalink
Better tests + minor improvements (#108)
Browse files Browse the repository at this point in the history
* parametrizing tests, improvements in preprocess_midi, fixes for miditoolkit 1.0.1

* fixing absolute path for data aug and io tests

* fix test file tree, data aug report saved in out_path

* fix in data augmentation saving paths

* forced disabling original in out_dir when calling data aug from tok_dataset

* using pytest tmp_path to write files, and TEST_LOG_DIR if required

* lighter and more elegant MIDI assertions + covering check_midi_equals

* better tokenization test sets, set_midi_max_tick method, renamed "nb" contractions to "num", handling empty tokens lists in methods

* dealing with empty midi file (#110)

* dealing with empty midi file

* add a new test midi tokenizer file instead of changing the original one

* delete test_midi_tokenizer

* Adding check empty input for _ids_to_tokens as well

---------

Co-authored-by: Nathan Fradet <[email protected]>

* adding tests for empty MIDI and associated fixes

* fixes from tests with empty midi + retry hf hub tests when http errors

* fix convert_sequence_to_tokseq when list in last dim is empty

* better tok test sets

* testing with multiple time resolutions, adjusting notes ends

* fix _quantize_time_signatures (delete_equal_successive_time_sig_changes)

---------

Co-authored-by: feiyuehchen <[email protected]>
  • Loading branch information
Natooz and feiyuehchen authored Nov 28, 2023
1 parent 62b9c2e commit 89c4678
Show file tree
Hide file tree
Showing 59 changed files with 1,462 additions and 1,363 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
# separate terms of service, privacy policy, and support
# documentation.

name: Upload Python Package
name: Publish package on PyPi

on:
release:
Expand All @@ -21,9 +21,9 @@ jobs:
runs-on: ubuntu-latest

steps:
- uses: actions/checkout@v3
- uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v3
uses: actions/setup-python@v4
with:
python-version: '3.x'
- name: Install dependencies
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ from miditoolkit import MidiFile
from pathlib import Path

# Creating a multitrack tokenizer configuration, read the doc to explore other parameters
config = TokenizerConfig(nb_velocities=16, use_chords=True, use_programs=True)
config = TokenizerConfig(num_velocities=16, use_chords=True, use_programs=True)
tokenizer = REMI(config)

# Loads a midi, converts to tokens, and back to a MIDI
Expand Down
40 changes: 27 additions & 13 deletions miditok/classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
Common classes.
"""
import json
import warnings
from copy import deepcopy
from dataclasses import dataclass
from pathlib import Path
Expand All @@ -20,8 +21,8 @@
DELETE_EQUAL_SUCCESSIVE_TIME_SIG_CHANGES,
LOG_TEMPOS,
MAX_PITCH_INTERVAL,
NB_TEMPOS,
NB_VELOCITIES,
NUM_TEMPOS,
NUM_VELOCITIES,
ONE_TOKEN_STREAM_FOR_PROGRAMS,
PITCH_BEND_RANGE,
PITCH_INTERVALS_MAX_TIME_DIST,
Expand Down Expand Up @@ -160,9 +161,9 @@ class TokenizerConfig:
lengths / resolutions. Note: for tokenization with ``Position`` tokens, the total number of possible
positions will be set at four times the maximum resolution given (``max(beat_res.values)``\).
(default: ``{(0, 4): 8, (4, 12): 4}``)
:param nb_velocities: number of velocity bins. In the MIDI norm, velocities can take
:param num_velocities: number of velocity bins. In the MIDI norm, velocities can take
up to 128 values (0 to 127). This parameter allows to reduce the number of velocity values.
The velocities of the MIDIs resolution will be downsampled to ``nb_velocities`` values, equally
The velocities of the MIDIs resolution will be downsampled to ``num_velocities`` values, equally
separated between 0 and 127. (default: ``32``)
:param special_tokens: list of special tokens. This must be given as a list of strings given
only the names of the tokens. (default: ``["PAD", "BOS", "EOS", "MASK"]``\)
Expand All @@ -178,7 +179,7 @@ class TokenizerConfig:
values to represent with the ``beat_res_rest`` argument. (default: ``False``)
:param use_tempos: will use ``Tempo`` tokens, if the tokenizer is compatible.
``Tempo`` tokens will specify the current tempo. This allows to train a model to predict tempo changes.
Tempo values are quantized accordingly to the ``nb_tempos`` and ``tempo_range`` entries in the
Tempo values are quantized accordingly to the ``num_tempos`` and ``tempo_range`` entries in the
``additional_tokens`` dictionary (default is 32 tempos from 40 to 250). (default: ``False``)
:param use_time_signatures: will use ``TimeSignature`` tokens, if the tokenizer is compatible.
``TimeSignature`` tokens will specify the current time signature. Note that :ref:`REMI` adds a
Expand Down Expand Up @@ -215,7 +216,7 @@ class TokenizerConfig:
:param chord_unknown: range of number of notes to represent unknown chords.
If you want to represent chords that does not match any combination in ``chord_maps``, use this argument.
Leave ``None`` to not represent unknown chords. (default: ``None``)
:param nb_tempos: number of tempos "bins" to use. (default: ``32``)
:param num_tempos: number of tempos "bins" to use. (default: ``32``)
:param tempo_range: range of minimum and maximum tempos within which the bins fall. (default: ``(40, 250)``)
:param log_tempos: will use log scaled tempo values instead of linearly scaled. (default: ``False``)
:param delete_equal_successive_tempo_changes: setting this option True will delete identical successive tempo
Expand All @@ -234,7 +235,7 @@ class TokenizerConfig:
durations. If you use this parameter, make sure to configure ``beat_res`` to cover the durations you expect.
(default: ``False``)
:param pitch_bend_range: range of the pitch bend to consider, to be given as a tuple with the form
``(lowest_value, highest_value, nb_of_values)``. There will be ``nb_of_values`` tokens equally spaced
``(lowest_value, highest_value, num_of_values)``. There will be ``num_of_values`` tokens equally spaced
between ``lowest_value` and `highest_value``. (default: ``(-8192, 8191, 32)``)
:param delete_equal_successive_time_sig_changes: setting this option True will delete identical successive time
signature changes when preprocessing a MIDI file after loading it. For examples, if a MIDI has two time
Expand Down Expand Up @@ -268,7 +269,7 @@ def __init__(
self,
pitch_range: Tuple[int, int] = PITCH_RANGE,
beat_res: Dict[Tuple[int, int], int] = BEAT_RES,
nb_velocities: int = NB_VELOCITIES,
num_velocities: int = NUM_VELOCITIES,
special_tokens: Sequence[str] = SPECIAL_TOKENS,
use_chords: bool = USE_CHORDS,
use_rests: bool = USE_RESTS,
Expand All @@ -282,7 +283,7 @@ def __init__(
chord_maps: Dict[str, Tuple] = CHORD_MAPS,
chord_tokens_with_root_note: bool = CHORD_TOKENS_WITH_ROOT_NOTE,
chord_unknown: Tuple[int, int] = CHORD_UNKNOWN,
nb_tempos: int = NB_TEMPOS,
num_tempos: int = NUM_TEMPOS,
tempo_range: Tuple[int, int] = TEMPO_RANGE,
log_tempos: bool = LOG_TEMPOS,
delete_equal_successive_tempo_changes: bool = DELETE_EQUAL_SUCCESSIVE_TEMPO_CHANGES,
Expand All @@ -306,16 +307,16 @@ def __init__(
f"(received {pitch_range})"
)
assert (
1 <= nb_velocities <= 127
), f"nb_velocities must be within 1 and 127 (received {nb_velocities})"
1 <= num_velocities <= 127
), f"num_velocities must be within 1 and 127 (received {num_velocities})"
assert (
0 <= max_pitch_interval <= 127
), f"max_pitch_interval must be within 0 and 127 (received {max_pitch_interval})."

# Global parameters
self.pitch_range: Tuple[int, int] = pitch_range
self.beat_res: Dict[Tuple[int, int], int] = beat_res
self.nb_velocities: int = nb_velocities
self.num_velocities: int = num_velocities
self.special_tokens: Sequence[str] = special_tokens

# Additional token types params, enabling additional token types
Expand All @@ -339,7 +340,7 @@ def __init__(
self.chord_unknown: Tuple[int, int] = chord_unknown

# Tempo params
self.nb_tempos: int = nb_tempos # nb of tempo bins for additional tempo tokens, quantized like velocities
self.num_tempos: int = num_tempos
self.tempo_range: Tuple[int, int] = tempo_range # (min_tempo, max_tempo)
self.log_tempos: bool = log_tempos
self.delete_equal_successive_tempo_changes = (
Expand Down Expand Up @@ -372,6 +373,19 @@ def __init__(
self.max_pitch_interval = max_pitch_interval
self.pitch_intervals_max_time_dist = pitch_intervals_max_time_dist

# Pop legacy kwargs
legacy_args = (
("nb_velocities", "num_velocities"),
("nb_tempos", "num_tempos"),
)
for legacy_arg, new_arg in legacy_args:
if legacy_arg in kwargs:
setattr(self, new_arg, kwargs.pop(legacy_arg))
warnings.warn(
f"Argument {legacy_arg} has been renamed {new_arg}, you should consider to update"
f"your code with this new argument name."
)

# Additional params
self.additional_params = kwargs

Expand Down
4 changes: 2 additions & 2 deletions miditok/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
PITCH_RANGE = (21, 109)
BEAT_RES = {(0, 4): 8, (4, 12): 4} # samples per beat
# nb of velocity bins, velocities values from 0 to 127 will be quantized
NB_VELOCITIES = 32
NUM_VELOCITIES = 32
# default special tokens
SPECIAL_TOKENS = ["PAD", "BOS", "EOS", "MASK"]

Expand Down Expand Up @@ -70,7 +70,7 @@

# Tempo params
# nb of tempo bins for additional tempo tokens, quantized like velocities
NB_TEMPOS = 32 # TODO raname num contractions
NUM_TEMPOS = 32
TEMPO_RANGE = (40, 250) # (min_tempo, max_tempo)
LOG_TEMPOS = False # log or linear scale tempos
DELETE_EQUAL_SUCCESSIVE_TEMPO_CHANGES = False
Expand Down
66 changes: 45 additions & 21 deletions miditok/midi_tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
DEFAULT_TOKENIZER_FILE_NAME,
MIDI_FILES_EXTENSIONS,
PITCH_CLASSES,
TEMPO,
TIME_DIVISION,
TIME_SIGNATURE,
UNKNOWN_CHORD_PREFIX,
Expand All @@ -45,6 +46,7 @@
get_midi_programs,
merge_same_program_tracks,
remove_duplicated_notes,
set_midi_max_tick,
)


Expand Down Expand Up @@ -76,18 +78,20 @@ def convert_sequence_to_tokseq(
# Deduce nb of subscripts / dims
nb_io_dims = len(tokenizer.io_format)
nb_seq_dims = 1
if isinstance(arg[1][0], list):
if len(arg[1]) > 0 and isinstance(arg[1][0], list):
nb_seq_dims += 1
if isinstance(arg[1][0][0], list):
if len(arg[1][0]) > 0 and isinstance(arg[1][0][0], list):
nb_seq_dims += 1
elif len(arg[1][0]) == 0 and nb_seq_dims == nb_io_dims - 1:
# Special case where the sequence contains no tokens, we increment anyway
nb_seq_dims += 1

# Check the number of dimensions is good
# In case of no one_token_stream and one dimension short --> unsqueeze
if not tokenizer.one_token_stream and nb_seq_dims == nb_io_dims - 1:
print(
f"The input sequence has one dimension less than expected ({nb_seq_dims} instead of "
f"{nb_io_dims}). It is being unsqueezed to conform with the tokenizer's i/o format "
f"({tokenizer.io_format})"
f"The input sequence has one dimension less than expected ({nb_seq_dims} instead of {nb_io_dims})."
f"It is being unsqueezed to conform with the tokenizer's i/o format ({tokenizer.io_format})"
)
arg = (arg[0], [arg[1]])

Expand Down Expand Up @@ -218,7 +222,7 @@ def __init__(
self.config.pitch_range[0] >= 0 and self.config.pitch_range[1] <= 128
), "You must specify a pitch_range between 0 and 127 (included, i.e. range.stop at 128)"
assert (
0 < self.config.nb_velocities < 128
0 < self.config.num_velocities < 128
), "You must specify a nb_velocities between 1 and 127 (included)"

# Tweak the tokenizer's configuration and / or attributes before creating the vocabulary
Expand All @@ -233,7 +237,7 @@ def __init__(
self.durations = self.__create_durations_tuples()
# [1:] so that there is no velocity_0
self.velocities = np.linspace(
0, 127, self.config.nb_velocities + 1, dtype=np.intc
0, 127, self.config.num_velocities + 1, dtype=np.intc
)[1:]
self._first_beat_res = list(self.config.beat_res.values())[0]
for beat_range, res in self.config.beat_res.items():
Expand All @@ -242,9 +246,15 @@ def __init__(
break

# Tempos
# _DEFAULT_TEMPO is useful when `log_tempos` is enabled
self.tempos = np.zeros(1)
self._DEFAULT_TEMPO = TEMPO
if self.config.use_tempos:
self.tempos = self.__create_tempos()
if self.config.log_tempos:
self._DEFAULT_TEMPO = self.tempos[
np.argmin(np.abs(self.tempos - TEMPO))
]

# Rests
self.rests = []
Expand Down Expand Up @@ -366,8 +376,7 @@ def preprocess_midi(self, midi: MidiFile):
if self.config.use_programs and self.one_token_stream:
merge_same_program_tracks(midi.instruments)

t = 0
while t < len(midi.instruments):
for t in range(len(midi.instruments) - 1, -1, -1):
# quantize notes attributes
self._quantize_notes(midi.instruments[t].notes, midi.ticks_per_beat)
# sort notes
Expand All @@ -388,17 +397,12 @@ def preprocess_midi(self, midi: MidiFile):
midi.instruments[t].pitch_bends, midi.ticks_per_beat
)
# TODO quantize control changes
t += 1

# Recalculate max_tick is this could have changed after notes quantization
if len(midi.instruments) > 0:
midi.max_tick = max(
[max([note.end for note in track.notes]) for track in midi.instruments]
)

# Process tempo changes
if self.config.use_tempos:
self._quantize_tempos(midi.tempo_changes, midi.ticks_per_beat)

# Process time signature changes
if len(midi.time_signature_changes) == 0: # can sometimes happen
midi.time_signature_changes.append(
TimeSignature(*TIME_SIGNATURE, 0)
Expand All @@ -408,6 +412,11 @@ def preprocess_midi(self, midi: MidiFile):
midi.time_signature_changes, midi.ticks_per_beat
)

# We do not change key signature changes, markers and lyrics here as they are not used by MidiTok (yet)

# Recalculate max_tick is this could have changed after notes quantization
set_midi_max_tick(midi)

def _quantize_notes(self, notes: List[Note], time_division: int):
r"""Quantize the notes attributes: their pitch, velocity, start and end values.
It shifts the notes so that they start at times that match the time resolution
Expand Down Expand Up @@ -464,6 +473,11 @@ def _quantize_tempos(self, tempos: List[TempoChange], time_division: int):
"""
ticks_per_sample = int(time_division / max(self.config.beat_res.values()))
prev_tempo = TempoChange(-1, -1)
# If we delete the successive equal tempo changes, we need to sort them by time
# Otherwise it is not required here as the tokens will be sorted by time
if self.config.delete_equal_successive_tempo_changes:
tempos.sort(key=lambda x: x.time)

i = 0
while i < len(tempos):
# Quantize tempo value
Expand Down Expand Up @@ -505,6 +519,11 @@ def _quantize_time_signatures(
)
previous_tick = 0 # first time signature change is always at tick 0
prev_ts = time_sigs[0]
# If we delete the successive equal tempo changes, we need to sort them by time
# Otherwise it is not required here as the tokens will be sorted by time
if self.config.delete_equal_successive_time_sig_changes:
time_sigs.sort(key=lambda x: x.time)

i = 1
while i < len(time_sigs):
time_sig = time_sigs[i]
Expand Down Expand Up @@ -562,7 +581,6 @@ def _quantize_sustain_pedals(self, pedals: List[Pedal], time_division: int):
)
if pedal.start == pedal.end:
pedal.end += ticks_per_sample
pedal.duration = pedal.end - pedal.start

def _quantize_pitch_bends(self, pitch_bends: List[PitchBend], time_division: int):
r"""Quantize the pitch bend events from a track. Their onset and offset times will be adjusted
Expand Down Expand Up @@ -609,6 +627,8 @@ def _midi_to_tokens(
# Create events list
all_events = []
if not self.one_token_stream:
if len(midi.instruments) == 0:
all_events.append([])
for i in range(len(midi.instruments)):
all_events.append([])

Expand Down Expand Up @@ -778,7 +798,7 @@ def _create_track_events(self, track: Instrument) -> List[Event]:

# Pitch / interval
add_absolute_pitch_token = True
if self.config.use_pitch_intervals:
if self.config.use_pitch_intervals and not track.is_drum:
if note.start != previous_note_onset:
if (
note.start - previous_note_onset <= max_time_interval
Expand Down Expand Up @@ -1487,7 +1507,7 @@ def __create_tempos(self) -> np.ndarray:
:return: the tempos.
"""
tempo_fn = np.geomspace if self.config.log_tempos else np.linspace
tempos = tempo_fn(*self.config.tempo_range, self.config.nb_tempos).round(2)
tempos = tempo_fn(*self.config.tempo_range, self.config.num_tempos).round(2)

return tempos

Expand Down Expand Up @@ -1810,7 +1830,7 @@ def decode_bpe(self, seq: Union[TokSequence, List[TokSequence]]):

def tokenize_midi_dataset(
self,
midi_paths: Union[str, Path, List[str], List[Path]],
midi_paths: Union[str, Path, Sequence[Union[str, Path]]],
out_dir: Union[str, Path],
overwrite_mode: bool = True,
tokenizer_config_file_name: str = DEFAULT_TOKENIZER_FILE_NAME,
Expand Down Expand Up @@ -1854,7 +1874,7 @@ def tokenize_midi_dataset(
out_dir.mkdir(parents=True, exist_ok=True)

# User gave a path to a directory, we'll scan it to find MIDI files
if not isinstance(midi_paths, list):
if not isinstance(midi_paths, Sequence):
if isinstance(midi_paths, str):
midi_paths = Path(midi_paths)
root_dir = midi_paths
Expand Down Expand Up @@ -1988,6 +2008,8 @@ def tokens_errors(
# If list of TokSequence -> recursive
if isinstance(tokens, list):
return [self.tokens_errors(tok_seq) for tok_seq in tokens]
elif len(tokens) == 0:
return 0

nb_tok_predicted = len(tokens) # used to norm the score
if self.has_bpe:
Expand Down Expand Up @@ -2094,6 +2116,8 @@ def save_tokens(
self.complete_sequence(tokens)
ids_bpe_encoded = tokens.ids_bpe_encoded
ids = tokens.ids
elif isinstance(tokens, list) and len(tokens) == 0:
pass
elif isinstance(tokens[0], TokSequence):
ids_bpe_encoded = []
for seq in tokens:
Expand Down
Loading

0 comments on commit 89c4678

Please sign in to comment.