-
Notifications
You must be signed in to change notification settings - Fork 13
/
Copy pathtrain.py
51 lines (42 loc) · 1.72 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import argparse
import os
import shutil
import sys

from torch.utils.tensorboard import SummaryWriter

# Make modules that live next to this script (etc/, data_loader.py,
# trainer.py) importable regardless of how/where the script is launched.
# This must run BEFORE the project-local imports below; placed after them it
# has no effect on how they are resolved.
BASEPATH = os.path.dirname(__file__)
sys.path.insert(0, BASEPATH)

from data_loader import get_dataloader  # noqa: E402  (needs sys.path tweak above)
from etc.utils import set_seed, ensure_dirs, get_config  # noqa: E402
from trainer import Trainer  # noqa: E402
def initialize_path(args, config, save=True):
    """Populate ``config`` with the experiment's directory paths.

    Every path hangs off ``./<config['name']>``:

    * ``main_dir``   -> ``./<name>``
    * ``model_dir``  -> ``./<name>/pth``
    * ``tb_dir``     -> ``./<name>/log``
    * ``info_dir``   -> ``./<name>/info``
    * ``output_dir`` -> ``./<name>/output``

    ``config`` is updated in place (nothing is returned). The five paths
    are then handed to ``ensure_dirs`` (presumably creating any that are
    missing -- see ``etc.utils`` to confirm).

    Args:
        args: Parsed CLI namespace; only ``args.config`` (path of the config
            file) is read.
        config: Mutable config mapping; must contain the key ``'name'``.
        save: When true, copy the file at ``args.config`` to
            ``<info_dir>/config.yaml`` as a record of the run's settings.
    """
    root = os.path.join('.', config['name'])
    config['main_dir'] = root

    # (config key, sub-directory name) pairs, in the order they are recorded.
    layout = (
        ('model_dir', "pth"),
        ('tb_dir', "log"),
        ('info_dir', "info"),
        ('output_dir', "output"),
    )
    for key, leaf in layout:
        config[key] = os.path.join(root, leaf)

    ensure_dirs([root] + [config[key] for key, _ in layout])

    if not save:
        return
    shutil.copy(args.config, os.path.join(config['info_dir'], 'config.yaml'))
parser = argparse.ArgumentParser()
parser.add_argument('--config', type=str, default='configs/config.yaml',
                    help='Path to the config file.')
args = parser.parse_args()

# Initialize: load the config and set up the experiment directory tree
# (initialize_path also copies the config file into <info_dir>).
config = get_config(args.config)
initialize_path(args, config)

# Set random seed for reproducibility
print("Random Seed: ", config['manualSeed'])
set_seed(config['manualSeed'])

# Dataloaders.
# NOTE(review): the source and target loaders are built by identical calls
# (both from the 'train' split) -- confirm this is intended.
train_src_loader = get_dataloader('train', config)
train_tar_loader = get_dataloader('train', config)
loader = {'train_src': train_src_loader, 'train_tar': train_tar_loader}

# TensorBoard writer; event files go under <tb_dir>/train.
train_writer = SummaryWriter(os.path.join(config['tb_dir'], 'train'))

# Trainer
trainer = Trainer(config)

# Record the generator's printed representation for later reference. The
# context manager guarantees the file is closed even if printing raises.
with open(os.path.join(config['info_dir'], "info-network"), "w") as tr_info:
    print(trainer.gen, file=tr_info)

try:
    trainer.train(loader, train_writer)
finally:
    # Flush buffered events to disk even if training fails or is interrupted
    # (SummaryWriter.close() ignores a second call, so this is safe even if
    # the trainer already closed it).
    train_writer.close()