-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdata.py
104 lines (87 loc) · 3.79 KB
/
data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
# in case of use, please quote https://github.com/tikitong/minicoco repo and https://stackoverflow.com/a/73249837/14864907 solution.
import os
import json
import argparse
import numpy as np
from pathlib import Path
from random import sample
from pycocotools.coco import COCO
from alive_progress import alive_bar
import requests
from requests.adapters import HTTPAdapter
from urllib3.util.retry import Retry
from concurrent.futures import ThreadPoolExecutor
# Command-line interface. All three options are effectively mandatory: the
# split counts are used directly as slice bounds and the category names are
# iterated, so a missing value would only surface later as an obscure TypeError.
parser = argparse.ArgumentParser(
    description="Create COCO-format train/val subsets restricted to the given categories "
                "and download their images."
)
parser.add_argument("annotation_file", type=str, help="annotations/instances_train2017.json path file.")
parser.add_argument("-t", "--training", type=int, required=True, help="number of images in the training set.")
parser.add_argument("-v", "--validation", type=int, required=True, help="number of images in the validation set.")
# dest stays "nargs" because the rest of the script reads args.nargs.
parser.add_argument("-cat", "--nargs", nargs='+', required=True, metavar="NAME", help="category names.")
args = parser.parse_args()
# Negative counts would silently slice from the END of the shuffled image list.
if args.training < 0 or args.validation < 0:
    parser.error("-t/--training and -v/--validation must be non-negative.")
# Directory Creation
Path("data/images").mkdir(parents=True, exist_ok=True)
Path("data/labels").mkdir(parents=True, exist_ok=True)
# Validate the cheap CLI input before the (slow) annotation file is loaded.
catNms = args.nargs
if not catNms:
    parser.error("at least one category name is required (-cat NAME [NAME ...]).")
# Load COCO Dataset
coco = COCO(args.annotation_file)
catIds = coco.getCatIds(catNms=catNms)
# getCatIds silently drops names it does not know. An empty catIds would make
# getImgIds return every image in the dataset, so reject unknown names up front.
knownNames = {cat["name"] for cat in coco.loadCats(catIds)}
unknownNames = [name for name in catNms if name not in knownNames]
if unknownNames:
    parser.error(f"unknown COCO category name(s): {', '.join(unknownNames)}")
# pycocotools intersects the per-category image sets, so these are the images
# that contain every requested category.
imgIds = coco.getImgIds(catIds=catIds)
imgOriginals = coco.loadImgs(imgIds)
# A full-length random sample is a shuffled copy of the image records.
imgShuffled = sample(imgOriginals, len(imgOriginals))
# Template for the generated COCO JSON files; only the "info" entry is fixed.
annotations = {
    "info": {
        "description": "my-project-name"
    }
}
# Function Definitions
def myImages(images: list, train: int, val: int) -> tuple:
    """Split *images* into a training slice followed by a validation slice.

    The first ``train`` items become the training set and the next ``val``
    items the validation set; anything after that is left out. Counts larger
    than what is available just produce shorter (possibly empty) slices.

    Returns:
        ``(training_images, validation_images)`` as new lists.
    """
    boundary = train
    training = images[:boundary]
    validation = images[boundary:boundary + val]
    return (training, validation)
def cocoJson(images: list) -> dict:
    """Build a COCO-format dict for *images*, limited to the selected categories.

    The selected COCO category ids (module-level ``catIds``) are renumbered
    1..N in ascending COCO-id order. Every included annotation's
    ``category_id`` is rewritten to the new numbering.

    Args:
        images: image records (as returned by ``coco.loadImgs``) to include.

    Returns:
        A new dict holding the entries of the module-level ``annotations``
        template plus ``images``, ``annotations`` and ``categories``.
        Neither the template nor the loaded COCO dataset is modified.

    Raises:
        ValueError: if a name in ``catNms`` has no matching COCO category.
    """
    # Resolve categories by id rather than re-querying each name: given a bare
    # string, coco.getCatIds treats it as a sequence and matches every category
    # whose name is a substring of it ("carrot" -> "car", "hot dog" -> "dog").
    selectedCats = sorted(coco.loadCats(catIds), key=lambda cat: cat["id"])
    knownNames = {cat["name"] for cat in selectedCats}
    missing = [name for name in catNms if name not in knownNames]
    if missing:
        raise ValueError(f"unknown COCO category name(s): {missing}")
    newIds = {cat["id"]: idx for idx, cat in enumerate(selectedCats, start=1)}
    cats = [{"id": newIds[cat["id"]], "name": cat["name"]} for cat in selectedCats]

    imageIds = [img["id"] for img in images]
    if imageIds:
        annIds = coco.getAnnIds(imgIds=imageIds, catIds=catIds, iscrowd=None)
    else:
        # pycocotools treats an empty imgIds as "no image filter". It would
        # return every annotation of these categories in the whole dataset.
        annIds = []
    # loadAnns hands back COCO's own dicts, so relabel copies, not originals.
    anns = [
        {**ann, "category_id": newIds[ann["category_id"]]}
        for ann in coco.loadAnns(annIds)
    ]

    return {**annotations, "images": images, "annotations": anns, "categories": cats}
def createJson(jsonfile: dict, train: bool) -> None:
    """Serialize *jsonfile* into the labels directory.

    The file is ``data/labels/train2017.json`` when *train* is true and
    ``data/labels/val2017.json`` otherwise. An existing file is overwritten.
    """
    if train:
        split = "train2017"
    else:
        split = "val2017"
    target = f"data/labels/{split}.json"
    with open(target, "w") as fh:
        json.dump(jsonfile, fh)
def download_image(im, session, attempts: int = 3, timeout: float = 30.0):
    """Download one image to ``data/images/<file_name>`` unless it already exists.

    Args:
        im: image record with ``file_name`` and ``coco_url`` keys.
        session: HTTP session used for the GET request (a ``requests.Session``
            in this script).
        attempts: maximum number of download attempts.
        timeout: per-request timeout in seconds, so a stalled connection
            cannot block a worker thread forever.

    Network and file-system errors are printed and retried. If every attempt
    fails, the image is skipped without raising.
    """
    target = f"data/images/{im['file_name']}"
    if os.path.isfile(target):
        return
    partial = f"{target}.part"
    for _ in range(attempts):
        try:
            response = session.get(im['coco_url'], timeout=timeout)
            # Without this check an HTTP error page would be saved as the image.
            # Because the file would then exist, it would never be re-fetched.
            response.raise_for_status()
            # Write under a temporary name, then rename. An interrupted write
            # then cannot leave a truncated file that later runs would skip.
            with open(partial, 'wb') as handler:
                handler.write(response.content)
            os.replace(partial, target)
            return
        except OSError as e:  # requests' RequestException derives from IOError/OSError
            print(f"Error downloading {im['file_name']}: {e}")
def downloadImages(img: list, title: str) -> None:
    """Download every image in *img* on 10 worker threads, with a progress bar.

    Args:
        img: image records accepted by ``download_image``.
        title: label shown next to the progress bar.

    Failures of individual downloads are printed; the batch continues.
    """
    from concurrent.futures import as_completed

    # Transport-level retries for connection errors; download_image adds its
    # own per-image retry loop on top of this.
    retry = Retry(connect=3, backoff_factor=0.5)
    adapter = HTTPAdapter(max_retries=retry)
    with requests.Session() as session:  # closes pooled connections when done
        session.mount('http://', adapter)
        session.mount('https://', adapter)
        with alive_bar(len(img), title=title) as bar:
            with ThreadPoolExecutor(max_workers=10) as executor:
                futures = {executor.submit(download_image, im, session): im for im in img}
                # Tick the bar only as downloads actually finish. Iterating the
                # plain futures list ticked it to 100% right after submission.
                for future in as_completed(futures):
                    try:
                        future.result()
                    except Exception as e:  # report, but keep the rest of the batch going
                        print(f"Error downloading {futures[future]['file_name']}: {e}")
                    bar()
# Split images and download
def _exportSplit(subset: list, isTrain: bool, title: str) -> dict:
    """Write the COCO JSON for one split, then download that split's images.

    Returns the dict that ``cocoJson`` produced for the split.
    """
    splitJson = cocoJson(subset)
    createJson(splitJson, train=isTrain)
    downloadImages(subset, title=title)
    return splitJson


imagetrain, imageval = myImages(imgShuffled, args.training, args.validation)
trainset = _exportSplit(imagetrain, True, 'Downloading images of the training set:')
valset = _exportSplit(imageval, False, 'Downloading images of the validation set:')