fix data augment with unique_track, Program tokens created from config.programs
Natooz committed Jul 6, 2023
1 parent 220f384 commit 30d5546
Showing 10 changed files with 34 additions and 18 deletions.
miditok/constants.py (1 addition & 1 deletion)
@@ -2,7 +2,7 @@
"""

CURRENT_VERSION_PACKAGE = "2.1.0" # used when saving the config of a tokenizer
CURRENT_VERSION_PACKAGE = "2.1.1" # used when saving the config of a tokenizer

MIDI_FILES_EXTENSIONS = [".mid", ".midi", ".MID", ".MIDI"]

miditok/data_augmentation/data_augmentation.py (24 additions & 7 deletions)
@@ -90,7 +90,10 @@ def data_augmentation_dataset(
            Tuple[int, int, int], List[Union[int, List[int]]]
        ] = {}
        for track, (_, is_drum) in zip(ids, programs):
-            if is_drum:  # we dont augment drums
+            # we dont augment drums
+            if not tokenizer.unique_track and is_drum:
+                continue
+            elif tokenizer.unique_track and all(p[1] for p in programs):
+                continue
            corrected_offsets = deepcopy(offsets)
            vel_dim = int(128 / len(tokenizer.velocities))
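The two branches encode the two token layouts. A minimal, self-contained sketch of the same skip rule (names mirror the hunk above; each entry of `programs` is a pair whose second element is the is_drum flag, the first presumably the program number):

    def should_skip(is_drum: bool, unique_track: bool, programs) -> bool:
        # Multi-track tokenizer: each drum track is skipped individually.
        if not unique_track and is_drum:
            return True
        # Single-sequence tokenizer: all tracks share one token sequence,
        # so augmentation is skipped only when every track is a drum.
        if unique_track and all(d for _, d in programs):
            return True
        return False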
@@ -113,12 +116,13 @@
                augmented_tokens[aug_offsets].append(seq)
            except KeyError:
                augmented_tokens[aug_offsets] = [seq]
-        for i, (track, (_, is_drum)) in enumerate(
-            zip(ids, programs)
-        ):  # adding drums to all already augmented
-            if is_drum:
-                for aug_offsets in augmented_tokens:
-                    augmented_tokens[aug_offsets].insert(i, track)
+        if not tokenizer.unique_track:
+            for i, (track, (_, is_drum)) in enumerate(
+                zip(ids, programs)
+            ):  # adding drums to all already augmented
+                if is_drum:
+                    for aug_offsets in augmented_tokens:
+                        augmented_tokens[aug_offsets].insert(i, track)

        # Save augmented tracks as json
        for aug_offsets, tracks_seq in augmented_tokens.items():
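For orientation, a sketch of the mapping built by this loop (offset values are hypothetical, and the key order pitch/velocity/duration is assumed from the Tuple[int, int, int] annotation above):

    # One entry per augmented version; each value lists the per-track token
    # sequences, with untouched drum tracks re-inserted at their original
    # indices so the list stays aligned with `programs`.
    augmented_tokens = {
        (12, 0, 0): [drum_track_seq, piano_seq_shifted_up_an_octave],
        (-12, 0, 0): [drum_track_seq, piano_seq_shifted_down_an_octave],
    }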
@@ -138,6 +142,8 @@ def data_augmentation_dataset(
            nb_augmentations += 1
            nb_tracks_augmented += len(tracks_seq)
        if copy_original_in_new_location and out_path is not None:
+            if tokenizer.unique_track:
+                ids = ids[0]
            tokenizer.save_tokens(
                ids, out_path / f"{file_path.stem}.json", programs
            )
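A sketch of the shapes being reconciled here (assumed from the hunk): the loaded ids of a unique_track tokenizer hold the single sequence wrapped in a list, while save_tokens takes the sequence itself.

    ids = [[3, 14, 15, 92]]          # one sequence, wrapped, as iterated above
    if tokenizer.unique_track:
        ids = ids[0]                 # -> [3, 14, 15, 92] for save_tokens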
@@ -448,6 +454,17 @@ def data_augmentation_tokens(
        )
        note_off_tokens = np.array(tokenizer.token_ids_of_type("NoteOff"))
        mask_pitch = np.isin(tokens, pitch_tokens)
+        # If applicable, removes drum notes from the mask
+        if tokenizer.unique_track:
+            for idx, is_note in enumerate(mask_pitch):
+                if (
+                    is_note
+                    and idx > 0
+                    and tokenizer[tokens[idx - 1]] == "Program_-1"
+                ):
+                    mask_pitch[idx] = False
+                    if len(note_off_tokens) > 0:
+                        note_off_tokens[idx] = False
    else:
        pitch_tokens = np.array(tokenizer.token_ids_of_type("Pitch", pitch_voc_idx))
        mask_pitch = np.full_like(tokens, 0, dtype=np.bool_)
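The check relies on the single-sequence layout, where each Pitch token is preceded by the Program token of its track. A compact numpy sketch of the same masking idea (the function name is hypothetical; the tokenizer[id] lookup of a token string follows the hunk above):

    import numpy as np

    def mask_augmentable_pitches(tokens, tokenizer, pitch_token_ids):
        # True wherever a token is a Pitch token...
        mask = np.isin(tokens, pitch_token_ids)
        for idx in np.flatnonzero(mask):
            # ...except right after Program_-1 (drums), which must not be
            # pitch-shifted.
            if idx > 0 and tokenizer[int(tokens[idx - 1])] == "Program_-1":
                mask[idx] = False
        return mask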
miditok/midi_tokenizer.py (1 addition & 3 deletions)
@@ -1025,9 +1025,7 @@ def learn_bpe(
                sample["ids"], as_one_str=True
            )  # list of str (bytes)
            iterator += (
-                [[byte_] for byte_ in bytes_]
-                if not self.unique_track
-                else [bytes_]
+                [[byte_] for byte_ in bytes_] if not self.unique_track else [bytes_]
            )

        # This doesn't seem to work, the trainer pre-processes the sequences, but then no word remains
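The reformatted one-liner feeds the BPE trainer two different word granularities; a toy illustration (bytes_ stands for one track's byte string, one character per base token id):

    bytes_ = "abc"
    per_token_words = [[b] for b in bytes_]   # [["a"], ["b"], ["c"]], one word per token
    whole_sequence = [bytes_]                 # ["abc"], the sequence as a single word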
miditok/tokenizations/cp_word.py (2 additions & 1 deletion)
@@ -410,7 +410,8 @@ def _create_base_vocabulary(self) -> List[List[str]]:
        # PROGRAM
        if self.config.use_programs:
            vocab += [
-                ["Ignore_None"] + [f"Program_{program}" for program in range(-1, 128)]
+                ["Ignore_None"]
+                + [f"Program_{program}" for program in self.config.programs]
            ]

        # CHORD
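This file and the tokenizations below all swap the hard-coded range(-1, 128) for self.config.programs. A hedged usage sketch (assuming TokenizerConfig accepts a programs argument backing the attribute read here; -1 denotes drums in miditok):

    from miditok import REMI, TokenizerConfig

    # Keep Program tokens only for drums (-1), piano (0) and guitar (24).
    config = TokenizerConfig(use_programs=True, programs=[-1, 0, 24])
    tokenizer = REMI(config)
    # The base vocabulary now holds Program_-1, Program_0 and Program_24
    # instead of all 129 Program tokens.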
miditok/tokenizations/midi_like.py (1 addition & 1 deletion)
@@ -312,7 +312,7 @@ def _create_base_vocabulary(self) -> List[str]:

        # PROGRAM
        if self.config.use_programs:
-            vocab += [f"Program_{program}" for program in range(-1, 128)]
+            vocab += [f"Program_{program}" for program in self.config.programs]

        return vocab

miditok/tokenizations/octuple.py (1 addition & 1 deletion)
@@ -301,7 +301,7 @@ def tokens_to_midi(
        current_time_sig_tick = 0
        current_time_sig_bar = 0

-        tracks = dict([(n, []) for n in range(-1, 128)])
+        tracks = dict([(n, []) for n in self.config.programs])
        for time_step in tokens:
            if any(tok.split("_")[1] == "None" for tok in time_step[:6]):
                continue  # Either padding, mask: error of prediction or end of sequence anyway
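On the decoding side the same config now bounds which programs can receive notes; a toy view of the container created here (the program set is the hypothetical one from the earlier sketch, and the note tuple format is made up for illustration):

    tracks = {n: [] for n in (-1, 0, 24)}   # one note buffer per allowed program
    tracks[0].append((60, 90, 0, 480))      # e.g. (pitch, velocity, start, end)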
miditok/tokenizations/remi.py (1 addition & 1 deletion)
@@ -329,7 +329,7 @@ def _create_base_vocabulary(self, sos_eos_tokens: bool = None) -> List[str]:

        # PROGRAM
        if self.config.use_programs:
-            vocab += [f"Program_{program}" for program in range(-1, 128)]
+            vocab += [f"Program_{program}" for program in self.config.programs]

        return vocab

miditok/tokenizations/structured.py (1 addition & 1 deletion)
@@ -222,7 +222,7 @@ def _create_base_vocabulary(self, sos_eos_tokens: bool = None) -> List[str]:

        # PROGRAM
        if self.config.use_programs:
-            vocab += [f"Program_{program}" for program in range(-1, 128)]
+            vocab += [f"Program_{program}" for program in self.config.programs]

        return vocab

miditok/tokenizations/tsd.py (1 addition & 1 deletion)
@@ -426,7 +426,7 @@ def _create_base_vocabulary(self, sos_eos_tokens: bool = False) -> List[str]:

        # PROGRAM
        if self.config.use_programs:
-            vocab += [f"Program_{program}" for program in range(-1, 128)]
+            vocab += [f"Program_{program}" for program in self.config.programs]

        # TIME SIGNATURE
        if self.config.use_time_signatures:
setup.py (1 addition & 1 deletion)
@@ -8,7 +8,7 @@
    author="Nathan Fradet",
    url="https://github.com/Natooz/MidiTok",
    packages=find_packages(exclude=("tests",)),
-    version="2.1.0",
+    version="2.1.1",
    license="MIT",
    description="A convenient MIDI tokenizer for Deep Learning networks, with multiple encoding strategies",
    long_description=long_description,