diff --git a/homr/accidental_detection.py b/homr/accidental_detection.py index 3b44180..3c38710 100644 --- a/homr/accidental_detection.py +++ b/homr/accidental_detection.py @@ -1,6 +1,6 @@ from homr import constants from homr.bounding_boxes import RotatedBoundingBox -from homr.model import Accidental, Prediction, Staff +from homr.model import Accidental, Staff def add_accidentals_to_staffs( @@ -32,8 +32,7 @@ def add_accidentals_to_staffs( position = point.find_position_in_unit_sizes(accidental) accidental_bbox = accidental.to_bounding_box() - prediction = Prediction({}, 0) - clef_symbol = Accidental(accidental_bbox, prediction, position) + clef_symbol = Accidental(accidental_bbox, position) staff.add_symbol(clef_symbol) result.append(clef_symbol) return result diff --git a/homr/model.py b/homr/model.py index cc30c17..48228fc 100644 --- a/homr/model.py +++ b/homr/model.py @@ -1,7 +1,6 @@ from abc import abstractmethod from collections.abc import Callable from enum import Enum -from typing import Any import cv2 import numpy as np @@ -16,28 +15,10 @@ RotatedBoundingBox, ) from homr.circle_of_fifths import get_circle_of_fifth_notes -from homr.results import ( - ClefType, - ResultClef, - ResultDuration, - ResultNote, - ResultNoteGroup, - ResultPitch, - ResultRest, - ResultSymbol, -) +from homr.results import ClefType, ResultPitch from homr.type_definitions import NDArray -class Prediction: - def __init__(self, result: dict[Any, Any], best: Any) -> None: - self.result = result - self.best = best - - def __str__(self) -> str: - return str(self.result) - - class InputPredictions: def __init__( self, @@ -62,10 +43,6 @@ class SymbolOnStaff(DebugDrawable): def __init__(self, center: tuple[float, float]) -> None: self.center = center - @abstractmethod - def to_result(self) -> ResultSymbol | None: - pass - @abstractmethod def copy(self) -> Self: pass @@ -84,54 +61,17 @@ def calc_x_distance_to(self, point: tuple[float, float]) -> float: return abs(self.center[0] - point[0]) -class AccidentalType(Enum): - SHARP = 1 - FLAT = 2 - NATURAL = 3 - - def __str__(self) -> str: - if self == AccidentalType.SHARP: - return "#" - elif self == AccidentalType.FLAT: - return "b" - elif self == AccidentalType.NATURAL: - return "n" - else: - raise Exception("Unknown AccidentalType") - - def to_alter(self) -> int: - if self == AccidentalType.SHARP: - return 1 - elif self == AccidentalType.FLAT: - return -1 - elif self == AccidentalType.NATURAL: - return 0 - else: - raise Exception("Unknown AccidentalType") - - class Accidental(SymbolOnStaff): - def __init__(self, box: BoundingBox, prediction: Prediction, position: int) -> None: + def __init__(self, box: BoundingBox, position: int) -> None: super().__init__(box.center) self.box = box - self.prediction = prediction self.position = position - def get_accidental_type(self) -> AccidentalType: - if self.prediction.best == "sharp": - return AccidentalType.SHARP - elif self.prediction.best == "flat": - return AccidentalType.FLAT - elif self.prediction.best == "natural": - return AccidentalType.NATURAL - else: - raise Exception("Unknown accidental type " + self.prediction.best) - def draw_onto_image(self, img: NDArray, color: tuple[int, int, int] = (255, 0, 0)) -> None: self.box.draw_onto_image(img, color) cv2.putText( img, - str(self.prediction.best) + "-" + str(self.position), + "accidental-" + str(self.position), (int(self.box.box[0]), int(self.box.box[1])), cv2.FONT_HERSHEY_SIMPLEX, 1, @@ -140,9 +80,6 @@ def draw_onto_image(self, img: NDArray, color: tuple[int, int, int] = (255, 0, 0 cv2.LINE_AA, ) - def to_result(self) -> None: - return None - def __str__(self) -> str: return "Accidental(" + str(self.center) + ")" @@ -150,21 +87,20 @@ def __repr__(self) -> str: return str(self) def copy(self) -> "Accidental": - return Accidental(self.box, self.prediction, self.position) + return Accidental(self.box, self.position) class Rest(SymbolOnStaff): - def __init__(self, box: BoundingBox, prediction: Prediction) -> None: + def __init__(self, box: BoundingBox) -> None: super().__init__(box.center) self.box = box - self.prediction = prediction self.has_dot = False def draw_onto_image(self, img: NDArray, color: tuple[int, int, int] = (255, 0, 0)) -> None: self.box.draw_onto_image(img, color) cv2.putText( img, - str(self.prediction.best), + "rest", (int(self.box.box[0]), int(self.box.box[1])), cv2.FONT_HERSHEY_SIMPLEX, 1, @@ -173,22 +109,6 @@ def draw_onto_image(self, img: NDArray, color: tuple[int, int, int] = (255, 0, 0 cv2.LINE_AA, ) - def get_duration(self) -> int: - text = self.prediction.best - duration_dict = { - "rest_whole": 4 * constants.duration_of_quarter, - "rest_half": 2 * constants.duration_of_quarter, - "rest_quarter": constants.duration_of_quarter, - "rest_8th": constants.duration_of_quarter // 2, - "rest_16th": constants.duration_of_quarter // 4, - "rest_32nd": constants.duration_of_quarter // 8, - "rest_64th": constants.duration_of_quarter // 16, - } - return duration_dict.get(text, 0) # returns 0 if text is not found in the dictionary - - def to_result(self) -> ResultRest: - return ResultRest(ResultDuration(self.get_duration(), self.has_dot)) - def __str__(self) -> str: return "Rest(" + str(self.center) + ")" @@ -196,7 +116,7 @@ def __repr__(self) -> str: return str(self) def copy(self) -> "Rest": - return Rest(self.box, self.prediction) + return Rest(self.box) class StemDirection(Enum): @@ -230,9 +150,7 @@ def __init__(self, step: str, alter: int | None, octave: int): self.alter = None self.octave = int(octave) - def move_by_position( - self, position: int, accidental: Accidental | None, circle_of_fifth: int - ) -> "Pitch": + def move_by_position(self, position: int, circle_of_fifth: int) -> "Pitch": # Find the current position of the note in the scale current_position = note_names.index(self.step) @@ -252,15 +170,6 @@ def move_by_position( else: alter = 1 - if accidental is not None: - accidental_type = accidental.get_accidental_type() - if accidental_type == AccidentalType.SHARP: - alter = 1 - elif accidental_type == AccidentalType.FLAT: - alter = -1 - elif accidental_type == AccidentalType.NATURAL: - alter = 0 - return Pitch(new_step, alter, new_octave) def to_result(self) -> ResultPitch: @@ -270,16 +179,11 @@ def copy(self) -> "Pitch": return Pitch(self.step, self.alter, self.octave) -reference_pitch_f_clef = Pitch("F", 0, 2) -reference_pitch_g_clef = Pitch("D", 0, 4) - - class Note(SymbolOnStaff): def __init__( self, box: BoundingEllipse, position: int, - notehead_type: Prediction, stem: RotatedBoundingBox | None, stem_direction: StemDirection | None, ): @@ -288,9 +192,8 @@ def __init__( self.position = position self.has_dot = False self.beam_count = 0 - self.notehead_type = notehead_type self.stem = stem - self.clef_type = ClefType.TREBLE + self.clef_type = ClefType.treble() self.circle_of_fifth = 0 self.accidental: Accidental | None = None self.stem_direction = stem_direction @@ -302,7 +205,7 @@ def draw_onto_image(self, img: NDArray, color: tuple[int, int, int] = (255, 0, 0 dot_string = "." if self.has_dot else "" cv2.putText( img, - str(self.notehead_type.best) + dot_string + str(self.position), + "note" + dot_string + str(self.position), (int(self.box.center[0]), int(self.box.center[1])), cv2.FONT_HERSHEY_SIMPLEX, 1, @@ -322,54 +225,16 @@ def get_pitch( ) -> Pitch: clef_type = self.clef_type if clef_type is None else clef_type circle_of_fifth = self.circle_of_fifth if circle_of_fifth is None else circle_of_fifth - reference = ( - reference_pitch_g_clef if clef_type == ClefType.TREBLE else reference_pitch_f_clef - ) - return reference.move_by_position(self.position, self.accidental, circle_of_fifth) - - def _adjust_duration_with_dot(self, duration: int) -> int: - if self.has_dot: - return duration * 3 // 2 - else: - return duration - - def _get_base_duration(self) -> int: - if self.notehead_type.best == NoteHeadType.HOLLOW: - if self.stem is None: - return 4 * constants.duration_of_quarter - else: - return 2 * constants.duration_of_quarter - if self.notehead_type.best == NoteHeadType.SOLID: - if self.stem is None: - # TODO that would be an odd result, return a quarter note - return constants.duration_of_quarter - elif self.beam_count == 0: - return constants.duration_of_quarter - elif self.beam_count > 0: - # TODO take the note count into account - return constants.duration_of_quarter // 2 - raise Exception( - "Unknown notehead type " - + str(self.notehead_type.best) - + " " - + str(self.stem) - + " " - + str(self.beam_count) - ) - - def get_duration(self) -> int: - return self._adjust_duration_with_dot(self._get_base_duration()) - - def to_result(self) -> ResultNote: - return ResultNote( - self.get_pitch().to_result(), ResultDuration(self.get_duration(), self.has_dot) - ) + reference = clef_type.get_reference_pitch() + reference_pitch = Pitch(reference.step, reference.alter, reference.octave) + # Position + 1 as the model uses a higher reference point on the staff + return reference_pitch.move_by_position(self.position + 1, circle_of_fifth) def to_tr_omr_note(self, clef_type: ClefType) -> str: pitch = self.get_pitch(clef_type=clef_type).to_result() - duration = ResultDuration(self.get_duration(), self.has_dot) - return "note-" + str(pitch) + "_" + str(duration) + # We have no information about the duration here and default to quarter + return "note-" + str(pitch) + "_quarter" def __str__(self) -> str: return "Note(" + str(self.center) + ", " + str(self.position) + ")" @@ -378,7 +243,7 @@ def __repr__(self) -> str: return str(self) def copy(self) -> "Note": - return Note(self.box, self.position, self.notehead_type, self.stem, self.stem_direction) + return Note(self.box, self.position, self.stem, self.stem_direction) class NoteGroup(SymbolOnStaff): @@ -388,9 +253,6 @@ def __init__(self, notes: list[Note]) -> None: # sort notes by pitch, highest position first self.notes = sorted(notes, key=lambda note: note.position, reverse=True) - def to_result(self) -> ResultNoteGroup: - return ResultNoteGroup([note.to_result() for note in self.notes]) - def draw_onto_image(self, img: NDArray, color: tuple[int, int, int] = (255, 0, 0)) -> None: for note in self.notes: note.draw_onto_image(img, color) @@ -421,9 +283,6 @@ def __init__(self, box: RotatedBoundingBox): def draw_onto_image(self, img: NDArray, color: tuple[int, int, int] = (255, 0, 0)) -> None: self.box.draw_onto_image(img, color) - def to_result(self) -> None: - return None - def __str__(self) -> str: return "BarLine(" + str(self.center) + ")" @@ -435,17 +294,16 @@ def copy(self) -> "BarLine": class Clef(SymbolOnStaff): - def __init__(self, box: BoundingBox, prediction: Prediction): + def __init__(self, box: BoundingBox): super().__init__(box.center) self.box = box - self.prediction = prediction self.accidentals: list[Accidental] = [] def draw_onto_image(self, img: NDArray, color: tuple[int, int, int] = (255, 0, 0)) -> None: self.box.draw_onto_image(img, color) cv2.putText( img, - str(self.prediction.best), + "clef", (self.box.box[0], self.box.box[1]), cv2.FONT_HERSHEY_SIMPLEX, 1, @@ -454,31 +312,6 @@ def draw_onto_image(self, img: NDArray, color: tuple[int, int, int] = (255, 0, 0 cv2.LINE_AA, ) - def get_clef_type(self) -> ClefType: - if self.prediction.best == "gclef": - return ClefType.TREBLE - elif self.prediction.best == "fclef": - return ClefType.BASS - else: - raise Exception("Unknown clef type " + self.prediction.best) - - def get_circle_of_fifth(self) -> int: - if len(self.accidentals) == 0: - return 0 - accidental_types = [accidental.get_accidental_type() for accidental in self.accidentals] - max_accidental = max(accidental_types, key=accidental_types.count) - if max_accidental == AccidentalType.SHARP: - return len(self.accidentals) - elif max_accidental == AccidentalType.FLAT: - return -len(self.accidentals) - else: - # TODO we could likely work with this as it implies that there - # are wrong detections, might be that naturals more look like sharps - return 0 - - def to_result(self) -> ResultClef: - return ResultClef(self.get_clef_type(), self.get_circle_of_fifth()) - def __str__(self) -> str: return "Clef(" + str(self.center) + ")" @@ -486,7 +319,7 @@ def __repr__(self) -> str: return str(self) def copy(self) -> "Clef": - return Clef(self.box, self.prediction) + return Clef(self.box) class StaffPoint: diff --git a/homr/note_detection.py b/homr/note_detection.py index 2addfd1..4976760 100644 --- a/homr/note_detection.py +++ b/homr/note_detection.py @@ -3,15 +3,7 @@ from homr import constants from homr.bounding_boxes import BoundingEllipse, DebugDrawable, RotatedBoundingBox -from homr.model import ( - Note, - NoteGroup, - NoteHeadType, - Prediction, - Staff, - StemDirection, - SymbolOnStaff, -) +from homr.model import Note, NoteGroup, Staff, StemDirection, SymbolOnStaff from homr.simple_logging import eprint from homr.type_definitions import NDArray @@ -221,7 +213,7 @@ def add_notes_to_staffs( ): continue position = point.find_position_in_unit_sizes(notehead.notehead) - note = create_detailed_note(notehead, symbols, position) + note = Note(notehead.notehead, position, notehead.stem, notehead.stem_direction) result.append(note) staff.add_symbol(note) number_of_notes = 0 @@ -238,18 +230,3 @@ def add_notes_to_staffs( "note groups", ) return result - - -def create_detailed_note(note: NoteheadWithStem, symbols: NDArray, position: int) -> Note: - ratio = note.notehead.get_color_ratio(symbols) - notehead_type = ( - NoteHeadType.HOLLOW if ratio < constants.notehead_type_threshold else NoteHeadType.SOLID - ) - ratio_as_prediction = Prediction( - { - NoteHeadType.HOLLOW: ratio, - NoteHeadType.SOLID: 1 - ratio, - }, - notehead_type, - ) - return Note(note.notehead, position, ratio_as_prediction, note.stem, note.stem_direction) diff --git a/homr/rest_detection.py b/homr/rest_detection.py index cd16626..d1e47d6 100644 --- a/homr/rest_detection.py +++ b/homr/rest_detection.py @@ -2,7 +2,7 @@ from homr import constants from homr.bounding_boxes import RotatedBoundingBox -from homr.model import Prediction, Rest, Staff +from homr.model import Rest, Staff def add_rests_to_staffs(staffs: list[Staff], rests: list[RotatedBoundingBox]) -> list[Rest]: @@ -36,8 +36,7 @@ def add_rests_to_staffs(staffs: list[Staff], rests: list[RotatedBoundingBox]) -> continue bbox = rest.to_bounding_box() - prediction = Prediction({}, 0) - rest_symbol = Rest(bbox, prediction) + rest_symbol = Rest(bbox) staff.add_symbol(rest_symbol) result.append(rest_symbol) return result diff --git a/homr/results.py b/homr/results.py index 6a92f7a..2342fc1 100644 --- a/homr/results.py +++ b/homr/results.py @@ -1,5 +1,3 @@ -from enum import Enum - import numpy as np from homr import constants @@ -11,21 +9,79 @@ def __init__(self) -> None: pass -class ClefType(Enum): - TREBLE = 1 - BASS = 2 +class ClefType: + @staticmethod + def treble() -> "ClefType": + return ClefType(sign="G", line=2) + + @staticmethod + def bass() -> "ClefType": + return ClefType(sign="F", line=4) + + def __init__(self, sign: str, line: int) -> None: + """ + Why we don't support more clef types, e.g. the other examples given in + https://www.w3.org/2021/06/musicxml40/musicxml-reference/elements/clef/: + + Since the other clef types share the same symbol with one of the ones we support, + we have to expect that the there is a lot of misdetections and this would degrade the + performance. + + E.g. if the treble and french violin (https://en.wikipedia.org/wiki/Clef) + are easily confused. If support french violin then we will have cases where # + the treble clef is detected as french violin and then the pitch will be wrong. + + If we get more training data and a reliable detecton of the rarer clef types, + we can add them here. + """ + self.sign = sign.upper() + + if self.sign not in ["G", "F", "C"]: + raise Exception("Unknown clef sign " + sign) + + # Extend get_reference_pitch if you add more clef types + treble_clef_line = 2 + bass_clef_line = 4 + alto_clef_line = 3 + if sign == "G" and line != treble_clef_line: + eprint("Unsupported treble clef line", line) + self.line = treble_clef_line + elif sign == "F" and line != bass_clef_line: + eprint("Unsupported bass clef line", line) + self.line = bass_clef_line + elif sign == "C" and line != alto_clef_line: + eprint("Unsupported alto clef line", line) + self.line = alto_clef_line + else: + self.line = line - def __str__(self) -> str: - if self == ClefType.TREBLE: - return "G" - elif self == ClefType.BASS: - return "F" + def __eq__(self, __value: object) -> bool: + if isinstance(__value, ClefType): + return self.sign == __value.sign and self.line == __value.line else: - raise Exception("Unknown ClefType") + return False + + def __hash__(self) -> int: + return hash((self.sign, self.line)) + + def __str__(self) -> str: + return f"{self.sign}{self.line}" def __repr__(self) -> str: return str(self) + def get_reference_pitch(self) -> "ResultPitch": + if self.sign == "G": + g2 = ResultPitch("C", 4, None) + return g2.move_by(2 * (self.line - 2), None) + elif self.sign == "F": + e2 = ResultPitch("E", 2, None) + return e2.move_by(2 * (self.line - 4), None) + elif self.sign == "C": + c3 = ResultPitch("C", 3, None) + return c3.move_by(2 * (self.line - 3), None) + raise ValueError("Unknown clef sign " + str(self)) + class ResultTimeSignature(ResultSymbol): def __init__(self, time_signature: str) -> None: @@ -92,6 +148,12 @@ def get_relative_position(self, other: "ResultPitch") -> int: - note_names.index(other.step) ) + def move_by(self, steps: int, alter: int | None) -> "ResultPitch": + step_index = (note_names.index(self.step) + steps) % 7 + step = note_names[step_index] + octave = self.octave + abs(steps - step_index) // 6 * np.sign(steps) + return ResultPitch(step, octave, alter) + def get_pitch_from_relative_position( reference_pitch: ResultPitch, relative_position: int, alter: int | None @@ -129,12 +191,7 @@ def __repr__(self) -> str: return str(self) def get_reference_pitch(self) -> ResultPitch: - if self.clef_type == ClefType.TREBLE: - return ResultPitch("D", 4, None) - elif self.clef_type == ClefType.BASS: - return ResultPitch("F", 2, None) - else: - raise Exception("Unknown ClefType " + str(self.clef_type)) + return self.clef_type.get_reference_pitch() def move_pitch_to_clef( diff --git a/homr/staff_parsing.py b/homr/staff_parsing.py index 9c9a641..80f9e0c 100644 --- a/homr/staff_parsing.py +++ b/homr/staff_parsing.py @@ -4,7 +4,7 @@ from homr import constants from homr.debug import Debug from homr.image_utils import crop_image_and_return_new_top -from homr.model import Clef, InputPredictions, MultiStaff, NoteGroup, Staff +from homr.model import InputPredictions, MultiStaff, NoteGroup, Staff from homr.results import ( ResultClef, ResultMeasure, @@ -249,13 +249,6 @@ def transform_coordinates(point: tuple[float, float]) -> tuple[float, float]: return staff.transform_coordinates(transform_coordinates) -def move_key_information(staff: Staff, destination: ResultStaff) -> None: - source = [clef for clef in staff.symbols if isinstance(clef, Clef)] - dest = [clef for clef in destination.get_symbols() if isinstance(clef, ResultClef)] - for i in range(min(len(source), len(dest))): - dest[i].circle_of_fifth = source[i].get_circle_of_fifth() - - def parse_staff_image( debug: Debug, ranges: list[float], index: int, staff: Staff, predictions: InputPredictions ) -> ResultStaff | None: diff --git a/homr/staff_parsing_tromr.py b/homr/staff_parsing_tromr.py index 47018c8..18aecea 100644 --- a/homr/staff_parsing_tromr.py +++ b/homr/staff_parsing_tromr.py @@ -1,3 +1,4 @@ +import re from collections import Counter import cv2 @@ -111,15 +112,10 @@ def _number_of_accidentals_in_model(staff: Staff) -> int: def _get_clef_type(result: str) -> ClefType | None: - g2_index = result.find("clef-G2") - f4_index = result.find("clef-F4") - - if g2_index == -1 and f4_index == -1: + match = re.search(r"clef-([A-G])([0-9])", result) + if match is None: return None - elif g2_index != -1 and (f4_index == -1 or g2_index < f4_index): - return ClefType.TREBLE - else: - return ClefType.BASS + return ClefType(match.group(1), int(match.group(2))) def _flatten_result(result: list[str]) -> list[str]: diff --git a/homr/tr_omr_parser.py b/homr/tr_omr_parser.py index d1181be..4d4b4d1 100644 --- a/homr/tr_omr_parser.py +++ b/homr/tr_omr_parser.py @@ -41,7 +41,7 @@ def number_of_accidentals(self) -> int: def parse_clef(self, clef: str) -> ResultClef: parts = clef.split("-") clef_type_str = parts[1] - clef_type = ClefType.TREBLE if clef_type_str.startswith("G") else ClefType.BASS + clef_type = ClefType(clef_type_str[0], int(clef_type_str[1])) self._clefs.append(clef_type) return ResultClef(clef_type, 0) diff --git a/homr/xml_generator.py b/homr/xml_generator.py index 5b80bf4..76a5f2d 100644 --- a/homr/xml_generator.py +++ b/homr/xml_generator.py @@ -2,7 +2,6 @@ from homr import constants from homr.results import ( - ClefType, ResultClef, ResultMeasure, ResultNote, @@ -73,12 +72,8 @@ def build_clef(model_clef: ResultClef) -> mxl.XMLAttributes: # type: ignore key.add_child(fifth) clef = mxl.XMLClef() attributes.add_child(clef) - if model_clef.clef_type == ClefType.TREBLE: - clef.add_child(mxl.XMLSign(value_="G")) - clef.add_child(mxl.XMLLine(value_=2)) - else: - clef.add_child(mxl.XMLSign(value_="F")) - clef.add_child(mxl.XMLLine(value_=4)) + clef.add_child(mxl.XMLSign(value_=model_clef.clef_type.sign)) + clef.add_child(mxl.XMLLine(value_=model_clef.clef_type.line)) return attributes diff --git a/tests/test_result_model.py b/tests/test_result_model.py index 04f1a4c..827e737 100644 --- a/tests/test_result_model.py +++ b/tests/test_result_model.py @@ -6,8 +6,8 @@ class TestResultModel(unittest.TestCase): def test_change_staff(self) -> None: - treble = ResultClef(ClefType.TREBLE, 1) - bass = ResultClef(ClefType.BASS, 1) + treble = ResultClef(ClefType.treble(), 1) + bass = ResultClef(ClefType.bass(), 1) self.assertEqual( str(move_pitch_to_clef(treble.get_reference_pitch(), treble, bass)), str(bass.get_reference_pitch()), @@ -33,3 +33,24 @@ def test_change_staff(self) -> None: self.assertEqual(str(move_pitch_to_clef(ResultPitch("B", 3, None), bass, treble)), "G5") self.assertEqual(str(move_pitch_to_clef(ResultPitch("C", 4, 1), bass, treble)), "A5#") + + def test_move_pitch(self) -> None: + note_c4 = ResultPitch("C", 4, None) + self.assertEqual(note_c4.move_by(1, None), ResultPitch("D", 4, None)) + self.assertEqual(note_c4.move_by(2, None), ResultPitch("E", 4, None)) + self.assertEqual(note_c4.move_by(3, None), ResultPitch("F", 4, None)) + self.assertEqual(note_c4.move_by(4, None), ResultPitch("G", 4, None)) + self.assertEqual(note_c4.move_by(5, None), ResultPitch("A", 4, None)) + self.assertEqual(note_c4.move_by(6, None), ResultPitch("B", 4, None)) + self.assertEqual(note_c4.move_by(7, None), ResultPitch("C", 5, None)) + self.assertEqual(note_c4.move_by(8, None), ResultPitch("D", 5, None)) + + note_d4 = ResultPitch("D", 4, None) + self.assertEqual(note_d4.move_by(0, None), ResultPitch("D", 4, None)) + self.assertEqual(note_d4.move_by(1, None), ResultPitch("E", 4, None)) + self.assertEqual(note_d4.move_by(2, None), ResultPitch("F", 4, None)) + self.assertEqual(note_d4.move_by(3, None), ResultPitch("G", 4, None)) + self.assertEqual(note_d4.move_by(4, None), ResultPitch("A", 4, None)) + self.assertEqual(note_d4.move_by(5, None), ResultPitch("B", 4, None)) + self.assertEqual(note_d4.move_by(6, None), ResultPitch("C", 5, None)) + self.assertEqual(note_d4.move_by(7, None), ResultPitch("D", 5, None)) diff --git a/tests/test_tr_omr_parser.py b/tests/test_tr_omr_parser.py index f7852f1..473360c 100644 --- a/tests/test_tr_omr_parser.py +++ b/tests/test_tr_omr_parser.py @@ -26,7 +26,7 @@ def test_parsing(self) -> None: [ ResultMeasure( [ - ResultClef(ClefType.TREBLE, -1), + ResultClef(ClefType.treble(), -1), ResultTimeSignature("4/4"), ResultNote( ResultPitch("A", 4, None), @@ -71,7 +71,7 @@ def test_parsing_no_final_bar_line(self) -> None: [ ResultMeasure( [ - ResultClef(ClefType.TREBLE, -1), + ResultClef(ClefType.treble(), -1), ResultTimeSignature("4/4"), ResultNote( ResultPitch("A", 4, None), @@ -148,7 +148,7 @@ def test_note_group_parsing(self) -> None: [ ResultMeasure( [ - ResultClef(ClefType.TREBLE, 0), + ResultClef(ClefType.treble(), 0), ResultNote( ResultPitch("D", 4, None), ResultDuration(constants.duration_of_quarter, False), @@ -280,7 +280,7 @@ def test_accidental_parsing(self) -> None: [ ResultMeasure( [ - ResultClef(ClefType.TREBLE, 2), + ResultClef(ClefType.treble(), 2), ResultNote( ResultPitch("D", 4, None), ResultDuration(constants.duration_of_quarter, False), diff --git a/training/music_xml.py b/training/music_xml.py index f8f4a9f..54b80f4 100644 --- a/training/music_xml.py +++ b/training/music_xml.py @@ -179,6 +179,27 @@ def _count_dots(note: mxl.XMLNote) -> str: # type: ignore return "." * len(dots) +def _get_triplet_mark(note: mxl.XMLNote) -> str: # type: ignore + time_modification = note.get_children_of_type(mxl.XSDComplexTypeTimeModification) + if len(time_modification) == 0: + return "" + actual_notes = time_modification[0].get_children_of_type(mxl.XMLActualNotes) + if len(actual_notes) == 0: + return "" + normal_notes = time_modification[0].get_children_of_type(mxl.XMLNormalNotes) + if len(normal_notes) == 0: + return "" + is_triplet = ( + int(actual_notes[0].value_) == 3 and int(normal_notes[0].value_) == 2 # noqa: PLR2004 + ) + is_sixtuplet = ( + int(actual_notes[0].value_) == 6 and int(normal_notes[0].value_) == 4 # noqa: PLR2004 + ) + if is_triplet or is_sixtuplet: + return "3" + return "" + + def _process_attributes( # type: ignore semantic: SemanticPart, attribute: mxl.XMLAttributes, key: KeyTransformation ) -> KeyTransformation: @@ -253,6 +274,7 @@ def _process_note( # type: ignore + "_" + _translate_duration(duration_type) + _count_dots(note), + # + _get_triplet_mark(note), ) return key diff --git a/training/transformer/train.py b/training/transformer/train.py index 28ac0ae..e0ea9cb 100644 --- a/training/transformer/train.py +++ b/training/transformer/train.py @@ -31,7 +31,7 @@ def load_training_index(file_path: str) -> list[str]: def contains_supported_clef(semantic: str) -> bool: if semantic.count("clef-") != 1: return False - return "clef-G2" in semantic or "clef-F4" in semantic + return True def filter_for_clefs(file_paths: list[str]) -> list[str]: