-
Notifications
You must be signed in to change notification settings - Fork 0
/
runner_train.py
90 lines (65 loc) · 3.32 KB
/
runner_train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import tensorflow as tf
import click
import logging
import os
import util.global_config as global_config
import util.builder
import util.helper
import trainloop
import networks.faster_rcnn_odapi_loader
# Configure the root logger to emit INFO-level (and higher) records; the
# progress messages written via logging.info in loop() depend on this.
logging.basicConfig(level=logging.INFO)
def loop(sess, input_pipeline_tensors, input_handles, network_tensors, frcnn):
    """Initialise (or restore) the model and run the configured mode.

    Creates a summary ``FileWriter`` for ``sess.graph`` and a ``tf.train.Saver``.
    Variables are then either restored from the latest checkpoint in
    ``cfg['checkpoints']`` (when ``cfg['restore']`` is truthy) or initialised
    with the global and local variable initialisers. In ``'training'`` mode,
    ``trainloop.run`` is called once per epoch, and the returned global step is
    threaded through successive calls. The summary writer is always closed on
    exit, including when an exception propagates.

    Args:
        sess: The ``tf.Session`` the graph is run in.
        input_pipeline_tensors: Input-pipeline tensors; passed through to
            ``trainloop.run``.
        input_handles: Dict of iterator initialisers and evaluated string
            handles (as built by ``main``); passed through to ``trainloop.run``.
        network_tensors: Network tensors; passed through to ``trainloop.run``.
        frcnn: The loaded Faster R-CNN wrapper; passed through to
            ``trainloop.run``.

    Raises:
        ValueError: If ``cfg['restore']`` is set but ``cfg['checkpoints']``
            contains no checkpoint.
    """
    filename = util.helper.summary_file([global_config.cfg['batch_size'],
                                         global_config.cfg['epochs']])
    train_writer = tf.summary.FileWriter(filename, sess.graph)
    try:
        saver = tf.train.Saver()
        if global_config.cfg['restore']:
            logging.info('Restoring from latest checkpoint.')
            checkpoint_dir = global_config.cfg['checkpoints']
            checkpoint = tf.train.latest_checkpoint(checkpoint_dir)
            # latest_checkpoint returns None when no checkpoint exists; fail with
            # a message naming the directory instead of handing None to restore().
            if checkpoint is None:
                raise ValueError('No checkpoint found in %r.' % (checkpoint_dir,))
            saver.restore(sess, checkpoint)
        else:
            init_op = tf.group(tf.global_variables_initializer(),
                               tf.local_variables_initializer())
            sess.run(init_op)

        globalstep = 0
        cls_weight = global_config.cfg['cls_weight']
        reg_weight = global_config.cfg['reg_weight']
        frcnn_saved = util.helper.loaddict(global_config.cfg['frcnn_saved_file'])

        mode = global_config.cfg['mode']
        if mode == 'training':
            with util.helper.timeit() as ttime:
                for epoch in range(global_config.cfg['epochs']):
                    # trainloop.run returns the updated global step, which is
                    # fed back in for the next epoch.
                    globalstep = trainloop.run(sess, input_pipeline_tensors, input_handles, network_tensors,
                                               train_writer, epoch, saver, globalstep, frcnn,
                                               cls_weight, reg_weight, frcnn_saved)
            logging.info('Done training (%s sec, %s steps).', ttime.time(), globalstep)
        # TODO: testing mode is not implemented yet. Previous draft:
        # elif global_config.cfg['mode'] == 'testing':
        #     with util.helper.timeit() as ttime:
        #         globalstep = trainloop.run(sess, input_pipeline_tensors, input_handles, network_tensors,
        #                                    frcnn, ###)
        #
        #     logging.info('Done testing (' + str(ttime.time()) + ' sec.)')
        else:
            logging.warning('Mode %r is not supported; nothing was run.', mode)
    finally:
        # Flush any pending summaries to disk and release the event file.
        train_writer.close()
@click.command()
@click.option("--config", default="config.yml", help="The configuration file.")
def main(config):
    """Command-line entry point for training.

    Loads the configuration and imports the Faster R-CNN graph. It then builds
    the input pipeline and the LSTM/classifier graph in a fresh ``tf.Graph``,
    evaluates the training/validation iterator handles in a new session, and
    hands everything to ``loop``.

    Args:
        config: Path to the configuration file (defaults to ``config.yml``).
    """
    # os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
    # Reading the config publishes it as the ``global_config.cfg`` dictionary
    # and creates the folders the run requires.
    global_config.read(config)

    detector = networks.faster_rcnn_odapi_loader.faster_rcnn_odapi()
    detector.import_graph(global_config.cfg['faster_rcnn_graph'])

    graph = tf.Graph()
    with graph.as_default():
        # Build the computational graph: input pipeline first, then the network.
        pipeline_tensors, raw_handles = util.builder.build_input_pipeline()
        net_tensors = util.builder.build_lstm_and_classifier()

        with tf.Session() as sess:
            tr_handle = sess.run(raw_handles['tr_h'])
            val_handle = sess.run(raw_handles['val_h'])
            # Bundle the iterator initialisers with the evaluated handles; this is
            # what lets loop() switch between the training and validation phases.
            handles = dict(
                training_initializer=raw_handles['tr_it'].initializer,
                validation_initializer=raw_handles['val_it'].initializer,
                training_handle=tr_handle,
                validation_handle=val_handle,
                handle=raw_handles['h'],
            )
            loop(sess, pipeline_tensors, handles, net_tensors, detector)
# Script entry point: invoke the click command `main` when run directly.
if __name__ == '__main__':
    main()