From 006ceb38d3f022c1e4ad0ef7979483cdf4a073fb Mon Sep 17 00:00:00 2001
From: Christian Liebhardt <christian.liebhardt@arcor.de>
Date: Sun, 22 Dec 2024 23:13:34 +0100
Subject: [PATCH] Added a script to run a checkpoint on an data set index and
 to get the symbol error rate for each entry of the index

---
 training/transformer/data_set_filters.py |  4 ++
 training/transformer/train.py            |  7 +-
 validation/ser_for_data_set.py           | 82 ++++++++++++++++++++++++
 3 files changed, 87 insertions(+), 6 deletions(-)
 create mode 100644 training/transformer/data_set_filters.py
 create mode 100644 validation/ser_for_data_set.py

diff --git a/training/transformer/data_set_filters.py b/training/transformer/data_set_filters.py
new file mode 100644
index 0000000..303265f
--- /dev/null
+++ b/training/transformer/data_set_filters.py
@@ -0,0 +1,4 @@
+def contains_supported_clef(semantic: str) -> bool:
+    if semantic.count("clef-") != 1:
+        return False
+    return "clef-G2" in semantic or "clef-F4" in semantic
diff --git a/training/transformer/train.py b/training/transformer/train.py
index e712a55..c8d6844 100644
--- a/training/transformer/train.py
+++ b/training/transformer/train.py
@@ -19,6 +19,7 @@
 )
 from training.run_id import get_run_id
 from training.transformer.data_loader import load_dataset
+from training.transformer.data_set_filters import contains_supported_clef
 from training.transformer.mix_datasets import mix_training_sets
 
 torch._dynamo.config.suppress_errors = True
@@ -29,12 +30,6 @@ def load_training_index(file_path: str) -> list[str]:
         return f.readlines()
 
 
-def contains_supported_clef(semantic: str) -> bool:
-    if semantic.count("clef-") != 1:
-        return False
-    return "clef-G2" in semantic or "clef-F4" in semantic
-
-
 def filter_for_clefs(file_paths: list[str]) -> list[str]:
     result = []
     for entry in file_paths:
diff --git a/validation/ser_for_data_set.py b/validation/ser_for_data_set.py
new file mode 100644
index 0000000..5b98794
--- /dev/null
+++ b/validation/ser_for_data_set.py
@@ -0,0 +1,82 @@
+import argparse
+import os
+import re
+
+import cv2
+import editdistance  # type: ignore
+
+from homr.simple_logging import eprint
+from homr.transformer.configs import Config
+from homr.transformer.staff2score import Staff2Score
+from training.transformer.data_set_filters import contains_supported_clef
+
+
+def calc_symbol_error_rate_for_list(dataset: list[str], result_file: str, config: Config) -> None:
+    model = Staff2Score(config, keep_all_symbols_in_chord=True)
+    i = 0
+    total = len(dataset)
+
+    with open(result_file, "w") as result:
+        result.write(
+            "img_path,ser,len_expected,len_actual,added_symbols,missing_symbols,expected,actual\n"
+        )
+        for sample in dataset:
+            img_path, semantic_path = sample.strip().split(",")
+            expected_str = _load_semantic_file(semantic_path)[0].strip()
+            if not contains_supported_clef(expected_str):
+                continue
+
+            image = cv2.imread(img_path)
+            actual = model.predict(image)[0].split("+")
+            expected = re.split(r"\+|\s+", expected_str)
+            if "timeSignature" not in expected:
+                # reference data has no time signature
+                actual = [symbol for symbol in actual if not symbol.startswith("timeSignature")]
+            actual = sort_chords(actual)
+            expected = sort_chords(expected)
+            distance = editdistance.eval(expected, actual)
+            ser = distance / len(expected)
+            ser = round(100 * ser)
+            i += 1
+            percentage = round(i / total * 100)
+            len_expected = len(expected)
+            len_actual = len(actual)
+            added_symbols = "+".join(set(actual) - set(expected))
+            missing_symbols = "+".join(set(expected) - set(actual))
+            result.write(
+                f"{img_path},{ser},{len_expected},{len_actual},{added_symbols},{missing_symbols},{"+".join(expected)},{"+".join(actual)}\n"
+            )
+            eprint(f"Progress: {percentage}%, SER: {ser}%")
+
+
+def _load_semantic_file(semantic_path: str) -> list[str]:
+    with open(semantic_path) as f:
+        return f.readlines()
+
+
+def sort_chords(symbols: list[str]) -> list[str]:
+    result = []
+    for symbol in symbols:
+        result.append(str.join("|", sorted(symbol.split("|"))))
+    return result
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="Calculate symbol error rate.")
+    parser.add_argument("checkpoint_file", type=str, help="Path to the checkpoint file.")
+    parser.add_argument("index_file", type=str, help="Path to the index file.")
+    args = parser.parse_args()
+
+    index_file = args.index_file
+    if not os.path.exists(index_file):
+        raise FileNotFoundError(f"Index file {index_file} does not exist.")
+
+    with open(index_file) as f:
+        index = f.readlines()
+
+    config = Config()
+    checkpoint_file = str(args.checkpoint_file)
+    config.filepaths.checkpoint = str(checkpoint_file)
+    result_file = index_file.split(".")[0] + "_ser.txt"
+    calc_symbol_error_rate_for_list(index, result_file, config)
+    eprint(result_file)