-
Notifications
You must be signed in to change notification settings - Fork 0
/
train.py
executable file
·122 lines (95 loc) · 4.4 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
"""
Copyright (C) 2019 NVIDIA Corporation. All rights reserved. Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
"""
import os
if "STY" not in os.environ.keys():
import multiprocessing
multiprocessing.set_start_method('spawn', True)
import sys
from collections import OrderedDict
from options.train_options import TrainOptions
from datasets.coco_loader import get_loader
import torch
from modules.helpers.iter_counter import IterationCounter
from modules.helpers.visualizer import Visualizer
from trainers.olie_trainer import OlieTrainer
from detectron2.checkpoint import DetectionCheckpointer
from modules.solov2.solov2 import SOLOv2
import warnings
from etaprogress.progress import ProgressBar
from adet.config import get_cfg
warnings.filterwarnings("ignore")
def setup_cfg(args):
    """Build a frozen detection config from parsed command-line options.

    Starts from the AdelaiDet default config, layers on the YAML file named
    by ``args.config_file``, then any ``key value`` overrides in ``args.opts``,
    and freezes the result so later code cannot mutate it accidentally.
    """
    config = get_cfg()
    config.merge_from_file(args.config_file)
    config.merge_from_list(args.opts)
    config.freeze()
    return config
def main(solo):
    """Train the OLIE model on COCO train2017 using a pretrained SOLOv2.

    Args:
        solo: a SOLOv2 instance (expected in eval mode — it is passed through
            to OlieTrainer, which presumably uses it as a frozen segmentation
            backbone; TODO confirm against OlieTrainer).
    """
    # parse options
    opt = TrainOptions().parse()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # load the dataset
    # NOTE(review): shuffle=False on the training loader looks deliberate
    # here but is unusual for training — confirm it is intended.
    dataloader, _ = get_loader(device=device, \
        root=opt.coco+'train2017', \
        json=opt.coco+'annotations/instances_train2017.json', \
        batch_size=opt.batch_size, \
        shuffle=False, \
        num_workers=0)
    # create trainer for our model
    trainer = OlieTrainer(opt, solo)
    # create tool for counting iterations
    iter_counter = IterationCounter(opt, len(dataloader))
    # create tool for visualization
    visualizer = Visualizer(opt)
    for epoch in iter_counter.training_epochs():
        iter_counter.record_epoch_start(epoch)
        total = len(dataloader)
        bar = ProgressBar(total, max_width=80)
        # start= offsets the index so a resumed epoch continues its count
        # from where the counter left off.
        for i, data_i in enumerate(dataloader, start=iter_counter.epoch_iter):
            bar.numerator = i+1
            # '\r' keeps the progress bar on one console line.
            print(bar, end='\r')
            iter_counter.record_one_iteration()
            # Training
            # train generator
            # Generator steps run only every D_steps_per_G iterations;
            # the discriminator steps on every iteration.
            if i % opt.D_steps_per_G == 0:
                trainer.run_generator_one_step(data_i)
            # train discriminator
            trainer.run_discriminator_one_step(data_i)
            # Visualizations
            if iter_counter.needs_printing():
                losses = trainer.get_latest_losses()
                visualizer.print_current_errors(epoch, iter_counter.epoch_iter,
                                                losses, iter_counter.time_per_iter)
                visualizer.plot_current_errors(losses, iter_counter.total_steps_so_far)
            if iter_counter.needs_displaying():
                # argmax over channel dim turns one-hot semantics into a
                # single-channel label map for display.
                visuals = OrderedDict([('input_label', trainer.get_semantics().max(dim=1)[1].cpu().unsqueeze(1)),
                                       ('synthesized_image', trainer.get_latest_generated()),
                                       ('real_image', data_i['image']),
                                       ('masked', trainer.get_mask())])
                if not opt.no_instance:
                    # Channel 35 appears to carry the instance map —
                    # TODO confirm against the semantics tensor layout.
                    visuals['instance'] = trainer.get_semantics()[:,35].cpu()
                visualizer.display_current_results(visuals, epoch, iter_counter.total_steps_so_far)
            if iter_counter.needs_saving():
                print('saving the latest model (epoch %d, total_steps %d)' %
                      (epoch, iter_counter.total_steps_so_far))
                trainer.save('latest')
                iter_counter.record_current_iter()
        trainer.update_learning_rate(epoch)
        iter_counter.record_epoch_end()
        # Periodic checkpoint: both a rolling 'latest' and an epoch-tagged copy.
        if epoch % opt.save_epoch_freq == 0 or \
           epoch == iter_counter.total_epochs:
            print('saving the model at the end of epoch %d, iters %d' %
                  (epoch, iter_counter.total_steps_so_far))
            trainer.save('latest')
            trainer.save(epoch)
    print('Training was successfully finished.')
if __name__ == '__main__':
    # Parse command-line training options (also parsed again inside main()).
    args = TrainOptions().parse()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Build the frozen detectron2/AdelaiDet config from the CLI options.
    cfg = setup_cfg(args)
    # Construct SOLOv2 and restore its pretrained weights from
    # cfg.MODEL.WEIGHTS before handing it to the training loop.
    solo = SOLOv2(cfg=cfg).to(device)
    checkpointer = DetectionCheckpointer(solo)
    checkpointer.load(cfg.MODEL.WEIGHTS)
    # eval() disables dropout/batch-norm updates; presumably SOLOv2 acts as a
    # fixed feature source during OLIE training — TODO confirm in OlieTrainer.
    main(solo=solo.eval())