diff --git a/.github/workflows/python-publish.yml b/.github/workflows/publish-pypi.yml similarity index 91% rename from .github/workflows/python-publish.yml rename to .github/workflows/publish-pypi.yml index 13291ca8..61c244dd 100644 --- a/.github/workflows/python-publish.yml +++ b/.github/workflows/publish-pypi.yml @@ -6,7 +6,7 @@ # separate terms of service, privacy policy, and support # documentation. -name: Upload Python Package +name: Publish package on PyPi on: release: @@ -21,9 +21,9 @@ jobs: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Set up Python - uses: actions/setup-python@v3 + uses: actions/setup-python@v4 with: python-version: '3.x' - name: Install dependencies diff --git a/README.md b/README.md index e0a65e78..a9a171d5 100644 --- a/README.md +++ b/README.md @@ -35,7 +35,7 @@ from miditoolkit import MidiFile from pathlib import Path # Creating a multitrack tokenizer configuration, read the doc to explore other parameters -config = TokenizerConfig(nb_velocities=16, use_chords=True, use_programs=True) +config = TokenizerConfig(num_velocities=16, use_chords=True, use_programs=True) tokenizer = REMI(config) # Loads a midi, converts to tokens, and back to a MIDI diff --git a/miditok/classes.py b/miditok/classes.py index a09b806b..4ee234cf 100644 --- a/miditok/classes.py +++ b/miditok/classes.py @@ -2,6 +2,7 @@ Common classes. """ import json +import warnings from copy import deepcopy from dataclasses import dataclass from pathlib import Path @@ -20,8 +21,8 @@ DELETE_EQUAL_SUCCESSIVE_TIME_SIG_CHANGES, LOG_TEMPOS, MAX_PITCH_INTERVAL, - NB_TEMPOS, - NB_VELOCITIES, + NUM_TEMPOS, + NUM_VELOCITIES, ONE_TOKEN_STREAM_FOR_PROGRAMS, PITCH_BEND_RANGE, PITCH_INTERVALS_MAX_TIME_DIST, @@ -160,9 +161,9 @@ class TokenizerConfig: lengths / resolutions. 
Note: for tokenization with ``Position`` tokens, the total number of possible positions will be set at four times the maximum resolution given (``max(beat_res.values)``\). (default: ``{(0, 4): 8, (4, 12): 4}``) - :param nb_velocities: number of velocity bins. In the MIDI norm, velocities can take + :param num_velocities: number of velocity bins. In the MIDI norm, velocities can take up to 128 values (0 to 127). This parameter allows to reduce the number of velocity values. - The velocities of the MIDIs resolution will be downsampled to ``nb_velocities`` values, equally + The velocities of the MIDIs resolution will be downsampled to ``num_velocities`` values, equally separated between 0 and 127. (default: ``32``) :param special_tokens: list of special tokens. This must be given as a list of strings given only the names of the tokens. (default: ``["PAD", "BOS", "EOS", "MASK"]``\) @@ -178,7 +179,7 @@ class TokenizerConfig: values to represent with the ``beat_res_rest`` argument. (default: ``False``) :param use_tempos: will use ``Tempo`` tokens, if the tokenizer is compatible. ``Tempo`` tokens will specify the current tempo. This allows to train a model to predict tempo changes. - Tempo values are quantized accordingly to the ``nb_tempos`` and ``tempo_range`` entries in the + Tempo values are quantized accordingly to the ``num_tempos`` and ``tempo_range`` entries in the ``additional_tokens`` dictionary (default is 32 tempos from 40 to 250). (default: ``False``) :param use_time_signatures: will use ``TimeSignature`` tokens, if the tokenizer is compatible. ``TimeSignature`` tokens will specify the current time signature. Note that :ref:`REMI` adds a @@ -215,7 +216,7 @@ class TokenizerConfig: :param chord_unknown: range of number of notes to represent unknown chords. If you want to represent chords that does not match any combination in ``chord_maps``, use this argument. Leave ``None`` to not represent unknown chords. 
(default: ``None``) - :param nb_tempos: number of tempos "bins" to use. (default: ``32``) + :param num_tempos: number of tempos "bins" to use. (default: ``32``) :param tempo_range: range of minimum and maximum tempos within which the bins fall. (default: ``(40, 250)``) :param log_tempos: will use log scaled tempo values instead of linearly scaled. (default: ``False``) :param delete_equal_successive_tempo_changes: setting this option True will delete identical successive tempo @@ -234,7 +235,7 @@ class TokenizerConfig: durations. If you use this parameter, make sure to configure ``beat_res`` to cover the durations you expect. (default: ``False``) :param pitch_bend_range: range of the pitch bend to consider, to be given as a tuple with the form - ``(lowest_value, highest_value, nb_of_values)``. There will be ``nb_of_values`` tokens equally spaced + ``(lowest_value, highest_value, num_of_values)``. There will be ``num_of_values`` tokens equally spaced between ``lowest_value` and `highest_value``. (default: ``(-8192, 8191, 32)``) :param delete_equal_successive_time_sig_changes: setting this option True will delete identical successive time signature changes when preprocessing a MIDI file after loading it. 
For examples, if a MIDI has two time @@ -268,7 +269,7 @@ def __init__( self, pitch_range: Tuple[int, int] = PITCH_RANGE, beat_res: Dict[Tuple[int, int], int] = BEAT_RES, - nb_velocities: int = NB_VELOCITIES, + num_velocities: int = NUM_VELOCITIES, special_tokens: Sequence[str] = SPECIAL_TOKENS, use_chords: bool = USE_CHORDS, use_rests: bool = USE_RESTS, @@ -282,7 +283,7 @@ def __init__( chord_maps: Dict[str, Tuple] = CHORD_MAPS, chord_tokens_with_root_note: bool = CHORD_TOKENS_WITH_ROOT_NOTE, chord_unknown: Tuple[int, int] = CHORD_UNKNOWN, - nb_tempos: int = NB_TEMPOS, + num_tempos: int = NUM_TEMPOS, tempo_range: Tuple[int, int] = TEMPO_RANGE, log_tempos: bool = LOG_TEMPOS, delete_equal_successive_tempo_changes: bool = DELETE_EQUAL_SUCCESSIVE_TEMPO_CHANGES, @@ -306,8 +307,8 @@ def __init__( f"(received {pitch_range})" ) assert ( - 1 <= nb_velocities <= 127 - ), f"nb_velocities must be within 1 and 127 (received {nb_velocities})" + 1 <= num_velocities <= 127 + ), f"num_velocities must be within 1 and 127 (received {num_velocities})" assert ( 0 <= max_pitch_interval <= 127 ), f"max_pitch_interval must be within 0 and 127 (received {max_pitch_interval})." 
@@ -315,7 +316,7 @@ def __init__( # Global parameters self.pitch_range: Tuple[int, int] = pitch_range self.beat_res: Dict[Tuple[int, int], int] = beat_res - self.nb_velocities: int = nb_velocities + self.num_velocities: int = num_velocities self.special_tokens: Sequence[str] = special_tokens # Additional token types params, enabling additional token types @@ -339,7 +340,7 @@ def __init__( self.chord_unknown: Tuple[int, int] = chord_unknown # Tempo params - self.nb_tempos: int = nb_tempos # nb of tempo bins for additional tempo tokens, quantized like velocities + self.num_tempos: int = num_tempos self.tempo_range: Tuple[int, int] = tempo_range # (min_tempo, max_tempo) self.log_tempos: bool = log_tempos self.delete_equal_successive_tempo_changes = ( @@ -372,6 +373,19 @@ def __init__( self.max_pitch_interval = max_pitch_interval self.pitch_intervals_max_time_dist = pitch_intervals_max_time_dist + # Pop legacy kwargs + legacy_args = ( + ("nb_velocities", "num_velocities"), + ("nb_tempos", "num_tempos"), + ) + for legacy_arg, new_arg in legacy_args: + if legacy_arg in kwargs: + setattr(self, new_arg, kwargs.pop(legacy_arg)) + warnings.warn( + f"Argument {legacy_arg} has been renamed {new_arg}, you should consider to update" + f"your code with this new argument name." 
+ ) + # Additional params self.additional_params = kwargs diff --git a/miditok/constants.py b/miditok/constants.py index e4015f18..5133fc88 100644 --- a/miditok/constants.py +++ b/miditok/constants.py @@ -22,7 +22,7 @@ PITCH_RANGE = (21, 109) BEAT_RES = {(0, 4): 8, (4, 12): 4} # samples per beat # nb of velocity bins, velocities values from 0 to 127 will be quantized -NB_VELOCITIES = 32 +NUM_VELOCITIES = 32 # default special tokens SPECIAL_TOKENS = ["PAD", "BOS", "EOS", "MASK"] @@ -70,7 +70,7 @@ # Tempo params # nb of tempo bins for additional tempo tokens, quantized like velocities -NB_TEMPOS = 32 # TODO raname num contractions +NUM_TEMPOS = 32 TEMPO_RANGE = (40, 250) # (min_tempo, max_tempo) LOG_TEMPOS = False # log or linear scale tempos DELETE_EQUAL_SUCCESSIVE_TEMPO_CHANGES = False diff --git a/miditok/midi_tokenizer.py b/miditok/midi_tokenizer.py index 3df10086..58bb7227 100644 --- a/miditok/midi_tokenizer.py +++ b/miditok/midi_tokenizer.py @@ -34,6 +34,7 @@ DEFAULT_TOKENIZER_FILE_NAME, MIDI_FILES_EXTENSIONS, PITCH_CLASSES, + TEMPO, TIME_DIVISION, TIME_SIGNATURE, UNKNOWN_CHORD_PREFIX, @@ -45,6 +46,7 @@ get_midi_programs, merge_same_program_tracks, remove_duplicated_notes, + set_midi_max_tick, ) @@ -76,18 +78,20 @@ def convert_sequence_to_tokseq( # Deduce nb of subscripts / dims nb_io_dims = len(tokenizer.io_format) nb_seq_dims = 1 - if isinstance(arg[1][0], list): + if len(arg[1]) > 0 and isinstance(arg[1][0], list): nb_seq_dims += 1 - if isinstance(arg[1][0][0], list): + if len(arg[1][0]) > 0 and isinstance(arg[1][0][0], list): + nb_seq_dims += 1 + elif len(arg[1][0]) == 0 and nb_seq_dims == nb_io_dims - 1: + # Special case where the sequence contains no tokens, we increment anyway nb_seq_dims += 1 # Check the number of dimensions is good # In case of no one_token_stream and one dimension short --> unsqueeze if not tokenizer.one_token_stream and nb_seq_dims == nb_io_dims - 1: print( - f"The input sequence has one dimension less than expected ({nb_seq_dims} 
instead of " - f"{nb_io_dims}). It is being unsqueezed to conform with the tokenizer's i/o format " - f"({tokenizer.io_format})" + f"The input sequence has one dimension less than expected ({nb_seq_dims} instead of {nb_io_dims})." + f"It is being unsqueezed to conform with the tokenizer's i/o format ({tokenizer.io_format})" ) arg = (arg[0], [arg[1]]) @@ -218,7 +222,7 @@ def __init__( self.config.pitch_range[0] >= 0 and self.config.pitch_range[1] <= 128 ), "You must specify a pitch_range between 0 and 127 (included, i.e. range.stop at 128)" assert ( - 0 < self.config.nb_velocities < 128 + 0 < self.config.num_velocities < 128 ), "You must specify a nb_velocities between 1 and 127 (included)" # Tweak the tokenizer's configuration and / or attributes before creating the vocabulary @@ -233,7 +237,7 @@ def __init__( self.durations = self.__create_durations_tuples() # [1:] so that there is no velocity_0 self.velocities = np.linspace( - 0, 127, self.config.nb_velocities + 1, dtype=np.intc + 0, 127, self.config.num_velocities + 1, dtype=np.intc )[1:] self._first_beat_res = list(self.config.beat_res.values())[0] for beat_range, res in self.config.beat_res.items(): @@ -242,9 +246,15 @@ def __init__( break # Tempos + # _DEFAULT_TEMPO is useful when `log_tempos` is enabled self.tempos = np.zeros(1) + self._DEFAULT_TEMPO = TEMPO if self.config.use_tempos: self.tempos = self.__create_tempos() + if self.config.log_tempos: + self._DEFAULT_TEMPO = self.tempos[ + np.argmin(np.abs(self.tempos - TEMPO)) + ] # Rests self.rests = [] @@ -366,8 +376,7 @@ def preprocess_midi(self, midi: MidiFile): if self.config.use_programs and self.one_token_stream: merge_same_program_tracks(midi.instruments) - t = 0 - while t < len(midi.instruments): + for t in range(len(midi.instruments) - 1, -1, -1): # quantize notes attributes self._quantize_notes(midi.instruments[t].notes, midi.ticks_per_beat) # sort notes @@ -388,17 +397,12 @@ def preprocess_midi(self, midi: MidiFile): 
midi.instruments[t].pitch_bends, midi.ticks_per_beat ) # TODO quantize control changes - t += 1 - - # Recalculate max_tick is this could have changed after notes quantization - if len(midi.instruments) > 0: - midi.max_tick = max( - [max([note.end for note in track.notes]) for track in midi.instruments] - ) + # Process tempo changes if self.config.use_tempos: self._quantize_tempos(midi.tempo_changes, midi.ticks_per_beat) + # Process time signature changes if len(midi.time_signature_changes) == 0: # can sometimes happen midi.time_signature_changes.append( TimeSignature(*TIME_SIGNATURE, 0) @@ -408,6 +412,11 @@ def preprocess_midi(self, midi: MidiFile): midi.time_signature_changes, midi.ticks_per_beat ) + # We do not change key signature changes, markers and lyrics here as they are not used by MidiTok (yet) + + # Recalculate max_tick as this could have changed after notes quantization + set_midi_max_tick(midi) + def _quantize_notes(self, notes: List[Note], time_division: int): r"""Quantize the notes attributes: their pitch, velocity, start and end values.
It shifts the notes so that they start at times that match the time resolution @@ -464,6 +473,11 @@ def _quantize_tempos(self, tempos: List[TempoChange], time_division: int): """ ticks_per_sample = int(time_division / max(self.config.beat_res.values())) prev_tempo = TempoChange(-1, -1) + # If we delete the successive equal tempo changes, we need to sort them by time + # Otherwise it is not required here as the tokens will be sorted by time + if self.config.delete_equal_successive_tempo_changes: + tempos.sort(key=lambda x: x.time) + i = 0 while i < len(tempos): # Quantize tempo value @@ -505,6 +519,11 @@ def _quantize_time_signatures( ) previous_tick = 0 # first time signature change is always at tick 0 prev_ts = time_sigs[0] + # If we delete the successive equal time signature changes, we need to sort them by time + # Otherwise it is not required here as the tokens will be sorted by time + if self.config.delete_equal_successive_time_sig_changes: + time_sigs.sort(key=lambda x: x.time) + i = 1 while i < len(time_sigs): time_sig = time_sigs[i] @@ -562,7 +581,6 @@ def _quantize_sustain_pedals(self, pedals: List[Pedal], time_division: int): ) if pedal.start == pedal.end: pedal.end += ticks_per_sample - pedal.duration = pedal.end - pedal.start def _quantize_pitch_bends(self, pitch_bends: List[PitchBend], time_division: int): r"""Quantize the pitch bend events from a track.
Their onset and offset times will be adjusted @@ -609,6 +627,8 @@ def _midi_to_tokens( # Create events list all_events = [] if not self.one_token_stream: + if len(midi.instruments) == 0: + all_events.append([]) for i in range(len(midi.instruments)): all_events.append([]) @@ -778,7 +798,7 @@ def _create_track_events(self, track: Instrument) -> List[Event]: # Pitch / interval add_absolute_pitch_token = True - if self.config.use_pitch_intervals: + if self.config.use_pitch_intervals and not track.is_drum: if note.start != previous_note_onset: if ( note.start - previous_note_onset <= max_time_interval @@ -1487,7 +1507,7 @@ def __create_tempos(self) -> np.ndarray: :return: the tempos. """ tempo_fn = np.geomspace if self.config.log_tempos else np.linspace - tempos = tempo_fn(*self.config.tempo_range, self.config.nb_tempos).round(2) + tempos = tempo_fn(*self.config.tempo_range, self.config.num_tempos).round(2) return tempos @@ -1810,7 +1830,7 @@ def decode_bpe(self, seq: Union[TokSequence, List[TokSequence]]): def tokenize_midi_dataset( self, - midi_paths: Union[str, Path, List[str], List[Path]], + midi_paths: Union[str, Path, Sequence[Union[str, Path]]], out_dir: Union[str, Path], overwrite_mode: bool = True, tokenizer_config_file_name: str = DEFAULT_TOKENIZER_FILE_NAME, @@ -1854,7 +1874,7 @@ def tokenize_midi_dataset( out_dir.mkdir(parents=True, exist_ok=True) # User gave a path to a directory, we'll scan it to find MIDI files - if not isinstance(midi_paths, list): + if not isinstance(midi_paths, Sequence): if isinstance(midi_paths, str): midi_paths = Path(midi_paths) root_dir = midi_paths @@ -1988,6 +2008,8 @@ def tokens_errors( # If list of TokSequence -> recursive if isinstance(tokens, list): return [self.tokens_errors(tok_seq) for tok_seq in tokens] + elif len(tokens) == 0: + return 0 nb_tok_predicted = len(tokens) # used to norm the score if self.has_bpe: @@ -2094,6 +2116,8 @@ def save_tokens( self.complete_sequence(tokens) ids_bpe_encoded = tokens.ids_bpe_encoded 
ids = tokens.ids + elif isinstance(tokens, list) and len(tokens) == 0: + pass elif isinstance(tokens[0], TokSequence): ids_bpe_encoded = [] for seq in tokens: diff --git a/miditok/tokenizations/cp_word.py b/miditok/tokenizations/cp_word.py index 0d4794a5..4835ede1 100644 --- a/miditok/tokenizations/cp_word.py +++ b/miditok/tokenizations/cp_word.py @@ -7,8 +7,9 @@ from miditoolkit import Instrument, MidiFile, Note, TempoChange, TimeSignature from ..classes import Event, TokSequence -from ..constants import MIDI_INSTRUMENTS, TEMPO, TIME_DIVISION, TIME_SIGNATURE +from ..constants import MIDI_INSTRUMENTS, TIME_DIVISION, TIME_SIGNATURE from ..midi_tokenizer import MIDITokenizer, _in_as_seq +from ..utils import set_midi_max_tick class CPWord(MIDITokenizer): @@ -91,9 +92,11 @@ def _add_time_events(self, events: List[Event]) -> List[List[Event]]: current_time_sig = TIME_SIGNATURE if self.config.log_tempos: # pick the closest to the default value - current_tempo = float(self.tempos[(np.abs(self.tempos - TEMPO)).argmin()]) + current_tempo = float( + self.tempos[(np.abs(self.tempos - self._DEFAULT_TEMPO)).argmin()] + ) else: - current_tempo = TEMPO + current_tempo = self._DEFAULT_TEMPO current_program = None ticks_per_bar = self._compute_ticks_per_bar( TimeSignature(*current_time_sig, 0), time_division @@ -372,7 +375,7 @@ def tokens_to_midi( # RESULTS instruments: Dict[int, Instrument] = {} - tempo_changes = [TempoChange(TEMPO, -1)] + tempo_changes = [TempoChange(self._DEFAULT_TEMPO, -1)] time_signature_changes = [] def check_inst(prog: int): @@ -536,12 +539,7 @@ def check_inst(prog: int): midi.instruments = list(instruments.values()) midi.tempo_changes = tempo_changes midi.time_signature_changes = time_signature_changes - midi.max_tick = max( - [ - max([note.end for note in track.notes]) if len(track.notes) > 0 else 0 - for track in midi.instruments - ] - ) + set_midi_max_tick(midi) # Write MIDI file if output_path: Path(output_path).mkdir(parents=True, exist_ok=True) @@ 
-675,6 +673,8 @@ def tokens_errors( # If list of TokSequence -> recursive if isinstance(tokens, list): return [self.tokens_errors(tok_seq) for tok_seq in tokens] + if len(tokens) == 0: + return 0 def cp_token_type(tok: List[int]) -> List[str]: family = self[0, tok[0]].split("_")[1] diff --git a/miditok/tokenizations/midi_like.py b/miditok/tokenizations/midi_like.py index 195c38f4..f3cbf82b 100644 --- a/miditok/tokenizations/midi_like.py +++ b/miditok/tokenizations/midi_like.py @@ -15,12 +15,11 @@ from ..classes import Event, TokSequence from ..constants import ( MIDI_INSTRUMENTS, - TEMPO, TIME_DIVISION, TIME_SIGNATURE, ) from ..midi_tokenizer import MIDITokenizer, _in_as_seq -from ..utils import fix_offsets_overlapping_notes +from ..utils import fix_offsets_overlapping_notes, set_midi_max_tick class MIDILike(MIDITokenizer): @@ -174,7 +173,7 @@ def tokens_to_midi( # RESULTS instruments: Dict[int, Instrument] = {} - tempo_changes = [TempoChange(TEMPO, -1)] + tempo_changes = [TempoChange(self._DEFAULT_TEMPO, -1)] time_signature_changes = [TimeSignature(*TIME_SIGNATURE, 0)] active_notes = {p: {} for p in self.config.programs} @@ -356,12 +355,7 @@ def clear_active_notes(): midi.instruments = list(instruments.values()) midi.tempo_changes = tempo_changes midi.time_signature_changes = time_signature_changes - midi.max_tick = max( - [ - max([note.end for note in track.notes]) if len(track.notes) > 0 else 0 - for track in midi.instruments - ] - ) + set_midi_max_tick(midi) # Write MIDI file if output_path: Path(output_path).mkdir(parents=True, exist_ok=True) @@ -427,7 +421,7 @@ def _create_token_types_graph(self) -> Dict[str, List[str]]: first_note_token_type = "NoteOn" dic["Velocity"] = [first_note_token_type, "TimeShift"] dic["NoteOff"] = ["NoteOff", first_note_token_type, "TimeShift"] - dic["TimeShift"] = ["NoteOff", first_note_token_type] + dic["TimeShift"] = ["NoteOff", first_note_token_type, "TimeShift"] if self.config.use_pitch_intervals: for token_type in 
("PitchIntervalTime", "PitchIntervalChord"): dic[token_type] = ["Velocity"] @@ -594,6 +588,8 @@ def tokens_errors( # If list of TokSequence -> recursive if isinstance(tokens, list): return [self.tokens_errors(tok_seq) for tok_seq in tokens] + if len(tokens) == 0: + return 0 nb_tok_predicted = len(tokens) # used to norm the score if self.has_bpe: @@ -620,7 +616,7 @@ def tokens_errors( ) for i in range(1, len(events)): - # err_tokens = events[i - 4: i + 4] # uncomment for debug + # err_tokens = events[i - 4 : i + 4] # uncomment for debug # Good token type if events[i].type in self.tokens_types_graph[events[i - 1].type]: if events[i].type in [ diff --git a/miditok/tokenizations/mmm.py b/miditok/tokenizations/mmm.py index a3c06e0c..46339da0 100644 --- a/miditok/tokenizations/mmm.py +++ b/miditok/tokenizations/mmm.py @@ -9,11 +9,11 @@ from ..constants import ( MIDI_INSTRUMENTS, MMM_DENSITY_BINS_MAX, - TEMPO, TIME_DIVISION, TIME_SIGNATURE, ) from ..midi_tokenizer import MIDITokenizer, _in_as_seq +from ..utils import set_midi_max_tick class MMM(MIDITokenizer): @@ -223,7 +223,7 @@ def tokens_to_midi( # RESULTS instruments: List[Instrument] = [] tempo_changes = [ - TempoChange(TEMPO, -1) + TempoChange(self._DEFAULT_TEMPO, -1) ] # mock the first tempo change to optimize below time_signature_changes = [ TimeSignature(*TIME_SIGNATURE, 0) @@ -321,12 +321,7 @@ def tokens_to_midi( midi.instruments = instruments midi.tempo_changes = tempo_changes midi.time_signature_changes = time_signature_changes - midi.max_tick = max( - [ - max([note.end for note in track.notes]) if len(track.notes) > 0 else 0 - for track in midi.instruments - ] - ) + set_midi_max_tick(midi) # Write MIDI file if output_path: Path(output_path).mkdir(parents=True, exist_ok=True) @@ -436,6 +431,8 @@ def tokens_errors( ) -> float: tokens_to_check = cast(TokSequence, tokens_to_check) nb_tok_predicted = len(tokens_to_check) # used to norm the score + if nb_tok_predicted == 0: + return 0 if self.has_bpe: 
self.decode_bpe(tokens_to_check) self.complete_sequence(tokens_to_check) diff --git a/miditok/tokenizations/mumidi.py b/miditok/tokenizations/mumidi.py index 1a382d10..20752b4b 100644 --- a/miditok/tokenizations/mumidi.py +++ b/miditok/tokenizations/mumidi.py @@ -9,11 +9,10 @@ from ..constants import ( DRUM_PITCH_RANGE, MIDI_INSTRUMENTS, - TEMPO, TIME_DIVISION, ) from ..midi_tokenizer import MIDITokenizer, _in_as_seq, _out_as_complete_seq -from ..utils import detect_chords +from ..utils import detect_chords, set_midi_max_tick class MuMIDI(MIDITokenizer): @@ -292,10 +291,10 @@ def tokens_to_midi( midi = MidiFile(ticks_per_beat=time_division) # Tempos - if self.config.use_tempos: + if self.config.use_tempos and len(tokens) > 0: first_tempo = float(tokens.tokens[0][3].split("_")[1]) else: - first_tempo = TEMPO + first_tempo = self._DEFAULT_TEMPO midi.tempo_changes.append(TempoChange(first_tempo, 0)) ticks_per_sample = time_division // max(self.config.beat_res.values()) @@ -351,6 +350,7 @@ def tokens_to_midi( ) ) midi.instruments[-1].notes = notes + set_midi_max_tick(midi) # Write MIDI file if output_path: @@ -462,6 +462,8 @@ def tokens_errors(self, tokens: Union[TokSequence, List, np.ndarray, Any]) -> fl :param tokens: sequence of tokens to check :return: the error ratio (lower is better) """ + if len(tokens) == 0: + return 0 tokens = tokens.tokens err = 0 previous_type = tokens[0][0].split("_")[0] diff --git a/miditok/tokenizations/octuple.py b/miditok/tokenizations/octuple.py index 17b178fd..7c3f42d7 100644 --- a/miditok/tokenizations/octuple.py +++ b/miditok/tokenizations/octuple.py @@ -8,11 +8,11 @@ from ..classes import Event, TokSequence from ..constants import ( MIDI_INSTRUMENTS, - TEMPO, TIME_DIVISION, TIME_SIGNATURE, ) from ..midi_tokenizer import MIDITokenizer, _in_as_seq +from ..utils import set_midi_max_tick class Octuple(MIDITokenizer): @@ -102,7 +102,7 @@ def _add_time_events(self, events: List[Event]) -> List[List[Event]]: current_pos = 0 previous_tick 
= 0 current_time_sig = TIME_SIGNATURE - current_tempo = TEMPO + current_tempo = self._DEFAULT_TEMPO current_program = None ticks_per_bar = self._compute_ticks_per_bar( TimeSignature(*current_time_sig, 0), time_division @@ -214,7 +214,7 @@ def tokens_to_midi( # RESULTS instruments: Dict[int, Instrument] = {} - tempo_changes = [TempoChange(TEMPO, -1)] + tempo_changes = [TempoChange(self._DEFAULT_TEMPO, -1)] time_signature_changes = [] def check_inst(prog: int): @@ -339,12 +339,8 @@ def check_inst(prog: int): midi.instruments = list(instruments.values()) midi.tempo_changes = tempo_changes midi.time_signature_changes = time_signature_changes - midi.max_tick = max( - [ - max([note.end for note in track.notes]) if len(track.notes) > 0 else 0 - for track in midi.instruments - ] - ) + set_midi_max_tick(midi) + # Write MIDI file if output_path: Path(output_path).mkdir(parents=True, exist_ok=True) @@ -429,6 +425,8 @@ def tokens_errors( # If list of TokSequence -> recursive if isinstance(tokens, list): return [self.tokens_errors(tok_seq) for tok_seq in tokens] + if len(tokens) == 0: + return 0 err = 0 current_bar = current_pos = -1 diff --git a/miditok/tokenizations/remi.py b/miditok/tokenizations/remi.py index 767e2e51..39560213 100644 --- a/miditok/tokenizations/remi.py +++ b/miditok/tokenizations/remi.py @@ -16,11 +16,11 @@ from ..classes import Event, TokenizerConfig, TokSequence from ..constants import ( MIDI_INSTRUMENTS, - TEMPO, TIME_DIVISION, TIME_SIGNATURE, ) from ..midi_tokenizer import MIDITokenizer, _in_as_seq +from ..utils import set_midi_max_tick class REMI(MIDITokenizer): @@ -261,7 +261,7 @@ def tokens_to_midi( # RESULTS instruments: Dict[int, Instrument] = {} - tempo_changes = [TempoChange(TEMPO, -1)] + tempo_changes = [TempoChange(self._DEFAULT_TEMPO, -1)] time_signature_changes = [] def check_inst(prog: int): @@ -457,12 +457,8 @@ def check_inst(prog: int): midi.instruments = list(instruments.values()) midi.tempo_changes = tempo_changes 
midi.time_signature_changes = time_signature_changes - midi.max_tick = max( - [ - max([note.end for note in track.notes]) if len(track.notes) > 0 else 0 - for track in midi.instruments - ] - ) + set_midi_max_tick(midi) + # Write MIDI file if output_path: Path(output_path).mkdir(parents=True, exist_ok=True) diff --git a/miditok/tokenizations/structured.py b/miditok/tokenizations/structured.py index 35dc0e13..19a721ed 100644 --- a/miditok/tokenizations/structured.py +++ b/miditok/tokenizations/structured.py @@ -7,11 +7,11 @@ from ..classes import Event, TokSequence from ..constants import ( MIDI_INSTRUMENTS, - TEMPO, TIME_DIVISION, TIME_SIGNATURE, ) from ..midi_tokenizer import MIDITokenizer, _in_as_seq +from ..utils import set_midi_max_tick class Structured(MIDITokenizer): @@ -150,6 +150,8 @@ def _midi_to_tokens( all_events = [] # Adds note tokens + if not self.one_token_stream and len(midi.instruments) == 0: + all_events.append([]) for track in midi.instruments: note_events = self._create_track_events(track) if self.one_token_stream: @@ -204,7 +206,7 @@ def tokens_to_midi( # RESULTS instruments: Dict[int, Instrument] = {} - tempo_changes = [TempoChange(TEMPO, 0)] + tempo_changes = [TempoChange(self._DEFAULT_TEMPO, 0)] time_signature_changes = [TimeSignature(*TIME_SIGNATURE, 0)] def check_inst(prog: int): @@ -274,12 +276,8 @@ def check_inst(prog: int): midi.instruments = list(instruments.values()) midi.tempo_changes = tempo_changes midi.time_signature_changes = time_signature_changes - midi.max_tick = max( - [ - max([note.end for note in track.notes]) if len(track.notes) > 0 else 0 - for track in midi.instruments - ] - ) + set_midi_max_tick(midi) + # Write MIDI file if output_path: Path(output_path).mkdir(parents=True, exist_ok=True) diff --git a/miditok/tokenizations/tsd.py b/miditok/tokenizations/tsd.py index aedddf89..be476cf8 100644 --- a/miditok/tokenizations/tsd.py +++ b/miditok/tokenizations/tsd.py @@ -15,11 +15,11 @@ from ..classes import Event, TokSequence 
from ..constants import ( MIDI_INSTRUMENTS, - TEMPO, TIME_DIVISION, TIME_SIGNATURE, ) from ..midi_tokenizer import MIDITokenizer, _in_as_seq +from ..utils import set_midi_max_tick class TSD(MIDITokenizer): @@ -142,7 +142,7 @@ def tokens_to_midi( # RESULTS instruments: Dict[int, Instrument] = {} - tempo_changes = [TempoChange(TEMPO, -1)] + tempo_changes = [TempoChange(self._DEFAULT_TEMPO, -1)] time_signature_changes = [TimeSignature(*TIME_SIGNATURE, 0)] def check_inst(prog: int): @@ -304,12 +304,8 @@ def check_inst(prog: int): midi.instruments = list(instruments.values()) midi.tempo_changes = tempo_changes midi.time_signature_changes = time_signature_changes - midi.max_tick = max( - [ - max([note.end for note in track.notes]) if len(track.notes) > 0 else 0 - for track in midi.instruments - ] - ) + set_midi_max_tick(midi) + # Write MIDI file if output_path: Path(output_path).mkdir(parents=True, exist_ok=True) @@ -371,7 +367,7 @@ def _create_token_types_graph(self) -> Dict[str, List[str]]: dic["Pitch"] = ["Velocity"] dic["Velocity"] = ["Duration"] dic["Duration"] = [first_note_token_type, "TimeShift"] - dic["TimeShift"] = [first_note_token_type] + dic["TimeShift"] = [first_note_token_type, "TimeShift"] if self.config.use_pitch_intervals: for token_type in ("PitchIntervalTime", "PitchIntervalChord"): dic[token_type] = ["Velocity"] diff --git a/miditok/utils/__init__.py b/miditok/utils/__init__.py index 88cacd46..59cc4a57 100644 --- a/miditok/utils/__init__.py +++ b/miditok/utils/__init__.py @@ -8,6 +8,7 @@ merge_tracks_per_class, nb_bar_pos, remove_duplicated_notes, + set_midi_max_tick, ) __all__ = [ @@ -20,4 +21,5 @@ "merge_tracks", "merge_same_program_tracks", "nb_bar_pos", + "set_midi_max_tick", ] diff --git a/miditok/utils/utils.py b/miditok/utils/utils.py index c20ff969..70effbbb 100644 --- a/miditok/utils/utils.py +++ b/miditok/utils/utils.py @@ -38,10 +38,12 @@ def convert_ids_tensors_to_list(ids: Any): # Recursively checks the content are ints (only check first 
item) el = ids[0] while isinstance(el, list): - el = el[0] + el = el[0] if len(el) > 0 else None # Check endpoint type - if not isinstance(el, int): + if el is None: + pass + elif not isinstance(el, int): # Recursively try to convert elements of the list for ei in range(len(ids)): ids[ei] = convert_ids_tensors_to_list(ids[ei]) @@ -391,6 +393,44 @@ def merge_same_program_tracks(tracks: List[Instrument]): del tracks[i] +def set_midi_max_tick(midi: MidiFile): + midi.max_tick = 0 + + # Parse track events + if len(midi.instruments) > 0: + event_type_attr = ( + ("notes", "end"), + ("pedals", "end"), + ("control_changes", "time"), + ("pitch_bends", "time"), + ) + for track in midi.instruments: + for event_type, time_attr in event_type_attr: + if len(getattr(track, event_type)) > 0: + midi.max_tick = max( + midi.max_tick, + max( + [ + getattr(event, time_attr) + for event in getattr(track, event_type) + ] + ), + ) + + # Parse global MIDI events + for event_type in ( + "tempo_changes", + "time_signature_changes", + "key_signature_changes", + "lyrics", + ): + if len(getattr(midi, event_type)) > 0: + midi.max_tick = max( + midi.max_tick, + max(event.time for event in getattr(midi, event_type)), + ) + + def nb_bar_pos( seq: Sequence[int], bar_token: int, position_tokens: Sequence[int] ) -> Tuple[int, int]: diff --git a/pyproject.toml b/pyproject.toml index 286ef4ee..aa4f6c5a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,11 +37,10 @@ classifiers = [ ] dependencies = [ "numpy>=1.19", - "miditoolkit", # TODO >=v1.0.1? 
+ "miditoolkit", "tqdm", "tokenizers>=0.13.0", "huggingface_hub>=0.16.4", - "scipy", # TODO remove when miditoolkit v1.0.1 ] [project.optional-dependencies] diff --git a/tests/Multitrack_MIDIs/Aicha.mid b/tests/MIDIs_multitrack/Aicha.mid similarity index 100% rename from tests/Multitrack_MIDIs/Aicha.mid rename to tests/MIDIs_multitrack/Aicha.mid diff --git a/tests/Multitrack_MIDIs/All The Small Things.mid b/tests/MIDIs_multitrack/All The Small Things.mid similarity index 100% rename from tests/Multitrack_MIDIs/All The Small Things.mid rename to tests/MIDIs_multitrack/All The Small Things.mid diff --git a/tests/Multitrack_MIDIs/Funkytown.mid b/tests/MIDIs_multitrack/Funkytown.mid similarity index 100% rename from tests/Multitrack_MIDIs/Funkytown.mid rename to tests/MIDIs_multitrack/Funkytown.mid diff --git a/tests/Multitrack_MIDIs/Girls Just Want to Have Fun.mid b/tests/MIDIs_multitrack/Girls Just Want to Have Fun.mid similarity index 100% rename from tests/Multitrack_MIDIs/Girls Just Want to Have Fun.mid rename to tests/MIDIs_multitrack/Girls Just Want to Have Fun.mid diff --git a/tests/Multitrack_MIDIs/I Gotta Feeling.mid b/tests/MIDIs_multitrack/I Gotta Feeling.mid similarity index 100% rename from tests/Multitrack_MIDIs/I Gotta Feeling.mid rename to tests/MIDIs_multitrack/I Gotta Feeling.mid diff --git a/tests/Multitrack_MIDIs/In Too Deep.mid b/tests/MIDIs_multitrack/In Too Deep.mid similarity index 100% rename from tests/Multitrack_MIDIs/In Too Deep.mid rename to tests/MIDIs_multitrack/In Too Deep.mid diff --git a/tests/Multitrack_MIDIs/Les Yeux Revolvers.mid b/tests/MIDIs_multitrack/Les Yeux Revolvers.mid similarity index 100% rename from tests/Multitrack_MIDIs/Les Yeux Revolvers.mid rename to tests/MIDIs_multitrack/Les Yeux Revolvers.mid diff --git a/tests/Multitrack_MIDIs/Mr. Blue Sky.mid b/tests/MIDIs_multitrack/Mr. Blue Sky.mid similarity index 100% rename from tests/Multitrack_MIDIs/Mr. Blue Sky.mid rename to tests/MIDIs_multitrack/Mr. 
Blue Sky.mid diff --git a/tests/Multitrack_MIDIs/Shut Up.mid b/tests/MIDIs_multitrack/Shut Up.mid similarity index 100% rename from tests/Multitrack_MIDIs/Shut Up.mid rename to tests/MIDIs_multitrack/Shut Up.mid diff --git a/tests/Multitrack_MIDIs/What a Fool Believes.mid b/tests/MIDIs_multitrack/What a Fool Believes.mid similarity index 100% rename from tests/Multitrack_MIDIs/What a Fool Believes.mid rename to tests/MIDIs_multitrack/What a Fool Believes.mid diff --git a/tests/One_track_MIDIs/6338816_Etude No. 4.mid b/tests/MIDIs_one_track/6338816_Etude No. 4.mid similarity index 100% rename from tests/One_track_MIDIs/6338816_Etude No. 4.mid rename to tests/MIDIs_one_track/6338816_Etude No. 4.mid diff --git a/tests/One_track_MIDIs/6354774_Macabre Waltz.mid b/tests/MIDIs_one_track/6354774_Macabre Waltz.mid similarity index 100% rename from tests/One_track_MIDIs/6354774_Macabre Waltz.mid rename to tests/MIDIs_one_track/6354774_Macabre Waltz.mid diff --git a/tests/One_track_MIDIs/Maestro_1.mid b/tests/MIDIs_one_track/Maestro_1.mid similarity index 100% rename from tests/One_track_MIDIs/Maestro_1.mid rename to tests/MIDIs_one_track/Maestro_1.mid diff --git a/tests/One_track_MIDIs/Maestro_10.mid b/tests/MIDIs_one_track/Maestro_10.mid similarity index 100% rename from tests/One_track_MIDIs/Maestro_10.mid rename to tests/MIDIs_one_track/Maestro_10.mid diff --git a/tests/One_track_MIDIs/Maestro_2.mid b/tests/MIDIs_one_track/Maestro_2.mid similarity index 100% rename from tests/One_track_MIDIs/Maestro_2.mid rename to tests/MIDIs_one_track/Maestro_2.mid diff --git a/tests/One_track_MIDIs/Maestro_3.mid b/tests/MIDIs_one_track/Maestro_3.mid similarity index 100% rename from tests/One_track_MIDIs/Maestro_3.mid rename to tests/MIDIs_one_track/Maestro_3.mid diff --git a/tests/One_track_MIDIs/Maestro_4.mid b/tests/MIDIs_one_track/Maestro_4.mid similarity index 100% rename from tests/One_track_MIDIs/Maestro_4.mid rename to tests/MIDIs_one_track/Maestro_4.mid diff --git 
a/tests/One_track_MIDIs/Maestro_5.mid b/tests/MIDIs_one_track/Maestro_5.mid similarity index 100% rename from tests/One_track_MIDIs/Maestro_5.mid rename to tests/MIDIs_one_track/Maestro_5.mid diff --git a/tests/One_track_MIDIs/Maestro_6.mid b/tests/MIDIs_one_track/Maestro_6.mid similarity index 100% rename from tests/One_track_MIDIs/Maestro_6.mid rename to tests/MIDIs_one_track/Maestro_6.mid diff --git a/tests/One_track_MIDIs/Maestro_7.mid b/tests/MIDIs_one_track/Maestro_7.mid similarity index 100% rename from tests/One_track_MIDIs/Maestro_7.mid rename to tests/MIDIs_one_track/Maestro_7.mid diff --git a/tests/One_track_MIDIs/Maestro_8.mid b/tests/MIDIs_one_track/Maestro_8.mid similarity index 100% rename from tests/One_track_MIDIs/Maestro_8.mid rename to tests/MIDIs_one_track/Maestro_8.mid diff --git a/tests/One_track_MIDIs/Maestro_9.mid b/tests/MIDIs_one_track/Maestro_9.mid similarity index 100% rename from tests/One_track_MIDIs/Maestro_9.mid rename to tests/MIDIs_one_track/Maestro_9.mid diff --git a/tests/One_track_MIDIs/POP909_008.mid b/tests/MIDIs_one_track/POP909_008.mid similarity index 100% rename from tests/One_track_MIDIs/POP909_008.mid rename to tests/MIDIs_one_track/POP909_008.mid diff --git a/tests/One_track_MIDIs/POP909_010.mid b/tests/MIDIs_one_track/POP909_010.mid similarity index 100% rename from tests/One_track_MIDIs/POP909_010.mid rename to tests/MIDIs_one_track/POP909_010.mid diff --git a/tests/One_track_MIDIs/POP909_022.mid b/tests/MIDIs_one_track/POP909_022.mid similarity index 100% rename from tests/One_track_MIDIs/POP909_022.mid rename to tests/MIDIs_one_track/POP909_022.mid diff --git a/tests/One_track_MIDIs/POP909_191.mid b/tests/MIDIs_one_track/POP909_191.mid similarity index 100% rename from tests/One_track_MIDIs/POP909_191.mid rename to tests/MIDIs_one_track/POP909_191.mid diff --git a/tests/MIDIs_one_track/empty.mid b/tests/MIDIs_one_track/empty.mid new file mode 100644 index 00000000..4d2cb2e7 Binary files /dev/null and 
b/tests/MIDIs_one_track/empty.mid differ diff --git a/tests/conftest.py b/tests/conftest.py index f21f6945..40241996 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,3 +1,8 @@ +""" +Pytest configuration file. +Doc: https://docs.pytest.org/en/latest/reference/reference.html +""" + import os import pytest diff --git a/tests/test_bpe.py b/tests/test_bpe.py index 9e0636e2..5b4f31d5 100644 --- a/tests/test_bpe.py +++ b/tests/test_bpe.py @@ -7,14 +7,20 @@ from copy import deepcopy from pathlib import Path from time import time -from typing import Union +from typing import Sequence, Union +import pytest from miditoolkit import MidiFile from tqdm import tqdm import miditok -from .tests_utils import TIME_SIGNATURE_RANGE_TESTS +from .utils import ( + MIDI_PATHS_ONE_TRACK, + SEED, + TIME_SIGNATURE_RANGE_TESTS, + TOKENIZATIONS_BPE, +) # Special beat res for test, up to 64 beats so the duration and time-shift values are # long enough for MIDI-Like and Structured encodings, and with a single beat resolution @@ -33,150 +39,133 @@ } -def test_bpe_conversion(data_path: Union[str, Path] = "./tests/One_track_MIDIs"): +@pytest.mark.parametrize("tokenization", TOKENIZATIONS_BPE) +def test_bpe_conversion( + tokenization: str, + tmp_path: Path, + midi_paths: Sequence[Union[str, Path]] = None, + seed: int = SEED, +): r"""Reads a few MIDI files, convert them into token sequences, convert them back to MIDI files. The converted back MIDI files should identical to original one, expect with note starting and ending times quantized, and maybe a some duplicated notes removed - :param data_path: root path to the data to test + :param tokenization: name of the tokenizer class to test. + :param midi_paths: list of paths of MIDI files to use for the tests. + :param seed: seed. 
""" - random.seed(777) - tokenizations = ["Structured", "REMI", "MIDILike", "TSD", "MMM"] - data_path = Path(data_path) - files = list(data_path.glob("**/*.mid")) + if midi_paths is None: + midi_paths = MIDI_PATHS_ONE_TRACK + random.seed(seed) # Creates tokenizers and computes BPE (build voc) - first_tokenizers = [] - first_samples_bpe = {tok: [] for tok in tokenizations} - for tokenization in tokenizations: - tokenizer_config = miditok.TokenizerConfig(**TOKENIZER_PARAMS) - tokenizer: miditok.MIDITokenizer = getattr(miditok, tokenization)( - tokenizer_config=tokenizer_config - ) - # Give a str to dir for coverage - tokenizer.tokenize_midi_dataset( - str(data_path), Path("tests", "test_results", tokenization) - ) - tokenizer.learn_bpe( - vocab_size=len(tokenizer) + 400, - tokens_paths=list( - Path("tests", "test_results", tokenization).glob("**/*.json") - ), - start_from_empty_voc=True, - ) - tokenizer.save_params( - Path("tests", "test_results", f"{tokenization}_bpe", "config.txt") - ) - first_tokenizers.append(tokenizer) - - test_id_to_token = { - id_: tokenizer._vocab_base_byte_to_token[byte_] - for id_, byte_ in tokenizer._vocab_base_id_to_byte.items() - } - vocab_inv = {v: k for k, v in tokenizer._vocab_base.items()} - assert ( - test_id_to_token == vocab_inv - ), "Vocabulary inversion failed, something might be wrong with the way they are built" - - for file_path in tqdm(files, desc=f"Checking BPE tok / detok ({tokenization})"): - tokens = tokenizer(file_path, apply_bpe_if_possible=False) - if not tokenizer.one_token_stream: - tokens = tokens[0] - to_tok = tokenizer._bytes_to_tokens(tokens.bytes) - to_id = tokenizer._tokens_to_ids(to_tok) - to_by = tokenizer._ids_to_bytes(to_id, as_one_str=True) - assert all( - [to_by == tokens.bytes, to_tok == tokens.tokens, to_id == tokens.ids] - ), "Conversion between tokens / bytes / ids failed, something might be wrong in vocabularies" - - tokenizer.apply_bpe(tokens) - first_samples_bpe[tokenization].append(tokens) + 
first_samples_bpe = [] + tokenizer_config = miditok.TokenizerConfig(**TOKENIZER_PARAMS) + tokenizer: miditok.MIDITokenizer = getattr(miditok, tokenization)( + tokenizer_config=tokenizer_config + ) + # Give a str to dir for coverage + tokenizer.tokenize_midi_dataset(midi_paths, tmp_path) + tokenizer.learn_bpe( + vocab_size=len(tokenizer) + 400, + tokens_paths=list(tmp_path.glob("**/*.json")), + start_from_empty_voc=True, + ) + tokenizer.save_params(tmp_path / "bpe_config.txt") + + test_id_to_token = { + id_: tokenizer._vocab_base_byte_to_token[byte_] + for id_, byte_ in tokenizer._vocab_base_id_to_byte.items() + } + vocab_inv = {v: k for k, v in tokenizer._vocab_base.items()} + assert ( + test_id_to_token == vocab_inv + ), "Vocabulary inversion failed, something might be wrong with the way they are built" + + for file_path in tqdm( + midi_paths, desc=f"Checking BPE tok / detok ({tokenization})" + ): + tokens = tokenizer(file_path, apply_bpe_if_possible=False) + if not tokenizer.one_token_stream: + tokens = tokens[0] + to_tok = tokenizer._bytes_to_tokens(tokens.bytes) + to_id = tokenizer._tokens_to_ids(to_tok) + to_by = tokenizer._ids_to_bytes(to_id, as_one_str=True) + assert all( + [to_by == tokens.bytes, to_tok == tokens.tokens, to_id == tokens.ids] + ), "Conversion between tokens / bytes / ids failed, something might be wrong in vocabularies" + + tokenizer.apply_bpe(tokens) + first_samples_bpe.append(tokens) # Reload (test) tokenizer from the saved config file - tokenizers = [] - for i, tokenization in enumerate(tokenizations): - tokenizers.append( - getattr(miditok, tokenization)( - params=Path( - "tests", "test_results", f"{tokenization}_bpe", "config.txt" - ) - ) - ) - assert ( - tokenizers[i] == first_tokenizers[i] - ), "Saving and reloading tokenizer failed. The reloaded tokenizer is different from the first one." 
+ tokenizer_reloaded = getattr(miditok, tokenization)( + params=tmp_path / "bpe_config.txt" + ) + assert ( + tokenizer_reloaded == tokenizer + ), "Saving and reloading tokenizer failed. The reloaded tokenizer is different from the first one." # Unbatched BPE at_least_one_error = False - tok_times = [] - for i, file_path in enumerate(tqdm(files, desc="Testing BPE unbatched")): + tok_time = 0 + for i, file_path in enumerate(tqdm(midi_paths, desc="Testing BPE unbatched")): midi = MidiFile(file_path) + tokens_no_bpe = tokenizer(deepcopy(midi), apply_bpe_if_possible=False) + if not tokenizer.one_token_stream: + tokens_no_bpe = tokens_no_bpe[0] + tokens_bpe = deepcopy(tokens_no_bpe) # with BPE - for tokenization, tokenizer in zip(tokenizations, tokenizers): - tokens_no_bpe = tokenizer(deepcopy(midi), apply_bpe_if_possible=False) - if not tokenizer.one_token_stream: - tokens_no_bpe = tokens_no_bpe[0] - tokens_bpe = deepcopy(tokens_no_bpe) # with BPE - - t0 = time() - tokenizer.apply_bpe(tokens_bpe) - tok_times.append(time() - t0) - - tokens_bpe_decoded = deepcopy(tokens_bpe) - tokenizer.decode_bpe(tokens_bpe_decoded) # BPE decomposed - if tokens_bpe != first_samples_bpe[tokenization][i]: - at_least_one_error = True - print( - f"Error with BPE for {tokenization} and {file_path.name}: " - f"BPE encoding failed after tokenizer reload" - ) - if tokens_no_bpe != tokens_bpe_decoded: - at_least_one_error = True - print( - f"Error with BPE for {tokenization} and {file_path.name}: encoding - decoding test failed" - ) - print(f"Mean BPE encoding time unbatched: {sum(tok_times) / len(tok_times):.2f}") + t0 = time() + tokenizer.apply_bpe(tokens_bpe) + tok_time += time() - t0 + + tokens_bpe_decoded = deepcopy(tokens_bpe) + tokenizer.decode_bpe(tokens_bpe_decoded) # BPE decomposed + if tokens_bpe != first_samples_bpe[i]: + at_least_one_error = True + print( + f"Error with BPE for {tokenization} and {file_path.name}: " + f"BPE encoding failed after tokenizer reload" + ) + if 
tokens_no_bpe != tokens_bpe_decoded: + at_least_one_error = True + print( + f"Error with BPE for {tokenization} and {file_path.name}: encoding - decoding test failed" + ) + print( + f"BPE encoding time un-batched: {tok_time:.2f} (mean: {tok_time / len(midi_paths):.4f})" + ) assert not at_least_one_error # Batched BPE at_least_one_error = False - tok_times = [] - for tokenization, tokenizer in zip(tokenizations, tokenizers): - samples_no_bpe = [] - for i, file_path in enumerate(tqdm(files, desc="Testing BPE batched")): - # Reads the midi - midi = MidiFile(file_path) - if not tokenizer.one_token_stream: - samples_no_bpe.append(tokenizer(midi, apply_bpe_if_possible=False)[0]) - else: - samples_no_bpe.append(tokenizer(midi, apply_bpe_if_possible=False)) - - t0 = time() - samples_bpe = deepcopy(samples_no_bpe) - tokenizer.apply_bpe(samples_bpe) - tok_times.append((time() - t0) / len(files)) - - samples_bpe_decoded = deepcopy(samples_bpe) - tokenizer.decode_bpe(samples_bpe_decoded) # BPE decomposed - for sample_bpe, sample_bpe_first in zip( - samples_bpe, first_samples_bpe[tokenization] - ): - if sample_bpe != sample_bpe_first: - at_least_one_error = True - print( - f"Error with BPE for {tokenization}: BPE encoding failed after tokenizer reload" - ) - for sample_no_bpe, sample_bpe_decoded in zip( - samples_no_bpe, samples_bpe_decoded - ): - if sample_no_bpe != sample_bpe_decoded: - at_least_one_error = True - print( - f"Error with BPE for {tokenization}: encoding - decoding test failed" - ) - print(f"Mean BPE encoding time batched: {sum(tok_times) / len(tok_times):.2f}") + samples_no_bpe = [] + for i, file_path in enumerate(tqdm(midi_paths, desc="Testing BPE batched")): + # Reads the midi + midi = MidiFile(file_path) + tokens_no_bpe = tokenizer(midi, apply_bpe_if_possible=False) + if not tokenizer.one_token_stream: + samples_no_bpe.append(tokens_no_bpe[0]) + else: + samples_no_bpe.append(tokens_no_bpe) + samples_bpe = deepcopy(samples_no_bpe) + t0 = time() + 
tokenizer.apply_bpe(samples_bpe) + tok_time = time() - t0 + samples_bpe_decoded = deepcopy(samples_bpe) + tokenizer.decode_bpe(samples_bpe_decoded) # BPE decomposed + for sample_bpe, sample_bpe_first in zip(samples_bpe, first_samples_bpe): + if sample_bpe != sample_bpe_first: + at_least_one_error = True + print( + f"Error with BPE for {tokenization}: BPE encoding failed after tokenizer reload" + ) + for sample_no_bpe, sample_bpe_decoded in zip(samples_no_bpe, samples_bpe_decoded): + if sample_no_bpe != sample_bpe_decoded: + at_least_one_error = True + print(f"Error with BPE for {tokenization}: encoding - decoding test failed") + print( + f"BPE encoding time batched: {tok_time:.2f} (mean: {tok_time / len(midi_paths):.4f})" + ) assert not at_least_one_error - - -if __name__ == "__main__": - test_bpe_conversion() diff --git a/tests/test_data_augmentation.py b/tests/test_data_augmentation.py new file mode 100644 index 00000000..e65a6730 --- /dev/null +++ b/tests/test_data_augmentation.py @@ -0,0 +1,251 @@ +#!/usr/bin/python3 python + +"""Test methods + +""" + +import json +from pathlib import Path +from typing import Union + +import numpy as np +import pytest +from miditoolkit import MidiFile +from tqdm import tqdm + +import miditok + +from .utils import ALL_TOKENIZATIONS, HERE + +MAX_NUM_FILES_TEST_TOKENS = 7 + + +def test_data_augmentation_midi( + tmp_path: Path, + data_path: Union[str, Path] = HERE / "MIDIs_multitrack", + tokenization: str = "MIDILike", +): + # We only test data augmentation on MIDIs with one tokenization, as tokenizers does not play here + + tokenizer = getattr(miditok, tokenization)() + midi_aug_path = tmp_path / "Multitrack_MIDIs_aug" / tokenization + miditok.data_augmentation.data_augmentation_dataset( + data_path, + tokenizer, + 2, + 1, + 1, + out_path=midi_aug_path, + copy_original_in_new_location=False, + ) + + aug_midi_paths = list(midi_aug_path.glob("**/*.mid")) + for aug_midi_path in tqdm( + aug_midi_paths, desc="CHECKING DATA AUGMENTATION 
ON MIDIS" + ): + # Determine offsets of file + parts = aug_midi_path.stem.split("§") + original_stem, offsets_str = parts[0], parts[1].split("_") + offsets = [0, 0, 0] + for offset_str in offsets_str: + for pos, letter in enumerate(["p", "v", "d"]): + if offset_str[0] == letter: + offsets[pos] = int(offset_str[1:]) + + # Loads MIDIs to compare + aug_midi = MidiFile(aug_midi_path) + original_midi = MidiFile(data_path / f"{original_stem}.mid") + + # Compare them + for track_ogi, track_aug in zip( + original_midi.instruments, aug_midi.instruments + ): + if track_ogi.is_drum: + continue + track_ogi.notes.sort(key=lambda x: (x.start, x.pitch, x.end, x.velocity)) + track_aug.notes.sort(key=lambda x: (x.start, x.pitch, x.end, x.velocity)) + for note_o, note_s in zip(track_ogi.notes, track_aug.notes): + assert note_s.pitch == note_o.pitch + offsets[0] + assert note_s.velocity in [ + tokenizer.velocities[0], + tokenizer.velocities[-1], + note_o.velocity + offsets[1], + ] + + +@pytest.mark.parametrize("tokenization", ALL_TOKENIZATIONS) +def test_data_augmentation_tokens( + tmp_path: Path, + tokenization: str, + data_path: Union[str, Path] = HERE / "MIDIs_multitrack", + max_num_files: int = MAX_NUM_FILES_TEST_TOKENS, +): + original_midi_paths = list(data_path.glob("**/*.mid"))[:max_num_files] + if tokenization == "MuMIDI": + pytest.skip( + "MuMIDI is not compatible with data augmentation at the token level" + ) + + tokenizer = getattr(miditok, tokenization)() + tokens_path = tmp_path / "Multitrack_tokens" / tokenization + tokens_aug_path = tmp_path / "Multitrack_tokens_aug" / tokenization + + print("PERFORMING DATA AUGMENTATION ON TOKENS") + tokenizer.tokenize_midi_dataset(original_midi_paths, tokens_path) + miditok.data_augmentation.data_augmentation_dataset( + tokens_path, + tokenizer, + 2, + 1, + 1, + out_path=tokens_aug_path, + copy_original_in_new_location=False, + ) + + # Getting tokens idx from tokenizer for assertions + aug_tokens_paths = 
list(tokens_aug_path.glob("**/*.json")) + pitch_voc_idx, vel_voc_idx, dur_voc_idx = None, None, None + note_off_tokens = [] + if tokenizer.is_multi_voc: + pitch_voc_idx = tokenizer.vocab_types_idx["Pitch"] + vel_voc_idx = tokenizer.vocab_types_idx["Velocity"] + dur_voc_idx = tokenizer.vocab_types_idx["Duration"] + pitch_tokens = np.array(tokenizer.token_ids_of_type("Pitch", pitch_voc_idx)) + vel_tokens = np.array(tokenizer.token_ids_of_type("Velocity", vel_voc_idx)) + dur_tokens = np.array(tokenizer.token_ids_of_type("Duration", dur_voc_idx)) + else: + pitch_tokens = np.array( + tokenizer.token_ids_of_type("Pitch") + tokenizer.token_ids_of_type("NoteOn") + ) + vel_tokens = np.array(tokenizer.token_ids_of_type("Velocity")) + dur_tokens = np.array(tokenizer.token_ids_of_type("Duration")) + note_off_tokens = np.array( + tokenizer.token_ids_of_type("NoteOff") + ) # for MidiLike + tok_vel_min, tok_vel_max = vel_tokens[0], vel_tokens[-1] + tok_dur_min, tok_dur_max = None, None + if tokenization != "MIDILike": + tok_dur_min, tok_dur_max = dur_tokens[0], dur_tokens[-1] + + for aug_token_path in aug_tokens_paths: + # Determine offsets of file + parts = aug_token_path.stem.split("§") + original_stem, offsets_str = parts[0], parts[1].split("_") + offsets = [0, 0, 0] + for offset_str in offsets_str: + for pos, letter in enumerate(["p", "v", "d"]): + if offset_str[0] == letter: + offsets[pos] = int(offset_str[1:]) + + # Loads tokens to compare + with open(aug_token_path) as json_file: + file = json.load(json_file) + aug_tokens = file["ids"] + + with open(tokens_path / f"{original_stem}.json") as json_file: + file = json.load(json_file) + original_tokens = file["ids"] + original_programs = file["programs"] if "programs" in file else None + + # Compare them + if tokenizer.one_token_stream: + original_tokens, aug_tokens = [original_tokens], [aug_tokens] + for ti, (original_track, aug_track) in enumerate( + zip(original_tokens, aug_tokens) + ): + if original_programs is not None 
and original_programs[ti][1]: # drums + continue + for idx, (original_token, aug_token) in enumerate( + zip(original_track, aug_track) + ): + if not tokenizer.is_multi_voc: + if original_token in pitch_tokens: + pitch_offset = offsets[0] + # no offset for drum pitches + if ( + tokenizer.one_token_stream + and idx > 0 + and tokenizer[original_track[idx - 1]] == "Program_-1" + ): + pitch_offset = 0 + assert aug_token == original_token + pitch_offset + elif original_token in vel_tokens: + assert aug_token in [ + original_token + offsets[1], + tok_vel_min, + tok_vel_max, + ] + elif original_token in dur_tokens and tokenization != "MIDILike": + assert aug_token in [ + original_token + offsets[2], + tok_dur_min, + tok_dur_max, + ] + elif original_token in note_off_tokens: + assert aug_token == original_token + offsets[0] + else: + if original_token[pitch_voc_idx] in pitch_tokens: + assert ( + aug_token[pitch_voc_idx] + == original_token[pitch_voc_idx] + offsets[0] + ) + elif original_token[vel_voc_idx] in vel_tokens: + assert aug_token[vel_voc_idx] in [ + original_token[vel_voc_idx] + offsets[1], + tok_vel_min, + tok_vel_max, + ] + elif ( + original_token[dur_voc_idx] in dur_tokens + and tokenization != "MIDILike" + ): + assert aug_token[dur_voc_idx] in [ + original_token[dur_voc_idx] + offsets[2], + tok_dur_min, + tok_dur_max, + ] + + +"""def time_data_augmentation_tokens_vs_mid(): + from time import time + tokenizers = [miditok.TSD(), miditok.REMI()] + data_paths = [Path('./tests/One_track_MIDIs'), Path('./tests/Multitrack_MIDIs')] + + for data_path in data_paths: + for tokenizer in tokenizers: + print(f'\n{data_path.stem} - {type(tokenizer).__name__}') + files = list(data_path.glob('**/*.mid')) + + # Testing opening midi -> augment midis -> tokenize midis + t0 = time() + for file_path in files: + # Reads the MIDI + try: + midi = MidiFile(Path(file_path)) + except Exception: # ValueError, OSError, FileNotFoundError, IOError, EOFError, mido.KeySignatureError + continue 
+ + offsets = miditok.data_augmentation.get_offsets(tokenizer, 2, 2, 2, midi=midi) + midis = miditok.data_augmentation.data_augmentation_midi(midi, tokenizer, *offsets) + for _, aug_mid in midis: + _ = tokenizer(aug_mid) + tt = time() - t0 + print(f'Opening midi -> augment midis -> tokenize midis: took {tt:.2f} sec ' + f'({tt / len(files):.2f} sec/file)') + + # Testing opening midi -> tokenize midi -> augment tokens + t0 = time() + for file_path in files: + # Reads the MIDI + try: + midi = MidiFile(Path(file_path)) + except Exception: # ValueError, OSError, FileNotFoundError, IOError, EOFError, mido.KeySignatureError + continue + + tokens = tokenizer(midi) + for track_tokens in tokens: + offsets = miditok.data_augmentation.get_offsets(tokenizer, 2, 2, 2, tokens=tokens) + _ = miditok.data_augmentation.data_augmentation_tokens(track_tokens, tokenizer, *offsets) + tt = time() - t0 + print(f'Opening midi -> tokenize midi -> augment tokens: took {tt:.2f} sec ' + f'({tt / len(files):.2f} sec/file)')""" diff --git a/tests/test_hf_hub.py b/tests/test_hf_hub.py index de0c5229..8c64b5d5 100644 --- a/tests/test_hf_hub.py +++ b/tests/test_hf_hub.py @@ -4,20 +4,37 @@ """ +from pathlib import Path +from time import sleep + +from huggingface_hub.utils._errors import HfHubHTTPError + from miditok import REMI, TSD +MAX_NUM_TRIES_HF_PUSH = 5 +NUM_SECONDS_RETRY = 8 + def test_push_and_load_to_hf_hub(hf_token: str): tokenizer = REMI() - tokenizer.push_to_hub("Natooz/MidiTok-tests", private=True, token=hf_token) + num_tries = 0 + while num_tries < MAX_NUM_TRIES_HF_PUSH: + try: + tokenizer.push_to_hub("Natooz/MidiTok-tests", private=True, token=hf_token) + except HfHubHTTPError as e: + if e.response.status_code in [500, 412, 429]: + num_tries += 1 + sleep(NUM_SECONDS_RETRY) + else: + num_tries = MAX_NUM_TRIES_HF_PUSH tokenizer2 = REMI.from_pretrained("Natooz/MidiTok-tests", token=hf_token) assert tokenizer == tokenizer2 -def test_from_pretrained_local(): +def 
test_from_pretrained_local(tmp_path: Path): # Here using paths to directories tokenizer = TSD() - tokenizer.save_pretrained("tests/tokenizer_confs") - tokenizer2 = TSD.from_pretrained("tests/tokenizer_confs") + tokenizer.save_pretrained(tmp_path) + tokenizer2 = TSD.from_pretrained(tmp_path) assert tokenizer == tokenizer2 diff --git a/tests/test_io_formats.py b/tests/test_io_formats.py index 93ad5559..22b09736 100644 --- a/tests/test_io_formats.py +++ b/tests/test_io_formats.py @@ -6,71 +6,79 @@ from copy import deepcopy from pathlib import Path +from typing import Any, Dict, Tuple, Union +import pytest from miditoolkit import MidiFile import miditok -from .tests_utils import ALL_TOKENIZATIONS, midis_equals - -BEAT_RES_TEST = {(0, 16): 8} -TOKENIZER_PARAMS = { - "beat_res": BEAT_RES_TEST, - "use_chords": True, - "use_rests": True, - "use_tempos": True, - "use_time_signatures": True, - "use_sustain_pedals": True, - "use_pitch_bends": True, - "use_programs": False, - "chord_maps": miditok.constants.CHORD_MAPS, - "chord_tokens_with_root_note": True, # Tokens will look as "Chord_C:maj" - "chord_unknown": (3, 6), - "beat_res_rest": {(0, 16): 4}, - "nb_tempos": 32, - "tempo_range": (40, 250), - "time_signature_range": {4: [4]}, -} - -programs_tokenizations = ["TSD", "REMI", "MIDILike", "Structured", "CPWord"] -test_cases_programs = [ - ( - { - "use_programs": True, - "one_token_stream_for_programs": True, - "program_changes": False, - }, - [], - ), - ( - { - "use_programs": True, - "one_token_stream_for_programs": True, - "program_changes": True, - }, - ["Structured", "CPWord"], - ), - ( - { - "use_programs": True, - "one_token_stream_for_programs": False, - "program_changes": False, - }, - ["Structured"], - ), +from .utils import ( + ALL_TOKENIZATIONS, + HERE, + TOKENIZER_CONFIG_KWARGS, + adjust_tok_params_for_tests, + prepare_midi_for_tests, +) + +default_params = deepcopy(TOKENIZER_CONFIG_KWARGS) +default_params.update( + { + "use_chords": True, + "use_rests": True, + 
"use_tempos": True, + "use_time_signatures": True, + "use_sustain_pedals": True, + "use_pitch_bends": True, + } +) +tokenizations_no_one_stream = [ + "TSD", + "REMI", + "MIDILike", + "Structured", + "CPWord", + "Octuple", ] - - -def encode_decode_and_check(tokenizer: miditok.MIDITokenizer, midi: MidiFile): +configs = ( + { + "use_programs": True, + "one_token_stream_for_programs": True, + "program_changes": False, + }, + { + "use_programs": True, + "one_token_stream_for_programs": True, + "program_changes": True, + }, + { + "use_programs": True, + "one_token_stream_for_programs": False, + "program_changes": False, + }, +) +TOK_PARAMS_IO = [] +for tokenization_ in ALL_TOKENIZATIONS: + params_ = deepcopy(default_params) + adjust_tok_params_for_tests(tokenization_, params_) + TOK_PARAMS_IO.append((tokenization_, params_)) + + if tokenization_ in tokenizations_no_one_stream: + for config in configs: + params_tmp = deepcopy(params_) + params_tmp.update(config) + TOK_PARAMS_IO.append((tokenization_, params_tmp)) + + +def encode_decode_and_check(tokenizer: miditok.MIDITokenizer, midi: MidiFile) -> bool: + """Tests if a + + :param tokenizer: + :param midi: + :return: + """ # Process the MIDI - midi_to_compare = deepcopy(midi) - for track in midi_to_compare.instruments: - if track.is_drum: - track.program = 0 # need to be done before sorting tracks per program - # MIDI produced with one_token_stream contains tracks with different orders - midi_to_compare.instruments.sort( - key=lambda x: (x.program, x.is_drum) - ) # sort tracks + midi_to_compare = prepare_midi_for_tests(midi) # Convert the midi to tokens, and keeps the ids (integers) tokens = tokenizer(midi_to_compare) @@ -90,62 +98,27 @@ def encode_decode_and_check(tokenizer: miditok.MIDITokenizer, midi: MidiFile): return True # Checks its good - decoded_midi.instruments.sort(key=lambda x: (x.program, x.is_drum)) - if type(tokenizer).__name__ == "MIDILike": - for track in decoded_midi.instruments: - 
track.notes.sort(key=lambda x: (x.start, x.pitch, x.end)) - errors = midis_equals(midi_to_compare, decoded_midi) - if len(errors) > 0: - print( - f"Failed to encode/decode NOTES with {tokenizer.__class__.__name__} ({len(errors)} errors)" - ) - return True + decoded_midi = prepare_midi_for_tests(decoded_midi, sort_notes=True) + return decoded_midi == midi_to_compare - return False - -def test_io_formats(): +@pytest.mark.parametrize("tok_params_set", TOK_PARAMS_IO) +def test_io_formats( + tok_params_set: Tuple[str, Dict[str, Any]], + midi_path: Union[str, Path] = HERE / "MIDIs_multitrack" / "Funkytown.mid", +): r"""Reads a few MIDI files, convert them into token sequences, convert them back to MIDI files. The converted back MIDI files should identical to original one, expect with note starting and ending - times quantized, and maybe a some duplicated notes removed + times quantized, and maybe a some duplicated notes removed. + + :param tok_params_set: tokenizer and its parameters to run. + :param midi_path: path to the MIDI file to test. 
""" - at_least_one_error = False - - file_path = Path("tests", "Multitrack_MIDIs", "Funkytown.mid") - midi = MidiFile(file_path) - - for tokenization in ALL_TOKENIZATIONS: - params = deepcopy(TOKENIZER_PARAMS) - if tokenization == "Structured": - params["beat_res"] = {(0, 512): 8} - elif tokenization == "Octuple": - params["use_time_signatures"] = False - tokenizer_config = miditok.TokenizerConfig(**params) - tokenizer: miditok.MIDITokenizer = getattr(miditok, tokenization)( - tokenizer_config=tokenizer_config - ) - - at_least_one_error = ( - encode_decode_and_check(tokenizer, midi) or at_least_one_error - ) - - # If TSD, also test in use_programs / one_token_stream mode - if tokenization in programs_tokenizations: - for custom_params, excluded_tok in test_cases_programs: - if tokenization in excluded_tok: - continue - params = deepcopy(TOKENIZER_PARAMS) - params.update(custom_params) - tokenizer_config = miditok.TokenizerConfig(**params) - tokenizer: miditok.MIDITokenizer = getattr(miditok, tokenization)( - tokenizer_config=tokenizer_config - ) - at_least_one_error = ( - encode_decode_and_check(tokenizer, midi) or at_least_one_error - ) + midi = MidiFile(midi_path) + tokenization, params = tok_params_set + tokenizer: miditok.MIDITokenizer = getattr(miditok, tokenization)( + tokenizer_config=miditok.TokenizerConfig(**params) + ) + at_least_one_error = encode_decode_and_check(tokenizer, midi) assert not at_least_one_error - - -if __name__ == "__main__": - test_io_formats() diff --git a/tests/test_methods.py b/tests/test_methods.py index 5f71bf57..dbd37a2f 100644 --- a/tests/test_methods.py +++ b/tests/test_methods.py @@ -4,13 +4,9 @@ """ -import json -import random from pathlib import Path -from typing import Union +from typing import Sequence, Union -import numpy as np -from miditoolkit import MidiFile from tensorflow import Tensor as tfTensor from tensorflow import convert_to_tensor from torch import ( @@ -22,11 +18,10 @@ from torch import ( Tensor as ptTensor, ) 
-from tqdm import tqdm import miditok -from .tests_utils import ALL_TOKENIZATIONS +from .utils import HERE, MIDI_PATHS_ALL def test_convert_tensors(): @@ -44,259 +39,23 @@ def test_convert_tensors(): assert as_list == original -"""def time_data_augmentation_tokens_vs_mid(): - from time import time - tokenizers = [miditok.TSD(), miditok.REMI()] - data_paths = [Path('./tests/One_track_MIDIs'), Path('./tests/Multitrack_MIDIs')] +def test_tokenize_datasets_file_tree( + tmp_path: Path, midi_paths: Sequence[Union[str, Path]] = None +): + if midi_paths is None: + midi_paths = MIDI_PATHS_ALL - for data_path in data_paths: - for tokenizer in tokenizers: - print(f'\n{data_path.stem} - {type(tokenizer).__name__}') - files = list(data_path.glob('**/*.mid')) - - # Testing opening midi -> augment midis -> tokenize midis - t0 = time() - for file_path in files: - # Reads the MIDI - try: - midi = MidiFile(Path(file_path)) - except Exception: # ValueError, OSError, FileNotFoundError, IOError, EOFError, mido.KeySignatureError - continue - - offsets = miditok.data_augmentation.get_offsets(tokenizer, 2, 2, 2, midi=midi) - midis = miditok.data_augmentation.data_augmentation_midi(midi, tokenizer, *offsets) - for _, aug_mid in midis: - _ = tokenizer(aug_mid) - tt = time() - t0 - print(f'Opening midi -> augment midis -> tokenize midis: took {tt:.2f} sec ' - f'({tt / len(files):.2f} sec/file)') - - # Testing opening midi -> tokenize midi -> augment tokens - t0 = time() - for file_path in files: - # Reads the MIDI - try: - midi = MidiFile(Path(file_path)) - except Exception: # ValueError, OSError, FileNotFoundError, IOError, EOFError, mido.KeySignatureError - continue - - tokens = tokenizer(midi) - for track_tokens in tokens: - offsets = miditok.data_augmentation.get_offsets(tokenizer, 2, 2, 2, tokens=tokens) - _ = miditok.data_augmentation.data_augmentation_tokens(track_tokens, tokenizer, *offsets) - tt = time() - t0 - print(f'Opening midi -> tokenize midi -> augment tokens: took {tt:.2f} 
sec ' - f'({tt / len(files):.2f} sec/file)')""" - - -def test_data_augmentation(): - data_path = Path("./tests/Multitrack_MIDIs") - original_midi_paths = list(data_path.glob("**/*.mid"))[:7] - ALL_TOKENIZATIONS.remove("MuMIDI") # not compatible - - for tokenization in ALL_TOKENIZATIONS: - print(f"TESTING WITH {tokenization}") - tokenizer = getattr(miditok, tokenization)() - midi_aug_path = Path("tests", "Multitrack_MIDIs_aug", tokenization) - tokens_path = Path("tests", "Multitrack_tokens", tokenization) - tokens_aug_path = Path("tests", "Multitrack_tokens_aug", tokenization) - - # We only perform and test data augmentation on MIDIs once, as tokenizers does not play here - if tokenization == "MIDILike": - print("PERFORMING DATA AUGMENTATION ON MIDIS") - miditok.data_augmentation.data_augmentation_dataset( - data_path, - tokenizer, - 2, - 1, - 1, - out_path=midi_aug_path, - copy_original_in_new_location=False, - ) - aug_midi_paths = list(midi_aug_path.glob("**/*.mid")) - for aug_midi_path in tqdm( - aug_midi_paths, desc="CHECKING DATA AUGMENTATION ON MIDIS" - ): - if "Mr. 
Blue Sky" in aug_midi_path.stem: - continue # TODO remove when miditoolkit v1.0.1 is released - # Determine offsets of file - parts = aug_midi_path.stem.split("§") - original_stem, offsets_str = parts[0], parts[1].split("_") - offsets = [0, 0, 0] - for offset_str in offsets_str: - for pos, letter in enumerate(["p", "v", "d"]): - if offset_str[0] == letter: - offsets[pos] = int(offset_str[1:]) - - # Loads MIDIs to compare - try: - aug_midi = MidiFile(aug_midi_path) - original_midi = MidiFile(data_path / f"{original_stem}.mid") - except Exception: # ValueError, OSError, FileNotFoundError, IOError, EOFError, mido.KeySignatureError - continue - - # Compare them - for original_track, aug_track in zip( - original_midi.instruments, aug_midi.instruments - ): - if original_track.is_drum: - continue - original_track.notes.sort( - key=lambda x: (x.start, x.pitch, x.end, x.velocity) - ) # sort notes - aug_track.notes.sort( - key=lambda x: (x.start, x.pitch, x.end, x.velocity) - ) # sort notes - for note_o, note_s in zip(original_track.notes, aug_track.notes): - assert note_s.pitch == note_o.pitch + offsets[0] - assert note_s.velocity in [ - tokenizer.velocities[0], - tokenizer.velocities[-1], - note_o.velocity + offsets[1], - ] - - print("PERFORMING DATA AUGMENTATION ON TOKENS") - tokenizer.tokenize_midi_dataset(original_midi_paths, tokens_path) - miditok.data_augmentation.data_augmentation_dataset( - tokens_path, - tokenizer, - 2, - 1, - 1, - out_path=tokens_aug_path, - copy_original_in_new_location=False, - ) - - # Getting tokens idx from tokenizer for assertions - aug_tokens_paths = list(tokens_aug_path.glob("**/*.json")) - pitch_voc_idx, vel_voc_idx, dur_voc_idx = None, None, None - note_off_tokens = [] - if tokenizer.is_multi_voc: - pitch_voc_idx = tokenizer.vocab_types_idx["Pitch"] - vel_voc_idx = tokenizer.vocab_types_idx["Velocity"] - dur_voc_idx = tokenizer.vocab_types_idx["Duration"] - pitch_tokens = np.array(tokenizer.token_ids_of_type("Pitch", pitch_voc_idx)) - 
vel_tokens = np.array(tokenizer.token_ids_of_type("Velocity", vel_voc_idx)) - dur_tokens = np.array(tokenizer.token_ids_of_type("Duration", dur_voc_idx)) - else: - pitch_tokens = np.array( - tokenizer.token_ids_of_type("Pitch") - + tokenizer.token_ids_of_type("NoteOn") - ) - vel_tokens = np.array(tokenizer.token_ids_of_type("Velocity")) - dur_tokens = np.array(tokenizer.token_ids_of_type("Duration")) - note_off_tokens = np.array( - tokenizer.token_ids_of_type("NoteOff") - ) # for MidiLike - tok_vel_min, tok_vel_max = vel_tokens[0], vel_tokens[-1] - tok_dur_min, tok_dur_max = None, None - if tokenization != "MIDILike": - tok_dur_min, tok_dur_max = dur_tokens[0], dur_tokens[-1] - - for aug_token_path in aug_tokens_paths: - # Determine offsets of file - parts = aug_token_path.stem.split("§") - original_stem, offsets_str = parts[0], parts[1].split("_") - offsets = [0, 0, 0] - for offset_str in offsets_str: - for pos, letter in enumerate(["p", "v", "d"]): - if offset_str[0] == letter: - offsets[pos] = int(offset_str[1:]) - - # Loads tokens to compare - with open(aug_token_path) as json_file: - file = json.load(json_file) - aug_tokens = file["ids"] - - with open(tokens_path / f"{original_stem}.json") as json_file: - file = json.load(json_file) - original_tokens = file["ids"] - original_programs = file["programs"] if "programs" in file else None - - # Compare them - if tokenizer.one_token_stream: - original_tokens, aug_tokens = [original_tokens], [aug_tokens] - for ti, (original_track, aug_track) in enumerate( - zip(original_tokens, aug_tokens) - ): - if original_programs is not None and original_programs[ti][1]: # drums - continue - for idx, (original_token, aug_token) in enumerate( - zip(original_track, aug_track) - ): - if not tokenizer.is_multi_voc: - if original_token in pitch_tokens: - pitch_offset = offsets[0] - # no offset for drum pitches - if ( - tokenizer.one_token_stream - and idx > 0 - and tokenizer[original_track[idx - 1]] == "Program_-1" - ): - pitch_offset 
= 0 - assert aug_token == original_token + pitch_offset - elif original_token in vel_tokens: - assert aug_token in [ - original_token + offsets[1], - tok_vel_min, - tok_vel_max, - ] - elif ( - original_token in dur_tokens and tokenization != "MIDILike" - ): - assert aug_token in [ - original_token + offsets[2], - tok_dur_min, - tok_dur_max, - ] - elif original_token in note_off_tokens: - assert aug_token == original_token + offsets[0] - else: - if original_token[pitch_voc_idx] in pitch_tokens: - assert ( - aug_token[pitch_voc_idx] - == original_token[pitch_voc_idx] + offsets[0] - ) - elif original_token[vel_voc_idx] in vel_tokens: - assert aug_token[vel_voc_idx] in [ - original_token[vel_voc_idx] + offsets[1], - tok_vel_min, - tok_vel_max, - ] - elif ( - original_token[dur_voc_idx] in dur_tokens - and tokenization != "MIDILike" - ): - assert aug_token[dur_voc_idx] in [ - original_token[dur_voc_idx] + offsets[2], - tok_dur_min, - tok_dur_max, - ] - - -def test_tokenize_datasets(data_path: Union[str, Path] = Path("./tests")): # Check the file tree is copied - random.seed(8) - midi_paths = list((data_path / "One_track_MIDIs").glob("**/*.mid")) + list( - (data_path / "Multitrack_MIDIs").glob("**/*.mid") - ) - midi_paths = random.choices(midi_paths, k=6) - config = miditok.TokenizerConfig() - tokenizer = miditok.TSD(config) - out_path = Path("tests", "test_results", "file_tree") - tokenizer.tokenize_midi_dataset(midi_paths, out_path) - json_paths = list(out_path.glob("**/*.json")) + tokenizer = miditok.TSD(miditok.TokenizerConfig()) + tokenizer.tokenize_midi_dataset(midi_paths, tmp_path, overwrite_mode=True) + json_paths = list(tmp_path.glob("**/*.json")) json_paths.sort(key=lambda x: x.stem) midi_paths.sort(key=lambda x: x.stem) - assert all( - json_path.relative_to(out_path).with_suffix(".test") - == midi_path.relative_to(data_path).with_suffix(".test") - for json_path, midi_path in zip(json_paths, midi_paths) - ) - tokenizer.tokenize_midi_dataset(midi_paths, 
out_path, overwrite_mode=False) - - -if __name__ == "__main__": - test_tokenize_datasets() - test_convert_tensors() - test_data_augmentation() + for json_path, midi_path in zip(json_paths, midi_paths): + assert ( + json_path.relative_to(tmp_path).with_suffix(".test") + == midi_path.relative_to(HERE).with_suffix(".test") + ), f"The file tree has not been reproduced as it should, for the file {midi_path} tokenized {json_path}" + + # Just make sure the non-overwrite mode doesn't crash + tokenizer.tokenize_midi_dataset(midi_paths, tmp_path, overwrite_mode=False) diff --git a/tests/test_multitrack.py b/tests/test_multitrack.py deleted file mode 100644 index 3549af7d..00000000 --- a/tests/test_multitrack.py +++ /dev/null @@ -1,163 +0,0 @@ -#!/usr/bin/python3 python - -"""Multitrack test file -""" - -from copy import deepcopy -from pathlib import Path -from time import time -from typing import Union - -from miditoolkit import MidiFile, Pedal -from tqdm import tqdm - -import miditok - -from .tests_utils import ( - ALL_TOKENIZATIONS, - adapt_tempo_changes_times, - remove_equal_successive_tempos, - tokenize_check_equals, -) - -BEAT_RES_TEST = {(0, 16): 8} -TOKENIZER_PARAMS = { - "beat_res": BEAT_RES_TEST, - "use_chords": True, - "use_rests": True, # tempo decode fails when False for MIDILike because beat_res range is too short - "use_tempos": True, - "use_time_signatures": True, - "use_sustain_pedals": True, - "use_pitch_bends": True, - "use_programs": True, - "chord_maps": miditok.constants.CHORD_MAPS, - "chord_tokens_with_root_note": True, # Tokens will look as "Chord_C:maj" - "chord_unknown": (3, 6), - "beat_res_rest": {(0, 2): 4, (2, 12): 2}, - "nb_tempos": 32, - "tempo_range": (40, 250), - "log_tempos": False, - "sustain_pedal_duration": False, - "one_token_stream_for_programs": True, - "program_changes": False, -} - -# Define kwargs sets -# The first set is empty, using the default params -params_kwargs_sets = {tok: [{}] for tok in ALL_TOKENIZATIONS} 
-programs_tokenizations = ["TSD", "REMI", "MIDILike", "Structured", "CPWord", "Octuple"] -for tok in programs_tokenizations: - params_kwargs_sets[tok].append( - {"one_token_stream_for_programs": False}, - ) -for tok in ["TSD", "REMI", "MIDILike"]: - params_kwargs_sets[tok].append( - {"program_changes": True}, - ) -# Disable tempos for Octuple with one_token_stream_for_programs, as tempos are carried by note tokens, and -# time signatures for the same reasons (as time could be shifted by on or several bars) -params_kwargs_sets["Octuple"][1]["use_tempos"] = False -params_kwargs_sets["Octuple"][0]["use_time_signatures"] = False -params_kwargs_sets["Octuple"][1]["use_time_signatures"] = False -# Increase the TimeShift voc for Structured as it doesn't support successive TimeShifts -for kwargs_set in params_kwargs_sets["Structured"]: - kwargs_set["beat_res"] = {(0, 512): 8} - - -def test_multitrack_midi_to_tokens_to_midi( - data_path: Union[str, Path] = "./tests/Multitrack_MIDIs", - saving_erroneous_midis: bool = False, -): - r"""Reads a few MIDI files, convert them into token sequences, convert them back to MIDI files. 
- The converted back MIDI files should identical to original one, expect with note starting and ending - times quantized, and maybe a some duplicated notes removed - """ - files = list(Path(data_path).glob("**/*.mid")) - at_least_one_error = False - t0 = time() - - for fi, file_path in enumerate(tqdm(files, desc="Testing multitrack")): - # Reads the MIDI - midi = MidiFile(Path(file_path)) - if midi.ticks_per_beat % max(BEAT_RES_TEST.values()) != 0: - continue - # add pedal messages - for ti in range(max(3, len(midi.instruments))): - midi.instruments[ti].pedals = [ - Pedal(start, start + 200) for start in [100, 600, 1800, 2200] - ] - - for tokenization in ALL_TOKENIZATIONS: - for pi, params_kwargs in enumerate(params_kwargs_sets[tokenization]): - idx = f"{fi}_{pi}" - params = deepcopy(TOKENIZER_PARAMS) - params.update(params_kwargs) - tokenizer_config = miditok.TokenizerConfig(**params) - tokenizer: miditok.MIDITokenizer = getattr(miditok, tokenization)( - tokenizer_config=tokenizer_config - ) - - # Process the MIDI - # midi notes / tempos / time signature quantized with the line above - midi_to_compare = deepcopy(midi) - for track in midi_to_compare.instruments: - if track.is_drum: - track.program = ( - 0 # need to be done before sorting tracks per program - ) - - # Sort and merge tracks if needed - # MIDI produced with one_token_stream contains tracks with different orders - # This step is also performed in preprocess_midi, but we need to call it here for the assertions below - tokenizer.preprocess_midi(midi_to_compare) - # For Octuple, as tempo is only carried at notes times, we need to adapt their times for comparison - # Same for CPWord which carries tempo with Position (for notes) - if tokenization in ["Octuple", "CPWord"]: - adapt_tempo_changes_times( - midi_to_compare.instruments, midi_to_compare.tempo_changes - ) - # When the tokenizer only decoded tempo changes different from the last tempo val - if tokenization in ["CPWord"]: - 
remove_equal_successive_tempos(midi_to_compare.tempo_changes) - - # MIDI -> Tokens -> MIDI - decoded_midi, has_errors = tokenize_check_equals( - midi_to_compare, tokenizer, idx, file_path.stem - ) - - if has_errors: - at_least_one_error = True - if saving_erroneous_midis: - decoded_midi.dump( - Path( - "tests", - "test_results", - f"{file_path.stem}_{tokenization}.mid", - ) - ) - midi_to_compare.dump( - Path( - "tests", - "test_results", - f"{file_path.stem}_{tokenization}_original.mid", - ) - ) - - ttotal = time() - t0 - print(f"Took {ttotal:.2f} seconds") - assert not at_least_one_error - - -if __name__ == "__main__": - import argparse - - parser = argparse.ArgumentParser(description="MIDI Encoding test") - parser.add_argument( - "--data", - type=str, - default="tests/Multitrack_MIDIs", - help="directory of MIDI files to use for test", - ) - args = parser.parse_args() - - test_multitrack_midi_to_tokens_to_midi(args.data) diff --git a/tests/test_one_track.py b/tests/test_one_track.py deleted file mode 100644 index 195d7278..00000000 --- a/tests/test_one_track.py +++ /dev/null @@ -1,167 +0,0 @@ -#!/usr/bin/python3 python - -"""One track test file -""" - -from copy import deepcopy -from pathlib import Path, PurePath -from time import time -from typing import Union - -from miditoolkit import MidiFile -from tqdm import tqdm - -import miditok -from miditok.constants import CHORD_MAPS - -from .tests_utils import ( - ALL_TOKENIZATIONS, - TIME_SIGNATURE_RANGE_TESTS, - adapt_tempo_changes_times, - adjust_pedal_durations, - remove_equal_successive_tempos, - tokenize_check_equals, -) - -BEAT_RES_TEST = {(0, 16): 8} -TOKENIZER_PARAMS = { - "beat_res": BEAT_RES_TEST, - "use_chords": False, # set false to speed up tests as it takes some time on maestro MIDIs - "use_rests": True, - "use_tempos": True, - "use_time_signatures": True, - "use_sustain_pedals": True, - "use_pitch_bends": True, - "use_programs": False, - "use_pitch_intervals": True, - "beat_res_rest": {(0, 2): 4, (2, 
12): 2}, - "nb_tempos": 32, - "tempo_range": (40, 250), - "log_tempos": True, - "time_signature_range": TIME_SIGNATURE_RANGE_TESTS, - "chord_maps": CHORD_MAPS, - "chord_tokens_with_root_note": True, # Tokens will look as "Chord_C:maj" - "chord_unknown": False, - "delete_equal_successive_time_sig_changes": True, - "delete_equal_successive_tempo_changes": True, - "sustain_pedal_duration": True, - "one_token_stream_for_programs": False, - "program_changes": True, -} - - -def test_one_track_midi_to_tokens_to_midi( - data_path: Union[str, Path, PurePath] = "./tests/One_track_MIDIs", - saving_erroneous_midis: bool = True, -): - r"""Reads a few MIDI files, convert them into token sequences, convert them back to MIDI files. - The converted back MIDI files should identical to original one, expect with note starting and ending - times quantized, and maybe a some duplicated notes removed - - :param data_path: root path to the data to test - :param saving_erroneous_midis: will save MIDIs converted back with errors, to be used to debug - """ - files = list(Path(data_path).glob("**/*.mid")) - at_least_one_error = False - t0 = time() - - for i, file_path in enumerate(tqdm(files, desc="Testing One Track")): - # Reads the midi - midi = MidiFile(file_path) - # midi.instruments = [midi.instruments[0]] - # Will store the tracks tokenized / detokenized, to be saved in case of errors - for ti, track in enumerate(midi.instruments): - track.name = f"original {ti} not quantized" - tracks_with_errors = [] - - for tokenization in ALL_TOKENIZATIONS: - params = deepcopy(TOKENIZER_PARAMS) - # Special beat res for test, up to 64 beats so the duration and time-shift values are - # long enough for Structured, and with a single beat resolution - if tokenization == "Structured": - params["beat_res"] = {(0, 64): 8} - elif tokenization == "Octuple": - params["max_bar_embedding"] = 300 - params["use_time_signatures"] = False # because of time shifted - elif tokenization == "CPWord": - # Rests and time 
sig can mess up with CPWord, when a Rest that is crossing new bar is followed - # by a new TimeSig change, as TimeSig are carried with Bar tokens (and there is None is this case) - if params["use_time_signatures"] and params["use_rests"]: - params["use_rests"] = False - - tokenizer_config = miditok.TokenizerConfig(**params) - tokenizer: miditok.MIDITokenizer = getattr(miditok, tokenization)( - tokenizer_config=tokenizer_config - ) - - # Process the MIDI - # midi notes / tempos / time signature quantized with the line above - midi_to_compare = deepcopy(midi) - for track in midi_to_compare.instruments: - if track.is_drum: - track.program = ( - 0 # need to be done before sorting tracks per program - ) - - # This step is also performed in preprocess_midi, but we need to call it here for the assertions below - tokenizer.preprocess_midi(midi_to_compare) - # For Octuple, as tempo is only carried at notes times, we need to adapt their times for comparison - # Same for CPWord which carries tempo with Position (for notes) - if tokenization in ["Octuple", "CPWord"]: - # We use the first track only, as it is the one for which tempos are decoded - adapt_tempo_changes_times( - [midi_to_compare.instruments[0]], midi_to_compare.tempo_changes - ) - # When the tokenizer only decoded tempo changes different from the last tempo val - if tokenization in ["CPWord"]: - remove_equal_successive_tempos(midi_to_compare.tempo_changes) - # Adjust pedal ends to the maximum possible value - if tokenizer.config.use_sustain_pedals: - for track in midi_to_compare.instruments: - adjust_pedal_durations(track.pedals, tokenizer, midi.ticks_per_beat) - # Store preprocessed track - if len(tracks_with_errors) == 0: - tracks_with_errors += midi_to_compare.instruments - for ti, track in enumerate(midi_to_compare.instruments): - track.name = f"original {ti} quantized" - - # printing the tokenizer shouldn't fail - _ = str(tokenizer) - - # MIDI -> Tokens -> MIDI - decoded_midi, has_errors = 
tokenize_check_equals( - midi_to_compare, tokenizer, i, file_path.stem - ) - - # Add track to error list - if has_errors: - for ti, track in enumerate(decoded_midi.instruments): - track.name = f"{ti} encoded with {tokenization}" - tracks_with_errors += decoded_midi.instruments - - # > 1 as the first one is the preprocessed - if len(tracks_with_errors) > len(midi.instruments): - at_least_one_error = True - if saving_erroneous_midis: - midi.tempo_changes = midi_to_compare.tempo_changes - midi.time_signature_changes = midi_to_compare.time_signature_changes - midi.instruments += tracks_with_errors - midi.dump(PurePath("tests", "test_results", file_path.name)) - - ttotal = time() - t0 - print(f"Took {ttotal:.2f} seconds") - assert not at_least_one_error - - -if __name__ == "__main__": - import argparse - - parser = argparse.ArgumentParser(description="MIDI Encoding test") - parser.add_argument( - "--data", - type=str, - default="tests/One_track_MIDIs", - help="directory of MIDI files to use for test", - ) - args = parser.parse_args() - test_one_track_midi_to_tokens_to_midi(args.data) diff --git a/tests/test_pytorch_data_loading.py b/tests/test_pytorch_data_loading.py index 4197b867..ca371559 100644 --- a/tests/test_pytorch_data_loading.py +++ b/tests/test_pytorch_data_loading.py @@ -5,13 +5,15 @@ """ from pathlib import Path -from typing import Sequence +from typing import Sequence, Union from miditoolkit import MidiFile from torch import randint import miditok +from .utils import MIDI_PATHS_MULTITRACK, MIDI_PATHS_ONE_TRACK + def test_split_seq(): min_seq_len = 50 @@ -26,16 +28,20 @@ def test_split_seq(): ], "Sequence split failed" -def test_dataset_ram(): - multitrack_midis_paths = list(Path("tests", "Multitrack_MIDIs").glob("**/*.mid"))[ - :3 - ] - one_track_midis_paths = list(Path("tests", "One_track_MIDIs").glob("**/*.mid"))[:3] - tokens_os_dir = Path("tests", "multitrack_tokens_os") +def test_dataset_ram( + tmp_path: Path, + midi_paths_one_track: 
Sequence[Union[str, Path]] = None, + midi_paths_multitrack: Sequence[Union[str, Path]] = None, +): + if midi_paths_one_track is None: + midi_paths_one_track = MIDI_PATHS_ONE_TRACK[:3] + if midi_paths_multitrack is None: + midi_paths_multitrack = MIDI_PATHS_MULTITRACK[:3] + tokens_os_dir = tmp_path / "multitrack_tokens_os" dummy_labels = { label: i for i, label in enumerate( - set(path.name.split("_")[0] for path in one_track_midis_paths) + set(path.name.split("_")[0] for path in midi_paths_one_track) ) } @@ -52,7 +58,7 @@ def get_labels_multitrack_one_stream(tokens: Sequence, _: Path) -> int: config = miditok.TokenizerConfig(use_programs=True) tokenizer_os = miditok.TSD(config) dataset_os = miditok.pytorch_data.DatasetTok( - one_track_midis_paths, + midi_paths_one_track, 50, 100, tokenizer_os, @@ -72,7 +78,7 @@ def get_labels_multitrack_one_stream(tokens: Sequence, _: Path) -> int: # MIDI + Multiple token streams + labels tokenizer_ms = miditok.TSD(miditok.TokenizerConfig()) dataset_ms = miditok.pytorch_data.DatasetTok( - multitrack_midis_paths, + midi_paths_multitrack, 50, 100, tokenizer_ms, @@ -85,7 +91,7 @@ def get_labels_multitrack_one_stream(tokens: Sequence, _: Path) -> int: # JSON + one token stream if not tokens_os_dir.is_dir(): tokenizer_os.tokenize_midi_dataset( - multitrack_midis_paths, + midi_paths_multitrack, tokens_os_dir, ) _ = miditok.pytorch_data.DatasetTok( @@ -95,19 +101,16 @@ def get_labels_multitrack_one_stream(tokens: Sequence, _: Path) -> int: func_to_get_labels=get_labels_multitrack_one_stream, ) - assert True - -def test_dataset_io(): - multitrack_midis_paths = list(Path("tests", "Multitrack_MIDIs").glob("**/*.mid"))[ - :3 - ] - tokens_os_dir = Path("tests", "multitrack_tokens_os") +def test_dataset_io(tmp_path: Path, midi_path: Sequence[Union[str, Path]] = None): + if midi_path is None: + midi_path = MIDI_PATHS_MULTITRACK[:3] + tokens_os_dir = tmp_path / "multitrack_tokens_os" if not tokens_os_dir.is_dir(): config = 
miditok.TokenizerConfig(use_programs=True) tokenizer = miditok.TSD(config) - tokenizer.tokenize_midi_dataset(multitrack_midis_paths, tokens_os_dir) + tokenizer.tokenize_midi_dataset(midi_path, tokens_os_dir) dataset = miditok.pytorch_data.DatasetJsonIO( list(tokens_os_dir.glob("**/*.json")), @@ -120,22 +123,22 @@ def test_dataset_io(): for _ in dataset: pass - assert True - -def test_split_dataset_to_subsequences(): - multitrack_midis_paths = list(Path("tests", "Multitrack_MIDIs").glob("**/*.mid"))[ - :3 - ] - tokens_os_dir = Path("tests", "multitrack_tokens_os") - tokens_split_dir = Path("tests", "multitrack_tokens_os_split") - tokens_split_dir_ms = Path("tests", "multitrack_tokens_ms_split") +def test_split_dataset_to_subsequences( + tmp_path: Path, + midi_path: Sequence[Union[str, Path]] = None, +): + if midi_path is None: + midi_path = MIDI_PATHS_MULTITRACK[:3] + tokens_os_dir = tmp_path / "multitrack_tokens_os" + tokens_split_dir = tmp_path / "multitrack_tokens_os_split" + tokens_split_dir_ms = tmp_path / "multitrack_tokens_ms_split" # One token stream if not tokens_os_dir.is_dir(): config = miditok.TokenizerConfig(use_programs=True) tokenizer = miditok.TSD(config) - tokenizer.tokenize_midi_dataset(multitrack_midis_paths, tokens_os_dir) + tokenizer.tokenize_midi_dataset(midi_path, tokens_os_dir) miditok.pytorch_data.split_dataset_to_subsequences( list(tokens_os_dir.glob("**/*.json")), tokens_split_dir, @@ -148,7 +151,7 @@ def test_split_dataset_to_subsequences(): if not tokens_split_dir_ms.is_dir(): config = miditok.TokenizerConfig(use_programs=False) tokenizer = miditok.TSD(config) - tokenizer.tokenize_midi_dataset(multitrack_midis_paths, tokens_split_dir_ms) + tokenizer.tokenize_midi_dataset(midi_path, tokens_split_dir_ms) miditok.pytorch_data.split_dataset_to_subsequences( list(tokens_split_dir_ms.glob("**/*.json")), tokens_split_dir, @@ -157,8 +160,6 @@ def test_split_dataset_to_subsequences(): False, ) - assert True - def test_collator(): collator = 
miditok.pytorch_data.DataCollator( @@ -197,13 +198,3 @@ def test_collator(): max(seq_lengths) + 1, 5, ] - - assert True - - -if __name__ == "__main__": - test_split_seq() - test_dataset_ram() - test_dataset_io() - test_split_dataset_to_subsequences() - test_collator() diff --git a/tests/test_results/.gitignore b/tests/test_results/.gitignore deleted file mode 100644 index 5e7d2734..00000000 --- a/tests/test_results/.gitignore +++ /dev/null @@ -1,4 +0,0 @@ -# Ignore everything in this directory -* -# Except this file -!.gitignore diff --git a/tests/test_saving_loading_config.py b/tests/test_saving_loading_config.py index 09d8f88a..cbfd44ef 100644 --- a/tests/test_saving_loading_config.py +++ b/tests/test_saving_loading_config.py @@ -5,9 +5,13 @@ """ +from pathlib import Path + +import pytest + import miditok -from .tests_utils import ALL_TOKENIZATIONS +from .utils import ALL_TOKENIZATIONS ADDITIONAL_TOKENS_TEST = { "use_chords": False, # set False to speed up tests as it takes some time on maestro MIDIs @@ -21,41 +25,35 @@ } -def test_saving_loading_tokenizer_config(): - for tokenization in ALL_TOKENIZATIONS: - config1 = miditok.TokenizerConfig() - config1.save_to_json(f"./tests/configs/tok_conf_{tokenization}.json") +@pytest.mark.parametrize("tokenization", ALL_TOKENIZATIONS) +def test_saving_loading_tokenizer_config(tokenization: str, tmp_path: Path): + config1 = miditok.TokenizerConfig() + config1.save_to_json(tmp_path / f"tok_conf_{tokenization}.json") - config2 = miditok.TokenizerConfig.load_from_json( - f"./tests/configs/tok_conf_{tokenization}.json" - ) + config2 = miditok.TokenizerConfig.load_from_json( + tmp_path / f"tok_conf_{tokenization}.json" + ) - assert config1 == config2 - config1.pitch_range = (0, 777) - assert config1 != config2 + assert config1 == config2 + config1.pitch_range = (0, 777) + assert config1 != config2 -def test_saving_loading_tokenizer(): +@pytest.mark.parametrize("tokenization", ALL_TOKENIZATIONS) +def 
test_saving_loading_tokenizer(tokenization: str, tmp_path: Path): r"""Tests to create tokenizers, save their config, and load it back. If all went well the tokenizer should be identical. """ - - for tokenization in ALL_TOKENIZATIONS: - tokenizer_config = miditok.TokenizerConfig(**ADDITIONAL_TOKENS_TEST) - tokenizer: miditok.MIDITokenizer = getattr(miditok, tokenization)( - tokenizer_config=tokenizer_config - ) - tokenizer.save_params(f"./tests/configs/{tokenization}.txt") - - tokenizer2: miditok.MIDITokenizer = getattr(miditok, tokenization)( - params=f"./tests/configs/{tokenization}.txt" - ) - assert tokenizer == tokenizer2 - if tokenization == "Octuple": - tokenizer.vocab[0]["PAD_None"] = 8 - assert tokenizer != tokenizer2 - - -if __name__ == "__main__": - test_saving_loading_tokenizer_config() - test_saving_loading_tokenizer() + tokenizer_config = miditok.TokenizerConfig(**ADDITIONAL_TOKENS_TEST) + tokenizer: miditok.MIDITokenizer = getattr(miditok, tokenization)( + tokenizer_config=tokenizer_config + ) + tokenizer.save_params(tmp_path / f"{tokenization}.txt") + + tokenizer2: miditok.MIDITokenizer = getattr(miditok, tokenization)( + params=tmp_path / f"{tokenization}.txt" + ) + assert tokenizer == tokenizer2 + if tokenization == "Octuple": + tokenizer.vocab[0]["PAD_None"] = 8 + assert tokenizer != tokenizer2 diff --git a/tests/test_tokenize_multitrack.py b/tests/test_tokenize_multitrack.py new file mode 100644 index 00000000..7d5dd6b3 --- /dev/null +++ b/tests/test_tokenize_multitrack.py @@ -0,0 +1,117 @@ +#!/usr/bin/python3 python + +"""Multitrack test file +""" + +from copy import deepcopy +from pathlib import Path +from typing import Any, Dict, Sequence, Tuple, Union + +import pytest +from miditoolkit import MidiFile, Pedal + +import miditok + +from .utils import ( + ALL_TOKENIZATIONS, + MIDI_PATHS_MULTITRACK, + TEST_LOG_DIR, + TOKENIZER_CONFIG_KWARGS, + adjust_tok_params_for_tests, + prepare_midi_for_tests, + tokenize_and_check_equals, +) + +default_params = 
deepcopy(TOKENIZER_CONFIG_KWARGS) +default_params.update( + { + "use_chords": True, + "use_rests": True, # tempo decode fails when False for MIDILike because beat_res range is too short + "use_tempos": True, + "use_time_signatures": True, + "use_sustain_pedals": True, + "use_pitch_bends": True, + "use_programs": True, + "sustain_pedal_duration": False, + "one_token_stream_for_programs": True, + "program_changes": False, + } +) +TOK_PARAMS_MULTITRACK = [] +tokenizations_non_one_stream = [ + "TSD", + "REMI", + "MIDILike", + "Structured", + "CPWord", + "Octuple", +] +tokenizations_program_change = ["TSD", "REMI", "MIDILike"] +for tokenization_ in ALL_TOKENIZATIONS: + params_ = deepcopy(default_params) + adjust_tok_params_for_tests(tokenization_, params_) + TOK_PARAMS_MULTITRACK.append((tokenization_, params_)) + + if tokenization_ in tokenizations_non_one_stream: + params_tmp = deepcopy(params_) + params_tmp["one_token_stream_for_programs"] = False + # Disable tempos for Octuple with one_token_stream_for_programs, as tempos are carried by note tokens + if tokenization_ == "Octuple": + params_tmp["use_tempos"] = False + TOK_PARAMS_MULTITRACK.append((tokenization_, params_tmp)) + if tokenization_ in tokenizations_program_change: + params_tmp = deepcopy(params_) + params_tmp["program_changes"] = True + TOK_PARAMS_MULTITRACK.append((tokenization_, params_tmp)) + + +@pytest.mark.parametrize("midi_path", MIDI_PATHS_MULTITRACK) +def test_multitrack_midi_to_tokens_to_midi( + midi_path: Union[str, Path], + tok_params_sets: Sequence[Tuple[str, Dict[str, Any]]] = None, + saving_erroneous_midis: bool = False, +): + r"""Reads a MIDI file, converts it into tokens, convert it back to a MIDI object. + The decoded MIDI should be identical to the original one after downsampling, and potentially notes deduplication. + We only parametrize for midi files, as it would otherwise require to load them multiple times each. 
+ # TODO test parametrize tokenization / params_set + + :param midi_path: path to the MIDI file to test. + :param tok_params_sets: sequence of tokenizer and its parameters to run. + :param saving_erroneous_midis: will save MIDIs decoded with errors, to be used to debug. + """ + if tok_params_sets is None: + tok_params_sets = TOK_PARAMS_MULTITRACK + at_least_one_error = False + + # Reads the MIDI and add pedal messages + midi = MidiFile(Path(midi_path)) + for ti in range(max(3, len(midi.instruments))): + midi.instruments[ti].pedals = [ + Pedal(start, start + 200) for start in [100, 600, 1800, 2200] + ] + + for tok_i, (tokenization, params) in enumerate(tok_params_sets): + tokenizer: miditok.MIDITokenizer = getattr(miditok, tokenization)( + tokenizer_config=miditok.TokenizerConfig(**params) + ) + + # Process the MIDI + # midi notes / tempos / time signature quantized with the line above + midi_to_compare = prepare_midi_for_tests(midi, tokenizer=tokenizer) + + # MIDI -> Tokens -> MIDI + decoded_midi, has_errors = tokenize_and_check_equals( + midi_to_compare, tokenizer, tok_i, midi_path.stem + ) + + if has_errors: + TEST_LOG_DIR.mkdir(exist_ok=True, parents=True) + at_least_one_error = True + if saving_erroneous_midis: + decoded_midi.dump(TEST_LOG_DIR / f"{midi_path.stem}_{tokenization}.mid") + midi_to_compare.dump( + TEST_LOG_DIR / f"{midi_path.stem}_{tokenization}_original.mid" + ) + + assert not at_least_one_error diff --git a/tests/test_tokenize_one_track.py b/tests/test_tokenize_one_track.py new file mode 100644 index 00000000..193ab7f6 --- /dev/null +++ b/tests/test_tokenize_one_track.py @@ -0,0 +1,113 @@ +#!/usr/bin/python3 python + +"""One track test file +""" + +from copy import deepcopy +from pathlib import Path +from typing import Any, Dict, Sequence, Tuple, Union + +import pytest +from miditoolkit import MidiFile + +import miditok + +from .utils import ( + ALL_TOKENIZATIONS, + MIDI_PATHS_ONE_TRACK, + TEST_LOG_DIR, + TOKENIZER_CONFIG_KWARGS, + 
adjust_tok_params_for_tests, + prepare_midi_for_tests, + tokenize_and_check_equals, +) + +default_params = deepcopy(TOKENIZER_CONFIG_KWARGS) +default_params.update( + { + "use_chords": False, # set false to speed up tests as it takes some time on maestro MIDIs + "use_rests": True, + "use_tempos": True, + "use_time_signatures": True, + "use_sustain_pedals": True, + "use_pitch_bends": True, + "use_pitch_intervals": True, + "log_tempos": True, + "chord_unknown": False, + "delete_equal_successive_time_sig_changes": True, + "delete_equal_successive_tempo_changes": True, + "sustain_pedal_duration": True, + } +) +TOK_PARAMS_ONE_TRACK = [] +for tokenization_ in ALL_TOKENIZATIONS: + params_ = deepcopy(default_params) + adjust_tok_params_for_tests(tokenization_, params_) + TOK_PARAMS_ONE_TRACK.append((tokenization_, params_)) + + +@pytest.mark.parametrize("midi_path", MIDI_PATHS_ONE_TRACK) +def test_one_track_midi_to_tokens_to_midi( + midi_path: Union[str, Path], + tok_params_sets: Sequence[Tuple[str, Dict[str, Any]]] = None, + saving_erroneous_midis: bool = True, +): + r"""Reads a MIDI file, converts it into tokens, convert it back to a MIDI object. + The decoded MIDI should be identical to the original one after downsampling, and potentially notes deduplication. + We only parametrize for midi files, as it would otherwise require to load them multiple times each. + # TODO test parametrize tokenization / params_set, if faster --> unique method for test tok (one+multi) + + :param midi_path: path to the MIDI file to test. + :param tok_params_sets: sequence of tokenizer and its parameters to run. + :param saving_erroneous_midis: will save MIDIs decoded with errors, to be used to debug. 
+ """ + if tok_params_sets is None: + tok_params_sets = TOK_PARAMS_ONE_TRACK + at_least_one_error = False + + # Reads the midi + midi = MidiFile(midi_path) + # Will store the tracks tokenized / detokenized, to be saved in case of errors + for ti, track in enumerate(midi.instruments): + track.name = f"original {ti} not quantized" + tracks_with_errors = [] + + for tok_i, (tokenization, params) in enumerate(tok_params_sets): + tokenizer: miditok.MIDITokenizer = getattr(miditok, tokenization)( + tokenizer_config=miditok.TokenizerConfig(**params) + ) + + # Process the MIDI + # preprocess_midi is also performed when tokenizing, but we need to call it here for following adaptations + midi_to_compare = prepare_midi_for_tests(midi, tokenizer=tokenizer) + # Store preprocessed track + if len(tracks_with_errors) == 0: + tracks_with_errors += midi_to_compare.instruments + for ti, track in enumerate(midi_to_compare.instruments): + track.name = f"original {ti} quantized" + + # printing the tokenizer shouldn't fail + _ = str(tokenizer) + + # MIDI -> Tokens -> MIDI + decoded_midi, has_errors = tokenize_and_check_equals( + midi_to_compare, tokenizer, tok_i, midi_path.stem + ) + + # Add track to error list + if has_errors: + at_least_one_error = True + for ti, track in enumerate(decoded_midi.instruments): + track.name = f"{tok_i} encoded with {tokenization}" + tracks_with_errors += decoded_midi.instruments + + # > 1 as the first one is the preprocessed + if len(tracks_with_errors) > len(midi.instruments): + if saving_erroneous_midis: + TEST_LOG_DIR.mkdir(exist_ok=True, parents=True) + midi.tempo_changes = midi_to_compare.tempo_changes + midi.time_signature_changes = midi_to_compare.time_signature_changes + midi.instruments += tracks_with_errors + midi.dump(TEST_LOG_DIR / midi_path.name) + + assert not at_least_one_error diff --git a/tests/test_utils.py b/tests/test_utils.py index 9e2f5f0a..5b9c926b 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -6,8 +6,18 @@ from copy 
import deepcopy from pathlib import Path - -from miditoolkit import MidiFile +from typing import Union + +import pytest +from miditoolkit import ( + ControlChange, + KeySignature, + MidiFile, + Pedal, + PitchBend, + TempoChange, + TimeSignature, +) from miditok import REMI from miditok.constants import CLASS_OF_INST @@ -18,18 +28,113 @@ nb_bar_pos, ) +from .utils import MIDI_PATHS_MULTITRACK, MIDI_PATHS_ONE_TRACK, check_midis_equals + + +def test_containers_assertions(): + tc1 = [TempoChange(120, 2), TempoChange(110, 0)] + tc2 = [TempoChange(120, 3), TempoChange(110, 0)] + tc3 = [TempoChange(120, 3), TempoChange(110, 0)] + assert tc1 != tc2 + assert tc2 == tc3 + + ts1 = [TimeSignature(4, 4, 0), TimeSignature(6, 4, 10)] + ts2 = [TimeSignature(2, 4, 0), TimeSignature(6, 4, 10)] + ts3 = [TimeSignature(2, 4, 0), TimeSignature(6, 4, 10)] + assert ts1 != ts2 + assert ts2 == ts3 + + sp1 = [Pedal(0, 2), TempoChange(10, 20)] + sp2 = [Pedal(0, 2), TempoChange(15, 20)] + sp3 = [Pedal(0, 2), TempoChange(15, 20)] + assert sp1 != sp2 + assert sp2 == sp3 + + pb1 = [PitchBend(120, 2), PitchBend(110, 0)] + pb2 = [PitchBend(120, 3), PitchBend(110, 0)] + pb3 = [PitchBend(120, 3), PitchBend(110, 0)] + assert pb1 != pb2 + assert pb2 == pb3 + + ks1 = [KeySignature("C#", 2), KeySignature("C#", 0)] + ks2 = [KeySignature("C#", 20), KeySignature("C#", 0)] + ks3 = [KeySignature("C#", 20), KeySignature("C#", 0)] + assert ks1 != ks2 + assert ks2 == ks3 + + cc1 = [ControlChange(120, 50, 2), ControlChange(110, 50, 0)] + cc2 = [ControlChange(120, 50, 2), ControlChange(110, 50, 10)] + cc3 = [ControlChange(120, 50, 2), ControlChange(110, 50, 10)] + assert cc1 != cc2 + assert cc2 == cc3 + + +@pytest.mark.parametrize("midi_path", MIDI_PATHS_ONE_TRACK) +def test_check_midi_equals(midi_path: Path): + midi = MidiFile(midi_path) + midi_copy = deepcopy(midi) + + # Check when midi is untouched + assert check_midis_equals(midi, midi_copy)[1] + + # Altering notes + i = 0 + while i < 
len(midi_copy.instruments): + if len(midi_copy.instruments[i].notes) > 0: + midi_copy.instruments[i].notes[-1].pitch += 5 + assert not check_midis_equals(midi, midi_copy)[1] + break + i += 1 + + # Altering track events + if len(midi_copy.instruments) > 0: + # Altering pedals + midi_copy = deepcopy(midi) + if len(midi_copy.instruments[0].pedals) == 0: + midi_copy.instruments[0].pedals.append(Pedal(0, 10)) + else: + midi_copy.instruments[0].pedals[-1].end += 10 + assert not check_midis_equals(midi, midi_copy)[1] -def test_merge_tracks(): - midi = MidiFile(Path("tests", "One_track_MIDIs", "Maestro_1.mid")) + # Altering pitch bends + midi_copy = deepcopy(midi) + if len(midi_copy.instruments[0].pitch_bends) == 0: + midi_copy.instruments[0].pitch_bends.append(PitchBend(50, 10)) + else: + midi_copy.instruments[0].pitch_bends[-1].end += 10 + assert not check_midis_equals(midi, midi_copy)[1] + + # Altering tempos + midi_copy = deepcopy(midi) + if len(midi_copy.tempo_changes) == 0: + midi_copy.tempo_changes.append(TempoChange(50, 10)) + else: + midi_copy.tempo_changes[-1].time += 10 + assert not check_midis_equals(midi, midi_copy)[1] + + # Altering time signatures + midi_copy = deepcopy(midi) + if len(midi_copy.time_signature_changes) == 0: + midi_copy.time_signature_changes.append(TimeSignature(4, 4, 10)) + else: + midi_copy.time_signature_changes[-1].time += 10 + assert not check_midis_equals(midi, midi_copy)[1] + + +def test_merge_tracks( + midi_path: Union[str, Path] = MIDI_PATHS_ONE_TRACK[0], +): + # Load MIDI and only keep the first track + midi = MidiFile(midi_path) + midi.instruments = [midi.instruments[0]] + + # Duplicate the track and merge it original_track = deepcopy(midi.instruments[0]) midi.instruments.append(deepcopy(midi.instruments[0])) - merge_tracks(midi.instruments) - assert len(midi.instruments[0].notes) == 2 * len(original_track.notes) # Test merge with effects - midi.instruments.append(deepcopy(midi.instruments[0])) merge_tracks(midi, effects=True) - 
assert len(midi.instruments[0].notes) == 4 * len(original_track.notes) + assert len(midi.instruments[0].notes) == 2 * len(original_track.notes) assert len(midi.instruments[0].pedals) == 2 * len(original_track.pedals) assert len(midi.instruments[0].control_changes) == 2 * len( original_track.control_changes @@ -37,44 +142,37 @@ def test_merge_tracks(): assert len(midi.instruments[0].pitch_bends) == 2 * len(original_track.pitch_bends) -def test_merge_same_program_tracks_and_by_class(): - multitrack_midi_paths = list(Path("tests", "Multitrack_MIDIs").glob("**/*.mid")) - for midi_path in multitrack_midi_paths: - midi = MidiFile(midi_path) - for track in midi.instruments: - if track.is_drum: - track.program = -1 - - # Test merge same program - midi_copy = deepcopy(midi) - programs = [track.program for track in midi_copy.instruments] - unique_programs = list(set(programs)) - merge_same_program_tracks(midi_copy.instruments) - new_programs = [track.program for track in midi_copy.instruments] - unique_programs.sort() - new_programs.sort() - assert new_programs == unique_programs - - # Test merge same class - midi_copy = deepcopy(midi) - merge_tracks_per_class( - midi_copy, - CLASS_OF_INST, - valid_programs=list(range(-1, 128)), - filter_pitches=True, - ) +@pytest.mark.parametrize("midi_path", MIDI_PATHS_MULTITRACK) +def test_merge_same_program_tracks_and_by_class(midi_path: Union[str, Path]): + midi = MidiFile(midi_path) + for track in midi.instruments: + if track.is_drum: + track.program = -1 + + # Test merge same program + midi_copy = deepcopy(midi) + programs = [track.program for track in midi_copy.instruments] + unique_programs = list(set(programs)) + merge_same_program_tracks(midi_copy.instruments) + new_programs = [track.program for track in midi_copy.instruments] + unique_programs.sort() + new_programs.sort() + assert new_programs == unique_programs + + # Test merge same class + midi_copy = deepcopy(midi) + merge_tracks_per_class( + midi_copy, + CLASS_OF_INST, + 
valid_programs=list(range(-1, 128)), + filter_pitches=True, + ) def test_nb_pos(): tokenizer = REMI() _ = nb_bar_pos( - tokenizer(Path("tests", "One_track_MIDIs", "Maestro_1.mid"))[0].ids, + tokenizer(MIDI_PATHS_ONE_TRACK[0])[0].ids, tokenizer["Bar_None"], tokenizer.token_ids_of_type("Position"), ) - - -if __name__ == "__main__": - test_merge_tracks() - test_merge_same_program_tracks_and_by_class() - test_nb_pos() diff --git a/tests/tests_utils.py b/tests/tests_utils.py deleted file mode 100644 index a7f5ce7c..00000000 --- a/tests/tests_utils.py +++ /dev/null @@ -1,300 +0,0 @@ -""" Test validation methods - -""" - -from typing import List, Tuple, Union - -import numpy as np -from miditoolkit import ( - Instrument, - Marker, - MidiFile, - Note, - Pedal, - PitchBend, - TempoChange, - TimeSignature, -) - -import miditok -from miditok.constants import TIME_SIGNATURE_RANGE - -ALL_TOKENIZATIONS = [ - "MIDILike", - "TSD", - "Structured", - "REMI", - "CPWord", - "Octuple", - "MuMIDI", - "MMM", -] -TIME_SIGNATURE_RANGE_TESTS = TIME_SIGNATURE_RANGE -TIME_SIGNATURE_RANGE_TESTS.update({2: [2, 3, 4]}) -TIME_SIGNATURE_RANGE_TESTS[4].append(8) - - -def midis_equals( - midi1: MidiFile, midi2: MidiFile -) -> List[Tuple[int, str, List[Tuple[str, Union[Note, int], int]]]]: - errors = [] - for track1, track2 in zip(midi1.instruments, midi2.instruments): - track_errors = track_equals(track1, track2) - if len(track_errors) > 0: - errors.append((track1.program, track1.name, track_errors)) - return errors - - -def track_equals( - track1: Instrument, track2: Instrument -) -> List[Tuple[str, Union[Note, int], int]]: - if len(track1.notes) != len(track2.notes): - return [("len", len(track2.notes), len(track1.notes))] - errors = [] - for note1, note2 in zip(track1.notes, track2.notes): - err = notes_equals(note1, note2) - if err != "": - errors.append((err, note2, getattr(note1, err))) - return errors - - -def notes_equals(note1: Note, note2: Note) -> str: - if note1.start != note2.start: - 
return "start" - elif note1.end != note2.end: - return "end" - elif note1.pitch != note2.pitch: - return "pitch" - elif note1.velocity != note2.velocity: - return "velocity" - return "" - - -def tempo_changes_equals( - tempo_changes1: List[TempoChange], tempo_changes2: List[TempoChange] -) -> List[Tuple[str, Union[TempoChange, int], float]]: - if len(tempo_changes1) != len(tempo_changes2): - return [("len", len(tempo_changes2), len(tempo_changes1))] - errors = [] - for tempo_change1, tempo_change2 in zip(tempo_changes1, tempo_changes2): - if tempo_change1.time != tempo_change2.time: - errors.append(("time", tempo_change1, tempo_change2.time)) - if tempo_change1.tempo != tempo_change2.tempo: - errors.append(("tempo", tempo_change1, tempo_change2.tempo)) - return errors - - -def time_signature_changes_equals( - time_sig_changes1: List[TimeSignature], time_sig_changes2: List[TimeSignature] -) -> List[Tuple[str, Union[TimeSignature, int], float]]: - if len(time_sig_changes1) != len(time_sig_changes2): - return [("len", len(time_sig_changes1), len(time_sig_changes2))] - errors = [] - for time_sig_change1, time_sig_change2 in zip(time_sig_changes1, time_sig_changes2): - if time_sig_change1.time != time_sig_change2.time: - errors.append(("time", time_sig_change1, time_sig_change2.time)) - if time_sig_change1.numerator != time_sig_change2.numerator: - errors.append(("numerator", time_sig_change1, time_sig_change2.numerator)) - if time_sig_change1.denominator != time_sig_change2.denominator: - errors.append( - ("denominator", time_sig_change1, time_sig_change2.denominator) - ) - return errors - - -def pedal_equals( - midi1: MidiFile, midi2: MidiFile -) -> List[List[Tuple[str, Union[Pedal, int], float]]]: - errors = [] - for inst1, inst2 in zip(midi1.instruments, midi2.instruments): - if len(inst1.pedals) != len(inst2.pedals): - errors.append([("len", len(inst1.pedals), len(inst2.pedals))]) - continue - errors.append([]) - for pedal1, pedal2 in zip(inst1.pedals, 
inst2.pedals): - if pedal1.start != pedal2.start: - errors[-1].append(("start", pedal1, pedal2.start)) - elif pedal1.end != pedal2.end: - errors[-1].append(("end", pedal1, pedal2.end)) - return errors - - -def pitch_bend_equals( - midi1: MidiFile, midi2: MidiFile -) -> List[List[Tuple[str, Union[PitchBend, int], float]]]: - errors = [] - for inst1, inst2 in zip(midi1.instruments, midi2.instruments): - if len(inst1.pitch_bends) != len(inst2.pitch_bends): - errors.append([("len", len(inst1.pitch_bends), len(inst2.pitch_bends))]) - continue - errors.append([]) - for pitch_bend1, pitch_bend2 in zip(inst1.pitch_bends, inst2.pitch_bends): - if pitch_bend1.time != pitch_bend2.time: - errors[-1].append(("time", pitch_bend1, pitch_bend2.time)) - elif pitch_bend1.pitch != pitch_bend2.pitch: - errors[-1].append(("pitch", pitch_bend1, pitch_bend2.pitch)) - return errors - - -def tokenize_check_equals( - midi: MidiFile, - tokenizer: miditok.MIDITokenizer, - file_idx: Union[int, str], - file_name: str, -) -> Tuple[MidiFile, bool]: - has_errors = False - tokenization = type(tokenizer).__name__ - midi.instruments.sort(key=lambda x: (x.program, x.is_drum)) - # merging is performed in preprocess only in one_token_stream mode - # but in multi token stream, decoding will actually keep one track per program - if tokenizer.config.use_programs: - miditok.utils.merge_same_program_tracks(midi.instruments) - - tokens = tokenizer(midi) - midi_decoded = tokenizer( - tokens, - miditok.utils.get_midi_programs(midi), - time_division=midi.ticks_per_beat, - ) - midi_decoded.instruments.sort(key=lambda x: (x.program, x.is_drum)) - if tokenization == "MIDILike": - for track in midi_decoded.instruments: - track.notes.sort(key=lambda x: (x.start, x.pitch, x.end)) - - # Checks types and values conformity following the rules - err_tse = tokenizer.tokens_errors(tokens) - if isinstance(err_tse, list): - err_tse = sum(err_tse) - if err_tse != 0.0: - print( - f"Validation of tokens types / values 
successions failed with {tokenization}: {err_tse:.2f}" - ) - - # Checks notes - errors = midis_equals(midi, midi_decoded) - if len(errors) > 0: - has_errors = True - for e, track_err in enumerate(errors): - if track_err[-1][0][0] != "len": - for err, note, exp in track_err[-1]: - midi_decoded.markers.append( - Marker( - f"{e}: with note {err} (pitch {note.pitch})", - note.start, - ) - ) - print( - f"MIDI {file_idx} - {file_name} / {tokenization} failed to encode/decode NOTES" - f"({sum(len(t[2]) for t in errors)} errors)" - ) - - # Checks tempos - if ( - tokenizer.config.use_tempos and tokenization != "MuMIDI" - ): # MuMIDI doesn't decode tempos - tempo_errors = tempo_changes_equals( - midi.tempo_changes, midi_decoded.tempo_changes - ) - if len(tempo_errors) > 0: - has_errors = True - print( - f"MIDI {file_idx} - {file_name} / {tokenization} failed to encode/decode TEMPO changes" - f"({len(tempo_errors)} errors)" - ) - - # Checks time signatures - if tokenizer.config.use_time_signatures: - time_sig_errors = time_signature_changes_equals( - midi.time_signature_changes, - midi_decoded.time_signature_changes, - ) - if len(time_sig_errors) > 0: - has_errors = True - print( - f"MIDI {file_idx} - {file_name} / {tokenization} failed to encode/decode TIME SIGNATURE changes" - f"({len(time_sig_errors)} errors)" - ) - - # Checks pedals - if tokenizer.config.use_sustain_pedals: - pedal_errors = pedal_equals(midi, midi_decoded) - if any(len(err) > 0 for err in pedal_errors): - has_errors = True - print( - f"MIDI {file_idx} - {file_name} / {tokenization} failed to encode/decode PEDALS" - f"({sum(len(err) for err in pedal_errors)} errors)" - ) - - # Checks pitch bends - if tokenizer.config.use_pitch_bends: - pitch_bend_errors = pitch_bend_equals(midi, midi_decoded) - if any(len(err) > 0 for err in pitch_bend_errors): - has_errors = True - print( - f"MIDI {file_idx} - {file_name} / {tokenization} failed to encode/decode PITCH BENDS" - f"({sum(len(err) for err in 
pitch_bend_errors)} errors)" - ) - - # TODO check control changes - - return midi_decoded, has_errors - - -def adapt_tempo_changes_times( - tracks: List[Instrument], tempo_changes: List[TempoChange] -): - r"""Will adapt the times of tempo changes depending on the - onset times of the notes of the MIDI. - This is needed to pass the tempo tests for Octuple as the tempos - will be decoded only from the notes. - - :param tracks: tracks of the MIDI to adapt the tempo changes - :param tempo_changes: tempo changes to adapt - """ - notes = sum((t.notes for t in tracks), []) - notes.sort(key=lambda x: x.start) - max_tick = max(note.start for note in notes) - - current_note_idx = 0 - tempo_idx = 1 - while tempo_idx < len(tempo_changes): - if tempo_changes[tempo_idx].time > max_tick: - del tempo_changes[tempo_idx] - continue - for n, note in enumerate(notes[current_note_idx:]): - if note.start >= tempo_changes[tempo_idx].time: - tempo_changes[tempo_idx].time = note.start - current_note_idx += n - break - if tempo_changes[tempo_idx].time == tempo_changes[tempo_idx - 1].time: - del tempo_changes[tempo_idx - 1] - continue - tempo_idx += 1 - - -def adjust_pedal_durations( - pedals: List[Pedal], tokenizer: miditok.MIDITokenizer, time_division: int -): - durations_in_tick = np.array( - [ - (beat * res + pos) * time_division // res - for beat, pos, res in tokenizer.durations - ] - ) - for pedal in pedals: - dur_index = np.argmin(np.abs(durations_in_tick - pedal.duration)) - beat, pos, res = tokenizer.durations[dur_index] - dur_index_in_tick = (beat * res + pos) * time_division // res - pedal.end = pedal.start + dur_index_in_tick - pedal.duration = pedal.end - pedal.start - - -def remove_equal_successive_tempos(tempo_changes: List[TempoChange]): - current_tempo = -1 - i = 0 - while i < len(tempo_changes): - if tempo_changes[i].tempo == current_tempo: - del tempo_changes[i] - continue - current_tempo = tempo_changes[i].tempo - i += 1 diff --git a/tests/utils.py b/tests/utils.py new 
file mode 100644 index 00000000..66d02796 --- /dev/null +++ b/tests/utils.py @@ -0,0 +1,360 @@ +""" +Test validation methods. +""" + +from copy import deepcopy +from pathlib import Path +from typing import Any, Dict, List, Tuple, Union + +import numpy as np +from miditoolkit import ( + Instrument, + Marker, + MidiFile, + Note, + Pedal, + TempoChange, + TimeSignature, +) + +import miditok +from miditok.constants import CHORD_MAPS, TIME_SIGNATURE, TIME_SIGNATURE_RANGE + +SEED = 777 + +HERE = Path(__file__).parent +MIDI_PATHS_ONE_TRACK = sorted((HERE / "MIDIs_one_track").rglob("*.mid")) +MIDI_PATHS_MULTITRACK = sorted((HERE / "MIDIs_multitrack").rglob("*.mid")) +MIDI_PATHS_ALL = sorted( + deepcopy(MIDI_PATHS_ONE_TRACK) + deepcopy(MIDI_PATHS_MULTITRACK) +) +TEST_LOG_DIR = HERE / "test_logs" + +# TOKENIZATIONS +ALL_TOKENIZATIONS = miditok.tokenizations.__all__ +TOKENIZATIONS_BPE = ["REMI", "MIDILike", "TSD", "MMM", "Structured"] + +# TOK CONFIG PARAMS +TIME_SIGNATURE_RANGE_TESTS = TIME_SIGNATURE_RANGE +TIME_SIGNATURE_RANGE_TESTS.update({2: [2, 3, 4]}) +TIME_SIGNATURE_RANGE_TESTS[4].append(8) +TOKENIZER_CONFIG_KWARGS = { + "beat_res": {(0, 4): 8, (4, 12): 4, (12, 16): 2}, + "beat_res_rest": {(0, 2): 4, (2, 12): 2}, + "num_tempos": 32, + "tempo_range": (40, 250), + "time_signature_range": TIME_SIGNATURE_RANGE_TESTS, + "chord_maps": CHORD_MAPS, + "chord_tokens_with_root_note": True, # Tokens will look as "Chord_C:maj" + "chord_unknown": (3, 6), + "delete_equal_successive_time_sig_changes": True, + "delete_equal_successive_tempo_changes": True, +} + + +def adjust_tok_params_for_tests(tokenization: str, params: Dict[str, Any]): + """Adjusts parameters (as dictionary for keyword arguments) depending on the tokenization. + + :param tokenization: tokenization. + :param params: parameters as a dictionary of keyword arguments. + """ + # Increase the TimeShift voc for Structured as it doesn't support successive TimeShifts. 
+ if tokenization == "Structured": + params["beat_res"] = {(0, 512): 8} + # We don't test time signatures with Octuple as it can lead to time shifts, as the TS changes are only + # detectable at the onset times of the notes. + elif tokenization == "Octuple": + params["max_bar_embedding"] = 300 + params["use_time_signatures"] = False + # Rests and time sig can mess up with CPWord, when a Rest that is crossing new bar is followed + # by a new TimeSig change, as TimeSig are carried with Bar tokens (and there is None is this case). + elif tokenization == "CPWord": + if params["use_time_signatures"] and params["use_rests"]: + params["use_rests"] = False + + +def prepare_midi_for_tests( + midi: MidiFile, sort_notes: bool = False, tokenizer: miditok.MIDITokenizer = None +) -> MidiFile: + """Prepares a midi for test by returning a copy with tracks sorted, and optionally notes. + It also + + :param midi: midi reference. + :param sort_notes: whether to sort the notes. This is not necessary before tokenizing a MIDI, as the sorting + will be performed by the tokenizer. (default: False) + :param tokenizer: in order to downsample the MIDI before sorting its content. + :return: a new MIDI object with track (and notes) sorted. 
+ """ + tokenization = type(tokenizer).__name__ if tokenizer is not None else None + new_midi = deepcopy(midi) + + # Downsamples the MIDI if a tokenizer is given + if tokenizer is not None: + tokenizer.preprocess_midi(new_midi) + + # For Octuple/CPWord, as tempo is only carried at notes times, we need to adapt their times for comparison + # Set tempo changes at onset times of notes + # We use the first track only, as it is the one for which tempos are decoded + if tokenizer.config.use_tempos and tokenization in ["Octuple", "CPWord"]: + if len(new_midi.instruments) > 0: + adapt_tempo_changes_times( + [new_midi.instruments[0]], new_midi.tempo_changes + ) + else: + new_midi.tempo_changes = [TempoChange(tokenizer._DEFAULT_TEMPO, 0)] + if ( + tokenizer.config.use_time_signatures + and tokenization in ["Octuple", "CPWord", "MMM"] + and len(new_midi.instruments) == 0 + ): + new_midi.time_signature_changes = [TimeSignature(*TIME_SIGNATURE, 0)] + + for track in new_midi.instruments: + # Adjust notes and pedal ends to the maximum possible value + if tokenizer is not None: + adjust_notes_durations(track.notes, tokenizer, midi.ticks_per_beat) + if tokenizer.config.use_sustain_pedals: + adjust_pedal_durations(track.pedals, tokenizer, midi.ticks_per_beat) + if track.is_drum: + track.program = 0 # need to be done before sorting tracks per program + if sort_notes: + track.notes.sort(key=lambda x: (x.start, x.pitch, x.end, x.velocity)) + + # Sorts tracks + # MIDI detokenized with one_token_stream contains tracks sorted by note occurrence + new_midi.instruments.sort(key=lambda x: (x.program, x.is_drum)) + + return new_midi + + +def midis_notes_equals( + midi1: MidiFile, midi2: MidiFile +) -> List[Tuple[int, str, List[Tuple[str, Union[Note, int], int]]]]: + """Checks if the notes from two MIDIs are all equal, and if not returns the list of errors. + + :param midi1: first MIDI. + :param midi2: second MIDI. + :return: list of errors. 
+ """ + errors = [] + for track1, track2 in zip(midi1.instruments, midi2.instruments): + track_errors = tracks_notes_equals(track1, track2) + if len(track_errors) > 0: + errors.append((track1.program, track1.name, track_errors)) + return errors + + +def tracks_notes_equals( + track1: Instrument, track2: Instrument +) -> List[Tuple[str, Union[Note, int], int]]: + if len(track1.notes) != len(track2.notes): + return [("len", len(track2.notes), len(track1.notes))] + errors = [] + for note1, note2 in zip(track1.notes, track2.notes): + err = notes_equals(note1, note2) + if err != "": + errors.append((err, note2, getattr(note1, err))) + return errors + + +def notes_equals(note1: Note, note2: Note) -> str: + if note1.start != note2.start: + return "start" + elif note1.end != note2.end: + return "end" + elif note1.pitch != note2.pitch: + return "pitch" + elif note1.velocity != note2.velocity: + return "velocity" + return "" + + +def check_midis_equals( + midi1: MidiFile, + midi2: MidiFile, + check_tempos: bool = True, + check_time_signatures: bool = True, + check_pedals: bool = True, + check_pitch_bends: bool = True, + log_prefix: str = "", +) -> Tuple[MidiFile, bool]: + has_errors = False + types_of_errors = [] + + # Checks notes and add markers if errors + errors = midis_notes_equals(midi1, midi2) + if len(errors) > 0: + has_errors = True + for e, track_err in enumerate(errors): + if track_err[-1][0][0] != "len": + for err, note, exp in track_err[-1]: + midi2.markers.append( + Marker( + f"{e}: with note {err} (pitch {note.pitch})", + note.start, + ) + ) + print( + f"{log_prefix} failed to encode/decode NOTES ({sum(len(t[2]) for t in errors)} errors)" + ) + + # Check pedals + if check_pedals: + for inst1, inst2 in zip(midi1.instruments, midi2.instruments): + if inst1.pedals != inst2.pedals: + types_of_errors.append("PEDALS") + break + + # Check pitch bends + if check_pitch_bends: + for inst1, inst2 in zip(midi1.instruments, midi2.instruments): + if inst1.pitch_bends != 
inst2.pitch_bends: + types_of_errors.append("PITCH BENDS") + break + + """# Check control changes + if check_control_changes: + for inst1, inst2 in zip(midi1.instruments, midi2.instruments): + if inst1.control_changes != inst2.control_changes: + types_of_errors.append("CONTROL CHANGES") + break""" + + # Checks tempos + if check_tempos: + if midi1.tempo_changes != midi2.tempo_changes: + types_of_errors.append("TEMPOS") + + # Checks time signatures + if check_time_signatures: + if midi1.time_signature_changes != midi2.time_signature_changes: + types_of_errors.append("TIME SIGNATURES") + + # Prints types of errors + has_errors = has_errors or len(types_of_errors) > 0 + for err_type in types_of_errors: + print(f"{log_prefix} failed to encode/decode {err_type}") + + return midi2, not has_errors + + +def tokenize_and_check_equals( + midi: MidiFile, + tokenizer: miditok.MIDITokenizer, + file_idx: Union[int, str], + file_name: str, +) -> Tuple[MidiFile, bool]: + tokenization = type(tokenizer).__name__ + log_prefix = f"MIDI {file_idx} - {file_name} / {tokenization}" + midi.instruments.sort(key=lambda x: (x.program, x.is_drum)) + # merging is performed in preprocess only in one_token_stream mode + # but in multi token stream, decoding will actually keep one track per program + if tokenizer.config.use_programs: + miditok.utils.merge_same_program_tracks(midi.instruments) + + # Tokenize and detokenize + tokens = tokenizer(midi) + midi_decoded = tokenizer( + tokens, + miditok.utils.get_midi_programs(midi) if len(midi.instruments) > 0 else None, + time_division=midi.ticks_per_beat, + ) + midi_decoded = prepare_midi_for_tests( + midi_decoded, sort_notes=tokenization == "MIDILike" + ) + + # Check decoded MIDI is identical + midi_decoded, no_error = check_midis_equals( + midi, + midi_decoded, + check_tempos=tokenizer.config.use_tempos and not tokenization == "MuMIDI", + check_time_signatures=tokenizer.config.use_time_signatures, + check_pedals=tokenizer.config.use_sustain_pedals, + 
check_pitch_bends=tokenizer.config.use_pitch_bends, + log_prefix=log_prefix, + ) + + # Checks types and values conformity following the rules + err_tse = tokenizer.tokens_errors(tokens) + if isinstance(err_tse, list): + err_tse = sum(err_tse) + if err_tse != 0.0: + no_error = False + print(f"{log_prefix} Validation of tokens types / values successions failed") + + return midi_decoded, not no_error + + +def adapt_tempo_changes_times( + tracks: List[Instrument], tempo_changes: List[TempoChange] +): + r"""Will adapt the times of tempo changes depending on the + onset times of the notes of the MIDI. + This is needed to pass the tempo tests for Octuple as the tempos + will be decoded only from the notes. + + :param tracks: tracks of the MIDI to adapt the tempo changes + :param tempo_changes: tempo changes to adapt + """ + notes = sum((t.notes for t in tracks), []) + notes.sort(key=lambda x: x.start) + max_tick = max(note.start for note in notes) + + current_note_idx = 0 + tempo_idx = 1 + while tempo_idx < len(tempo_changes): + if tempo_changes[tempo_idx].time > max_tick: + del tempo_changes[tempo_idx] + continue + for n, note in enumerate(notes[current_note_idx:]): + if note.start >= tempo_changes[tempo_idx].time: + tempo_changes[tempo_idx].time = note.start + current_note_idx += n + break + if tempo_changes[tempo_idx].time == tempo_changes[tempo_idx - 1].time: + del tempo_changes[tempo_idx - 1] + continue + tempo_idx += 1 + + +def adjust_notes_durations( + notes: List[Note], tokenizer: miditok.MIDITokenizer, time_division: int +): + """Adapt notes offset times so that they match the possible durations covered by a tokenizer. + + :param notes: list of Note objects to adapt. + :param tokenizer: tokenizer (needed for durations). + :param time_division: time division of the MIDI of origin. 
+ """ + durations_in_tick = np.array( + [ + (beat * res + pos) * time_division // res + for beat, pos, res in tokenizer.durations + ] + ) + for note in notes: + dur_index = np.argmin(np.abs(durations_in_tick - note.duration)) + beat, pos, res = tokenizer.durations[dur_index] + dur_index_in_tick = (beat * res + pos) * time_division // res + note.end = note.start + dur_index_in_tick + + +def adjust_pedal_durations( + pedals: List[Pedal], tokenizer: miditok.MIDITokenizer, time_division: int +): + """Adapt pedal offset times so that they match the possible durations covered by a tokenizer. + + :param pedals: list of Pedal objects to adapt. + :param tokenizer: tokenizer (needed for durations). + :param time_division: time division of the MIDI of origin. + """ + durations_in_tick = np.array( + [ + (beat * res + pos) * time_division // res + for beat, pos, res in tokenizer.durations + ] + ) + for pedal in pedals: + dur_index = np.argmin(np.abs(durations_in_tick - pedal.duration)) + beat, pos, res = tokenizer.durations[dur_index] + dur_index_in_tick = (beat * res + pos) * time_division // res + pedal.end = pedal.start + dur_index_in_tick