refactor: clean up dataset generation script
beuss-git committed May 18, 2023
1 parent c0fd0a6 commit 1f63d2a
Showing 1 changed file with 47 additions and 139 deletions.
186 changes: 47 additions & 139 deletions tools/generate_dataset/generate_dataset.py
@@ -18,16 +18,15 @@
 import yaml
 from tqdm import tqdm
 
-CVAT_EXPORTS_FOLDER = Path(r"D:\dataset_temp\cvat_exports")
-DATASET_FOLDER = Path(r"D:\dataset_temp\generated_test")
+CVAT_EXPORTS_FOLDER = Path(r"path\to\cvat\exports")
+DATASET_FOLDER = Path(r"path\to\output\dataset")
 SOURCE_VIDEO_FOLDERS = [
     Path(r"X:\Myggbukta 2022"),
     Path(r"X:\DISK1 - Høyregga 17+18 og myggbukta 2020 mai NTNU"),
 ]
 BACKGROUND_IMAGE_PERCENTAGE = 0.1
 TRAIN_SPLIT = 0.8
 PNG_QUALITY = 3  # 0-9 PNG compression level; lossless, higher compresses more but encodes slower
-# VAL_SPLIT = 0.3  # Not used, just the rest of the train split
 MAX_WORKERS = 8
 
 CLASSES = [
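For a sense of scale (hypothetical numbers): with 900 annotated images, BACKGROUND_IMAGE_PERCENTAGE = 0.1 allows up to int(900 * 0.1) = 90 background images, roughly 9% of the resulting 990-image dataset; see adjust_background_images further down.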
@@ -57,7 +56,6 @@ def get_video_filename(annotation_xml: Path) -> str:
 def extract_annotations(
     cvat_exports_folder: Path, output_folder: Path, processed_counter: mp.Value
 ) -> None:
-
     # NOTE: We do this because background images are randomly selected, and we
     # don't want to keep adding more background images to the dataset on repeated
     # runs; other parameters and the dataset itself might have changed as well.
@@ -72,7 +70,6 @@
 def process_zip_file(
     zip_filepath: Path, id: int, total: int, processed_counter: mp.Value
 ) -> None:
-
     try:
         with zipfile.ZipFile(zip_filepath, "r") as zip_file:
             zip_file.extract(
@@ -90,7 +87,6 @@ def process_zip_file(
         generate_yolo_dataset(
             output_folder / video_filename,
             annotation_path,
-            BACKGROUND_IMAGE_PERCENTAGE,
         )
     except FileNotFoundError as err:
         print(err)
@@ -197,8 +193,6 @@ def parse_xml_annotation(
 def generate_yolo_dataset(
     yolo_dataset_folder: Path,
     annotation_path: Path,
-    background_image_percentage: float,
-    background_image_no_annotation_buffer: int = 50,
 ) -> None:
     os.makedirs(yolo_dataset_folder)
 
@@ -210,10 +204,6 @@
     ) = parse_xml_annotation(annotation_path)
 
     def get_video_path(video_filename: str) -> Path:
-        """
-        This is kinda hacky, but it works.
-        Should have globbed it directly, but escaping the brackets was a pain.
-        """
         for source_video_folder in SOURCE_VIDEO_FOLDERS:
             search_name = os.path.splitext(video_filename)[0]
             # NOTE: we replace underscores with spaces due to handbrake naming
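As an illustration of the naming mismatch (hypothetical filename): an annotation referencing myggbukta_2022_07_14.mp4 has to match a HandBrake-encoded source named myggbukta 2022 07 14.mp4, so the underscores in search_name are swapped for spaces before the search over SOURCE_VIDEO_FOLDERS.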
@@ -229,14 +219,10 @@ def get_video_path(video_filename: str) -> Path:
 
     video_path = get_video_path(video_filename)
 
-    def extract_frames(video_path: Path, background_frames: List[int]) -> None:
+    def extract_frames(video_path: Path) -> None:
         video = cv2.VideoCapture(str(video_path))
 
-        frames_to_extract = list(
-            filter(
-                lambda n: n in annotations or n in background_frames, range(frame_count)
-            )
-        )
+        frames_to_extract = list(filter(lambda n: n in annotations, range(frame_count)))
 
         for frame_number in frames_to_extract:
             video.set(cv2.CAP_PROP_POS_FRAMES, frame_number)
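The write call that follows the seek is outside the visible hunk. Assuming the zero-padded naming used for the annotation files and the PNG_QUALITY constant above, it plausibly looks like this sketch (names assumed, not part of the diff):

    ok, frame = video.read()  # read the frame we just seeked to
    if ok:
        cv2.imwrite(
            str(yolo_dataset_folder / f"{frame_number:06}.png"),  # assumed naming
            frame,
            [cv2.IMWRITE_PNG_COMPRESSION, PNG_QUALITY],
        )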
@@ -249,47 +235,6 @@ def extract_frames(video_path: Path, background_frames: List[int]) -> None:
             )
         video.release()
 
-    def determine_background_frames(max_frame_number: int) -> List[int]:
-        """
-        Determines background frames based on background_image_percentage.
-        A frame must have at least background_image_no_annotation_buffer frames
-        between it and any annotated frame to count as a background frame.
-        This prevents background frames from being too close to the annotated
-        frames and containing half a fish. The result needs to be verified by
-        a human.
-        """
-        # Store annotated frame numbers in a set
-        annotated_frames = set(annotations.keys())
-
-        # Create a set of frame numbers that are too close to annotated frames
-        exclude_frames = set(
-            frame_number
-            for annotated_frame_number in annotated_frames
-            for frame_number in range(
-                max(0, annotated_frame_number - background_image_no_annotation_buffer),
-                min(
-                    max_frame_number,
-                    annotated_frame_number + background_image_no_annotation_buffer + 1,
-                ),
-            )
-        )
-
-        # Create the set of background frames that are not too close to annotated frames
-        background_frames = (
-            set(range(max_frame_number + 1)) - annotated_frames - exclude_frames
-        )
-
-        # Convert the background frames set to a list and shuffle it
-        background_frames = list(background_frames)
-        random.shuffle(background_frames)
-
-        # Determine the number of background frames to use
-        n_background_frames = int(background_image_percentage * len(annotations))
-
-        return background_frames[:n_background_frames]
-
     # Iterate all the annotations and generate the YOLO dataset.
     # Each annotations entry maps a frame number to a dict of labels, and each
     # label holds a list of bounding boxes.
     video_width, video_height = video_resolution
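To make the comment above concrete, a small sketch of the assumed shapes (the label and the CVAT pixel-box format xtl, ytl, xbr, ybr are illustrative, not taken from the file):

    # annotations: {frame_number: {label: [boxes]}}, e.g.
    # {1042: {"salmon": [(312.0, 120.5, 401.2, 198.0)]}}
    xtl, ytl, xbr, ybr = 312.0, 120.5, 401.2, 198.0
    x_center = (xtl + xbr) / 2 / video_width  # normalized to 0..1 as YOLO expects
    y_center = (ytl + ybr) / 2 / video_height
    width = (xbr - xtl) / video_width
    height = (ybr - ytl) / video_height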
@@ -322,16 +267,48 @@ def determine_background_frames(max_frame_number: int) -> List[int]:
                     f"{CLASSES.index(label)} {x_center} {y_center} {width} {height}\n"
                 )
 
-    background_frames = determine_background_frames(frame_count)
+    # Extract the annotated frames
+    extract_frames(video_path)
 
-    # Write empty annotation files for the background frames
-    for frame_number in background_frames:
-        annotation_file = yolo_dataset_folder / f"{frame_number:06}.txt"
-        with open(annotation_file, "w", encoding="utf-8") as file:
-            pass
-
-    # Extract annotation frames and background frames
-    extract_frames(video_path, background_frames)
+
+def get_background_images(images: List[Path]) -> List[Path]:
+    # Background images are those with an empty annotation file or no
+    # annotation file at all
+    background_images = []
+    for image in images:
+        annotation_file = image.with_suffix(".txt")
+        if not annotation_file.exists() or annotation_file.stat().st_size == 0:
+            background_images.append(image)
+    return background_images
+
+
+def adjust_background_images(images: List[Path]) -> List[Path]:
+    background_images: List[Path] = get_background_images(images)
+
+    # Remove all background images from the list of images so we don't use them twice
+    images = set(images) - set(background_images)
+
+    # Shuffle the manual background images
+    random.shuffle(background_images)
+
+    # Adjust the number of background images to match BACKGROUND_IMAGE_PERCENTAGE
+    background_image_count = int(len(images) * BACKGROUND_IMAGE_PERCENTAGE)
+    if len(background_images) < background_image_count:
+        print(
+            f"WARNING: not enough background images, using {len(background_images)} instead of {background_image_count}"
+        )
+        background_image_count = len(background_images)
+
+    background_images = background_images[:background_image_count]
+
+    # NOTE: this doesn't make background images exactly BACKGROUND_IMAGE_PERCENTAGE
+    # of the total; it adds that fraction of the annotated-image count, so the
+    # final share ends up slightly lower.
+
+    # Add the selected background images back to the list of images
+    images = list(images) + background_images
+
+    print(f"Percentage of background images: {len(background_images) / len(images):.2f}")
+
+    return list(images)
 
 
 def split_train_val(dataset_path: Path, train_split: float) -> None:
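Concretely (hypothetical numbers): from 1000 annotated images, adjust_background_images selects int(1000 * 0.1) = 100 background images, so backgrounds end up as 100 / 1100 ≈ 9.1% of the final list rather than 10%, which is what the NOTE above is pointing out.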
@@ -348,8 +325,11 @@ def split_train_val(dataset_path: Path, train_split: float) -> None:
         # images = [img.relative_to(dataset_path) for img in images]
         all_images.extend(images)
 
+    all_images = adjust_background_images(all_images)
+
     # Shuffle the images and split them into train and val sets
     random.shuffle(all_images)
+
     train_size = int(len(all_images) * train_split)
 
     print(f"Total images: {len(all_images)}")
@@ -367,82 +347,10 @@ def split_train_val(dataset_path: Path, train_split: float) -> None:
             val_file.write(f"{img}\n")
 
 
-def merge_datasets(dataset_a: Path, dataset_b: Path, output_folder: Path):
-    """
-    Merge the train.txt and val.txt files from two datasets into one,
-    copy obj.names and obj.data from the first dataset, and write the
-    merged train.txt and val.txt files to the output folder.
-    """
-    train_a = dataset_a / "train.txt"
-    val_a = dataset_a / "val.txt"
-    train_b = dataset_b / "train.txt"
-    val_b = dataset_b / "val.txt"
-
-    with open(train_a, "r", encoding="utf-8") as file:
-        train_a_lines = file.readlines()
-    with open(val_a, "r", encoding="utf-8") as file:
-        val_a_lines = file.readlines()
-    with open(train_b, "r", encoding="utf-8") as file:
-        train_b_lines = file.readlines()
-    with open(val_b, "r", encoding="utf-8") as file:
-        val_b_lines = file.readlines()
-
-    # FIXME: This is a hack to fix the paths in the train.txt and val.txt files.
-    # It is only needed for the yolo dataset (b) at the moment, because it
-    # prepends the data/ folder to the paths.
-    """
-    def remove_data_path(lines: List[str]) -> List[str]:
-        # Remove data/ from each line by dropping the first 5 characters
-        lines = list(map(lambda line: line[5:], lines))
-        return [f"{line}" for line in lines]
-    train_b_lines = remove_data_path(train_b_lines)
-    val_b_lines = remove_data_path(val_b_lines)
-    """
-
-    # Merge the train and val files
-    train_lines = train_a_lines + train_b_lines
-    val_lines = val_a_lines + val_b_lines
-
-    # Shuffle the lines
-    random.shuffle(train_lines)
-    random.shuffle(val_lines)
-
-    # Write the new train and val files
-    with open(output_folder / "train.txt", "w", encoding="utf-8") as file:
-        file.writelines(train_lines)
-
-    with open(output_folder / "val.txt", "w", encoding="utf-8") as file:
-        file.writelines(val_lines)
-
-    # shutil.copy(dataset_a / "obj.names", output_folder / "obj.names")
-    # shutil.copy(dataset_a / "obj.data", output_folder / "obj.data")
-
-    # Create the dataset.yaml config file
-    dataset_yaml = {
-        "path": str(output_folder),
-        "train": "train.txt",
-        "val": "val.txt",
-        "names": CLASSES,
-    }
-    with open(output_folder / "dataset.yaml", "w", encoding="utf-8") as yaml_file:
-        yaml.dump(dataset_yaml, yaml_file, default_flow_style=False)
-
-
 if __name__ == "__main__":
-    """
     processed_counter = mp.Value("i", 0)
     try:
         extract_annotations(CVAT_EXPORTS_FOLDER, DATASET_FOLDER, processed_counter)
     except FileNotFoundError as e:
         print(e)
-    """
     split_train_val(DATASET_FOLDER, TRAIN_SPLIT)
-    # generate_obj_files(DATASET_FOLDER)
-
-    # Split the downloaded yolo dataset into train and val
-    split_train_val(Path(r"D:\dataset_temp\yolo_updated_with_images"), TRAIN_SPLIT)
-    merge_datasets(
-        DATASET_FOLDER,  # Path(r"C:\Users\benja\Documents\datasets\nina_yolo_new"),
-        Path(r"D:\dataset_temp\yolo_updated_with_images"),
-        output_folder=Path(r"D:\dataset_temp\spliced"),
-    )
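For reference, yaml.dump with default_flow_style=False writes block style and sorts keys alphabetically, so the dataset.yaml produced by the removed merge_datasets came out along these lines (class names here are placeholders):

    names:
    - class_0
    - class_1
    path: D:\dataset_temp\spliced
    train: train.txt
    val: val.txt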
