Skip to content

Commit

Permalink
Added paramters which allow to customize the generated music xml
Browse files Browse the repository at this point in the history
  • Loading branch information
liebharc committed Jul 15, 2024
1 parent 784d431 commit b05a8be
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 13 deletions.
28 changes: 23 additions & 5 deletions homr/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
from homr.title_detection import detect_title
from homr.transformer.configs import default_config
from homr.type_definitions import NDArray
from homr.xml_generator import generate_xml
from homr.xml_generator import XmlGeneratorArguments, generate_xml

os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"

Expand Down Expand Up @@ -138,7 +138,10 @@ def predict_symbols(debug: Debug, predictions: InputPredictions) -> PredictedSym


def process_image( # noqa: PLR0915
image_path: str, enable_debug: bool, enable_cache: bool
image_path: str,
enable_debug: bool,
enable_cache: bool,
xml_generator_args: XmlGeneratorArguments,
) -> tuple[str, str, str]:
eprint("Processing " + image_path)
predictions, debug = load_and_preprocess_predictions(image_path, enable_debug, enable_cache)
Expand Down Expand Up @@ -233,7 +236,7 @@ def process_image( # noqa: PLR0915
result_staffs = maintain_accidentals(result_staffs)

eprint("Writing XML")
xml = generate_xml(result_staffs, title)
xml = generate_xml(xml_generator_args, result_staffs, title)
xml.write(xml_file)

eprint(
Expand Down Expand Up @@ -314,27 +317,42 @@ def main() -> None:
parser.add_argument(
"--cache", action="store_true", help="Read an existing cache file or create a new one"
)
parser.add_argument(
"--output-large-page",
action="store_true",
help="Adds instructions to the musicxml so that it gets rendered on larger pages",
)
parser.add_argument(
"--output-metronome", type=int, help="Adds a metronome to the musicxml with the given bpm"
)
parser.add_argument(
"--output-tempo", type=int, help="Adds a tempo to the musicxml with the given bpm"
)
args = parser.parse_args()

download_weights()
if args.init:
eprint("Init finished")
return

xml_generator_args = XmlGeneratorArguments(
args.output_large_page, args.output_metronome, args.output_tempo
)

if not args.image:
eprint("No image provided")
parser.print_help()
sys.exit(1)
elif os.path.isfile(args.image):
process_image(args.image, args.debug, args.cache)
process_image(args.image, args.debug, args.cache, xml_generator_args)
elif os.path.isdir(args.image):
image_files = get_all_image_files_in_folder(args.image)
eprint("Processing", len(image_files), "files:", image_files)
error_files = []
for image_file in image_files:
eprint("=========================================")
try:
process_image(image_file, args.debug, args.cache)
process_image(image_file, args.debug, args.cache, xml_generator_args)
eprint("Finished", image_file)
except Exception as e:
eprint(f"An error occurred while processing {image_file}: {e}")
Expand Down
54 changes: 46 additions & 8 deletions homr/xml_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,13 @@
)


class XmlGeneratorArguments:
def __init__(self, large_page: bool | None, metronome: int | None, tempo: int | None):
self.large_page = large_page
self.metronome = metronome
self.tempo = tempo


def build_work(f_name: str) -> mxl.XMLWork: # type: ignore
work = mxl.XMLWork()
title = mxl.XMLWorkTitle()
Expand All @@ -20,7 +27,9 @@ def build_work(f_name: str) -> mxl.XMLWork: # type: ignore
return work


def build_defaults() -> mxl.XMLDefaults: # type: ignore
def build_defaults(args: XmlGeneratorArguments) -> mxl.XMLDefaults: # type: ignore
if not args.large_page:
return mxl.XMLDefaults()
# These values are larger than a letter or A4 format so that
# we only have to break staffs with every new detected staff
# This works well for electronic formats, if the results are supposed
Expand Down Expand Up @@ -146,9 +155,33 @@ def build_chord(chord: ResultChord) -> list[mxl.XMLNote]: # type: ignore
return build_note_group(chord)


def build_measure(measure: ResultMeasure, measure_number: int) -> mxl.XMLMeasure: # type: ignore
def build_add_time_direction(args: XmlGeneratorArguments) -> mxl.XMLDirection: # type: ignore
if not args.metronome:
return mxl.XMLDirection()
direction = mxl.XMLDirection()
direction_type = mxl.XMLDirectionType()
direction.add_child(direction_type)
metronome = mxl.XMLMetronome()
direction_type.add_child(metronome)
beat_unit = mxl.XMLBeatUnit(value_="quarter")
metronome.add_child(beat_unit)
per_minute = mxl.XMLPerMinute(value_=str(args.metronome))
metronome.add_child(per_minute)
if args.tempo:
direction.add_child(mxl.XMLSound(tempo=args.tempo))
else:
direction.add_child(mxl.XMLSound(tempo=args.metronome))
return direction


def build_measure( # type: ignore
args: XmlGeneratorArguments, measure: ResultMeasure, is_first_part: bool, measure_number: int
) -> mxl.XMLMeasure:
result = mxl.XMLMeasure(number=str(measure_number))
if measure.is_new_line:
is_first_measure = measure_number == 1
if is_first_measure and is_first_part:
result.add_child(build_add_time_direction(args))
if measure.is_new_line and not is_first_measure:
result.add_child(mxl.XMLPrint(new_system="yes"))
for symbol in measure.symbols:
if isinstance(symbol, ResultClef):
Expand All @@ -163,20 +196,25 @@ def build_measure(measure: ResultMeasure, measure_number: int) -> mxl.XMLMeasure
return result


def build_part(staff: ResultStaff, index: int) -> mxl.XMLPart: # type: ignore
def build_part( # type: ignore
args: XmlGeneratorArguments, staff: ResultStaff, index: int
) -> mxl.XMLPart:
part = mxl.XMLPart(id=get_part_id(index))
measure_number = 1
is_first_part = index == 0
for measure in staff.measures:
part.add_child(build_measure(measure, measure_number))
part.add_child(build_measure(args, measure, is_first_part, measure_number))
measure_number += 1
return part


def generate_xml(staffs: list[ResultStaff], title: str) -> mxl.XMLElement: # type: ignore
def generate_xml( # type: ignore
args: XmlGeneratorArguments, staffs: list[ResultStaff], title: str
) -> mxl.XMLElement:
root = mxl.XMLScorePartwise()
root.add_child(build_work(title))
root.add_child(build_defaults())
root.add_child(build_defaults(args))
root.add_child(build_part_list(len(staffs)))
for index, staff in enumerate(staffs):
root.add_child(build_part(staff, index))
root.add_child(build_part(args, staff, index))
return root

0 comments on commit b05a8be

Please sign in to comment.