forked from AliaksandrSiarohin/wc-gan
-
Notifications
You must be signed in to change notification settings - Fork 0
/
tiny_imagenet.py
67 lines (61 loc) · 2.28 KB
/
tiny_imagenet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
from tensorflow.python.keras.utils.data_utils import get_file
import os
import numpy as np
from skimage.transform import resize
from skimage import img_as_ubyte
from tqdm import tqdm
from skimage.io import imread
from skimage.color import gray2rgb
import pickle
def load_data():
    """Loads tiny-imagenet dataset.

    The archive is fetched and extracted through `get_file` (called with
    `cache_dir='.'`), then every image under `train/<class>/images` and
    `test/images` is read into preallocated `uint8` arrays. Grayscale
    images are expanded to three channels.

    Class indices are assigned by enumerating the class directory names in
    sorted order, so the label mapping and the sample order are reproducible
    across machines (`os.listdir` alone returns entries in arbitrary order).

    # Returns
        Tuple of Numpy arrays: `(x_train, y_train), (x_test, y_test)`.
        `x_train` has shape `(100000, 64, 64, 3)` and `y_train` has shape
        `(100000,)` with integer class indices. `x_test` has shape
        `(10000, 64, 64, 3)`; `y_test` is always `None` because no test
        labels are read.
    """
    origin = 'http://cs231n.stanford.edu/tiny-imagenet-200.zip'
    path = get_file('tiny-imagenet-200.zip', origin=origin, extract=True,
                    cache_dir='.', archive_format='zip')
    # Strip only the trailing archive suffix to get the extracted directory;
    # a blanket str.replace would also mangle any parent directory whose
    # name happens to contain '.zip'.
    if path.endswith('.zip'):
        path = path[:-len('.zip')]

    def read_rgb(name):
        """Read one image file, expanding grayscale images to 3 channels."""
        image = imread(name)
        if image.ndim == 2:
            image = gray2rgb(image)
        return image

    def load_train_images():
        """Read all training images and their integer class labels."""
        train_dir = os.path.join(path, 'train')
        # Buffers are sized as 200 classes x 500 images each.
        X = np.empty((500 * 200, 64, 64, 3), dtype='uint8')
        Y = np.empty((500 * 200, ), dtype='int')
        # List the class directories once; sorting fixes the name -> index
        # mapping independently of the filesystem's listing order.
        class_names = sorted(os.listdir(train_dir))
        i = 0
        for label, cls in enumerate(tqdm(class_names, ascii=True)):
            image_dir = os.path.join(train_dir, cls, 'images')
            for img in sorted(os.listdir(image_dir)):
                X[i] = read_rgb(os.path.join(image_dir, img))
                Y[i] = label
                i += 1
        # Report how many images were actually read.
        print(i)
        return X, Y

    def load_test_images():
        """Read all test images; labels are not loaded (returned as None)."""
        image_dir = os.path.join(path, 'test', 'images')
        X = np.empty((100 * (50 + 50), 64, 64, 3), dtype='uint8')
        Y = None
        i = 0
        for img in tqdm(sorted(os.listdir(image_dir)), ascii=True):
            X[i] = read_rgb(os.path.join(image_dir, img))
            i += 1
        # Report how many images were actually read.
        print(i)
        return X, Y

    print ("Loading images...")
    X_train, Y_train = load_train_images()
    X_test, Y_test = load_test_images()
    return (X_train, Y_train), (X_test, Y_test)