-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain.py
111 lines (91 loc) · 2.72 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
from os import error
import tensorflow.keras as keras
import tensorflow as tf
import sys
import getopt
import importlib
import numpy as np
import dlci
USAGE = (
    "usage: train.py SCRIPT -i INDEX [-e EPOCHS] [-p PATIENCE]"
    " [-r REPETITIONS] [-n NAME]"
)

# Training settings used when the matching command-line option is absent.
# They are also merged into the headers handed to dlci.save_training_results.
DEFAULT_SETTINGS = {
    "epochs": 100000,
    "patience": 100,
    # Previously had no default, so omitting -r crashed with KeyError.
    "repet": 1,
}

# Options whose values are integer training settings, mapped to their key.
_INT_OPTIONS = {"-e": "epochs", "-p": "patience", "-r": "repet"}


def parse_command_line(argv):
    """Parse a ``sys.argv``-style list into the training configuration.

    The expected shape is ``train.py SCRIPT -i INDEX [options]``: ``argv[1]``
    names the module that provides ``get_model``/``get_generator``, and the
    getopt-style options follow it.

    Args:
        argv: Full argument vector, including the program name at ``argv[0]``.

    Returns:
        A ``(script, index, username, settings)`` tuple. ``index`` is the raw
        ``-i`` string (not yet parsed). ``username`` is the ``-n`` value, or
        ``None`` when ``-n`` was not given. ``settings`` is a fresh copy of
        ``DEFAULT_SETTINGS`` with any ``-e``/``-p``/``-r`` overrides applied.

    Raises:
        SystemExit: With a message (and therefore a non-zero exit status)
            when the script name or ``-i`` is missing, an option is unknown,
            or an integer option receives a non-integer value.
    """
    if len(argv) < 2:
        sys.exit(USAGE)
    script = argv[1]
    try:
        opts, _ = getopt.getopt(argv[2:], "i:e:p:n:r:")
    except getopt.GetoptError as err:
        sys.exit("{0}\n{1}".format(err, USAGE))

    # Copy so repeated calls never mutate the module-level defaults.
    settings = dict(DEFAULT_SETTINGS)
    index = None
    username = None
    for opt, value in opts:
        if opt == "-i":
            index = value
        elif opt == "-n":
            username = value
        else:
            # getopt only returns declared options, so this is -e, -p or -r.
            try:
                settings[_INT_OPTIONS[opt]] = int(value)
            except ValueError:
                sys.exit(
                    "option {0} expects an integer, got {1!r}".format(opt, value)
                )
    if index is None:
        # The original called os.error(...), which merely constructs an
        # OSError without raising or printing it, then exited with status 0.
        sys.exit("option -i not set")
    return script, index, username, settings


def train_index(user_models, index, username, settings):
    """Compile, train, evaluate and save results for one model index.

    Args:
        user_models: Module exposing ``get_model(index)`` and
            ``get_generator(index)``, each returning ``(header_dict, obj)``.
        index: One entry produced by ``dlci.parse_index``.
        username: Name recorded with the saved results.
        settings: Training settings (``epochs``, ``patience``, ``repet``).
            They are also merged into the saved headers.

    Returns:
        ``False`` if looking up the model or generator raised ``KeyError``
        (treated as "ran past the last defined model"). Otherwise ``True``.
    """
    try:
        m_header_dict, model = user_models.get_model(index)
        d_header_dict, generator = user_models.get_generator(index)
    except KeyError as err:
        print(err)
        print(
            "The above error likely occurred because the index range was"
            " larger than the number of defined models. If so, you can"
            " safely disregard the error."
        )
        return False

    # Later dicts win on key collisions: settings override both header dicts.
    headers = {**m_header_dict, **d_header_dict, **settings}
    model.compile(loss="mae", metrics=dlci.metrics())
    print("")
    print("=" * 50, "START", "=" * 50)
    print(
        "Running trial on model {0}, with patience of {1}".format(
            index, settings["patience"]
        )
    )

    # restore_best_weights rolls the model back to its best-val_loss epoch
    # when training stops.
    early_stopping = keras.callbacks.EarlyStopping(
        monitor="val_loss",
        patience=settings["patience"],
        restore_best_weights=True,
    )

    print("Grabbing validation data...")
    validation_data = dlci.get_validation_set(2)
    inputs = validation_data[0]
    targets = validation_data[1]
    # val_loss is computed on the first validation sample only; expand_dims
    # adds a leading batch axis of size 1. Built once, reused every repetition.
    validation_pair = (
        np.expand_dims(inputs[0], axis=0),
        np.expand_dims(targets[0], axis=0),
    )

    # NOTE(review): the model is not rebuilt between repetitions, so each
    # repetition resumes from the weights the previous one ended with.
    # Confirm this is intended.
    for _ in range(settings["repet"]):
        # ``generator`` is used as a context manager around training;
        # presumably it acquires/releases data resources - TODO confirm.
        with generator:
            history = model.fit(
                generator,
                epochs=settings["epochs"],
                callbacks=[early_stopping],
                validation_data=validation_pair,
                validation_batch_size=1,
            )
        predictions = [
            model.predict(np.expand_dims(sample, axis=0), batch_size=1)
            for sample in inputs[:2]
        ]
        dlci.save_training_results(
            index=index,
            name=username,
            history=history.history,
            # presumably the compiled model wraps a ``generator`` sub-model,
            # which is what gets saved - verify against user_models.
            model=model.generator,
            headers=headers,
            inputs=inputs,
            predictions=predictions,
            targets=targets,
        )
    return True


def main(argv=None):
    """Run the training CLI.

    Args:
        argv: ``sys.argv``-style argument vector. Defaults to ``sys.argv``.

    Raises:
        SystemExit: On invalid usage (see ``parse_command_line``).
    """
    if argv is None:
        argv = sys.argv
    # Parse first so usage errors are reported before any TensorFlow/dlci work.
    script, index_spec, username, settings = parse_command_line(argv)
    tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
    if username is None:
        username = dlci.get_user_name()
    indices = dlci.parse_index(index_spec)
    user_models = importlib.import_module(script)
    for index in indices:
        if not train_index(user_models, index, username, settings):
            break


if __name__ == "__main__":
    main()