-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathtrain_efficientdet.py
258 lines (234 loc) · 11.3 KB
/
train_efficientdet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
#coding=utf-8
import os
import logging
import warnings
import time
import numpy as np
import mxnet as mx
from mxnet import nd
from mxnet import gluon
from mxnet import autograd
import gluoncv as gcv
gcv.utils.check_version('0.6.0')
from gluoncv import data as gdata
from gluoncv import utils as gutils
from gluoncv.model_zoo import get_model
from gluoncv.data.batchify import Tuple, Stack, Pad
from gluoncv.data.transforms.presets.ssd import SSDDefaultTrainTransform
from gluoncv.data.transforms.presets.ssd import SSDDefaultValTransform
from gluoncv.utils.metrics.voc_detection import VOC07MApMetric
from gluoncv.utils.metrics.coco_detection import COCODetectionMetric
from gluoncv.utils.metrics.accuracy import Accuracy
from gluoncv.utils import LRScheduler, LRSequential
from mxnet.contrib import amp
from config import parse_args
from model.efficientdet import get_efficientdet
from loss import EfficientDetLoss
def get_dataset(dataset, args):
    """Build the train/validation datasets and the matching evaluation metric.

    Parameters
    ----------
    dataset : str
        Dataset name, matched case-insensitively. ``'voc'`` and ``'coco'``
        are supported.
    args : namespace
        Parsed options. For COCO this reads ``dataset_root``, ``save_prefix``,
        ``data_shape`` and ``val_interval``. ``num_samples`` is read for every
        dataset.

    Returns
    -------
    tuple
        ``(train_dataset, val_dataset, val_metric)``.

    Raises
    ------
    NotImplementedError
        If ``dataset`` is neither ``'voc'`` nor ``'coco'``.

    Notes
    -----
    ``args`` is modified in place: for COCO a ``val_interval`` of 1 is raised
    to 10, and a negative ``num_samples`` is replaced by the size of the
    training set.
    """
    name = dataset.lower()
    if name == 'voc':
        # VOC uses the dataset class' default root; args.dataset_root is not consulted here.
        train_dataset = gdata.VOCDetection(
            splits=[(2007, 'trainval'), (2012, 'trainval')])
        val_dataset = gdata.VOCDetection(
            splits=[(2007, 'test')])
        val_metric = VOC07MApMetric(iou_thresh=0.5, class_names=val_dataset.classes)
    elif name == 'coco':
        coco_root = os.path.join(args.dataset_root, 'coco')
        train_dataset = gdata.COCODetection(root=coco_root, splits='instances_train2017')
        val_dataset = gdata.COCODetection(
            root=coco_root, splits='instances_val2017', skip_empty=False)
        val_metric = COCODetectionMetric(
            val_dataset, args.save_prefix + '_eval', cleanup=True,
            data_shape=(args.data_shape, args.data_shape))
        # coco validation is slow, consider increase the validation interval
        if args.val_interval == 1:
            args.val_interval = 10
    else:
        raise NotImplementedError('Dataset: {} not implemented.'.format(dataset))
    if args.num_samples < 0:
        args.num_samples = len(train_dataset)
    return train_dataset, val_dataset, val_metric
def get_dataloader(net, train_dataset, val_dataset, data_shape, batch_size, num_workers):
    """Create the training and validation data loaders.

    Parameters
    ----------
    net : network
        Called once on a zero image in train mode; its third output is used
        as the anchors for training-target generation.
    train_dataset, val_dataset : dataset
        Datasets to wrap; each is transformed with the SSD default
        train/val transform before loading.
    data_shape : int
        Used as both the width and the height of the network input.
    batch_size : int
        Batch size for both loaders.
    num_workers : int
        Worker count for the training loader; the validation loader uses
        ``num_workers // 2``.

    Returns
    -------
    tuple
        ``(train_loader, val_loader)``.
    """
    width = height = data_shape

    # use fake data to generate fixed anchors for target generation
    dummy_input = mx.nd.zeros((1, 3, height, width))
    with autograd.train_mode():
        _, _, anchors = net(dummy_input)
    anchors = anchors.as_in_context(mx.cpu())

    # Training samples are (image, cls_targets, box_targets), all stacked.
    train_transform = SSDDefaultTrainTransform(width, height, anchors)
    train_loader = gluon.data.DataLoader(
        train_dataset.transform(train_transform),
        batch_size,
        shuffle=True,
        batchify_fn=Tuple(Stack(), Stack(), Stack()),
        last_batch='rollover',
        num_workers=num_workers)

    # Validation samples are (image, label); labels are padded with -1.
    val_transform = SSDDefaultValTransform(width, height)
    val_loader = gluon.data.DataLoader(
        val_dataset.transform(val_transform),
        batch_size,
        shuffle=False,
        batchify_fn=Tuple(Stack(), Pad(pad_val=-1)),
        last_batch='keep',
        num_workers=num_workers // 2)

    return train_loader, val_loader
def save_params(net, best_map, current_map, epoch, save_interval, prefix):
    """Save checkpoints for the current epoch.

    Parameters
    ----------
    net : network
        Object exposing ``save_parameters(path)``.
    best_map : list
        One-element list holding the best score seen so far; updated in
        place when ``current_map`` beats it.
    current_map : float
        Score of the current epoch (converted with ``float``).
    epoch : int
        Current epoch number, used in the log line and the periodic file name.
    save_interval : int
        When truthy, a checkpoint is also written every ``save_interval``
        epochs; a falsy value disables periodic checkpoints.
    prefix : str
        Path prefix for every file written by this function.

    Notes
    -----
    A new best score writes ``<prefix>_best.params`` (overwritten each time)
    and appends ``"<epoch>:\\t<score>"`` to ``<prefix>_best_map.log``.
    Periodic checkpoints are named ``<prefix>_<epoch>_<score>.params``.
    """
    current_map = float(current_map)
    if current_map > best_map[0]:
        best_map[0] = current_map
        # The best-checkpoint name depends on the prefix only.
        net.save_parameters('{:s}_best.params'.format(prefix))
        with open(prefix + '_best_map.log', 'a') as f:
            f.write('{:04d}:\t{:.4f}\n'.format(epoch, current_map))
    if save_interval and epoch % save_interval == 0:
        net.save_parameters('{:s}_{:04d}_{:.4f}.params'.format(prefix, epoch, current_map))
def validate(net, val_data, ctx, eval_metric):
    """Run the network over the validation loader and return the metric result.

    Parameters
    ----------
    net : network
        Detector; called per device slice and expected to return
        ``(ids, scores, bboxes)``.
    val_data : iterable
        Yields batches whose element 0 is the image batch and element 1 the
        label batch.
    ctx : list
        Devices the batch is split across.
    eval_metric : metric
        Reset at the start, updated once per batch.

    Returns
    -------
    Whatever ``eval_metric.get()`` returns after the last batch.
    """
    eval_metric.reset()
    # set nms threshold and topk constraint
    net.set_nms(nms_thresh=0.45, nms_topk=400)
    net.hybridize()

    for batch in val_data:
        images, labels = batch[0], batch[1]
        image_slices = gluon.utils.split_and_load(
            images, ctx_list=ctx, batch_axis=0, even_split=False)
        label_slices = gluon.utils.split_and_load(
            labels, ctx_list=ctx, batch_axis=0, even_split=False)
        # Upper clip bound is axis 2 of the image batch
        # (presumably the height of square NCHW inputs; confirm).
        clip_max = images.shape[2]

        pred_bboxes, pred_ids, pred_scores = [], [], []
        true_bboxes, true_ids, true_difficults = [], [], []
        for image_slice, label_slice in zip(image_slices, label_slices):
            # get prediction results
            ids, scores, bboxes = net(image_slice)
            pred_ids.append(ids)
            pred_scores.append(scores)
            # clip to image size
            pred_bboxes.append(bboxes.clip(0, clip_max))

            # split ground truths: columns 0-3 boxes, 4 class id, 5 difficult flag
            true_ids.append(label_slice.slice_axis(axis=-1, begin=4, end=5))
            true_bboxes.append(label_slice.slice_axis(axis=-1, begin=0, end=4))
            if label_slice.shape[-1] > 5:
                difficult = label_slice.slice_axis(axis=-1, begin=5, end=6)
            else:
                difficult = None
            true_difficults.append(difficult)

        # update metric
        eval_metric.update(
            pred_bboxes, pred_ids, pred_scores, true_bboxes, true_ids, true_difficults)

    return eval_metric.get()
def _build_lr_scheduler(args):
    """Return the linear-warmup + decay learning-rate schedule described by ``args``."""
    if args.lr_decay_period > 0:
        decay_epochs = list(range(args.lr_decay_period, args.epochs, args.lr_decay_period))
    else:
        decay_epochs = [int(i) for i in args.lr_decay_epoch.split(',')]
    # The second scheduler only covers the epochs after warmup, so the
    # milestones are shifted by the warmup length.
    decay_epochs = [e - args.warmup_epochs for e in decay_epochs]
    num_batches = args.num_samples // args.batch_size
    return LRSequential([
        LRScheduler('linear', base_lr=0, target_lr=args.lr,
                    nepochs=args.warmup_epochs, iters_per_epoch=num_batches),
        LRScheduler(args.lr_mode, base_lr=args.lr,
                    nepochs=args.epochs - args.warmup_epochs,
                    iters_per_epoch=num_batches,
                    step_epoch=decay_epochs,
                    step_factor=args.lr_decay, power=2),
    ])


def _init_train_logger(save_prefix):
    """Return the root logger at INFO level, also writing to ``<save_prefix>_train.log``."""
    logging.basicConfig()
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    log_file_path = save_prefix + '_train.log'
    log_dir = os.path.dirname(log_file_path)
    if log_dir:
        # Create-then-check avoids the race of check-then-create.
        try:
            os.makedirs(log_dir)
        except OSError:
            if not os.path.exists(log_dir):
                raise
    logger.addHandler(logging.FileHandler(log_file_path))
    return logger


def train(net, train_data, val_data, eval_metric, ctx, args):
    """Training pipeline.

    Parameters
    ----------
    net : network
        Model to train; called per device slice and expected to return
        ``(cls_pred, box_pred, anchors)``.
    train_data : iterable
        Yields ``(image, cls_targets, box_targets)`` batches.
    val_data : iterable
        Validation loader passed to ``validate``.
    eval_metric : metric
        Validation metric passed to ``validate``.
    ctx : list
        Devices to train on.
    args : namespace
        Parsed options (learning-rate schedule, optimizer, AMP, logging,
        validation and checkpoint intervals, ``save_prefix``, epochs).

    Notes
    -----
    Epochs run from ``args.start_epoch + 1`` through ``args.epochs``
    inclusive. After each epoch ``save_params`` is called; on epochs where
    validation is skipped the score passed to it is ``0.``.
    """
    net.collect_params().reset_ctx(ctx)
    trainer = gluon.Trainer(
        net.collect_params(), 'sgd',
        {'lr_scheduler': _build_lr_scheduler(args), 'wd': args.wd, 'momentum': args.momentum},
        update_on_kvstore=(False if args.amp else None))
    if args.amp:
        amp.init_trainer(trainer)

    # NOTE(review): `classes` is a module-level name assigned in the
    # `__main__` section of this file; calling train() from an importing
    # module requires that global to exist.
    cls_box_loss = EfficientDetLoss(len(classes) + 1, rho=0.1, lambd=50.0)
    ce_metric = mx.metric.Loss('FocalLoss')
    smoothl1_metric = mx.metric.Loss('SmoothL1')

    # set up logger
    logger = _init_train_logger(args.save_prefix)
    logger.info(args)
    logger.info('Start training from [Epoch {}]'.format(args.start_epoch))

    best_map = [0]
    local_batch_size = int(args.batch_size)
    for epoch in range(args.start_epoch + 1, args.epochs + 1):
        logger.info("[Epoch {}] Set learning rate to {}".format(epoch, trainer.learning_rate))
        ce_metric.reset()
        smoothl1_metric.reset()
        tic = time.time()
        btic = time.time()
        net.hybridize()

        for i, batch in enumerate(train_data):
            data = gluon.utils.split_and_load(batch[0], ctx_list=ctx, batch_axis=0)
            cls_targets = gluon.utils.split_and_load(batch[1], ctx_list=ctx, batch_axis=0)
            box_targets = gluon.utils.split_and_load(batch[2], ctx_list=ctx, batch_axis=0)

            with autograd.record():
                cls_preds = []
                box_preds = []
                for x in data:
                    cls_pred, box_pred, _ = net(x)
                    cls_preds.append(cls_pred)
                    box_preds.append(box_pred)
                sum_loss, cls_loss, box_loss = cls_box_loss(
                    cls_preds, box_preds, cls_targets, box_targets)
                if args.amp:
                    with amp.scale_loss(sum_loss, trainer) as scaled_loss:
                        autograd.backward(scaled_loss)
                else:
                    autograd.backward(sum_loss)
            # since we have already normalized the loss, we don't want to normalize
            # by batch-size anymore
            trainer.step(1)

            ce_metric.update(0, [l * local_batch_size for l in cls_loss])
            smoothl1_metric.update(0, [l * local_batch_size for l in box_loss])
            if args.log_interval and not (i + 1) % args.log_interval:
                name1, loss1 = ce_metric.get()
                name2, loss2 = smoothl1_metric.get()
                logger.info('[Epoch {}][Batch {}], Speed: {:.3f} samples/sec, {}={:.3f}, {}={:.3f}'.format(
                    epoch, i, args.batch_size/(time.time()-btic), name1, loss1, name2, loss2))
            # Restart the timer every batch: the speed above divides one
            # batch worth of samples by the time since this point.
            btic = time.time()

        name1, loss1 = ce_metric.get()
        name2, loss2 = smoothl1_metric.get()
        logger.info('[Epoch {}] Training cost: {:.3f}, {}={:.3f}, {}={:.3f}'.format(
            epoch, (time.time()-tic), name1, loss1, name2, loss2))

        # A val_interval of 0 disables interval-based validation instead of
        # raising ZeroDivisionError, mirroring the save/log interval guards.
        validate_now = bool(args.val_interval) and epoch % args.val_interval == 0
        checkpoint_now = bool(args.save_interval) and epoch % args.save_interval == 0
        if validate_now or checkpoint_now:
            # consider reduce the frequency of validation to save time
            map_name, mean_ap = validate(net, val_data, ctx, eval_metric)
            val_msg = '\n'.join(['{}={}'.format(k, v) for k, v in zip(map_name, mean_ap)])
            logger.info('[Epoch {}] Validation: \n{}'.format(epoch, val_msg))
            current_map = float(mean_ap[-1])
        else:
            current_map = 0.
        save_params(net, best_map, current_map, epoch, args.save_interval, args.save_prefix)
# Script entry point: parse options, build the network(s) and data loaders,
# then run training. Everything stays at module level because `train` reads
# `classes` as a module global.
if __name__ == '__main__':
    args = parse_args()
    if args.amp:
        amp.init()
    # fix seed for mxnet, numpy and python builtin random generator.
    gutils.random.seed(args.seed)

    # One GPU context per id in the comma-separated list; CPU when none are given.
    ctx = [mx.gpu(int(i)) for i in args.gpus.split(',') if i.strip()]
    ctx = ctx or [mx.cpu()]

    net_name = '_'.join((args.network, str(args.data_shape), args.dataset))
    # Create the output directory (including missing parents) without a
    # check-then-create race; an already existing path is accepted.
    try:
        os.makedirs(args.save_prefix)
    except OSError:
        if not os.path.exists(args.save_prefix):
            raise
    # NOTE(review): plain concatenation; the files only land inside the
    # directory created above if save_prefix ends with a path separator.
    args.save_prefix += net_name

    dataset_name = args.dataset.lower()
    if dataset_name == 'coco':
        classes = gdata.COCODetection.CLASSES
    elif dataset_name == 'voc':
        classes = gdata.VOCDetection.CLASSES
    else:
        raise NotImplementedError('Dataset: {} not implemented.'.format(args.dataset))

    if args.syncbn and len(ctx) > 1:
        net = get_efficientdet(args.network, classes,
                               pretrained_base=False, act_type=args.act_type,
                               norm_layer=gluon.contrib.nn.SyncBatchNorm,
                               norm_kwargs={'num_devices': len(ctx)})
        async_net = get_efficientdet(args.network, classes,
                                     act_type=args.act_type,
                                     pretrained_base=False)  # used by cpu worker
    else:
        net = get_efficientdet(args.network, classes, act_type=args.act_type,
                               pretrained_base=False, norm_layer=gluon.nn.BatchNorm)
        async_net = net

    # Without SyncBN both names refer to the same network, so initialise
    # and load it once instead of twice.
    separate_async_net = async_net is not net
    net.initialize()
    if separate_async_net:
        async_net.initialize()
    resume_path = args.resume.strip()
    if resume_path:
        net.load_parameters(resume_path)
        if separate_async_net:
            async_net.load_parameters(resume_path)
    net.hybridize()

    train_dataset, val_dataset, eval_metric = get_dataset(args.dataset, args)
    batch_size = args.batch_size
    train_data, val_data = get_dataloader(
        async_net, train_dataset, val_dataset, args.data_shape, batch_size, args.num_workers)

    # training
    train(net, train_data, val_data, eval_metric, ctx, args)