-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathcoco.py
197 lines (164 loc) · 5.99 KB
/
coco.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
"""
@File : coco.py
@Version : 1.0
@Author : laugh12321
@Contact : [email protected]
@Date : 2022/09/01 18:08:09
@Desc : Convert the TT-100K dataset to COCO format
"""
import argparse
import json
import os
from multiprocessing import Process
from typing import Tuple
from tqdm import tqdm
class Loader:
    """Load the TT-100K annotation files and the per-split image-id lists.

    On construction the two annotation JSON files found in the dataset root
    ("annotations.json" and "annotations_all.json") are read and merged, and
    the three "ids.txt" files are read. The results are exposed through
    read-only properties.
    """

    def __init__(self, data_dir: str) -> None:
        """Read and merge everything needed from the dataset root.

        Args:
            data_dir (str): Root directory of the TT-100K dataset.

        Raises:
            OSError: If an annotation or ids file cannot be opened.
            json.JSONDecodeError: If an annotation file is not valid JSON.
            KeyError: If an annotation file lacks the "types" or "imgs" key.
        """
        self.__data_dir = data_dir
        # Annotation file of the 2016 release.
        self.__annos16_dir = os.path.join(data_dir, "annotations.json")
        # Annotation file of the 2021 release.
        self.__annos21_dir = os.path.join(data_dir, "annotations_all.json")
        self.__categories, self.__annos = self.__get_annotations()
        self.__train_ids, self.__val_ids, self.__test_ids = self.__get_ids()

    @property
    def categories(self) -> dict:
        """Mapping from category name to its integer id."""
        return self.__categories

    @property
    def annotations(self) -> dict:
        """Merged per-image annotation records (2021 + 2016)."""
        return self.__annos

    @property
    def train_ids(self) -> list:
        """Image ids of the training split."""
        return self.__train_ids

    @property
    def val_ids(self) -> list:
        """Image ids of the validation split."""
        return self.__val_ids

    @property
    def test_ids(self) -> list:
        """Image ids of the test split."""
        return self.__test_ids

    def __get_annotations(self) -> Tuple[dict, dict]:
        """Read both annotation files and merge their contents.

        Returns:
            Tuple[dict, dict]: ``(categories, images)``. ``categories`` maps
            each category name from either file to its position in the
            sorted, de-duplicated list of names. ``images`` is the union of
            both "imgs" mappings; on a duplicate key the 2021 entry wins.
        """
        # Context managers close the handles deterministically instead of
        # leaving that to the garbage collector.
        with open(self.__annos16_dir, encoding="utf-8") as file:
            annos16 = json.load(file)
        with open(self.__annos21_dir, encoding="utf-8") as file:
            annos21 = json.load(file)

        names = sorted(set(annos16["types"] + annos21["types"]))
        categories = {name: index for index, name in enumerate(names)}
        return categories, annos16["imgs"] | annos21["imgs"]

    def __get_ids(self) -> Tuple[list, list, list]:
        """Read the image-id list of every split.

        The training ids come from "train/ids.txt", the validation ids from
        "test/ids.txt" and the test ids from "other/ids.txt".

        Returns:
            Tuple[list, list, list]: ``(train_ids, val_ids, test_ids)``, one
            id string per non-blank line of the corresponding file.
        """

        def read_ids(folder: str) -> list:
            path = os.path.join(self.__data_dir, folder, "ids.txt")
            with open(path, encoding="utf-8") as file:
                # Blank lines would otherwise become "" ids, which are never
                # valid keys of the annotation mapping.
                return [line.strip() for line in file if line.strip()]

        return read_ids("train"), read_ids("test"), read_ids("other")
class TT100k2COCO(Loader):
    """Write the TT-100K splits loaded by ``Loader`` as COCO-style JSON files."""

    def __init__(self, data_dir: str) -> None:
        """Load the dataset and choose the output directory.

        Args:
            data_dir (str): Root directory of the TT-100K dataset. Output
                files go to its "annotations" subdirectory.
        """
        super().__init__(data_dir)
        self.save_dir = os.path.join(data_dir, "annotations")

    @staticmethod
    def __bbox2xywh(bbox: dict) -> list:
        """Convert a TT-100K bbox to the COCO bbox layout.

        Args:
            bbox (dict): TT-100K box with keys "xmin", "ymin", "xmax", "ymax".

        Returns:
            list: COCO box ``[xmin, ymin, width, height]``.
        """
        return [
            bbox["xmin"],
            bbox["ymin"],
            bbox["xmax"] - bbox["xmin"],
            bbox["ymax"] - bbox["ymin"],
        ]

    def format2coco(self, ids: list, json_path: str) -> None:
        """Convert the given images to COCO format and write one JSON file.

        Args:
            ids (list): Image ids; each must be a key of ``self.annotations``.
            json_path (str): Path of the annotations JSON file to write.

        Raises:
            KeyError: If an id is not in ``self.annotations`` or an object's
                category is not in ``self.categories``.
        """
        coco_json = {"images": [], "annotations": [], "categories": []}
        # Category name -> id for the categories that actually occur in this
        # split, kept in order of first appearance. A dict gives constant-time
        # membership checks, unlike scanning a list for every object.
        used_categories = {}
        # Running counter so every annotation gets its own id; reusing the
        # image index would give all objects of one image the same id.
        annotation_id = 0

        for image_id in tqdm(ids):
            anno = self.annotations[image_id]
            coco_json["images"].append({
                "file_name": anno["path"],
                # Every image is recorded with a fixed 2048x2048 size.
                "height": 2048,
                "width": 2048,
                "id": anno["id"],
            })
            for item in anno["objects"]:
                xywh = self.__bbox2xywh(item["bbox"])
                category = item["category"]
                category_id = self.categories[category]
                # Ids start at 1: some COCO evaluation tooling uses 0 to mean
                # "no match", so an annotation with id 0 is best avoided.
                annotation_id += 1
                coco_json["annotations"].append({
                    "area": xywh[2] * xywh[3],  # width * height
                    "iscrowd": 0,
                    "image_id": anno["id"],
                    "bbox": xywh,
                    "category_id": category_id,
                    "id": annotation_id,
                })
                used_categories.setdefault(category, category_id)

        coco_json["categories"] = [
            {"id": category_id, "name": category}
            for category, category_id in used_categories.items()
        ]
        with open(json_path, "w", encoding="utf-8") as file:
            json.dump(coco_json,
                      file,
                      indent=4,
                      sort_keys=False,
                      ensure_ascii=False)

    def processing(self) -> None:
        """Convert the three splits in parallel, one process per split.

        Creates ``self.save_dir`` if needed, writes "train2017.json",
        "val2017.json" and "test2017.json" into it, and returns once all
        worker processes have finished.
        """
        os.makedirs(self.save_dir, exist_ok=True)
        jobs = (
            (self.train_ids, "train2017.json"),
            (self.val_ids, "val2017.json"),
            (self.test_ids, "test2017.json"),
        )
        # Create the processes.
        workers = [
            Process(
                target=self.format2coco,
                kwargs={
                    "ids": split_ids,
                    "json_path": os.path.join(self.save_dir, file_name),
                },
            )
            for split_ids, file_name in jobs
        ]
        # Start them all first so the splits are converted concurrently ...
        for worker in workers:
            worker.start()
        # ... then wait, so the files exist when this method returns.
        for worker in workers:
            worker.join()
def parse_args() -> argparse.Namespace:
    """Parse the command-line arguments of the converter.

    Returns:
        argparse.Namespace: Namespace whose ``data_dir`` attribute holds the
        value given to ``--data_dir``.

    Raises:
        SystemExit: Raised by argparse when ``--data_dir`` is missing or the
            arguments are otherwise invalid.
    """
    parser = argparse.ArgumentParser(
        description="TT-100K dataset to COCO format.")
    # Required: without it data_dir would be None and the first path join
    # would fail with an unrelated-looking TypeError.
    parser.add_argument("--data_dir", type=str, required=True, help="数据位置")
    return parser.parse_args()
if __name__ == "__main__":
    # Script entry point: read the dataset root from the command line, build
    # the converter for it and run the conversion.
    args = parse_args()
    converter = TT100k2COCO(args.data_dir)
    converter.processing()