Skip to content

Commit

Permalink
add data augmentation as option
Browse files Browse the repository at this point in the history
  • Loading branch information
Packophys authored and Packophys committed Aug 25, 2023
1 parent d6458ad commit bb01a9f
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 7 deletions.
25 changes: 18 additions & 7 deletions src/plasticorigins/training/data/make_dataset2.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,21 +65,31 @@ def main(args: Namespace) -> None:
print(
f"found {cpos} valid annotations with images and {cneg} unmatched annotations"
)
if args.split:
train_files, val_files = get_train_valid(yolo_filelist, args.split)

train_files, val_files = get_train_valid(yolo_filelist, args.split)
train_files = data_augmentation_for_yolo_data(data_dir, train_files)

if args.artificial_data:
# use data augmentation for artificial data only if original data have been processed
if args.artificial_data and args.data_augmentation:

artificial_data_dir = Path(args.artificial_data)
artificial_data_list = [Path(path).as_posix() for path in os.listdir(artificial_data_dir / "images")]
artificial_train_files, artificial_val_files = get_train_valid(artificial_data_list, args.split)
artificial_train_files = data_augmentation_for_yolo_data(artificial_data_dir, artificial_train_files)
# concatenate original images and artificial data
train_files = train_files + artificial_train_files
val_files = val_files + artificial_val_files
with open(data_dir / "train.txt", "w") as f:
for path in artificial_train_files:
f.write(path + "\n")
with open(data_dir / "val.txt", "w") as f:
for path in artificial_val_files:
f.write(path + "\n")

else:
train_files, val_files = get_train_valid(yolo_filelist)

if args.data_augmentation:
train_files = data_augmentation_for_yolo_data(data_dir, train_files)

generate_yolo_files(data_dir, train_files, val_files, args.nb_classes)
generate_yolo_files(data_dir, train_files, val_files, args.nb_classes)


if __name__ == "__main__":
Expand Down Expand Up @@ -107,6 +117,7 @@ def main(args: Namespace) -> None:
)
parser.add_argument("--nb-classes", type=int, default=10)
parser.add_argument("--split", type=float, default=0.85)
parser.add_argument("--data-augmentation", type=int, default=0)
parser.add_argument("--limit-data", type=int, default=0)
parser.add_argument(
"--exclude-img-folder",
Expand Down
3 changes: 3 additions & 0 deletions tests/test_plasticorigins/training/test_make_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
quality_filters="[good,medium]",
nb_classes=12,
limit_data=0,
data_augmentation=0,
exclude_img_folder=path_data + "exclude_ids/",
split=0.85,
)
Expand All @@ -32,6 +33,7 @@
quality_filters=None,
nb_classes=12,
limit_data=0,
data_augmentation=0,
exclude_img_folder=None,
split=0.85,
)
Expand All @@ -46,6 +48,7 @@
quality_filters="[good,medium]",
nb_classes=12,
limit_data=0,
data_augmentation=0,
exclude_img_folder=None,
split=0.85,
)
Expand Down

0 comments on commit bb01a9f

Please sign in to comment.