From bb01a9f796ce8c13e4ef63648f0ee5c9afa32a5b Mon Sep 17 00:00:00 2001 From: Packophys Date: Fri, 25 Aug 2023 12:59:50 +0200 Subject: [PATCH] add data augmentation as option --- .../training/data/make_dataset2.py | 25 +++++++++++++------ .../training/test_make_dataset.py | 3 +++ 2 files changed, 21 insertions(+), 7 deletions(-) diff --git a/src/plasticorigins/training/data/make_dataset2.py b/src/plasticorigins/training/data/make_dataset2.py index ab483ec..3900764 100644 --- a/src/plasticorigins/training/data/make_dataset2.py +++ b/src/plasticorigins/training/data/make_dataset2.py @@ -65,21 +65,31 @@ def main(args: Namespace) -> None: print( f"found {cpos} valid annotations with images and {cneg} unmatched annotations" ) + if args.split: + train_files, val_files = get_train_valid(yolo_filelist, args.split) - train_files, val_files = get_train_valid(yolo_filelist, args.split) - train_files = data_augmentation_for_yolo_data(data_dir, train_files) - - if args.artificial_data: + # use data augmentation for artificial data only if original data have been processed + if args.artificial_data and args.data_augmentation: artificial_data_dir = Path(args.artificial_data) artificial_data_list = [Path(path).as_posix() for path in os.listdir(artificial_data_dir / "images")] artificial_train_files, artificial_val_files = get_train_valid(artificial_data_list, args.split) artificial_train_files = data_augmentation_for_yolo_data(artificial_data_dir, artificial_train_files) # concatenate original images and artificial data - train_files = train_files + artificial_train_files - val_files = val_files + artificial_val_files + with open(data_dir / "train.txt", "w") as f: + for path in artificial_train_files: + f.write(path + "\n") + with open(data_dir / "val.txt", "w") as f: + for path in artificial_val_files: + f.write(path + "\n") + + else: + train_files, val_files = get_train_valid(yolo_filelist) + + if args.data_augmentation: + train_files = data_augmentation_for_yolo_data(data_dir, train_files) - generate_yolo_files(data_dir, train_files, val_files, args.nb_classes) + generate_yolo_files(data_dir, train_files, val_files, args.nb_classes) if __name__ == "__main__": @@ -107,6 +117,7 @@ def main(args: Namespace) -> None: ) parser.add_argument("--nb-classes", type=int, default=10) parser.add_argument("--split", type=float, default=0.85) + parser.add_argument("--data-augmentation", type=int, default=0) parser.add_argument("--limit-data", type=int, default=0) parser.add_argument( "--exclude-img-folder", diff --git a/tests/test_plasticorigins/training/test_make_dataset.py b/tests/test_plasticorigins/training/test_make_dataset.py index 2b9c3b3..15f7d6e 100644 --- a/tests/test_plasticorigins/training/test_make_dataset.py +++ b/tests/test_plasticorigins/training/test_make_dataset.py @@ -17,6 +17,7 @@ quality_filters="[good,medium]", nb_classes=12, limit_data=0, + data_augmentation=0, exclude_img_folder=path_data + "exclude_ids/", split=0.85, ) @@ -32,6 +33,7 @@ quality_filters=None, nb_classes=12, limit_data=0, + data_augmentation=0, exclude_img_folder=None, split=0.85, ) @@ -46,6 +48,7 @@ quality_filters="[good,medium]", nb_classes=12, limit_data=0, + data_augmentation=0, exclude_img_folder=None, split=0.85, )