-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathtrain_tood_yymnist.py
139 lines (106 loc) · 5.16 KB
/
train_tood_yymnist.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
import copy
import os.path as osp
import os
import numpy as np
import mmcv
import sys
sys.path.insert(0,'.')
from mmdet.datasets.builder import DATASETS
from mmdet.datasets.custom import CustomDataset
from mmcv import Config
@DATASETS.register_module()
class yyminstDataset(CustomDataset):
CLASSES = ('0', '1', '2','3','4','5','6','7','8','9')
PALETTE = [(220, 20, 60), (119, 11, 32), (0, 0, 142), (0, 0, 230),
(106, 0, 228), (0, 60, 100), (0, 80, 100), (0, 0, 70),
(0, 0, 192), (250, 170, 30), (100, 170, 30), (220, 220, 0),
(175, 116, 175), (250, 0, 30), (165, 42, 42), (255, 77, 255),
(0, 226, 252), (182, 182, 255), (0, 82, 0), (120, 166, 157),
(110, 76, 0), (174, 57, 255), (199, 100, 0), (72, 0, 118),
(255, 179, 240), (0, 125, 92), (209, 0, 151), (188, 208, 182),
(0, 220, 176), (255, 99, 164), (92, 0, 73), (133, 129, 255),
(78, 180, 255), (0, 228, 0), (174, 255, 243), (45, 89, 255),
(134, 134, 103), (145, 148, 174), (255, 208, 186),
(197, 226, 255), (171, 134, 1), (109, 63, 54), (207, 138, 255),
(151, 0, 95), (9, 80, 61), (84, 105, 51), (74, 65, 105),
(166, 196, 102), (208, 195, 210), (255, 109, 65), (0, 143, 149),
(179, 0, 194), (209, 99, 106), (5, 121, 0), (227, 255, 205),
(147, 186, 208), (153, 69, 1), (3, 95, 161), (163, 255, 0),
(119, 0, 170), (0, 182, 199), (0, 165, 120), (183, 130, 88),
(95, 32, 0), (130, 114, 135), (110, 129, 133), (166, 74, 118),
(219, 142, 185), (79, 210, 114), (178, 90, 62), (65, 70, 15),
(127, 167, 115), (59, 105, 106), (142, 108, 45), (196, 172, 0),
(95, 54, 80), (128, 76, 255), (201, 57, 1), (246, 0, 122),
(191, 162, 208)]
def load_annotations(self, ann_file):
cat2label = {k: i for i, k in enumerate(self.CLASSES)}
# load image list from file
ann_list = mmcv.list_from_file(self.ann_file)
data_infos = []
# convert annotations to middle format
for ann in ann_list:
image_id,_,label_info= ann.partition(" ")
filename = f'{self.img_prefix}{image_id}'
image = mmcv.imread(filename)
height, width = image.shape[:2]
data_info = dict(filename=f'{image_id}', width=width, height=height)
# load annotations
lines = label_info.split()
gt_bboxes = []
gt_labels = []
gt_bboxes_ignore = []
gt_labels_ignore = []
for line in lines:
content = line.split(",")
bbox_name= content[-1]
bbox = content[0:-1]
bbox = [float(e) for e in bbox]
if bbox_name in cat2label:
gt_labels.append(cat2label[bbox_name])
gt_bboxes.append(bbox)
else:
gt_labels_ignore.append(-1)
gt_bboxes_ignore.append(bbox)
data_anno = dict(
bboxes=np.array(gt_bboxes, dtype=np.float32).reshape(-1, 4),
labels=np.array(gt_labels, dtype=np.long),
bboxes_ignore=np.array(gt_bboxes_ignore,
dtype=np.float32).reshape(-1, 4),
labels_ignore=np.array(gt_labels_ignore, dtype=np.long))
data_info.update(ann=data_anno)
data_infos.append(data_info)
return data_infos
#cfg = Config.fromfile("myconfigs/yolox_s_8x8_300e_coco_yymnist.py")
cfg = Config.fromfile("configs/tood/tood_r50_fpn_1x_coco.py")
from mmdet.apis import set_random_seed
cfg.load_from = None
# Set up working dir to save files and logs.
cfg.work_dir = './tutorial_exps/tood_card'
# We can set the evaluation interval to reduce the evaluation times
cfg.evaluation.interval = 5
# We can set the checkpoint saving interval to reduce the storage cost
cfg.checkpoint_config.interval = 10
# Set seed thus the results are more reproducible
cfg.seed = 0
set_random_seed(0, deterministic=False)
cfg.data.samples_per_gpu=4
cfg.data.workers_per_gpu=2
cfg.runner = dict(type='EpochBasedRunner', max_epochs=100)
cfg.gpu_ids = range(1)
# We can initialize the logger for training and have a look
# at the final config used for training
print(f'Config:\n{cfg.pretty_text}')
# with open("myconfigs/tood_r50_fpn_1x_coco_card.py","w+") as f:
# f.writelines(cfg.pretty_text)
if __name__ == '__main__':
from mmdet.datasets import build_dataset
from mmdet.models import build_detector
from train.train import train_detector
cfg = Config.fromfile("myconfigs/tood_r50_fpn_1x_yymnist.py")
datasets = [build_dataset(cfg.data.train)]
model = build_detector(
cfg.model, train_cfg=cfg.get('train_cfg'), test_cfg=cfg.get('test_cfg'))
for dataset in datasets:
classes=('0', '1', '2', '3', '4', '5', '6', '7', '8', '9'),
mmcv.mkdir_or_exist(os.path.abspath(cfg.work_dir))
train_detector(model, datasets, cfg, distributed=False, validate=True)