diff --git a/homr/download_utils.py b/homr/download_utils.py index 8a12910..a40fea3 100644 --- a/homr/download_utils.py +++ b/homr/download_utils.py @@ -38,21 +38,34 @@ def download_file(url: str, filename: str) -> None: eprint() # Add newline after download progress -def unzip_file(filename: str, output_folder: str) -> None: +def unzip_file(filename: str, output_folder: str, flatten_root_entry: bool = False) -> None: with zipfile.ZipFile(filename, "r") as zip_ref: - for member in zip_ref.namelist(): + zip_contents = zip_ref.namelist() + + if flatten_root_entry: + common_prefix = os.path.commonprefix(zip_contents) + if common_prefix and common_prefix.endswith("/"): + zip_contents_dict = { + file: os.path.relpath(file, common_prefix) for file in zip_contents + } + else: + zip_contents_dict = {file: file for file in zip_contents} + else: + zip_contents_dict = {file: file for file in zip_contents} + + for original, member in zip_contents_dict.items(): # Ensure file path is safe if os.path.isabs(member) or ".." in member: eprint(f"Skipping potentially unsafe file {member}") continue # Handle directories - if member.endswith("/"): + if original.endswith("/"): os.makedirs(os.path.join(output_folder, member), exist_ok=True) continue # Extract file - source = zip_ref.open(member) + source = zip_ref.open(original) target = open(os.path.join(output_folder, member), "wb") with source, target: diff --git a/training/convert_grandstaff.py b/training/convert_grandstaff.py index db19374..d35f68d 100644 --- a/training/convert_grandstaff.py +++ b/training/convert_grandstaff.py @@ -1,7 +1,5 @@ import multiprocessing import os -import platform -import stat import sys from pathlib import Path @@ -13,7 +11,7 @@ from torchvision import transforms as tr # type: ignore from torchvision.transforms import Compose # type: ignore -from homr.download_utils import download_file, untar_file +from homr.download_utils import download_file, untar_file, unzip_file from homr.simple_logging import eprint from homr.staff_dewarping import warp_image_randomly from homr.staff_parsing import add_image_into_tr_omr_canvas @@ -27,24 +25,19 @@ grandstaff_root = os.path.join(dataset_root, "grandstaff") grandstaff_train_index = os.path.join(grandstaff_root, "index.txt") -hum2xml = os.path.join(dataset_root, "hum2xml") -if platform.system() == "Windows": - eprint("Transformer training is only implemented for Linux") - eprint("Feel free to submit a PR to support Windows") - eprint("The main work should be to download hum2xml.exe and change the calls") - eprint("to use the exe-file instead of the linux binary.") - sys.exit(1) -if not os.path.exists(hum2xml): - eprint("Downloading hum2xml from https://extras.humdrum.org/man/hum2xml/") - download_file("http://extras.humdrum.org/bin/linux/hum2xml", hum2xml) - os.chmod(hum2xml, stat.S_IXUSR) - if not os.path.exists(grandstaff_root): eprint("Downloading grandstaff from https://sites.google.com/view/multiscore-project/datasets") grandstaff_archive = os.path.join(dataset_root, "grandstaff.tgz") download_file("https://grfia.dlsi.ua.es/musicdocs/grandstaff.tgz", grandstaff_archive) untar_file(grandstaff_archive, grandstaff_root) + eprint("Adding musicxml files to grandstaff dataset") + music_xml_download = os.path.join(dataset_root, "grandstaff_musicxml.zip") + download_file( + "https://github.com/liebharc/grandstaff_musicxml/archive/refs/heads/main.zip", + music_xml_download, + ) + unzip_file(music_xml_download, grandstaff_root, flatten_root_entry=True) def _get_dark_pixels_per_row(image: NDArray) -> NDArray: @@ -213,10 +206,6 @@ def _convert_file( # noqa: PLR0911 basename = str(path).replace(".krn", "") image_file = str(path).replace(".krn", ".jpg") musicxml = str(path).replace(".krn", ".musicxml") - result = os.system(f"{hum2xml} {path} > {musicxml}") # noqa: S605 - if result != 0: - eprint(f"Failed to convert {path}") - return [] upper_semantic, lower_semantic = _music_xml_to_semantic(musicxml, basename) if upper_semantic is None or lower_semantic is None: eprint(f"Failed to convert {musicxml}")