Skip to content

Commit

Permalink
Added a script to run a checkpoint on a data set index and to get th…
Browse files Browse the repository at this point in the history
…e symbol error rate for each entry of the index
  • Loading branch information
liebharc committed Dec 22, 2024
1 parent 92b8f4c commit 006ceb3
Show file tree
Hide file tree
Showing 3 changed files with 87 additions and 6 deletions.
4 changes: 4 additions & 0 deletions training/transformer/data_set_filters.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
def contains_supported_clef(semantic: str) -> bool:
    """Return True if the semantic encoding has exactly one clef and it is a supported one (G2 or F4)."""
    has_single_clef = semantic.count("clef-") == 1
    return has_single_clef and ("clef-G2" in semantic or "clef-F4" in semantic)
7 changes: 1 addition & 6 deletions training/transformer/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
)
from training.run_id import get_run_id
from training.transformer.data_loader import load_dataset
from training.transformer.data_set_filters import contains_supported_clef
from training.transformer.mix_datasets import mix_training_sets

torch._dynamo.config.suppress_errors = True
Expand All @@ -29,12 +30,6 @@ def load_training_index(file_path: str) -> list[str]:
return f.readlines()


def contains_supported_clef(semantic: str) -> bool:
    """Accept only encodings with exactly one clef, and only if that clef is G2 or F4."""
    if semantic.count("clef-") != 1:
        return False
    return any(clef in semantic for clef in ("clef-G2", "clef-F4"))


def filter_for_clefs(file_paths: list[str]) -> list[str]:
result = []
for entry in file_paths:
Expand Down
82 changes: 82 additions & 0 deletions validation/ser_for_data_set.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
import argparse
import os
import re

import cv2
import editdistance # type: ignore

from homr.simple_logging import eprint
from homr.transformer.configs import Config
from homr.transformer.staff2score import Staff2Score
from training.transformer.data_set_filters import contains_supported_clef


def calc_symbol_error_rate_for_list(dataset: list[str], result_file: str, config: Config) -> None:
    """Run the model over every dataset entry and write a CSV line with the
    symbol error rate (SER) per image to `result_file`.

    Each dataset entry is expected to be a line of the form
    "img_path,semantic_path". Entries whose semantic file does not contain
    exactly one supported clef (G2 or F4) are skipped.
    """
    model = Staff2Score(config, keep_all_symbols_in_chord=True)
    processed = 0
    total = len(dataset)

    with open(result_file, "w") as result:
        result.write(
            "img_path,ser,len_expected,len_actual,added_symbols,missing_symbols,expected,actual\n"
        )
        for sample in dataset:
            img_path, semantic_path = sample.strip().split(",")
            expected_str = _load_semantic_file(semantic_path)[0].strip()
            if not contains_supported_clef(expected_str):
                continue

            image = cv2.imread(img_path)
            actual = model.predict(image)[0].split("+")
            expected = re.split(r"\+|\s+", expected_str)
            if "timeSignature" not in expected:
                # reference data has no time signature
                actual = [symbol for symbol in actual if not symbol.startswith("timeSignature")]
            actual = sort_chords(actual)
            expected = sort_chords(expected)
            distance = editdistance.eval(expected, actual)
            # Guard against an empty reference sequence (ZeroDivisionError).
            ser = round(100 * distance / max(len(expected), 1))
            processed += 1
            # NOTE(review): `total` includes skipped entries, so this
            # percentage under-reports progress when entries are filtered out.
            percentage = round(processed / total * 100)
            len_expected = len(expected)
            len_actual = len(actual)
            added_symbols = "+".join(set(actual) - set(expected))
            missing_symbols = "+".join(set(expected) - set(actual))
            # Pre-join outside the f-string: re-using double quotes inside a
            # double-quoted f-string is a SyntaxError before Python 3.12.
            expected_joined = "+".join(expected)
            actual_joined = "+".join(actual)
            result.write(
                f"{img_path},{ser},{len_expected},{len_actual},{added_symbols},"
                f"{missing_symbols},{expected_joined},{actual_joined}\n"
            )
            eprint(f"Progress: {percentage}%, SER: {ser}%")


def _load_semantic_file(semantic_path: str) -> list[str]:
with open(semantic_path) as f:
return f.readlines()


def sort_chords(symbols: list[str]) -> list[str]:
    """Normalize each chord by sorting its pipe-separated notes alphabetically."""
    return ["|".join(sorted(symbol.split("|"))) for symbol in symbols]


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Calculate symbol error rate.")
    parser.add_argument("checkpoint_file", type=str, help="Path to the checkpoint file.")
    parser.add_argument("index_file", type=str, help="Path to the index file.")
    args = parser.parse_args()

    index_file = args.index_file
    if not os.path.exists(index_file):
        raise FileNotFoundError(f"Index file {index_file} does not exist.")

    with open(index_file) as f:
        index = f.readlines()

    config = Config()
    checkpoint_file = str(args.checkpoint_file)
    config.filepaths.checkpoint = str(checkpoint_file)
    # Use os.path.splitext to strip only the final extension. The previous
    # index_file.split(".")[0] truncated any path containing another dot,
    # e.g. "./data/index.txt" -> "" and "data.v2/index.txt" -> "data".
    result_file = os.path.splitext(index_file)[0] + "_ser.txt"
    calc_symbol_error_rate_for_list(index, result_file, config)
    eprint(result_file)

0 comments on commit 006ceb3

Please sign in to comment.