diff --git a/training/show_examples_from_index.py b/training/show_examples_from_index.py index 0882950..3194141 100644 --- a/training/show_examples_from_index.py +++ b/training/show_examples_from_index.py @@ -1,21 +1,54 @@ # ruff: noqa: T201 +import argparse import os -import sys from typing import Any import cv2 import numpy as np from termcolor import colored -index_file_name = sys.argv[1] -number_of_samples_per_iteration = int(sys.argv[2]) +parser = argparse.ArgumentParser(description="Show examples from a dataset index") +parser.add_argument("index", type=str, help="Index file name") +parser.add_argument("number_of_images", type=int, help="Number of images to show at once") +parser.add_argument( + "--sorted", + action="store_true", + help="Show the images in the order they are in the index file", +) +parser.add_argument( + "--min-ser", + type=int, + help="Minimum SER to show", +) +parser.add_argument( + "--max-ser", + type=int, + help="Maximum SER to show", +) +args = parser.parse_args() + +index_file_name = args.index +number_of_samples_per_iteration = int(args.number_of_images) index_file = open(index_file_name) index_lines = index_file.readlines() index_file.close() -np.random.shuffle(index_lines) +ser_position = 2 + +if not args.sorted: + np.random.shuffle(index_lines) + +if args.min_ser is not None: + index_lines = [ + line for line in index_lines if int(line.split(",")[ser_position]) >= args.min_ser + ] + +if args.max_ser is not None: + index_lines = [ + line for line in index_lines if int(line.split(",")[ser_position]) <= args.max_ser + ] def print_color(text: str, highlights: list[str], color: Any) -> None: @@ -29,10 +62,11 @@ def print_color(text: str, highlights: list[str], color: Any) -> None: while True: - batch = [] - for _ in range(number_of_samples_per_iteration): + batch: list[str] = [] + while len(batch) < number_of_samples_per_iteration: if len(index_lines) == 0: break + batch.append(index_lines.pop()) if len(batch) == 0: @@ -45,7 +79,12 @@ def print_color(text: str, highlights: list[str], color: Any) -> None: print("==========================================") print() for line in batch: - image_path, semantic_path = line.strip().split(",") + cells = line.strip().split(",") + image_path = cells[0] + semantic_path = cells[1] + ser: None | int = None + if len(cells) > ser_position: + ser = int(cells[ser_position]) agnostic_path = semantic_path.replace(".semantic", ".agnostic") image = cv2.imread(image_path) with open(semantic_path) as file: @@ -60,7 +99,10 @@ def print_color(text: str, highlights: list[str], color: Any) -> None: else: images = np.concatenate((images, image), axis=0) print() - print(">>> " + image_path) + if ser is not None: + print(">>> " + image_path + f" SER: {ser}%") + else: + print(">>> " + image_path) print_color(semantic, ["barline", "#", "N", "b"], "green") cv2.imshow("Images", images) # type: ignore escKey = 27 diff --git a/validation/ser_for_data_set.py b/validation/ser_for_data_set.py index 19ebbb1..1f59f55 100644 --- a/validation/ser_for_data_set.py +++ b/validation/ser_for_data_set.py @@ -17,9 +17,6 @@ def calc_symbol_error_rate_for_list(dataset: list[str], result_file: str, config total = len(dataset) with open(result_file, "w") as result: - result.write( - "img_path,ser,len_expected,len_actual,added_symbols,missing_symbols,expected,actual\n" - ) for sample in dataset: img_path, semantic_path = sample.strip().split(",") expected_str = _load_semantic_file(semantic_path)[0].strip()