-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathtrain.py
executable file
·31 lines (22 loc) · 974 Bytes
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import argparse
from Trainers.Trainer import getTrainer
from configs.config_loader import load_config
from tboard import initiateTensorboard
def build_parser():
    """Build the command-line parser for the training script.

    Options (destination attribute in parentheses):
        -c / --conf                  path to the config file (``conf_file``)
        -m / --model                 model name (``model``)
        -tb / --tensorboard          boolean flag (``tb_flag``)
        -tbpth / --tensorboard_path  tensorboard path (``tb_path``)

    No option is marked required, so an omitted option arrives as ``None``
    (or ``False`` for the ``-tb`` flag).

    Returns:
        argparse.ArgumentParser: the configured parser.
    """
    parser = argparse.ArgumentParser(description='Mondep Configuration')
    parser.add_argument(
        "-c", "--conf", action="store", dest="conf_file",
        help="Path to config file",
    )
    parser.add_argument(
        "-m", "--model", action="store", dest="model",
        help="model name",
    )
    parser.add_argument(
        "-tb", "--tensorboard", action="store_true", dest="tb_flag",
        help="tensorboard flag",
    )
    parser.add_argument(
        "-tbpth", "--tensorboard_path", action="store", dest="tb_path",
        help="tensorboard path",
    )
    return parser


def main(argv=None):
    """Parse CLI arguments, load the config, and train the selected model.

    Args:
        argv: list of argument strings to parse. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, which matches running the script
            directly.
    """
    args = build_parser().parse_args(argv)

    # NOTE(review): initiateTensorboard lives in the ``tboard`` module; how it
    # treats tb_flag=False or tb_path=None is not visible here -- confirm.
    initiateTensorboard(args.tb_flag, args.tb_path)

    # Omitted -c / -m options are forwarded to load_config as None.
    config = load_config(config_path=args.conf_file, model_name=args.model)

    # getTrainer picks a trainer from the loaded config; train() runs it.
    net = getTrainer(config)
    net.train()
    print("Model Training Completed: ", config['model_name'])


if __name__ == "__main__":
    main()