split_data.py
import argparse
import itertools
import os
import pickle
import functools
import numpy as np
from datasets import get_data, validate_datasets
from net import alexnet, preprocess_images_alexnet
from predictions import get_predictions_filepath
from weights import validate_weights


def __adapt_to_softmax(x, length):
    """Convert a label index (or an iterable of indices) into one-hot rows of the given length."""
    try:
        len(x)
    except TypeError:
        x = [x]
    softmaxs = np.zeros([len(x), length])
    for row, col in enumerate(x):
        softmaxs[row, col] = 1
    return softmaxs
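# Example behaviour (illustrative values):
#   __adapt_to_softmax(3, 5)      -> [[0., 0., 0., 1., 0.]]
#   __adapt_to_softmax([0, 2], 3) -> [[1., 0., 0.], [0., 0., 1.]]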


def __image_truth_generator(truths_mapping, datasets_directory, load_image, output_length, path_output=None):
    """Yield (image, one-hot truth) pairs indefinitely, optionally recording each yielded image path."""
    assert len(truths_mapping) > 0
    for image_path, truth_label in itertools.cycle(truths_mapping.items()):
        image = load_image(os.path.join(datasets_directory, image_path))
        truth = __adapt_to_softmax(truth_label, output_length)
        if path_output is not None:
            # record the path before yielding so it stays aligned with the emitted image
            path_output.append(image_path)
        yield image, truth


def __image_generator(image_truth_generator):
    """Strip the truths so only images are fed to predict_generator."""
    for image, _ in image_truth_generator:
        yield image


def retrain(model, model_name, datasets_directory, datasets, image_truth_generator, num_epochs=10, image_batch_size=1000):
    """Retrain the model on each dataset in turn and save the resulting weights."""
    for retrain_dataset in datasets:
        print("Collect %s" % retrain_dataset)
        truths_mapping = get_data(retrain_dataset, datasets_directory)
        generator = image_truth_generator(truths_mapping, datasets_directory=datasets_directory)
        print("Retrain on %s" % retrain_dataset)
        # TODO: use best weights based on validation error
        model.fit_generator(generator, samples_per_epoch=len(truths_mapping), nb_epoch=num_epochs,
                            max_q_size=image_batch_size)
        weights_file = "weights/retrain/%s/%s/%depochs.h5" % (model_name, retrain_dataset, num_epochs)
        print("Save weights to %s" % weights_file)
        os.makedirs(os.path.dirname(weights_file), exist_ok=True)
        model.save_weights(weights_file)


def predict(model, weights_names, datasets_directory, datasets, image_truth_generator, image_batch_size=1000):
    """Run the model with each set of weights over each dataset and pickle the predictions."""
    for dataset_name in datasets:
        print("Collect %s" % dataset_name)
        truths_mapping = get_data(dataset_name, datasets_directory)
        for weights_name in weights_names:
            print("Predicting with %s" % weights_name)
            model.load_weights("weights/%s.h5" % weights_name)
            image_paths = []
            generator = image_truth_generator(truths_mapping, datasets_directory=datasets_directory,
                                              path_output=image_paths)
            generator = __image_generator(generator)
            predictions = model.predict_generator(generator, val_samples=len(truths_mapping),
                                                  max_q_size=image_batch_size)
            # the queue may have pulled more images than val_samples, so only keep paths that have predictions
            mapped_predictions = dict((image, prediction) for image, prediction
                                      in zip(image_paths[:len(predictions)], predictions))
            results_filepath = get_predictions_filepath(dataset_name, weights_name)
            print("Writing predictions to %s" % results_filepath)
            with open(results_filepath, 'wb') as results_file:
                pickle.dump({'predictions': mapped_predictions, 'dataset': dataset_name, 'weights': weights_name},
                            results_file)


if __name__ == '__main__':
    # options
    models = {'alexnet': alexnet}
    image_preprocessors = {'alexnet': preprocess_images_alexnet}
    # params - command line
    parser = argparse.ArgumentParser(description='Neural Net Robustness')
    parser.add_argument('--model', type=str, default=next(iter(models)),
                        help='The model to run', choices=models.keys())
    parser.add_argument('--datasets_directory', type=str, default='datasets',
                        help='The directory all datasets are stored in')
    parser.add_argument('--datasets', type=str, nargs='+', default=['ILSVRC2012/val'],
                        help='The datasets to either re-train the model with or to predict')
    parser.add_argument('--weights', type=str, nargs='+', default=None,
                        help='The set of weights to use for prediction - re-trains the model if this is not set')
    parser.add_argument('--num_epochs', type=int, default=50,
                        help='how many epochs to search for optimal weights during training')
    parser.add_argument('--image_batch_size', type=int, default=1000,
                        help='how many images to load into memory at once')
    args = parser.parse_args()
    print('Running with args', args)
    weights = args.weights
    if weights is not None:
        validate_weights(weights)
    datasets = args.datasets
    validate_datasets(datasets, args.datasets_directory)
    # model
    model = models[args.model]()
    output_shape = model.get_output_shape_at(-1)
    generator = functools.partial(__image_truth_generator,
                                  load_image=lambda path: image_preprocessors[args.model]([path]),
                                  output_length=output_shape[1])
    if weights is None:
        print("Retraining")
        model.load_weights("weights/%s.h5" % args.model)
        retrain(model, args.model, args.datasets_directory, datasets, num_epochs=args.num_epochs,
                image_truth_generator=generator, image_batch_size=args.image_batch_size)
    else:
        print("Predicting")
        predict(model, weights, args.datasets_directory, datasets,
                image_truth_generator=generator, image_batch_size=args.image_batch_size)
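
# Usage sketch (illustrative invocations; they assume the weights/ and datasets/
# directories are laid out the way this script expects):
#
#   retrain alexnet, saving weights under weights/retrain/alexnet/...:
#     python split_data.py --model alexnet --datasets ILSVRC2012/val --num_epochs 50
#
#   predict with an existing weights file weights/alexnet.h5:
#     python split_data.py --model alexnet --datasets ILSVRC2012/val --weights alexnet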