-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcoco_api_use_case.py
105 lines (93 loc) · 3.55 KB
/
coco_api_use_case.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
"""
coco API use case.
Created On 9th Mar, 2020
Author: bohang.li
"""
import os
import torch
from torch.utils.data import Dataset, DataLoader
from collections import Counter
from torchvision import datasets
from pycocotools.coco import COCO
import skimage.io as sio
import pandas as pd
# Export script: read a COCO-format annotation file, keep the images that carry
# exactly one category which is also present in `category_list`, and write one
# row per kept image (running id, file name up to its first ".", category name)
# to ../embedding.csv.
image_folder = "/ldap_home/bohang.li/image_search_download/download"
jsonfile = "/ldap_home/bohang.li/image_search_download/image_search_data_annotation.json"
# Alternative annotation file:
# jsonfile = "/ldap_home/bohang.li/image_search_download/image_search_val.json"

# NOTE(review): `category_list` is read on the next line but is never assigned
# earlier in this file, so this raises NameError unless the name is injected
# from elsewhere -- confirm where the category list is supposed to come from.
category_list = [str(x) for x in category_list]
inv_cat_idx = {k: v for v, k in enumerate(category_list)}
# Set copy of the requested names so the per-image membership test below is
# O(1) instead of a linear scan of the list for every image.
_wanted_categories = set(category_list)

coco_dict = COCO(jsonfile)
img2anns = coco_dict.imgToAnns
samples = []
count = 0
for img, labels in img2anns.items():
    cat_ids = {lbl["category_id"] for lbl in labels}
    # Skip images without exactly one distinct category. The original author
    # noted that dropping multi-label images only removes
    # 2168466 - 2168449 = 17 images.
    if len(cat_ids) != 1:
        continue
    category = coco_dict.cats[next(iter(cat_ids))]["name"]
    # Skip classes that were not requested.
    if category not in _wanted_categories:
        continue
    img_filename = coco_dict.imgs[img]["file_name"]
    # NOTE(review): split(".")[0] keeps only the text before the FIRST dot, so
    # a name such as "a.b.jpg" becomes "a" -- confirm file names never contain
    # extra dots before relying on this column as a unique key.
    samples.append({"id": str(count).zfill(8), "image": img_filename.split(".")[0], "label": category})
    count += 1
res = pd.DataFrame(samples)
res.to_csv("../embedding.csv", index=False)
class ImageSearchDataset(Dataset):
    """
    Dataset for internal image search dataset.

    Samples come from a COCO-format annotation file. An image is kept only
    when all of its annotations share a single category and that category's
    name is in ``category_list``. A sample's label is the integer position of
    its category name in ``category_list``.
    """

    def __init__(self, jsonfile, category_list, image_folder, transform=None, target_transform=None):
        """
        :param jsonfile: coco format json file.
        :param category_list: list, categories needed. Every entry is
            converted with ``str()`` before use.
        :param image_folder: str, the image folder path.
        :param transform: optional callable applied to the loaded image in
            ``__getitem__``.
        :param target_transform: optional callable applied to the label in
            ``__getitem__``.
        """
        category_list = [str(x) for x in category_list]
        self.image_folder = image_folder
        self.category_list = category_list
        self.transform = transform
        self.target_transform = target_transform
        coco_dict = COCO(jsonfile)
        # `inv_cat_idx` rebuilds its dict on every access, so read it exactly
        # once here instead of once per image.
        self.samples = self._collect_samples(
            coco_dict.imgToAnns, coco_dict.cats, coco_dict.imgs, self.inv_cat_idx
        )

    @staticmethod
    def _collect_samples(img_to_anns, cats, imgs, label_index):
        """
        Select the single-category images whose category is requested.

        :param img_to_anns: dict, image id -> list of annotation dicts, each
            with a "category_id" key.
        :param cats: dict, category id -> category dict with a "name" key.
        :param imgs: dict, image id -> image dict with a "file_name" key.
        :param label_index: dict, category name -> integer label.
        :return: list of {"image": file_name, "label": int} dicts, in the
            iteration order of ``img_to_anns``.
        """
        samples = []
        for img, labels in img_to_anns.items():
            cat_ids = {lbl["category_id"] for lbl in labels}
            # Skip images without exactly one distinct category. The original
            # author noted that dropping multi-label images only removes
            # 2168466 - 2168449 = 17 images.
            if len(cat_ids) != 1:
                continue
            category = cats[next(iter(cat_ids))]["name"]
            # One dict lookup serves as both the "is this class requested"
            # test and the name -> label conversion. Labels are list
            # positions, so None can only mean "not requested".
            label = label_index.get(category)
            if label is None:
                continue
            samples.append({"image": imgs[img]["file_name"], "label": label})
        return samples

    def __len__(self):
        """Return the number of kept samples."""
        return len(self.samples)

    def __getitem__(self, item):
        """
        Load one sample.

        :param item: index into ``self.samples``.
        :return: (image, label); the image is read with ``skimage.io.imread``
            from ``image_folder`` and both values are passed through their
            transform when one was given.
        """
        sample = self.samples[item]
        img_file, label = sample["image"], sample["label"]
        image = sio.imread(os.path.join(self.image_folder, img_file))
        if self.transform is not None:
            image = self.transform(image)
        if self.target_transform is not None:
            label = self.target_transform(label)
        return image, label

    @property
    def cat_idx(self):
        """Return a new dict mapping integer label -> category name."""
        return dict(enumerate(self.category_list))

    @property
    def inv_cat_idx(self):
        """
        Return a new dict mapping category name -> integer label.

        If a name occurs more than once in ``category_list`` its last
        position wins.
        """
        return {name: idx for idx, name in enumerate(self.category_list)}

    def stats(self):
        """
        Count the kept samples per category.

        :return: collections.Counter mapping category name -> sample count.
        """
        # Build the label -> name mapping once rather than once per sample.
        label_names = self.cat_idx
        return Counter(label_names[sample["label"]] for sample in self.samples)