[Feature] Support RT-DETR #11395

Open
wants to merge 19 commits into
base: dev-3.x
43 changes: 43 additions & 0 deletions configs/rtdetr/README.md
@@ -0,0 +1,43 @@
# RT-DETR

> [DETRs Beat YOLOs on Real-time Object Detection](https://arxiv.org/abs/2304.08069)

<!-- [ALGORITHM] -->

## Abstract

Recently, end-to-end transformer-based detectors (DETRs) have achieved remarkable performance. However, the issue of the high computational cost of DETRs has not been effectively addressed, limiting their practical application and preventing them from fully exploiting the benefits of no post-processing, such as non-maximum suppression (NMS). In this paper, we first analyze the influence of NMS in modern real-time object detectors on inference speed, and establish an end-to-end speed benchmark. To avoid the inference delay caused by NMS, we propose a Real-Time DEtection TRansformer (RT-DETR), the first real-time end-to-end object detector to our best knowledge. Specifically, we design an efficient hybrid encoder to efficiently process multi-scale features by decoupling the intra-scale interaction and cross-scale fusion, and propose IoU-aware query selection to improve the initialization of object queries. In addition, our proposed detector supports flexible adjustment of the inference speed by using different decoder layers without the need for retraining, which facilitates the practical application of real-time object detectors. Our RT-DETR-L achieves 53.0% AP on COCO val2017 and 114 FPS on T4 GPU, while RT-DETR-X achieves 54.8% AP and 74 FPS, outperforming all YOLO detectors of the same scale in both speed and accuracy. Furthermore, our RT-DETR-R50 achieves 53.1% AP and 108 FPS, outperforming DINO-Deformable-DETR-R50 by 2.2% AP in accuracy and by about 21 times in FPS. Source code and pre-trained models are available at [this https URL](https://github.com/lyuwenyu/RT-DETR).

<div align=center>
<img src="https://user-images.githubusercontent.com/17582080/262603054-42636690-1ecf-4647-b075-842ecb9bc562.png"/>
</div>

## Results and Models

| Backbone | Model | Lr schd | box AP | Config | Download |
| :------: | :----------: | :-----: | :----: | :----------------------------------------: | :-------------------------------------------------------------------------------------------------------------------------: |
| R-18vd | RT-DETR Dec3 | 72e | 46.5 | [config](./rtdetr_r18vd_8xb2-72e_coco.py) | [model](https://github.com/flytocc/mmdetection/releases/download/model_zoo/rtdetr_r18vd_8xb2-72e_coco_3dda8dd4.pth) \| log |
| R-34vd | RT-DETR Dec4 | 72e | 48.9 | [config](./rtdetr_r34vd_8xb2-72e_coco.py) | [model](https://github.com/flytocc/mmdetection/releases/download/model_zoo/rtdetr_r34vd_8xb2-72e_coco_9159eb52.pth) \| log |
| R-50vd | RT-DETR Dec6 | 72e | 53.1 | [config](./rtdetr_r50vd_8xb2-72e_coco.py) | [model](https://github.com/flytocc/mmdetection/releases/download/model_zoo/rtdetr_r50vd_8xb2-72e_coco_ad2bdcfe.pth) \| log |
| R-101vd | RT-DETR Dec6 | 72e | 54.3 | [config](./rtdetr_r101vd_8xb2-72e_coco.py) | [model](https://github.com/flytocc/mmdetection/releases/download/model_zoo/rtdetr_r101vd_8xb2-72e_coco_83ad1b19.pth) \| log |

### NOTE

The weights are converted from the [official repo](https://github.com/lyuwenyu/RT-DETR).

The performance is unstable: `RT-DETR` with `R-50vd` may fluctuate by about 0.4 mAP.

## Citation

We provide the config files for RT-DETR: [DETRs Beat YOLOs on Real-time Object Detection](https://arxiv.org/abs/2304.08069).

```latex
@misc{lv2023detrs,
      title={DETRs Beat YOLOs on Real-time Object Detection},
      author={Wenyu Lv and Shangliang Xu and Yian Zhao and Guanzhong Wang and Jinman Wei and Cheng Cui and Yuning Du and Qingqing Dang and Yi Liu},
      year={2023},
      eprint={2304.08069},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}
```
69 changes: 69 additions & 0 deletions configs/rtdetr/metafile.yml
@@ -0,0 +1,69 @@
Collections:
  - Name: RT-DETR
    Metadata:
      Training Data: COCO
      Training Techniques:
        - AdamW
        - Weight Decay
        - Multi Scale Train
        - Gradient Clip
      Training Resources: 8x A100 GPUs
      Architecture:
        - ResNet
        - Transformer
    Paper:
      URL: https://arxiv.org/abs/2304.08069
      Title: 'DETRs Beat YOLOs on Real-time Object Detection'
    README: configs/rtdetr/README.md
    Code:
      URL: https://github.com/flytocc/mmdetection/blob/f7cf93dcc8d5574393ca1eeb67a97f30da4290c7/mmdet/models/detectors/rtdetr.py#L15
      Version: v3.3.0

Models:
  - Name: rtdetr_r18vd_8xb2-72e_coco
    In Collection: RT-DETR
    Config: configs/rtdetr/rtdetr_r18vd_8xb2-72e_coco.py
    Metadata:
      Epochs: 72
    Results:
      - Task: Object Detection
        Dataset: COCO
        Metrics:
          box AP: 46.5
    Weights: https://github.com/flytocc/mmdetection/releases/download/model_zoo/rtdetr_r18vd_8xb2-72e_coco_3dda8dd4.pth

  - Name: rtdetr_r34vd_8xb2-72e_coco
    In Collection: RT-DETR
    Config: configs/rtdetr/rtdetr_r34vd_8xb2-72e_coco.py
    Metadata:
      Epochs: 72
    Results:
      - Task: Object Detection
        Dataset: COCO
        Metrics:
          box AP: 48.9
    Weights: https://github.com/flytocc/mmdetection/releases/download/model_zoo/rtdetr_r34vd_8xb2-72e_coco_9159eb52.pth

  - Name: rtdetr_r50vd_8xb2-72e_coco
    In Collection: RT-DETR
    Config: configs/rtdetr/rtdetr_r50vd_8xb2-72e_coco.py
    Metadata:
      Epochs: 72
    Results:
      - Task: Object Detection
        Dataset: COCO
        Metrics:
          box AP: 53.1
    Weights: https://github.com/flytocc/mmdetection/releases/download/model_zoo/rtdetr_r50vd_8xb2-72e_coco_ad2bdcfe.pth

  - Name: rtdetr_r101vd_8xb2-72e_coco
    In Collection: RT-DETR
    Config: configs/rtdetr/rtdetr_r101vd_8xb2-72e_coco.py
    Metadata:
      Epochs: 72
    Results:
      - Task: Object Detection
        Dataset: COCO
        Metrics:
          box AP: 54.3
    Weights: https://github.com/flytocc/mmdetection/releases/download/model_zoo/rtdetr_r101vd_8xb2-72e_coco_83ad1b19.pth
16 changes: 16 additions & 0 deletions configs/rtdetr/rtdetr_r101vd_8xb2-72e_coco.py
@@ -0,0 +1,16 @@
_base_ = './rtdetr_r50vd_8xb2-72e_coco.py'
pretrained = 'https://github.com/flytocc/mmdetection/releases/download/model_zoo/resnet101vd_ssld_pretrained_64ed664a.pth' # noqa

model = dict(
    backbone=dict(
        depth=101, init_cfg=dict(type='Pretrained', checkpoint=pretrained)),
    neck=dict(out_channels=384),
    encoder=dict(
        in_channels=[384, 384, 384],
        fpn_cfg=dict(in_channels=[384, 384, 384]),
        layer_cfg=dict(
            self_attn_cfg=dict(embed_dims=384),
            ffn_cfg=dict(embed_dims=384, feedforward_channels=2048))))

# set all layers in backbone to lr_mult=0.01
_base_.optim_wrapper.paramwise_cfg.custom_keys.backbone.lr_mult = 0.01
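
To make the `lr_mult` override concrete: the multiplier scales the optimizer's base learning rate for the parameters matched by the key. The sketch below only does that arithmetic with the values from these configs (`lr=0.0001` in the R-50vd base config, `lr_mult=0.1` there, `0.01` here). The helper name `effective_lr` is hypothetical and not part of MMDetection.

```python
def effective_lr(base_lr: float, lr_mult: float) -> float:
    """Learning rate applied to a parameter group after scaling."""
    return base_lr * lr_mult


BASE_LR = 1e-4  # optimizer lr in rtdetr_r50vd_8xb2-72e_coco.py

r50_backbone_lr = effective_lr(BASE_LR, 0.1)    # value in the base config
r101_backbone_lr = effective_lr(BASE_LR, 0.01)  # override in this config
print(r50_backbone_lr, r101_backbone_lr)
```

Under these numbers the R-101vd backbone is updated with a learning rate ten times smaller than the R-50vd backbone.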
31 changes: 31 additions & 0 deletions configs/rtdetr/rtdetr_r18vd_8xb2-72e_coco.py
@@ -0,0 +1,31 @@
_base_ = './rtdetr_r50vd_8xb2-72e_coco.py'
pretrained = 'https://github.com/flytocc/mmdetection/releases/download/model_zoo/resnet18vd_pretrained_55f5a0d6.pth' # noqa

model = dict(
    backbone=dict(
        depth=18,
        frozen_stages=-1,
        norm_cfg=dict(type='SyncBN', requires_grad=True),
        norm_eval=False,
        init_cfg=dict(type='Pretrained', checkpoint=pretrained)),
    neck=dict(in_channels=[128, 256, 512]),
    encoder=dict(fpn_cfg=dict(expansion=0.5)),
    decoder=dict(num_layers=3))

# set all layers in backbone to lr_mult=0.1
# set all norm layers to decay_mult=0.0
num_blocks_list = (2, 2, 2, 2)  # r18
downsample_norm_idx_list = (2, 3, 3, 3)  # r18
backbone_norm_multi = dict(lr_mult=0.1, decay_mult=0.0)
_base_.optim_wrapper.paramwise_cfg.custom_keys.update({
    f'backbone.layer{stage_id + 1}.{block_id}.bn': backbone_norm_multi
    for stage_id, num_blocks in enumerate(num_blocks_list)
    for block_id in range(num_blocks)
})
_base_.optim_wrapper.paramwise_cfg.custom_keys.update({
    f'backbone.layer{stage_id + 1}.{block_id}.downsample.{downsample_norm_idx - 1}':  # noqa
    backbone_norm_multi
    for stage_id, (num_blocks, downsample_norm_idx) in enumerate(
        zip(num_blocks_list, downsample_norm_idx_list))
    for block_id in range(num_blocks)
})
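
The two dict comprehensions are easier to review once the generated key names are visible. The sketch below rebuilds only the names, using plain lists instead of the config's `custom_keys` dict. `backbone_norm_keys` is a hypothetical helper, and how the keys are matched against parameter names is left to the optimizer wrapper constructor, which this sketch does not model.

```python
def backbone_norm_keys(num_blocks_list, downsample_norm_idx_list):
    """Rebuild the custom_keys names produced by the two comprehensions."""
    bn_keys = [
        f'backbone.layer{stage_id + 1}.{block_id}.bn'
        for stage_id, num_blocks in enumerate(num_blocks_list)
        for block_id in range(num_blocks)
    ]
    downsample_keys = [
        f'backbone.layer{stage_id + 1}.{block_id}.downsample.{norm_idx - 1}'
        for stage_id, (num_blocks, norm_idx) in enumerate(
            zip(num_blocks_list, downsample_norm_idx_list))
        for block_id in range(num_blocks)
    ]
    return bn_keys, downsample_keys


# Values from the R-18vd config; the R-34vd config uses (3, 4, 6, 3).
bn_keys, ds_keys = backbone_norm_keys((2, 2, 2, 2), (2, 3, 3, 3))
for key in bn_keys[:2] + ds_keys[:1] + ds_keys[2:3]:
    print(key)
```

Every generated key maps to the same `dict(lr_mult=0.1, decay_mult=0.0)`, so the intent is that backbone norm parameters keep the reduced backbone learning rate and are excluded from weight decay.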
31 changes: 31 additions & 0 deletions configs/rtdetr/rtdetr_r34vd_8xb2-72e_coco.py
@@ -0,0 +1,31 @@
_base_ = './rtdetr_r50vd_8xb2-72e_coco.py'
pretrained = 'https://github.com/flytocc/mmdetection/releases/download/model_zoo/resnet34vd_pretrained_f6a72dc5.pth' # noqa

model = dict(
    backbone=dict(
        depth=34,
        frozen_stages=-1,
        norm_cfg=dict(type='SyncBN', requires_grad=True),
        norm_eval=False,
        init_cfg=dict(type='Pretrained', checkpoint=pretrained)),
    neck=dict(in_channels=[128, 256, 512]),
    encoder=dict(fpn_cfg=dict(expansion=0.5)),
    decoder=dict(num_layers=4))

# set all layers in backbone to lr_mult=0.1
# set all norm layers to decay_mult=0.0
num_blocks_list = (3, 4, 6, 3)  # r34
downsample_norm_idx_list = (2, 3, 3, 3)  # r34
backbone_norm_multi = dict(lr_mult=0.1, decay_mult=0.0)
_base_.optim_wrapper.paramwise_cfg.custom_keys.update({
    f'backbone.layer{stage_id + 1}.{block_id}.bn': backbone_norm_multi
    for stage_id, num_blocks in enumerate(num_blocks_list)
    for block_id in range(num_blocks)
})
_base_.optim_wrapper.paramwise_cfg.custom_keys.update({
    f'backbone.layer{stage_id + 1}.{block_id}.downsample.{downsample_norm_idx - 1}':  # noqa
    backbone_norm_multi
    for stage_id, (num_blocks, downsample_norm_idx) in enumerate(
        zip(num_blocks_list, downsample_norm_idx_list))
    for block_id in range(num_blocks)
})
189 changes: 189 additions & 0 deletions configs/rtdetr/rtdetr_r50vd_8xb2-72e_coco.py
@@ -0,0 +1,189 @@
_base_ = [
    '../_base_/datasets/coco_detection.py', '../_base_/default_runtime.py'
]
pretrained = 'https://github.com/flytocc/mmdetection/releases/download/model_zoo/resnet50vd_ssld_v2_pretrained_d037e232.pth' # noqa

model = dict(
    type='RTDETR',
    num_queries=300,  # num_matching_queries, 900 for DINO
    with_box_refine=True,
    as_two_stage=True,
    data_preprocessor=dict(
        type='DetDataPreprocessor',
        batch_augments=[
            dict(
                type='BatchSyncRandomResize',
                interval=1,
                interpolations=['nearest', 'bilinear', 'bicubic', 'area'],
                random_sizes=[
                    480, 512, 544, 576, 608, 640, 640, 640, 672, 704, 736, 768,
                    800
                ])
        ],
        mean=[0, 0, 0],  # [123.675, 116.28, 103.53] for DINO
        std=[255, 255, 255],  # [58.395, 57.12, 57.375] for DINO
        bgr_to_rgb=True,
        pad_size_divisor=1),
    backbone=dict(
        type='ResNetV1d',  # ResNet for DINO
        depth=50,
        num_stages=4,
        out_indices=(1, 2, 3),
        frozen_stages=0,  # -1 for DINO
        norm_cfg=dict(type='SyncBN', requires_grad=False),  # BN for DINO
        norm_eval=True,
        style='pytorch',
        init_cfg=dict(type='Pretrained', checkpoint=pretrained)),
    neck=dict(
        type='ChannelMapper',
        in_channels=[512, 1024, 2048],
        kernel_size=1,
        out_channels=256,
        act_cfg=None,
        norm_cfg=dict(type='SyncBN', requires_grad=True),  # GN for DINO
        num_outs=3),  # 4 for DINO
    encoder=dict(
        use_encoder_idx=[2],
        num_encoder_layers=1,
        in_channels=[256, 256, 256],
        fpn_cfg=dict(
            type='RTDETRFPN',
            in_channels=[256, 256, 256],
            out_channels=256,
            expansion=1.0,
            norm_cfg=dict(type='SyncBN', requires_grad=True)),
        layer_cfg=dict(
            self_attn_cfg=dict(embed_dims=256, num_heads=8, dropout=0.0),
            ffn_cfg=dict(
                embed_dims=256,
                feedforward_channels=1024,  # 2048 for DINO
                ffn_drop=0.0,
                act_cfg=dict(type='GELU')))),  # ReLU for DINO
    decoder=dict(
        num_layers=6,
        return_intermediate=True,
        layer_cfg=dict(
            self_attn_cfg=dict(embed_dims=256, num_heads=8, dropout=0.0),
            cross_attn_cfg=dict(
                embed_dims=256,
                num_levels=3,  # 4 for DINO
                dropout=0.0),
            ffn_cfg=dict(
                embed_dims=256,
                feedforward_channels=1024,  # 2048 for DINO
                ffn_drop=0.0)),
        post_norm_cfg=None),
    bbox_head=dict(
        type='RTDETRHead',
        num_classes=80,
        sync_cls_avg_factor=True,
        loss_cls=dict(
            type='RTDETRVarifocalLoss',  # FocalLoss in DINO
            use_sigmoid=True,
            alpha=0.75,
            gamma=2.0,
            iou_weighted=True,
            loss_weight=1.0),
        loss_bbox=dict(type='L1Loss', loss_weight=5.0),
        loss_iou=dict(type='GIoULoss', loss_weight=2.0)),
    dn_cfg=dict(  # TODO: Move to model.train_cfg ?
        label_noise_scale=0.5,
        box_noise_scale=1.0,
        group_cfg=dict(dynamic=True, num_groups=None,
                       num_dn_queries=100)),  # TODO: half num_dn_queries
    # training and testing settings
    train_cfg=dict(
        assigner=dict(
            type='HungarianAssigner',
            match_costs=[
                dict(type='FocalLossCost', weight=2.0),
                dict(type='BBoxL1Cost', weight=5.0, box_format='xywh'),
                dict(type='IoUCost', iou_mode='giou', weight=2.0)
            ])),
    test_cfg=dict(max_per_img=300))
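
For reference, `mean=[0, 0, 0]` and `std=[255, 255, 255]` in `data_preprocessor` amount to rescaling 8-bit pixel values to the [0, 1] range, rather than applying the ImageNet statistics noted in the DINO comments. A minimal sketch, assuming the usual per-channel `(x - mean) / std` normalization; `normalize_pixel` is a hypothetical helper, not an MMDetection function.

```python
def normalize_pixel(pixel, mean, std):
    """Apply per-channel (value - mean) / std to one RGB pixel."""
    return [(v - m) / s for v, m, s in zip(pixel, mean, std)]


MEAN = [0, 0, 0]        # from data_preprocessor in this config
STD = [255, 255, 255]   # from data_preprocessor in this config

black = normalize_pixel([0, 0, 0], MEAN, STD)
white = normalize_pixel([255, 255, 255], MEAN, STD)
print(black, white)
```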

# train_pipeline, NOTE the img_scale and the Pad's size_divisor are different
# from the default setting in mmdet.
interpolations = ['nearest', 'bilinear', 'bicubic', 'area', 'lanczos']
train_pipeline = [
    dict(type='LoadImageFromFile', backend_args={{_base_.backend_args}}),
    dict(type='LoadAnnotations', with_bbox=True),
    dict(
        type='RandomApply',
        transforms=dict(type='PhotoMetricDistortion'),
        prob=0.8),
    dict(type='Expand', mean=[0, 0, 0]),
    dict(
        type='RandomApply', transforms=dict(type='MinIoURandomCrop'),
        prob=0.8),
    dict(type='FilterAnnotations', min_gt_bbox_wh=(1, 1), keep_empty=False),
    dict(type='RandomFlip', prob=0.5),
    dict(
        type='RandomChoice',
        transforms=[[
            dict(
                type='Resize',
                scale=(640, 640),
                keep_ratio=False,
                interpolation=interpolation)
        ] for interpolation in interpolations]),
    dict(type='FilterAnnotations', min_gt_bbox_wh=(1, 1), keep_empty=False),
    dict(type='PackDetInputs')
]

test_pipeline = [
    dict(type='LoadImageFromFile', backend_args={{_base_.backend_args}}),
    dict(
        type='Resize',
        scale=(640, 640),
        keep_ratio=False,
        interpolation='bicubic'),
    dict(type='LoadAnnotations', with_bbox=True),
    dict(
        type='PackDetInputs',
        meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
                   'scale_factor'))
]

train_dataloader = dict(
    dataset=dict(
        filter_cfg=dict(filter_empty_gt=False), pipeline=train_pipeline))
val_dataloader = dict(dataset=dict(pipeline=test_pipeline))
test_dataloader = dict(dataset=dict(pipeline=test_pipeline))

# optimizer
optim_wrapper = dict(
    type='OptimWrapper',
    optimizer=dict(type='AdamW', lr=0.0001, weight_decay=0.0001),
    clip_grad=dict(max_norm=0.1, norm_type=2),
    paramwise_cfg=dict(
        custom_keys={'backbone': dict(lr_mult=0.1)},
        norm_decay_mult=0,
        bypass_duplicate=True))
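
The `clip_grad=dict(max_norm=0.1, norm_type=2)` setting is fairly tight, so it may help to see what norm-based clipping does to a gradient. The sketch below is a simplified, pure-Python rendition of global L2-norm clipping in the style of PyTorch's `clip_grad_norm_`; the function name `clip_by_global_norm` and the `eps` value are assumptions for illustration, and real gradients are tensors rather than a flat list.

```python
import math


def clip_by_global_norm(grads, max_norm, eps=1e-6):
    """Rescale grads so their L2 norm does not exceed max_norm."""
    total_norm = math.sqrt(sum(g * g for g in grads))
    # Only ever shrink gradients; leave small ones untouched.
    scale = min(1.0, max_norm / (total_norm + eps))
    return [g * scale for g in grads], total_norm


clipped, norm_before = clip_by_global_norm([3.0, 4.0], max_norm=0.1)
norm_after = math.sqrt(sum(g * g for g in clipped))
print(norm_before, norm_after)
```

A gradient whose norm is already below `max_norm` passes through unchanged, so the direction of every update is preserved and only its magnitude is capped.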

# learning policy
max_epochs = 72
train_cfg = dict(
    type='EpochBasedTrainLoop', max_epochs=max_epochs, val_interval=1)

val_cfg = dict(type='ValLoop')
test_cfg = dict(type='TestLoop')

param_scheduler = [
    dict(
        type='LinearLR', start_factor=0.001, by_epoch=False, begin=0, end=2000)
]
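
The schedule above is a linear warmup over the first 2000 iterations, starting at 0.001 times the base learning rate. The sketch below computes an idealized version of that ramp with the values from this config. `warmup_lr` is a hypothetical helper, and the exact endpoint handling inside the scheduler (for example an off-by-one on the last warmup step) is an assumption and may differ.

```python
BASE_LR = 1e-4        # optimizer lr in this config
START_FACTOR = 0.001  # LinearLR start_factor
WARMUP_ITERS = 2000   # LinearLR end - begin


def warmup_lr(iteration):
    """Idealized learning rate at a given training iteration."""
    progress = min(iteration, WARMUP_ITERS) / WARMUP_ITERS
    factor = START_FACTOR + (1.0 - START_FACTOR) * progress
    return BASE_LR * factor


for it in (0, 1000, 2000, 5000):
    print(it, warmup_lr(it))
```

Since `param_scheduler` lists no other entry, the learning rate stays at the base value once warmup ends.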

# NOTE: `auto_scale_lr` is for automatically scaling LR,
# USER SHOULD NOT CHANGE ITS VALUES.
# base_batch_size = (8 GPUs) x (2 samples per GPU)
auto_scale_lr = dict(base_batch_size=16)
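
`base_batch_size=16` records the total batch size the learning rate was tuned for (8 GPUs x 2 samples per GPU). When automatic scaling is enabled, the learning rate is adjusted in proportion to the actual total batch size. A minimal sketch of that linear scaling rule; `scale_lr` is a hypothetical helper and the GPU counts in the usage lines are made-up examples.

```python
def scale_lr(base_lr, num_gpus, samples_per_gpu, base_batch_size=16):
    """Linearly scale the learning rate with the total batch size."""
    actual_batch_size = num_gpus * samples_per_gpu
    return base_lr * actual_batch_size / base_batch_size


reference = scale_lr(1e-4, num_gpus=8, samples_per_gpu=2)  # the tuned setup
half = scale_lr(1e-4, num_gpus=4, samples_per_gpu=2)       # half the batch
double = scale_lr(1e-4, num_gpus=8, samples_per_gpu=4)     # twice the batch
print(reference, half, double)
```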

custom_hooks = [
    dict(
        type='EMAHook',
        ema_type='ExpMomentumEMA',
        momentum=0.0001,
        update_buffers=True,
        priority=49)
]
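
The hook above keeps an exponential moving average of the weights with `momentum=0.0001`. The sketch below illustrates one plausible reading of `ExpMomentumEMA`: the effective momentum starts near 1 and decays exponentially toward the configured value, so the average tracks the model closely early in training. The decay form and `gamma=2000` are assumptions not taken from this diff, and both function names are hypothetical.

```python
import math


def exp_momentum(steps, momentum=0.0001, gamma=2000):
    """Assumed schedule: decays from about 1 toward `momentum`."""
    return (1 - momentum) * math.exp(-(1 + steps) / gamma) + momentum


def ema_update(averaged, source, steps):
    """Blend one scalar weight into its running average."""
    m = exp_momentum(steps)
    return averaged * (1 - m) + source * m


for steps in (0, 2000, 20000, 100000):
    print(steps, exp_momentum(steps))
```

With a small final momentum the averaged weights change slowly late in training, which is the usual motivation for evaluating with the EMA copy.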