Skip to content

Commit

Permalink
Prepared for triplets
Browse files Browse the repository at this point in the history
This change requires us to retrain the transformer in order to take effect
  • Loading branch information
liebharc committed Jul 3, 2024
1 parent 1a2b944 commit 8a4d309
Show file tree
Hide file tree
Showing 7 changed files with 217 additions and 203 deletions.
16 changes: 3 additions & 13 deletions homr/accidental_rules.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,5 @@
from homr.circle_of_fifths import get_circle_of_fifth_notes
from homr.results import (
ResultClef,
ResultNote,
ResultNoteGroup,
ResultPitch,
ResultStaff,
)
from homr.results import ResultChord, ResultClef, ResultNote, ResultPitch, ResultStaff


def _keep_accidentals_until_cancelled(staff: ResultStaff) -> None:
Expand All @@ -18,9 +12,7 @@ def _keep_accidentals_until_cancelled(staff: ResultStaff) -> None:
for symbol in measure.symbols:
if isinstance(symbol, ResultClef):
accidentals = {}
elif isinstance(symbol, ResultNote):
_process_note(symbol, accidentals)
elif isinstance(symbol, ResultNoteGroup):
elif isinstance(symbol, ResultChord):
for note in symbol.notes:
_process_note(note, accidentals)

Expand Down Expand Up @@ -49,9 +41,7 @@ def _apply_key_signature(staff: ResultStaff) -> None:
if isinstance(symbol, ResultClef):
circle_of_fifth = symbol.circle_of_fifth
circle_of_fifth_notes = get_circle_of_fifth_notes(circle_of_fifth)
elif isinstance(symbol, ResultNote):
_apply_key_to_pitch(symbol.pitch, circle_of_fifth, circle_of_fifth_notes)
elif isinstance(symbol, ResultNoteGroup):
elif isinstance(symbol, ResultChord):
for note in symbol.notes:
_apply_key_to_pitch(note.pitch, circle_of_fifth, circle_of_fifth_notes)

Expand Down
91 changes: 51 additions & 40 deletions homr/results.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from enum import Enum

import numpy as np

from homr import constants
Expand Down Expand Up @@ -228,52 +230,57 @@ def get_duration_name(duration: int) -> str:
return result


class ResultDuration:
def __init__(self, duration: int, has_dot: bool):
self.duration = duration
self.has_dot = has_dot
self.duration_name = get_duration_name(self.base_duration())
class DurationModifier(Enum):
NONE = 0
DOT = 1
TRIPLET = 2

def base_duration(self) -> int:
return self.duration * 2 // 3 if self.has_dot else self.duration
def __init__(self, duration: int) -> None:
self.duration = duration

def __eq__(self, __value: object) -> bool:
if isinstance(__value, ResultDuration):
return self.duration == __value.duration and self.has_dot == __value.has_dot
def __str__(self) -> str:
if self.duration == self.NONE:
return ""
elif self.duration == self.DOT:
return "."
elif self.duration == self.TRIPLET:
return "³"
else:
return False

def __hash__(self) -> int:
return hash((self.duration, self.has_dot))
return "Invalid duration"

def __str__(self) -> str:
return f"{self.duration_name}{'.' if self.has_dot else ''}"

def __repr__(self) -> str:
return str(self)
def _adjust_duration(duration: int, modifier: DurationModifier) -> int:
if modifier == DurationModifier.DOT:
return duration * 3 // 2
elif modifier == DurationModifier.TRIPLET:
return duration * 2 // 3
else:
return duration


class ResultRest(ResultSymbol):
def __init__(self, duration: ResultDuration):
self.duration = duration
class ResultDuration:
def __init__(self, base_duration: int, modifier: DurationModifier = DurationModifier.NONE):
self.duration = _adjust_duration(base_duration, modifier)
self.modifier = modifier
self.duration_name = get_duration_name(base_duration)

def __eq__(self, __value: object) -> bool:
if isinstance(__value, ResultRest):
return self.duration == __value.duration
if isinstance(__value, ResultDuration):
return self.duration == __value.duration and self.modifier == __value.modifier
else:
return False

def __hash__(self) -> int:
return hash(self.duration)
return hash((self.duration, self.modifier))

def __str__(self) -> str:
return f"R_{self.duration}"
return f"{self.duration_name}{str(self.modifier)}"

def __repr__(self) -> str:
return str(self)


class ResultNote(ResultSymbol):
class ResultNote:
def __init__(self, pitch: ResultPitch, duration: ResultDuration):
self.pitch = pitch
self.duration = duration
Expand All @@ -294,24 +301,30 @@ def __repr__(self) -> str:
return str(self)


class ResultNoteGroup(ResultSymbol):
def __init__(self, notes: list[ResultNote]):
class ResultChord(ResultSymbol):
"""
A chord which contains 0 to many pitches. 0 pitches indicates that this is a rest.
The duration of the chord is the distance to the next chord. The individual pitches
my have a different duration.
"""

def __init__(self, duration: ResultDuration, notes: list[ResultNote]):
self.notes = notes
self.duration = notes[0].duration if len(notes) > 0 else ResultDuration(0, False)
self.duration = duration

@property
def is_rest(self) -> bool:
return len(self.notes) == 0

def __eq__(self, __value: object) -> bool:
if isinstance(__value, ResultNoteGroup):
if len(self.notes) != len(__value.notes):
return False
for i in range(len(self.notes)):
if self.notes[i] != __value.notes[i]:
return False
return True
if isinstance(__value, ResultChord):
return self.duration == __value.duration and self.notes == __value.notes
else:
return False

def __hash__(self) -> int:
return hash(tuple(self.notes))
return hash((self.notes, self.duration))

def __str__(self) -> str:
return f"{'&'.join(map(str, self.notes))}"
Expand Down Expand Up @@ -356,9 +369,7 @@ def __repr__(self) -> str:

def length_in_quarters(self) -> float:
return sum(
symbol.duration.base_duration()
for symbol in self.symbols
if isinstance(symbol, ResultNote | ResultNoteGroup | ResultRest)
symbol.duration.duration for symbol in self.symbols if isinstance(symbol, ResultChord)
)


Expand Down
15 changes: 3 additions & 12 deletions homr/staff_parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,9 @@
from homr.image_utils import crop_image_and_return_new_top
from homr.model import InputPredictions, MultiStaff, NoteGroup, Staff
from homr.results import (
ResultChord,
ResultClef,
ResultMeasure,
ResultNote,
ResultNoteGroup,
ResultStaff,
move_pitch_to_clef,
)
Expand Down Expand Up @@ -284,11 +283,7 @@ def _pick_most_dominant_clef(staff: ResultStaff) -> ResultStaff: # noqa: C901,
if isinstance(symbol, ResultClef):
last_clef_was_originally = ResultClef(symbol.clef_type, 0)
symbol.clef_type = most_frequent_clef_type
elif isinstance(symbol, ResultNote):
symbol.pitch = move_pitch_to_clef(
symbol.pitch, last_clef_was_originally, most_frequent_clef
)
elif isinstance(symbol, ResultNoteGroup):
elif isinstance(symbol, ResultChord):
for note in symbol.notes:
note.pitch = move_pitch_to_clef(
note.pitch, last_clef_was_originally, most_frequent_clef
Expand All @@ -298,11 +293,7 @@ def _pick_most_dominant_clef(staff: ResultStaff) -> ResultStaff: # noqa: C901,
if isinstance(symbol, ResultClef):
last_clef_was_originally = ResultClef(symbol.clef_type, 0)
symbol.clef_type = most_frequent_clef_type
elif isinstance(measure_symbol, ResultNote):
measure_symbol.pitch = move_pitch_to_clef(
measure_symbol.pitch, last_clef_was_originally, most_frequent_clef
)
elif isinstance(measure_symbol, ResultNoteGroup):
elif isinstance(measure_symbol, ResultChord):
for note in measure_symbol.notes:
note.pitch = move_pitch_to_clef(
note.pitch, last_clef_was_originally, most_frequent_clef
Expand Down
60 changes: 27 additions & 33 deletions homr/tr_omr_parser.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
from homr import constants
from homr.results import (
ClefType,
DurationModifier,
ResultChord,
ResultClef,
ResultDuration,
ResultMeasure,
ResultNote,
ResultNoteGroup,
ResultPitch,
ResultRest,
ResultStaff,
ResultTimeSignature,
)
Expand Down Expand Up @@ -87,20 +87,30 @@ def parse_duration_name(self, duration_name: str) -> int:
}
return duration_mapping.get(duration_name, constants.duration_of_quarter // 16)

def _adjust_duration_with_dot(self, duration: int, has_dot: bool) -> int:
def parse_duration(self, duration: str) -> ResultDuration:
has_dot = duration.endswith(".")
# We use ³ as triplet indicator as it's not a valid duration name
# or note name and thus we have no risk of confusion
is_triplet = duration.endswith("³")

modifier = DurationModifier.NONE
if has_dot:
return duration * 3 // 2
else:
return duration
duration = duration[:-1]
modifier = DurationModifier.DOT
elif is_triplet:
duration = duration[:-1]
modifier = DurationModifier.TRIPLET
return ResultDuration(
self.parse_duration_name(duration),
modifier,
)

def parse_note(self, note: str) -> ResultNote:
try:
note_details = note.split("-")[1]
pitch_and_duration = note_details.split("_")
pitch = pitch_and_duration[0]
duration = pitch_and_duration[1]
has_dot = duration.endswith(".")
duration_name = duration[:-1] if has_dot else duration
note_name = pitch[0]
octave = int(pitch[1])
alter = None
Expand All @@ -115,22 +125,12 @@ def parse_note(self, note: str) -> ResultNote:
else:
alter = 0

return ResultNote(
ResultPitch(note_name, octave, alter),
ResultDuration(
self._adjust_duration_with_dot(
self.parse_duration_name(duration_name), has_dot
),
has_dot,
),
)
return ResultNote(ResultPitch(note_name, octave, alter), self.parse_duration(duration))
except Exception:
eprint("Failed to parse note: " + note)
return ResultNote(
ResultPitch("C", 4, 0), ResultDuration(constants.duration_of_quarter, False)
)
return ResultNote(ResultPitch("C", 4, 0), ResultDuration(constants.duration_of_quarter))

def parse_notes(self, notes: str) -> ResultNote | ResultNoteGroup | ResultRest | None:
def parse_notes(self, notes: str) -> ResultChord | None:
note_parts = notes.split("|")
note_parts = [note_part for note_part in note_parts if note_part.startswith("note")]
rest_parts = [rest_part for rest_part in note_parts if rest_part.startswith("rest")]
Expand All @@ -139,21 +139,15 @@ def parse_notes(self, notes: str) -> ResultNote | ResultNoteGroup | ResultRest |
return None
else:
return self.parse_rest(rest_parts[0])
if len(note_parts) == 1:
return self.parse_note(note_parts[0])
else:
return ResultNoteGroup([self.parse_note(note_part) for note_part in note_parts])
result_notes = [self.parse_note(note_part) for note_part in note_parts]
return ResultChord(result_notes[0].duration, result_notes)

def parse_rest(self, rest: str) -> ResultRest:
def parse_rest(self, rest: str) -> ResultChord:
rest = rest.split("|")[0]
duration = rest.split("-")[1]
has_dot = duration.endswith(".")
duration_name = duration[:-1] if has_dot else duration
return ResultRest(
ResultDuration(
self._adjust_duration_with_dot(self.parse_duration_name(duration_name), has_dot),
has_dot,
)
return ResultChord(
self.parse_duration(duration),
[],
)

def parse_tr_omr_output(self, output: str) -> ResultStaff: # noqa: C901
Expand Down
29 changes: 18 additions & 11 deletions homr/xml_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@

from homr import constants
from homr.results import (
DurationModifier,
ResultChord,
ResultClef,
ResultMeasure,
ResultNote,
ResultNoteGroup,
ResultRest,
ResultStaff,
)

Expand Down Expand Up @@ -77,7 +77,7 @@ def build_clef(model_clef: ResultClef) -> mxl.XMLAttributes: # type: ignore
return attributes


def build_rest(model_rest: ResultRest) -> mxl.XMLNote: # type: ignore
def build_rest(model_rest: ResultChord) -> mxl.XMLNote: # type: ignore
note = mxl.XMLNote()
note.add_child(mxl.XMLRest(measure="yes"))
note.add_child(mxl.XMLDuration(value_=model_rest.duration.duration))
Expand All @@ -104,12 +104,17 @@ def build_note(model_note: ResultNote, is_chord=False) -> mxl.XMLNote: # type:
note.add_child(mxl.XMLDuration(value_=model_duration.duration))
note.add_child(mxl.XMLStaff(value_=1))
note.add_child(mxl.XMLVoice(value_="1"))
if model_duration.has_dot:
if model_duration.modifier == DurationModifier.DOT:
note.add_child(mxl.XMLDot())
elif model_duration.modifier == DurationModifier.TRIPLET:
time_modification = mxl.XMLTimeModification()
time_modification.add_child(mxl.XMLActualNotes(value_=3))
time_modification.add_child(mxl.XMLNormalNotes(value_=2))
note.add_child(time_modification)
return note


def build_note_group(note_group: ResultNoteGroup) -> mxl.XMLNote: # type: ignore
def build_note_group(note_group: ResultChord) -> list[mxl.XMLNote]: # type: ignore
result = []
is_first = True
for note in note_group.notes:
Expand All @@ -118,19 +123,21 @@ def build_note_group(note_group: ResultNoteGroup) -> mxl.XMLNote: # type: ignor
return result


def build_chord(chord: ResultChord) -> list[mxl.XMLNote]: # type: ignore
if chord.is_rest:
return [build_rest(chord)]
return build_note_group(chord)


def build_measure(measure: ResultMeasure, measure_number: int) -> mxl.XMLMeasure: # type: ignore
result = mxl.XMLMeasure(number=str(measure_number))
if measure.is_new_line:
result.add_child(mxl.XMLPrint(new_system="yes"))
for symbol in measure.symbols:
if isinstance(symbol, ResultClef):
result.add_child(build_clef(symbol))
elif isinstance(symbol, ResultRest):
result.add_child(build_rest(symbol))
elif isinstance(symbol, ResultNote):
result.add_child(build_note(symbol))
elif isinstance(symbol, ResultNoteGroup):
for element in build_note_group(symbol):
elif isinstance(symbol, ResultChord):
for element in build_chord(symbol):
result.add_child(element)
return result

Expand Down
Loading

0 comments on commit 8a4d309

Please sign in to comment.