forked from WZMIAOMIAO/deep-learning-for-image-processing

my_dataset_coco.py
import os
import copy
import torch
import numpy as np
import cv2
import torch.utils.data as data
from pycocotools.coco import COCO


class CocoKeypoint(data.Dataset):
    """COCO person keypoint dataset: yields one person instance (image + bbox/keypoint target) per sample."""
    def __init__(self,
                 root,
                 dataset="train",
                 years="2017",
                 transforms=None,
                 det_json_path=None,
                 fixed_size=(256, 192)):
        super().__init__()
        assert dataset in ["train", "val"], 'dataset must be in ["train", "val"]'
        anno_file = f"person_keypoints_{dataset}{years}.json"
        assert os.path.exists(root), "path '{}' does not exist.".format(root)
        self.img_root = os.path.join(root, f"{dataset}{years}")
        assert os.path.exists(self.img_root), "path '{}' does not exist.".format(self.img_root)
        self.anno_path = os.path.join(root, "annotations", anno_file)
        assert os.path.exists(self.anno_path), "file '{}' does not exist.".format(self.anno_path)
        self.fixed_size = fixed_size
        self.mode = dataset
        self.transforms = transforms
        self.coco = COCO(self.anno_path)
        img_ids = list(sorted(self.coco.imgs.keys()))

        if det_json_path is not None:
            # use external person detection results instead of ground-truth boxes
            det = self.coco.loadRes(det_json_path)
        else:
            det = self.coco

        self.valid_person_list = []
        obj_idx = 0
        for img_id in img_ids:
            img_info = self.coco.loadImgs(img_id)[0]
            ann_ids = det.getAnnIds(imgIds=img_id)
            anns = det.loadAnns(ann_ids)
            for ann in anns:
                # only keep the person class
                if ann["category_id"] != 1:
                    print(f'warning: found unsupported category_id: {ann["category_id"]}, only id 1 (person) is supported')
                    continue

                # the COCO_val2017_detections_AP_H_56_person.json file only contains detection info,
                # no keypoint info, so skip this check
                if det_json_path is None:
                    # skip objects without keypoint annotations
                    if "keypoints" not in ann:
                        continue
                    if max(ann["keypoints"]) == 0:
                        continue

                xmin, ymin, w, h = ann['bbox']
                # use only valid bounding boxes
                if w > 0 and h > 0:
                    info = {
                        "box": [xmin, ymin, w, h],
                        "image_path": os.path.join(self.img_root, img_info["file_name"]),
                        "image_id": img_id,
                        "image_width": img_info['width'],
                        "image_height": img_info['height'],
                        "obj_origin_hw": [h, w],
                        "obj_index": obj_idx,
                        "score": ann["score"] if "score" in ann else 1.
                    }

                    # the COCO_val2017_detections_AP_H_56_person.json file only contains detection info,
                    # no keypoint info, so skip keypoints here as well
                    if det_json_path is None:
                        keypoints = np.array(ann["keypoints"]).reshape([-1, 3])
                        visible = keypoints[:, 2]
                        keypoints = keypoints[:, :2]
                        info["keypoints"] = keypoints
                        info["visible"] = visible

                    self.valid_person_list.append(info)
                    obj_idx += 1

    def __getitem__(self, idx):
        target = copy.deepcopy(self.valid_person_list[idx])

        image = cv2.imread(target["image_path"])
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        if self.transforms is not None:
            # the transform is expected to return the (possibly modified) image and target
            image, target = self.transforms(image, target)

        return image, target

    def __len__(self):
        return len(self.valid_person_list)

    @staticmethod
    def collate_fn(batch):
        imgs_tuple, targets_tuple = tuple(zip(*batch))
        imgs_tensor = torch.stack(imgs_tuple)
        return imgs_tensor, targets_tuple


if __name__ == '__main__':
    train = CocoKeypoint("/data/coco2017/", dataset="val")
    print(len(train))
    t = train[0]
    print(t)
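
    # A minimal usage sketch (not part of the original file): wiring the dataset into a
    # torch DataLoader with the collate_fn above. collate_fn can only torch.stack images
    # once the transform returns fixed-size tensors, so demo_transform below is a
    # hypothetical stand-in for the repo's real transform pipeline, assuming a
    # fixed_size of (256, 192) as (h, w).
    from torch.utils.data import DataLoader

    def demo_transform(img, tgt):
        # resize to width=192, height=256, then convert HWC uint8 -> CHW float tensor in [0, 1]
        img = cv2.resize(img, (192, 256))
        img = torch.from_numpy(img).permute(2, 0, 1).float() / 255.
        return img, tgt

    val_dataset = CocoKeypoint("/data/coco2017/", dataset="val", transforms=demo_transform)
    val_loader = DataLoader(val_dataset,
                            batch_size=4,
                            shuffle=False,
                            collate_fn=CocoKeypoint.collate_fn)
    images, targets = next(iter(val_loader))
    print(images.shape)  # expected: torch.Size([4, 3, 256, 192])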