-
Notifications
You must be signed in to change notification settings - Fork 0
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Request for ResNet101 Config and Weights for RayDN. #3
Comments
Hi @kkkcx, thank you for your interest in our work. We do not plan to release the R101 config and weights for the time being. They may be released after the paper is accepted. Please stay tuned. |
Thank you for your response. I have attempted to modify the training config for r101 on my own. My config is as follows: _base_ = [
'../../../mmdetection3d/configs/_base_/datasets/nus-3d.py',
'../../../mmdetection3d/configs/_base_/default_runtime.py'
]
# Shared scalar/list settings referenced by the model, pipeline and schedule
# sections of this config.
backbone_norm_cfg = {"type": "LN", "requires_grad": True}
plugin = True
plugin_dir = "projects/mmdet3d_plugin/"

# Whenever the point cloud range is edited here, the models must be updated
# to use the same range.
point_cloud_range = [-51.2, -51.2, -5.0, 51.2, 51.2, 3.0]
voxel_size = [0.2, 0.2, 8]

img_norm_cfg = {
    "mean": [123.675, 116.28, 103.53],
    "std": [58.395, 57.12, 57.375],
    "to_rgb": True,
}

# nuScenes detection is usually run with these 10 classes.
class_names = [
    "car",
    "truck",
    "construction_vehicle",
    "bus",
    "trailer",
    "barrier",
    "motorcycle",
    "bicycle",
    "pedestrian",
    "traffic_cone",
]

# Schedule bookkeeping: iterations per epoch is 28130 floor-divided by the
# global batch size (num_gpus * batch_size).
num_gpus = 8
batch_size = 2
num_iters_per_epoch = 28130 // (num_gpus * batch_size)
num_epochs = 60

queue_length = 1
num_frame_losses = 1

collect_keys = [
    "lidar2img",
    "intrinsics",
    "extrinsics",
    "timestamp",
    "img_timestamp",
    "ego_pose",
    "ego_pose_inv",
]

input_modality = {
    "use_lidar": False,
    "use_camera": True,
    "use_radar": False,
    "use_map": False,
    "use_external": True,
}
# Detector definition: ResNet-101 image backbone, FPN neck, a 2D ROI head and
# the 3D bbox head, followed by the training-time assigner settings.
model = {
    "type": "RepDetr3D",
    "num_frame_head_grads": num_frame_losses,
    "num_frame_backbone_grads": num_frame_losses,
    "num_frame_losses": num_frame_losses,
    "use_grid_mask": True,
    "stride": [8, 16, 32, 64],
    "position_level": [0, 1, 2, 3],
    # ResNet-50 backbone settings that the ResNet-101 entry below replaces,
    # kept (disabled) for reference:
    # "img_backbone": {
    #     "init_cfg": {
    #         "type": "Pretrained",
    #         "checkpoint": "ckpts/cascade_mask_rcnn_r50_fpn_coco-20e_20e_nuim_20201009_124951-40963960.pth",
    #         "prefix": "backbone.",
    #     },
    #     "type": "ResNet",
    #     "depth": 50,
    #     "num_stages": 4,
    #     "out_indices": (0, 1, 2, 3),
    #     "frozen_stages": -1,
    #     "norm_cfg": {"type": "BN2d", "requires_grad": False},
    #     "norm_eval": True,
    #     "with_cp": True,
    #     "style": "pytorch",
    # },
    "img_backbone": {
        "type": "ResNet",
        "depth": 101,
        "num_stages": 4,
        "frozen_stages": -1,
        "style": "pytorch",
        "with_cp": True,
        "out_indices": (0, 1, 2, 3),
        "norm_eval": True,
        "norm_cfg": {"type": "BN", "requires_grad": False},
        "init_cfg": {
            "type": "Pretrained",
            "checkpoint": "ckpts/cascade_mask_rcnn_r101_fpn_1x_nuim_20201024_134804-45215b1e.pth",
            "prefix": "backbone.",
        },
    },
    "img_neck": {
        "type": "FPN",  # remove unused parameters
        "start_level": 1,
        "add_extra_convs": "on_output",
        "relu_before_extra_convs": True,
        "in_channels": [256, 512, 1024, 2048],
        "out_channels": 256,
        "num_outs": 4,
    },
    "img_roi_head": {
        "type": "YOLOXHeadCustom",
        "num_classes": 10,
        "in_channels": 256,
        "strides": [8, 16, 32, 64],
        "train_cfg": {
            "assigner": {"type": "SimOTAAssigner", "center_radius": 2.5},
        },
        "test_cfg": {
            "score_thr": 0.01,
            "nms": {"type": "nms", "iou_threshold": 0.65},
        },
    },
    "pts_bbox_head": {
        "type": "RayDNHead",
        "num_classes": 10,
        "in_channels": 256,
        "num_query": 300,
        "memory_len": 512,
        "topk_proposals": 128,
        "num_propagated": 128,
        "scalar": 10,  # noise groups
        "noise_scale": 1.0,
        "dn_weight": 1.0,  # dn loss weight
        "split": 0.75,  # positive rate
        "with_dn": True,
        "raydn_group": 1,
        "raydn_num": 5,
        "raydn_alpha": 8,
        "raydn_beta": 2,
        "raydn_radius": 3,
        "with_ego_pos": True,
        "match_with_velo": False,
        "code_weights": [2.0, 2.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
        "transformer": {
            "type": "Detr3DTransformer",
            "decoder": {
                "type": "Detr3DTransformerDecoder",
                "embed_dims": 256,
                "num_layers": 6,
                "transformerlayers": {
                    "type": "Detr3DTemporalDecoderLayer",
                    "batch_first": True,
                    "attn_cfgs": [
                        {
                            "type": "MultiheadAttention",
                            "embed_dims": 256,
                            "num_heads": 8,
                            "dropout": 0.1,
                        },
                        {
                            "type": "DeformableFeatureAggregationCuda",
                            "embed_dims": 256,
                            "num_groups": 8,
                            "num_levels": 4,
                            "num_cams": 6,
                            "dropout": 0.1,
                            "num_pts": 13,
                            "bias": 2.0,
                        },
                    ],
                    "feedforward_channels": 2048,
                    "ffn_dropout": 0.1,
                    "with_cp": True,  # use checkpoint to save memory
                    "operation_order": (
                        "self_attn",
                        "norm",
                        "cross_attn",
                        "norm",
                        "ffn",
                        "norm",
                    ),
                },
            },
        },
        "bbox_coder": {
            "type": "NMSFreeCoder",
            "post_center_range": [-61.2, -61.2, -10.0, 61.2, 61.2, 10.0],
            "pc_range": point_cloud_range,
            "max_num": 300,
            "voxel_size": voxel_size,
            "num_classes": 10,
        },
        "loss_cls": {
            "type": "FocalLoss",
            "use_sigmoid": True,
            "gamma": 2.0,
            "alpha": 0.25,
            "loss_weight": 2.0,
        },
        "loss_bbox": {"type": "L1Loss", "loss_weight": 0.25},
        "loss_iou": {"type": "GIoULoss", "loss_weight": 0.0},
    },
    # Model training and testing settings.
    "train_cfg": {
        "pts": {
            "grid_size": [512, 512, 1],
            "voxel_size": voxel_size,
            "point_cloud_range": point_cloud_range,
            "out_size_factor": 4,
            "assigner": {
                "type": "HungarianAssigner3D",
                "cls_cost": {"type": "FocalLossCost", "weight": 2.0},
                "reg_cost": {"type": "BBox3DL1Cost", "weight": 0.25},
                # Placeholder cost (weight 0.0); present only for
                # compatibility with the DETR head.
                "iou_cost": {"type": "IoUCost", "weight": 0.0},
                "pc_range": point_cloud_range,
            },
        },
    },
}
# Dataset class name, data location and file backend used by the dataset
# entries of this config.
dataset_type = "CustomNuScenesDataset"
data_root = "./data/nuscenes/"
file_client_args = {"backend": "disk"}

# Image-level augmentation settings handed to ResizeCropFlipRotImage.
ida_aug_conf = {
    "resize_lim": (0.38, 0.55),
    "final_dim": (256, 704),
    "bot_pct_lim": (0.0, 0.0),
    "rot_lim": (0.0, 0.0),
    "H": 900,
    "W": 1600,
    "rand_flip": True,
}
# Training pipeline: load images and annotations, filter objects, augment,
# normalize/pad, then format and collect the keys consumed by the model.
train_pipeline = [
    {"type": "LoadMultiViewImageFromFiles", "to_float32": True},
    {
        "type": "LoadAnnotations3D",
        "with_bbox_3d": True,
        "with_label_3d": True,
        "with_bbox": True,
        "with_label": True,
        "with_bbox_depth": True,
    },
    {"type": "ObjectRangeFilter", "point_cloud_range": point_cloud_range},
    {"type": "ObjectNameFilter", "classes": class_names},
    {
        "type": "ResizeCropFlipRotImage",
        "data_aug_conf": ida_aug_conf,
        "training": True,
    },
    {
        "type": "GlobalRotScaleTransImage",
        "rot_range": [-0.3925, 0.3925],
        "translation_std": [0, 0, 0],
        "scale_ratio_range": [0.95, 1.05],
        "reverse_angle": True,
        "training": True,
    },
    {"type": "NormalizeMultiviewImage", **img_norm_cfg},
    {"type": "PadMultiViewImage", "size_divisor": 32},
    {
        "type": "PETRFormatBundle3D",
        "class_names": class_names,
        "collect_keys": [*collect_keys, "prev_exists"],
    },
    {
        "type": "Collect3D",
        "keys": [
            "gt_bboxes_3d",
            "gt_labels_3d",
            "img",
            "gt_bboxes",
            "gt_labels",
            "centers2d",
            "depths",
            "prev_exists",
            *collect_keys,
        ],
        "meta_keys": (
            "filename",
            "ori_shape",
            "img_shape",
            "pad_shape",
            "scale_factor",
            "flip",
            "box_mode_3d",
            "box_type_3d",
            "img_norm_cfg",
            "scene_token",
            "gt_bboxes_3d",
            "gt_labels_3d",
        ),
    },
]
# Test-time pipeline: same load/resize/normalize/pad steps with training
# disabled, wrapped up by a MultiScaleFlipAug3D stage that formats and
# collects the image plus the shared collect_keys.
test_pipeline = [
    {"type": "LoadMultiViewImageFromFiles", "to_float32": True},
    {
        "type": "ResizeCropFlipRotImage",
        "data_aug_conf": ida_aug_conf,
        "training": False,
    },
    {"type": "NormalizeMultiviewImage", **img_norm_cfg},
    {"type": "PadMultiViewImage", "size_divisor": 32},
    {
        "type": "MultiScaleFlipAug3D",
        "img_scale": (1333, 800),
        "pts_scale_ratio": 1,
        "flip": False,
        "transforms": [
            {
                "type": "PETRFormatBundle3D",
                "collect_keys": collect_keys,
                "class_names": class_names,
                "with_label": False,
            },
            {
                "type": "Collect3D",
                "keys": ["img", *collect_keys],
                "meta_keys": (
                    "filename",
                    "ori_shape",
                    "img_shape",
                    "pad_shape",
                    "scale_factor",
                    "flip",
                    "box_mode_3d",
                    "box_type_3d",
                    "img_norm_cfg",
                    "scene_token",
                ),
            },
        ],
    },
]
# Dataloader settings. The val and test entries are configured identically
# and both read the validation annotation file.
data = {
    "samples_per_gpu": batch_size,
    "workers_per_gpu": 4,
    "train": {
        "type": dataset_type,
        "data_root": data_root,
        "ann_file": data_root + "nuscenes2d_temporal_infos_train.pkl",
        "num_frame_losses": num_frame_losses,
        "seq_split_num": 2,
        "seq_mode": True,
        "pipeline": train_pipeline,
        "classes": class_names,
        "modality": input_modality,
        "collect_keys": [*collect_keys, "img", "prev_exists", "img_metas"],
        "queue_length": queue_length,
        "test_mode": False,
        "use_valid_flag": True,
        "filter_empty_gt": False,
        "box_type_3d": "LiDAR",
    },
    "val": {
        "type": dataset_type,
        "data_root": data_root,
        "pipeline": test_pipeline,
        "collect_keys": [*collect_keys, "img", "img_metas"],
        "queue_length": queue_length,
        "ann_file": data_root + "nuscenes2d_temporal_infos_val.pkl",
        "classes": class_names,
        "modality": input_modality,
    },
    "test": {
        "type": dataset_type,
        "data_root": data_root,
        "pipeline": test_pipeline,
        "collect_keys": [*collect_keys, "img", "img_metas"],
        "queue_length": queue_length,
        "ann_file": data_root + "nuscenes2d_temporal_infos_val.pkl",
        "classes": class_names,
        "modality": input_modality,
    },
    "shuffler_sampler": {"type": "InfiniteGroupEachSampleInBatchSampler"},
    "nonshuffler_sampler": {"type": "DistributedSampler"},
}
# Optimizer: AdamW, with the image backbone's learning rate scaled by 0.25.
optimizer = {
    "type": "AdamW",
    "lr": 4e-4,  # bs 8: 2e-4 || bs 16: 4e-4
    "paramwise_cfg": {
        "custom_keys": {
            "img_backbone": {"lr_mult": 0.25},
        },
    },
    "weight_decay": 0.01,
}

# FP16 optimizer hook with dynamic loss scaling and L2 gradient-norm clipping.
optimizer_config = {
    "type": "Fp16OptimizerHook",
    "loss_scale": "dynamic",
    "grad_clip": {"max_norm": 35, "norm_type": 2},
}

# Learning policy: cosine annealing preceded by a linear warmup.
lr_config = {
    "policy": "CosineAnnealing",
    "warmup": "linear",
    "warmup_iters": 500,
    "warmup_ratio": 1.0 / 3,
    "min_lr_ratio": 1e-3,
}
# Evaluation interval is num_iters_per_epoch * num_epochs iterations.
evaluation = {
    "interval": num_iters_per_epoch * num_epochs,
    "pipeline": test_pipeline,
}

# Must remain False whenever checkpointing is in use.
find_unused_parameters = False

# Checkpoint interval is num_iters_per_epoch iterations.
checkpoint_config = {
    "interval": num_iters_per_epoch,
    "max_keep_ckpts": 1000,
}
runner = dict(
type='IterBasedRunner', max_iters=num_epochs * num_iters_per_epoch) and the results obtained are as follows: mAP: 0.4636 Per-class results: It appears that the results are still not satisfactory and there remains a gap compared to the outcomes you have achieved. May I kindly ask if there might be any inaccuracies or improvements needed in my config? I deeply appreciate your guidance. |
Hi Dr. Liu,
Thank you for your excellent work on RayDN! I have been reading your paper and noticed that ResNet101 achieved better results than ResNet50 in the experiments. Could you please provide the config file and training weights for the ResNet101 version of RayDN? Thank you very much!!
The text was updated successfully, but these errors were encountered: