-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathcub_loader.py
68 lines (52 loc) · 2.12 KB
/
cub_loader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
"""
The Caltech-UCSD birds dataset.
"""
import os
import math
import numpy as np
import torch
import torch.utils.data as data
from PIL import Image
import utils
class CUBImages(data.Dataset):
    """Dataset over a chosen subset of classes of the Caltech-UCSD Birds images.

    Metadata is read from four whitespace-separated text files located
    directly under ``root``: ``images.txt``, ``image_class_labels.txt``,
    ``classes.txt`` and ``bounding_boxes.txt``. The first column of every
    file is ignored, and the three per-image files are matched to each other
    purely by line position. Image files are resolved relative to
    ``root/images``.

    Args:
        root: Dataset directory holding the metadata files and the
            ``images`` folder.
        classes: Container of zero-based class ids to keep. It must support
            ``in`` and ``len``.
        transform: Optional callable applied to each image after resizing.
        im_size: Size value forwarded to ``utils.Resize``.
    """

    def __init__(self, root, classes=range(200), transform=None, im_size=128):
        self.loader = utils.DefaultImageLoader
        self.transform = transform
        self.im_size = im_size

        # paths
        self.root = root
        self.im_base_path = os.path.join(root, 'images')

        # load metadata (each file is closed as soon as it has been read)
        images = [fields[1] for fields in
                  self._read_fields(os.path.join(root, 'images.txt'))]
        # labels in the file are shifted down by one to become zero-based
        labels = [int(fields[1]) - 1 for fields in
                  self._read_fields(os.path.join(root, 'image_class_labels.txt'))]
        self.birdnames = [fields[1] for fields in
                          self._read_fields(os.path.join(root, 'classes.txt'))]
        # box coordinates are rounded to the nearest integer
        boxes = [[int(round(float(c))) for c in fields[1:]] for fields in
                 self._read_fields(os.path.join(root, 'bounding_boxes.txt'))]

        # which classes to include
        self.classes = classes
        self.num_classes = len(classes)
        keep = [label in classes for label in labels]

        # keep only the entries whose label belongs to the requested classes
        self.images = [image for image, kept in zip(images, keep) if kept]
        self.labels = np.array([label for label, kept in zip(labels, keep) if kept])
        self.boxes = np.array([box for box, kept in zip(boxes, keep) if kept])

        # number of images
        self.num_images = len(self.images)
        print("CUB loader initialized for %d classes, %d images" % (self.num_classes, self.num_images))

    @staticmethod
    def _read_fields(path):
        """Return the whitespace-split fields of every non-blank line in ``path``."""
        with open(path, 'r') as handle:
            # Blank lines are skipped so a trailing empty line does not raise.
            return [line.split() for line in handle if line.strip()]

    @staticmethod
    def _box_to_corners(box):
        """Convert an ``(x, y, width, height)`` box to ``(left, upper, right, lower)``."""
        x, y, width, height = (int(v) for v in box)
        return (x, y, x + width, y + height)

    def __len__(self):
        """Return the number of images kept after class filtering."""
        return self.num_images

    def __getitem__(self, index):
        """Load, crop, resize and optionally transform the image at ``index``.

        Returns:
            A tuple ``(img, label, index)`` where ``label`` is the zero-based
            class id stored for that image.
        """
        img = self.loader(os.path.join(self.im_base_path, self.images[index]))
        # The crop result must be assigned: the previous code discarded it, so
        # the bounding box never took effect.
        # NOTE(review): assumes the loader returns a PIL image and that
        # bounding_boxes.txt rows are (x, y, width, height), as described in
        # the CUB-200-2011 README - confirm.
        img = img.crop(self._box_to_corners(self.boxes[index]))
        img = utils.Resize(img, self.im_size)
        if self.transform is not None:
            img = self.transform(img)
        return img, self.labels[index], index