Skip to content

Commit

Permalink
[Feature] Release RTMDet models and configs. (open-mmlab#8870)
Browse files Browse the repository at this point in the history
* [Feature] Release RTMDet models and configs.

* update config

* update link and metafile

* update
  • Loading branch information
RangiLyu authored and ZwwWayne committed Sep 26, 2022
1 parent c43a635 commit 9b79bd5
Show file tree
Hide file tree
Showing 9 changed files with 418 additions and 2 deletions.
25 changes: 25 additions & 0 deletions configs/rtmdet/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
# RTMDet

<!-- [ALGORITHM] -->

## Abstract

Our tech-report will be released soon.

<div align=center>
<img src="https://user-images.githubusercontent.com/12907710/192182907-f9a671d6-89cb-4d73-abd8-c2b9dada3c66.png"/>
</div>

## Results and Models

| Backbone | size | box AP | Params(M) | FLOPS(G) | TRT-FP16-Latency(ms) | Config | Download |
| :---------: | :--: | :----: | :-------: | :------: | :------------------: | :----------------------------------------: | :-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: |
| RTMDet-tiny | 640 | 40.9 | 4.8 | 8.1 | 0.98 | [config](./rtmdet_tiny_8xb32-300e_coco.py) | [model](https://download.openmmlab.com/mmdetection/v3.0/rtmdet/rtmdet_tiny_8xb32-300e_coco/rtmdet_tiny_8xb32-300e_coco_20220902_112414-78e30dcc.pth) \| [log](https://download.openmmlab.com/mmdetection/v3.0/rtmdet/rtmdet_tiny_8xb32-300e_coco/rtmdet_s_8xb32-300e_coco_20220902_112414.log.json) |
| RTMDet-s | 640 | 44.5 | 8.89 | 14.8 | 1.22 | [config](./rtmdet_s_8xb32-300e_coco.py) | [model](https://download.openmmlab.com/mmdetection/v3.0/rtmdet/rtmdet_s_8xb32-300e_coco/rtmdet_s_8xb32-300e_coco_20220905_161602-a61dc0d2.pth) \| [log](https://download.openmmlab.com/mmdetection/v3.0/rtmdet/rtmdet_s_8xb32-300e_coco/rtmdet_s_8xb32-300e_coco_20220905_161602.log.json) |
| RTMDet-m | 640 | 49.1 | 24.71 | 39.27 | 1.62 | [config](./rtmdet_m_8xb32-300e_coco.py) | [model](https://download.openmmlab.com/mmdetection/v3.0/rtmdet/rtmdet_m_8xb32-300e_coco/rtmdet_m_8xb32-300e_coco_20220719_112220-229f527c.pth) \| [log](https://download.openmmlab.com/mmdetection/v3.0/rtmdet/rtmdet_m_8xb32-300e_coco/rtmdet_m_8xb32-300e_coco_20220719_112220.log.json) |
| RTMDet-l | 640 | 51.3 | 52.3 | 80.23 | 2.44 | [config](./rtmdet_l_8xb32-300e_coco.py) | [model](https://download.openmmlab.com/mmdetection/v3.0/rtmdet/rtmdet_l_8xb32-300e_coco/rtmdet_l_8xb32-300e_coco_20220719_112030-5a0be7c4.pth) \| [log](https://download.openmmlab.com/mmdetection/v3.0/rtmdet/rtmdet_l_8xb32-300e_coco/rtmdet_l_8xb32-300e_coco_20220719_112030.log.json) |
| RTMDet-x | 640 | 52.6 | 94.86 | 141.67 | 3.10 | [config](./rtmdet_x_8xb32-300e_coco.py) | [model](https://download.openmmlab.com/mmdetection/v3.0/rtmdet/rtmdet_x_8xb32-300e_coco/rtmdet_x_8xb32-300e_coco_20220715_230555-cc79b9ae.pth) \| [log](https://download.openmmlab.com/mmdetection/v3.0/rtmdet/rtmdet_x_8xb32-300e_coco/rtmdet_x_8xb32-300e_coco_20220715_230555.log.json) |

**Note**:

1. The inference speed is measured on an NVIDIA 3090 GPU with TensorRT 8.4.3, cuDNN 8.2.0, FP16, batch size=1, and the without NMS.
81 changes: 81 additions & 0 deletions configs/rtmdet/metafile.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
Collections:
- Name: RTMDet
Metadata:
Training Data: COCO
Training Techniques:
- AdamW
- Flat Cosine Annealing
Training Resources: 8x A100 GPUs
Architecture:
- CSPNeXt
- CSPNeXtPAFPN
README: configs/rtmdet/README.md
Code:
URL: https://github.com/open-mmlab/mmdetection/blob/v3.0.0rc1/mmdet/models/detectors/rtmdet.py#L6
Version: v3.0.0rc1

Models:
- Name: rtmdet_tiny_8xb32-300e_coco
In Collection: RTMDet
Config: configs/rtmdet/rtmdet_tiny_8xb32-300e_coco.py
Metadata:
Training Memory (GB): 7.6
Epochs: 300
Results:
- Task: Object Detection
Dataset: COCO
Metrics:
box AP: 40.9
Weights: https://download.openmmlab.com/mmdetection/v3.0/rtmdet/rtmdet_tiny_8xb32-300e_coco/rtmdet_tiny_8xb32-300e_coco_20220902_112414-78e30dcc.pth

- Name: rtmdet_s_8xb32-300e_coco
In Collection: RTMDet
Config: configs/rtmdet/rtmdet_s_8xb32-300e_coco.py
Metadata:
Training Memory (GB): 7.6
Epochs: 300
Results:
- Task: Object Detection
Dataset: COCO
Metrics:
box AP: 44.5
Weights: https://download.openmmlab.com/mmdetection/v3.0/rtmdet/rtmdet_s_8xb32-300e_coco/rtmdet_s_8xb32-300e_coco_20220905_161602-a61dc0d2.pth

- Name: rtmdet_m_8xb32-300e_coco
In Collection: RTMDet
Config: configs/rtmdet/rtmdet_m_8xb32-300e_coco.py
Metadata:
Training Memory (GB): 7.6
Epochs: 300
Results:
- Task: Object Detection
Dataset: COCO
Metrics:
box AP: 49.1
Weights: https://download.openmmlab.com/mmdetection/v3.0/rtmdet/rtmdet_m_8xb32-300e_coco/rtmdet_m_8xb32-300e_coco_20220719_112220-229f527c.pth

- Name: rtmdet_l_8xb32-300e_coco
In Collection: RTMDet
Config: configs/rtmdet/rtmdet_l_8xb32-300e_coco.py
Metadata:
Training Memory (GB): 7.6
Epochs: 300
Results:
- Task: Object Detection
Dataset: COCO
Metrics:
box AP: 51.3
Weights: https://download.openmmlab.com/mmdetection/v3.0/rtmdet/rtmdet_l_8xb32-300e_coco/rtmdet_l_8xb32-300e_coco_20220719_112030-5a0be7c4.pth

- Name: rtmdet_x_8xb32-300e_coco
In Collection: RTMDet
Config: configs/rtmdet/rtmdet_x_8xb32-300e_coco.py
Metadata:
Training Memory (GB): 7.6
Epochs: 300
Results:
- Task: Object Detection
Dataset: COCO
Metrics:
box AP: 52.6
Weights: https://download.openmmlab.com/mmdetection/v3.0/rtmdet/rtmdet_x_8xb32-300e_coco/rtmdet_x_8xb32-300e_coco_20220715_230555-cc79b9ae.pth
181 changes: 181 additions & 0 deletions configs/rtmdet/rtmdet_l_8xb32-300e_coco.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,181 @@
_base_ = [
'../_base_/default_runtime.py', '../_base_/schedules/schedule_1x.py',
'../_base_/datasets/coco_detection.py'
]
model = dict(
type='RTMDet',
data_preprocessor=dict(
type='DetDataPreprocessor',
mean=[103.53, 116.28, 123.675],
std=[57.375, 57.12, 58.395],
bgr_to_rgb=False,
batch_augments=None),
backbone=dict(
type='CSPNeXt',
arch='P5',
expand_ratio=0.5,
deepen_factor=1,
widen_factor=1,
channel_attention=True,
norm_cfg=dict(type='SyncBN'),
act_cfg=dict(type='SiLU')),
neck=dict(
type='CSPNeXtPAFPN',
in_channels=[256, 512, 1024],
out_channels=256,
num_csp_blocks=3,
expand_ratio=0.5,
norm_cfg=dict(type='SyncBN'),
act_cfg=dict(type='SiLU')),
bbox_head=dict(
type='RTMDetSepBNHead',
num_classes=80,
in_channels=256,
stacked_convs=2,
feat_channels=256,
anchor_generator=dict(
type='MlvlPointGenerator', offset=0, strides=[8, 16, 32]),
bbox_coder=dict(type='DistancePointBBoxCoder'),
loss_cls=dict(
type='QualityFocalLoss',
use_sigmoid=True,
beta=2.0,
loss_weight=1.0),
loss_bbox=dict(type='GIoULoss', loss_weight=2.0),
with_objectness=False,
exp_on_reg=True,
share_conv=True,
pred_kernel_size=1,
norm_cfg=dict(type='SyncBN'),
act_cfg=dict(type='SiLU')),
train_cfg=dict(
assigner=dict(type='DynamicSoftLabelAssigner', topk=13),
allowed_border=-1,
pos_weight=-1,
debug=False),
test_cfg=dict(
nms_pre=1000,
min_bbox_size=0,
score_thr=0.05,
nms=dict(type='nms', iou_threshold=0.6),
max_per_img=100),
)

train_pipeline = [
dict(
type='LoadImageFromFile',
file_client_args={{_base_.file_client_args}}),
dict(type='LoadAnnotations', with_bbox=True),
dict(type='CachedMosaic', img_scale=(640, 640), pad_val=114.0),
dict(
type='RandomResize',
scale=(1280, 1280),
ratio_range=(0.1, 2.0),
keep_ratio=True),
dict(type='RandomCrop', crop_size=(640, 640)),
dict(type='YOLOXHSVRandomAug'),
dict(type='RandomFlip', prob=0.5),
dict(type='Pad', size=(640, 640), pad_val=dict(img=(114, 114, 114))),
dict(
type='CachedMixUp',
img_scale=(640, 640),
ratio_range=(1.0, 1.0),
max_cached_images=20,
pad_val=(114, 114, 114)),
dict(type='PackDetInputs')
]

train_pipeline_stage2 = [
dict(
type='LoadImageFromFile',
file_client_args={{_base_.file_client_args}}),
dict(type='LoadAnnotations', with_bbox=True),
dict(
type='RandomResize',
scale=(640, 640),
ratio_range=(0.1, 2.0),
keep_ratio=True),
dict(type='RandomCrop', crop_size=(640, 640)),
dict(type='YOLOXHSVRandomAug'),
dict(type='RandomFlip', prob=0.5),
dict(type='Pad', size=(640, 640), pad_val=dict(img=(114, 114, 114))),
dict(type='PackDetInputs')
]

test_pipeline = [
dict(
type='LoadImageFromFile',
file_client_args={{_base_.file_client_args}}),
dict(type='Resize', scale=(640, 640), keep_ratio=True),
dict(type='Pad', size=(640, 640), pad_val=dict(img=(114, 114, 114))),
dict(
type='PackDetInputs',
meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
'scale_factor'))
]

train_dataloader = dict(
batch_size=32,
num_workers=10,
batch_sampler=None,
pin_memory=True,
dataset=dict(pipeline=train_pipeline))
val_dataloader = dict(
batch_size=5, num_workers=10, dataset=dict(pipeline=test_pipeline))
test_dataloader = val_dataloader

max_epochs = 300
stage2_num_epochs = 20
base_lr = 0.004
interval = 10

train_cfg = dict(
max_epochs=max_epochs,
val_interval=interval,
dynamic_intervals=[(max_epochs - stage2_num_epochs, 1)])

# optimizer
optim_wrapper = dict(
_delete_=True,
type='OptimWrapper',
optimizer=dict(type='AdamW', lr=base_lr, weight_decay=0.05),
paramwise_cfg=dict(
norm_decay_mult=0, bias_decay_mult=0, bypass_duplicate=True))

# learning rate
param_scheduler = [
dict(
type='LinearLR',
start_factor=1.0e-5,
by_epoch=False,
begin=0,
end=1000),
dict(
# use cosine lr from 150 to 300 epoch
type='CosineAnnealingLR',
eta_min=base_lr * 0.05,
begin=max_epochs // 2,
end=max_epochs,
T_max=max_epochs // 2,
by_epoch=True,
convert_to_iter_based=True),
]

# hooks
default_hooks = dict(
checkpoint=dict(
interval=interval,
max_keep_ckpts=3 # only keep latest 3 checkpoints
))
custom_hooks = [
dict(
type='EMAHook',
ema_type='ExpMomentumEMA',
momentum=0.0002,
update_buffers=True,
priority=49),
dict(
type='PipelineSwitchHook',
switch_epoch=max_epochs - stage2_num_epochs,
switch_pipeline=train_pipeline_stage2)
]
6 changes: 6 additions & 0 deletions configs/rtmdet/rtmdet_m_8xb32-300e_coco.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
_base_ = './rtmdet_l_8xb32-300e_coco.py'

model = dict(
backbone=dict(deepen_factor=0.67, widen_factor=0.75),
neck=dict(in_channels=[192, 384, 768], out_channels=192, num_csp_blocks=2),
bbox_head=dict(in_channels=192, feat_channels=192))
66 changes: 66 additions & 0 deletions configs/rtmdet/rtmdet_s_8xb32-300e_coco.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
_base_ = './rtmdet_l_8xb32-300e_coco.py'
checkpoint = 'TODO:imagenet_pretrain' # noqa
model = dict(
backbone=dict(
deepen_factor=0.33,
widen_factor=0.5,
init_cfg=dict(
type='Pretrained', prefix='backbone.', checkpoint=checkpoint)),
neck=dict(in_channels=[128, 256, 512], out_channels=128, num_csp_blocks=1),
bbox_head=dict(in_channels=128, feat_channels=128, exp_on_reg=False))

train_pipeline = [
dict(
type='LoadImageFromFile',
file_client_args={{_base_.file_client_args}}),
dict(type='LoadAnnotations', with_bbox=True),
dict(type='CachedMosaic', img_scale=(640, 640), pad_val=114.0),
dict(
type='RandomResize',
scale=(1280, 1280),
ratio_range=(0.5, 2.0),
keep_ratio=True),
dict(type='RandomCrop', crop_size=(640, 640)),
dict(type='YOLOXHSVRandomAug'),
dict(type='RandomFlip', prob=0.5),
dict(type='Pad', size=(640, 640), pad_val=dict(img=(114, 114, 114))),
dict(
type='CachedMixUp',
img_scale=(640, 640),
ratio_range=(1.0, 1.0),
max_cached_images=20,
pad_val=(114, 114, 114)),
dict(type='PackDetInputs')
]

train_pipeline_stage2 = [
dict(
type='LoadImageFromFile',
file_client_args={{_base_.file_client_args}}),
dict(type='LoadAnnotations', with_bbox=True),
dict(
type='RandomResize',
scale=(640, 640),
ratio_range=(0.5, 2.0),
keep_ratio=True),
dict(type='RandomCrop', crop_size=(640, 640)),
dict(type='YOLOXHSVRandomAug'),
dict(type='RandomFlip', prob=0.5),
dict(type='Pad', size=(640, 640), pad_val=dict(img=(114, 114, 114))),
dict(type='PackDetInputs')
]

train_dataloader = dict(dataset=dict(pipeline=train_pipeline))

custom_hooks = [
dict(
type='EMAHook',
ema_type='ExpMomentumEMA',
momentum=0.0002,
update_buffers=True,
priority=49),
dict(
type='PipelineSwitchHook',
switch_epoch=280,
switch_pipeline=train_pipeline_stage2)
]
Loading

0 comments on commit 9b79bd5

Please sign in to comment.