-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdata_loader.py
137 lines (108 loc) · 4.26 KB
/
data_loader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
import os
import random
import torch
import torchvision.transforms as transforms
import torch.utils.data as data
from PIL import Image
from tqdm import tqdm
def print_error(e):
    """
    Report an error: dump a traceback with traceback.print_exc(), then print e.

    args:
        e: object
            the error to report; it is written with print() after the traceback.
    """
    from traceback import print_exc

    print_exc()
    print(e)
class DTDDataLoader(data.Dataset):
    """
    Dataset holding, in memory, the RGB images of selected texture folders.

    For every texture name listed in img_list_path, each file found in
    data_root/<texture name> is opened, converted to RGB and appended to
    self.images when the dataset is constructed.
    """

    def __init__(self, data_root, img_list_path="train.txt", transform=None):
        """
        args:
            data_root: str
                root directory of the dataset.
                if we are using DTD, it should be the "path_to_DTDdataset_file/dtd/images"
            img_list_path: str
                path to a text file listing which texture names to use,
                one name per line, like
                    dotted
                    stripe
                    blotchy
                    .
                    .
                    .
                surrounding whitespace is stripped and blank lines are ignored.
            transform: torchvision.transforms
                image transforms.
                if it is None, it will perform, ToTensor->Normalize(std, mean at .5)
        """
        self.transform = transform
        self.images = []

        # Read the texture names. Blank lines are dropped because an empty
        # name would make os.path.join() point back at data_root itself.
        with open(img_list_path, "r") as file:
            stripped_lines = [line.strip() for line in file]
        self.texture_names = [name for name in stripped_lines if name]

        tqdm.write("loading images...")
        # Load every image of every listed texture folder.
        for tex_name in tqdm(self.texture_names, desc="textures", ncols=80):
            texture_file_path = os.path.join(data_root, tex_name)
            try:
                # os.listdir() returns entries in arbitrary order; sort so the
                # sample order is the same on every run.
                img_names = sorted(os.listdir(texture_file_path))
            except (OSError, ValueError):
                # Missing or unreadable texture folder: report it and move on
                # to the next texture instead of iterating a stale/undefined
                # file list.
                tqdm.write("pass {}".format(tex_name))
                continue

            for img_name in img_names:
                img_path = os.path.join(texture_file_path, img_name)
                try:
                    # convert() returns a new image, so the source file can be
                    # closed as soon as the conversion is done.
                    with Image.open(img_path) as img:
                        self.images.append(img.convert('RGB'))
                except Exception:
                    # Best effort: anything that cannot be loaded as an image
                    # is reported and skipped.
                    tqdm.write("pass {}".format(img_name))

        self.data_num = len(self.images)

    def __getitem__(self, index):
        """
        Return the image stored at index, after applying the transform.

        If no transform was given, ToTensor followed by
        Normalize(mean=.5, std=.5 per channel) is applied instead.
        """
        if self.transform is not None:
            img = self.transform(self.images[index])
        else:
            img = transforms.ToTensor()(self.images[index])
            img = transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))(img)
        return img

    def __len__(self):
        """Return the number of images that were loaded."""
        return self.data_num
# Dataset for your own data: reads every file in data_dir.
class ImageDataLoader(data.Dataset):
    """
    Dataset holding, in memory, every loadable image found in one directory.

    Each entry of data_dir is opened, converted to RGB and appended to
    self.images when the dataset is constructed; entries that cannot be
    loaded are reported and skipped.
    """

    def __init__(self, data_dir, transform=None):
        """
        args:
            data_dir: str
                directory whose files are loaded as images.
            transform: torchvision.transforms
                image transforms.
                if it is None, it will perform, ToTensor->Normalize(std, mean at .5)
        """
        self.transform = transform
        self.images = []

        tqdm.write("loading images...")
        # os.listdir() returns entries in arbitrary order; sort so the sample
        # order is the same on every run.
        img_names = sorted(os.listdir(data_dir))

        # Load every entry of the directory.
        for img_name in tqdm(img_names, desc="image", ncols=80):
            img_file_path = os.path.join(data_dir, img_name)
            try:
                # convert() returns a new image, so the source file can be
                # closed as soon as the conversion is done.
                with Image.open(img_file_path) as img:
                    self.images.append(img.convert('RGB'))
            except Exception:
                # Best effort: anything that cannot be loaded as an image
                # (sub-directories, non-image files, ...) is reported and skipped.
                tqdm.write("pass {}".format(img_name))

        self.data_num = len(self.images)

    def __getitem__(self, index):
        """
        Return the image stored at index, after applying the transform.

        If no transform was given, ToTensor followed by
        Normalize(mean=.5, std=.5 per channel) is applied instead.
        """
        if self.transform is not None:
            img = self.transform(self.images[index])
        else:
            img = transforms.ToTensor()(self.images[index])
            img = transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))(img)
        return img

    def __len__(self):
        """Return the number of images that were loaded."""
        return self.data_num
# data loader for dataset
def get_loader(data_set, batch_size, shuffle, num_workers):
    """
    Build a torch.utils.data.DataLoader over data_set.

    args:
        data_set:
            dataset passed to DataLoader as `dataset`.
        batch_size:
            passed to DataLoader as `batch_size`.
        shuffle:
            passed to DataLoader as `shuffle`.
        num_workers:
            passed to DataLoader as `num_workers`.
    returns:
        the constructed torch.utils.data.DataLoader.
    """
    loader_options = {
        "dataset": data_set,
        "batch_size": batch_size,
        "shuffle": shuffle,
        "num_workers": num_workers,
    }
    return torch.utils.data.DataLoader(**loader_options)