Skip to content

Commit

Permalink
show_examples_from_index now also works on index_ser.txt files
Browse files: browse the repository at this point in the history
  • Loading branch information
liebharc committed Dec 23, 2024
1 parent 3664b68 commit eee9bd9
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 11 deletions.
58 changes: 50 additions & 8 deletions training/show_examples_from_index.py
Original file line number | Diff line number | Diff line change
@@ -1,21 +1,54 @@
# ruff: noqa: T201

import argparse
import os
import sys
from typing import Any

import cv2
import numpy as np
from termcolor import colored

index_file_name = sys.argv[1]
number_of_samples_per_iteration = int(sys.argv[2])
parser = argparse.ArgumentParser(description="Show examples from a dataset index")
parser.add_argument("index", type=str, help="Index file name")
parser.add_argument("number_of_images", type=int, help="Number of images to show at once")
parser.add_argument(
"--sorted",
action="store_true",
help="Show the images in the order they are in the index file",
)
parser.add_argument(
"--min-ser",
type=int,
help="Minimum SER to show",
)
parser.add_argument(
"--max-ser",
type=int,
help="Maximum SER to show",
)
args = parser.parse_args()

index_file_name = args.index
number_of_samples_per_iteration = int(args.number_of_images)

index_file = open(index_file_name)
index_lines = index_file.readlines()
index_file.close()

np.random.shuffle(index_lines)
ser_position = 2

if not args.sorted:
np.random.shuffle(index_lines)

if args.min_ser is not None:
index_lines = [
line for line in index_lines if int(line.split(",")[ser_position]) >= args.min_ser
]

if args.max_ser is not None:
index_lines = [
line for line in index_lines if int(line.split(",")[ser_position]) <= args.max_ser
]


def print_color(text: str, highlights: list[str], color: Any) -> None:
Expand All @@ -29,10 +62,11 @@ def print_color(text: str, highlights: list[str], color: Any) -> None:


while True:
batch = []
for _ in range(number_of_samples_per_iteration):
batch: list[str] = []
while len(batch) < number_of_samples_per_iteration:
if len(index_lines) == 0:
break

batch.append(index_lines.pop())

if len(batch) == 0:
Expand All @@ -45,7 +79,12 @@ def print_color(text: str, highlights: list[str], color: Any) -> None:
print("==========================================")
print()
for line in batch:
image_path, semantic_path = line.strip().split(",")
cells = line.strip().split(",")
image_path = cells[0]
semantic_path = cells[1]
ser: None | int = None
if len(cells) > ser_position:
ser = int(cells[ser_position])
agnostic_path = semantic_path.replace(".semantic", ".agnostic")
image = cv2.imread(image_path)
with open(semantic_path) as file:
Expand All @@ -60,7 +99,10 @@ def print_color(text: str, highlights: list[str], color: Any) -> None:
else:
images = np.concatenate((images, image), axis=0)
print()
print(">>> " + image_path)
if ser is not None:
print(">>> " + image_path + f" SER: {ser}%")
else:
print(">>> " + image_path)
print_color(semantic, ["barline", "#", "N", "b"], "green")
cv2.imshow("Images", images) # type: ignore
escKey = 27
Expand Down
3 changes: 0 additions & 3 deletions validation/ser_for_data_set.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -17,9 +17,6 @@ def calc_symbol_error_rate_for_list(dataset: list[str], result_file: str, config
total = len(dataset)

with open(result_file, "w") as result:
result.write(
"img_path,ser,len_expected,len_actual,added_symbols,missing_symbols,expected,actual\n"
)
for sample in dataset:
img_path, semantic_path = sample.strip().split(",")
expected_str = _load_semantic_file(semantic_path)[0].strip()
Expand Down

0 comments on commit eee9bd9

Please sign in to comment.