-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtraining.py
88 lines (73 loc) · 2.64 KB
/
training.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
from utils.data_utils.DataHandler import DataHandler
from models.models import dot_win37_dep9, dot_win19_dep9
from keras.callbacks import (
EarlyStopping, LearningRateScheduler, TensorBoard, ModelCheckpoint,
ReduceLROnPlateau)
from keras.optimizers import Adagrad
import keras.backend as K
import numpy as np
import os
import tensorflow as tf
import datetime
def create_experiment(path):
    """Create a fresh, timestamped experiment directory under *path*.

    The directory name is the current local time formatted as
    ``DDMMYY_HHMMSS``, so successive runs get distinct folders.

    Parameters
    ----------
    path : str
        Root directory under which the experiment folder is created.

    Returns
    -------
    dict
        ``{"path": <full path to the new experiment directory>}``.
    """
    stamp = datetime.datetime.now().strftime("%d%m%y_%H%M%S")
    experiment_path = os.path.join(path, stamp)
    # Create the directory tree; a pre-existing directory is tolerated,
    # matching the original isdir-then-makedirs check.
    os.makedirs(experiment_path, exist_ok=True)
    return {"path": experiment_path}
def get_callbacks(experiment):
    """Build the Keras callbacks used during training.

    Parameters
    ----------
    experiment : dict
        Must contain key ``"path"`` — the experiment directory used for
        the model checkpoint file and TensorBoard logs (as returned by
        ``create_experiment``).

    Returns
    -------
    list
        ``[ModelCheckpoint, TensorBoard, ReduceLROnPlateau]``.
    """
    def schedule(epoch, lr):
        # Decay the learning rate 5x once past epoch 10; otherwise keep
        # it unchanged.  The original version returned nothing for
        # epoch <= 10 (UnboundLocalError if ever used with
        # LearningRateScheduler), and its second branch
        # (epoch % 5 == 0 and epoch > 10) was unreachable because the
        # first branch already captured every epoch > 10.
        if epoch > 10:
            return lr * 0.2
        return lr

    monitor, patience, verbose = "val_loss", 10, True
    # NOTE(review): `es` and `schedule` are constructed but not included
    # in the returned callback list — kept as-is for behavioral parity;
    # confirm whether early stopping / LR scheduling should be enabled.
    es = EarlyStopping(monitor=monitor, patience=patience, verbose=verbose)
    mdlchkpt = ModelCheckpoint(
        filepath=os.path.join(experiment["path"], "siamese_net.hdf5"),
        monitor=monitor, save_best_only=True)
    tb = TensorBoard(log_dir=experiment["path"])
    lrs = ReduceLROnPlateau(patience=patience, verbose=verbose)
    return [mdlchkpt, tb, lrs]
def main():
    """Entry point: load KITTI 2015 patch data, build the siamese
    network, and train it with checkpointing/TensorBoard callbacks."""
    # Filenames of the pre-generated training / validation patch binaries.
    split_files = {
        "train": "tr_160_18_100.bin",
        "val": "val_40_18_100.bin"
    }
    config = {
        "batch_size": 32,
        "data_version": "kitti2015",
        "util_root": "/home/marco/repos/EfficientStereoMatching/data/KITTI2015/debug_15/",
        "data_root": "/home/marco/repos/EfficientStereoMatching/data/KITTI2015/data_scene_flow/training",
        "experiment_root": "/home/marco/repos/EfficientStereoMatching/experiments",
        "num_val_loc": 1000,
        "num_tr_img": 160,  # TODO: Avoid hardcoding this
        "num_val_img": 40,
    }

    # One handler per split; each exposes a batch generator and the
    # pixel locations it was built from.
    train_handler = DataHandler(config)
    train_handler.load(split_files["train"])
    val_handler = DataHandler(config)
    val_handler.load(split_files["val"])

    # Patch sizes come back from the data handler after loading.
    network = dot_win37_dep9(
        train_handler.args["l_psz"], train_handler.args["r_psz"])
    network.build_model()
    model = network.model

    experiment = create_experiment(config["experiment_root"])
    model.compile(optimizer=Adagrad(), loss="categorical_crossentropy")

    batch = config["batch_size"]
    model.fit_generator(
        generator=train_handler.generator,
        steps_per_epoch=train_handler.pixel_loc.shape[0] // batch,
        epochs=10,
        verbose=1,
        callbacks=get_callbacks(experiment),
        validation_data=val_handler.generator,
        validation_steps=val_handler.pixel_loc.shape[0] // batch)
# Run training only when executed as a script (not on import).
if __name__ == "__main__":
    main()