data/__init__.py

import importlib
import torch.utils.data
from data.base_data_loader import BaseDataLoader
from data.base_dataset import BaseDataset


def find_dataset_using_name(dataset_mode):
    # Given the option --dataset_mode [datasetmode], the file
    # "data/datasetmode_dataset.py" will be imported.
    dataset_filename = "data." + dataset_mode + "_dataset"
    datasetlib = importlib.import_module(dataset_filename)

    # In that file, the class named DatasetModeDataset() will be
    # instantiated. It has to be a subclass of BaseDataset, and the
    # name match is case-insensitive.
    dataset = None
    target_dataset_name = dataset_mode.replace('_', '') + 'dataset'
    for name, cls in datasetlib.__dict__.items():
        if name.lower() == target_dataset_name.lower() \
           and issubclass(cls, BaseDataset):
            dataset = cls

    if dataset is None:
        raise NotImplementedError(
            "In %s.py, there should be a subclass of BaseDataset "
            "with a class name that matches %s in lowercase."
            % (dataset_filename, target_dataset_name))

    return dataset
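
# Example of the lookup above (with a hypothetical mode name):
# --dataset_mode aligned imports data/aligned_dataset.py, and a class
# AlignedDataset defined there matches, because
# 'aligneddataset' == 'AlignedDataset'.lower().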


def get_option_setter(dataset_name):
    # Return the static method that lets the chosen dataset class add
    # its own options to the command-line parser.
    dataset_class = find_dataset_using_name(dataset_name)
    return dataset_class.modify_commandline_options
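
# Sketch of the convention this relies on (hypothetical dataset class; the
# (parser, is_train) signature is an assumption about how the option parser
# calls it, not something defined in this file):
#
#     class AlignedDataset(BaseDataset):
#         @staticmethod
#         def modify_commandline_options(parser, is_train):
#             parser.add_argument('--example_flag', type=int, default=0)
#             return parser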


def create_dataset(opt):
    # Instantiate the dataset class selected by --dataset_mode and
    # initialize it with the parsed options.
    dataset = find_dataset_using_name(opt.dataset_mode)
    instance = dataset()
    instance.initialize(opt)
    print("dataset [%s] was created" % (instance.name()))
    return instance


def CreateDataLoader(opt):
    # Entry point for the training/test scripts: build the wrapper
    # data loader defined below.
    data_loader = CustomDatasetDataLoader()
    data_loader.initialize(opt)
    return data_loader


# Wrapper class of Dataset class that performs multi-threaded data loading
class CustomDatasetDataLoader(BaseDataLoader):
    def name(self):
        return 'CustomDatasetDataLoader'

    def initialize(self, opt):
        BaseDataLoader.initialize(self, opt)
        self.dataset = create_dataset(opt)
        self.dataloader = torch.utils.data.DataLoader(
            self.dataset,
            batch_size=opt.batch_size,
            shuffle=not opt.serial_batches,
            num_workers=int(opt.num_threads))

    def load_data(self):
        return self

    def __len__(self):
        # Cap the reported length at --max_dataset_size.
        return min(len(self.dataset), self.opt.max_dataset_size)

    def __iter__(self):
        # Stop yielding batches once roughly --max_dataset_size samples
        # have been consumed (counted in whole batches).
        for i, data in enumerate(self.dataloader):
            if i * self.opt.batch_size >= self.opt.max_dataset_size:
                break
            yield data
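

# Usage sketch (assumptions: `opt` provides exactly the fields read above,
# and 'aligned' is a hypothetical dataset_mode backed by
# data/aligned_dataset.py):
#
#     from argparse import Namespace
#     from data import CreateDataLoader
#
#     opt = Namespace(dataset_mode='aligned', batch_size=4,
#                     serial_batches=False, num_threads=2,
#                     max_dataset_size=float('inf'))
#     data_loader = CreateDataLoader(opt)
#     dataset = data_loader.load_data()    # returns the wrapper itself
#     for i, batch in enumerate(dataset):  # batches from the inner DataLoader
#         ...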