"""draw_classifier_train_job.py

Downloads Quick, Draw! bitmap data, trains a sketch classifier, and saves
the model together with its architecture metadata for the serving API.
"""
import argparse
import json
import os

import numpy as np
import torch
import wget
from sklearn.model_selection import train_test_split

from utils.image_utils import join_transformed_images
from utils.train_utils import build_model, fit_model, evaluate_model

def load_label_dict():
    """Load the category labels shared with the serving API."""
    with open('service/label_dict.json') as f:
        return json.load(f)
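
# label_dict.json is assumed to map category names to class indices, for
# example (hypothetical contents):
#
#   {"cat": 0, "dog": 1, "airplane": 2}
#
# Only the category names and their count are used in this script.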
# URL to dataset in GCP Storage
dataset_url = 'https://storage.googleapis.com/quickdraw_dataset/full/numpy_bitmap/'
labels = load_label_dict()
data_filepath = 'datasets'
num_categories = len(labels)
# Where the trained model is stored for the serving API
model_path = 'service/models/model.nnet'

# Hyperparameters for our network
input_size = 784  # 28x28 grayscale bitmaps, flattened
hidden_sizes = [128, 100, 64]
output_size = num_categories  # one output unit per drawing category
dropout = 0.0

# Fit parameters
n_chunks = 1000
learning_rate = 0.003
weight_decay = 0
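
# build_model (from utils.train_utils) is expected to assemble a
# feed-forward network from these settings. A minimal sketch of what such
# a network could look like (an assumption, not the actual implementation):
#
#   torch.nn.Sequential(
#       torch.nn.Linear(input_size, 128), torch.nn.ReLU(), torch.nn.Dropout(dropout),
#       torch.nn.Linear(128, 100), torch.nn.ReLU(), torch.nn.Dropout(dropout),
#       torch.nn.Linear(100, 64), torch.nn.ReLU(), torch.nn.Dropout(dropout),
#       torch.nn.Linear(64, output_size),
#   )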

def download_datasets(labels):
    """
    Download the bitmap file for each label, skipping files already on disk.

    :param labels: iterable of category names
    """
    os.makedirs(data_filepath, exist_ok=True)
    for category in labels:
        if not os.path.exists(data_filepath + '/' + str(category) + '.npy'):
            print("Start downloading data for [{}].".format(category))
            url = dataset_url + str(category) + '.npy'
            wget.download(
                url=url,
                out=data_filepath
            )
            print("Dataset for {} was successfully downloaded.".format(category))
        else:
            print("Dataset for {} is already downloaded.".format(category))

def prepare_datasets(labels, num_examples):
    """
    Load up to num_examples drawings per label, assign class indices,
    and split the result into train and test sets.

    :param labels: iterable of category names
    :param num_examples: number of examples to keep per category
    :return: X_train, X_test, y_train, y_test
    """
    classes_dict = {}
    for category in labels:
        classes_dict[category] = np.load(data_filepath + '/' + str(category) + '.npy')
    # Scale pixels to [0, 1] and append the class index as an extra column
    for i, (key, value) in enumerate(classes_dict.items()):
        value = value.astype('float32') / 255.
        classes_dict[key] = np.c_[value, i * np.ones(len(value))]
    lst = []
    for key, value in classes_dict.items():
        lst.append(value[:num_examples])
    tmp = np.concatenate(lst)
    # Split the data into features and class labels (X & y respectively)
    y = tmp[:, -1].astype('float32')
    X = tmp[:, :input_size]
    # Split the combined dataset into train/test parts
    return train_test_split(X, y, test_size=0.3, random_state=1)
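
# Worked example of the resulting shapes (assuming each .npy file holds at
# least num_examples drawings): with 3 categories and num_examples=3000,
# tmp has 9000 rows of 785 columns (784 pixels plus the class index), so
# train_test_split with test_size=0.3 returns 6300 training and 2700 test
# examples of 784 features each.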

def save_np_data(data, file_name):
    """Save an array to the datasets directory."""
    with open('{}/{}'.format(data_filepath, file_name), 'wb') as f:
        np.save(f, data)


def load_np_data(file_name):
    """Load an array previously saved to the datasets directory."""
    with open('{}/{}'.format(data_filepath, file_name), 'rb') as f:
        return np.load(f)

def main(num_examples, epochs, from_cache):
    print('Starting the training process with the following properties:')
    print('Number of examples per category: {}'.format(num_examples))
    print('Training epochs: {}'.format(epochs))
    if from_cache:
        print('Loading data from the local cache')
        X_train = load_np_data('X_train.npy')
        y_train = load_np_data('y_train.npy')
        X_test = load_np_data('X_test.npy')
        y_test = load_np_data('y_test.npy')
    else:
        download_datasets(labels)
        X_train, X_test, y_train, y_test = prepare_datasets(labels, num_examples)
        # Cache the raw splits so later runs can pass --fromCache
        save_np_data(X_train, 'X_train.npy')
        save_np_data(y_train, 'y_train.npy')
        save_np_data(X_test, 'X_test.npy')
        save_np_data(y_test, 'y_test.npy')
    print('Generating transformed images and joining them to the train dataset')
    X_train, y_train = join_transformed_images(X_train, y_train)
    train = torch.from_numpy(X_train).float()
    train_labels = torch.from_numpy(y_train).long()
    test = torch.from_numpy(X_test).float()
    test_labels = torch.from_numpy(y_test).long()
    print('Build model')
    model = build_model(input_size, output_size, hidden_sizes, dropout)
    print('Start fitting')
    fit_model(model, train, train_labels,
              epochs=epochs, n_chunks=n_chunks, learning_rate=learning_rate,
              weight_decay=weight_decay)
    evaluate_model(model, train, train_labels, test, test_labels)
    # Store the architecture alongside the weights so the serving side can
    # rebuild the same network before loading the state dict
    metainfo = {'input_size': input_size,
                'output_size': output_size,
                'hidden_layers': hidden_sizes,
                'dropout': dropout,
                'state_dict': model.state_dict()}
    print('End fit')
    torch.save(metainfo, model_path)
    print("Model saved to {}\n".format(model_path))

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--num_examples', type=int, default=3000,
                        help='number of examples to use per category')
    parser.add_argument('--epochs', type=int, default=25,
                        help='number of training epochs')
    parser.add_argument('--fromCache', action='store_true',
                        help='reuse the cached train/test splits')
    args = parser.parse_args()
    main(args.num_examples, args.epochs, args.fromCache)
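
# Example invocations, assuming the script is run from the repository root:
#
#   python draw_classifier_train_job.py
#   python draw_classifier_train_job.py --num_examples 5000 --epochs 50
#   python draw_classifier_train_job.py --fromCache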