[NeurIPS 2023] Masked Image Residual Learning for Scaling Deeper Vision Transformers

This is a PaddlePaddle implementation of the paper Masked Image Residual Learning for Scaling Deeper Vision Transformers

Abstract: Deeper Vision Transformers (ViTs) are more challenging to train. We expose a degradation problem in deeper layers of ViT when using masked image modeling (MIM) for pre-training. To ease the training of deeper ViTs, we introduce a self-supervised learning framework called Masked Image Residual Learning (MIRL), which significantly alleviates the degradation problem, making scaling ViT along depth a promising direction for performance upgrade. We reformulate the pretraining objective for deeper layers of ViT as learning to recover the residual of the masked image. We provide extensive empirical evidence showing that deeper ViTs can be effectively optimized using MIRL and easily gain accuracy from increased depth. With the same level of computational complexity as ViT-Base and ViT-Large, we instantiate 4.5× and 2× deeper ViTs, dubbed ViT-S-54 and ViT-B-48. The deeper ViT-S-54, costing 3× less than ViT-Large, achieves performance on par with ViT-Large. ViT-B-48 achieves 86.2% top-1 accuracy on ImageNet. On one hand, deeper ViTs pre-trained with MIRL exhibit excellent generalization capabilities on downstream tasks, such as object detection and semantic segmentation. On the other hand, MIRL demonstrates high pre-training efficiency. With less pre-training time, MIRL yields competitive performance compared to other approaches.
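To make the residual idea concrete, below is a heavily simplified conceptual sketch of a residual-style masked-image objective written with PaddlePaddle. It is not the released training code: the two-way segment split, the decoder heads (`head_main`, `head_res`), and the loss weighting are illustrative assumptions; the actual formulation splits the encoder into several segments (see the paper and PRETRAIN.md).

```python
# Conceptual sketch only: illustrates "deeper layers learn the residual of the
# masked image", not this repo's training code. Module names are hypothetical,
# and the real method splits the encoder into several segments.
import paddle
import paddle.nn.functional as F

def mirl_style_loss(patch_targets, mask, shallow_blocks, deep_blocks,
                    head_main, head_res):
    # patch_targets: (B, N, D) per-patch pixel targets of the full image
    # mask:          (B, N) boolean, True for patches hidden from the encoder
    keep = paddle.logical_not(mask).astype('float32').unsqueeze(-1)
    h_shallow = shallow_blocks(patch_targets * keep)   # shallower ViT segment
    h_deep = deep_blocks(h_shallow)                    # deeper ViT segment

    main_pred = head_main(h_shallow)                   # reconstruct the main image component
    res_pred = head_res(h_deep)                        # deeper segment predicts the residual
    res_target = patch_targets - main_pred.detach()    # what the main reconstruction missed

    m = mask.astype('float32').unsqueeze(-1)           # score masked patches only
    loss_main = F.mse_loss(main_pred * m, patch_targets * m)
    loss_res = F.mse_loss(res_pred * m, res_target * m)
    return loss_main + loss_res
```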

Architecture

Updates

09/Nov/2023

Uploaded the pre-trained and fine-tuned models.

07/Oct/2023

The preprint is available on arXiv.

Prerequisites

This repo works with PaddlePaddle 2.3 or higher.
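If you want to verify your environment, a minimal check (assuming the standard `paddlepaddle` or `paddlepaddle-gpu` wheel is installed) looks like this:

```python
# Minimal environment check; expects PaddlePaddle >= 2.3.
import paddle

print(paddle.__version__)   # should report 2.3.x or newer
paddle.utils.run_check()    # verifies the install and reports whether a GPU is usable
```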

Pretrain on ImageNet-1K

The pre-training instruction is in PRETRAIN.md.

The following table provides pretrained checkpoints and logs used in the paper.

| model    | pre-trained 300 epochs | pre-trained 800 epochs |
|----------|------------------------|------------------------|
| ViT-B-48 | checkpoint/log         | checkpoint/log         |
| ViT-S-54 | checkpoint/log         | checkpoint/log         |
| ViT-B-24 | checkpoint/log         | -                      |

We also converted the Paddle checkpoint of ViT-B-48 pre-trained for 800 epochs into a PyTorch version (download link) using the script paddle2pytorch.py.
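For reference, the core of such a conversion is transposing the 2-D weight matrices, because `paddle.nn.Linear` stores weights as `(in_features, out_features)` while `torch.nn.Linear` uses `(out_features, in_features)`. The sketch below only illustrates that step and is not the contents of `paddle2pytorch.py`; the checkpoint layout and key handling are assumptions.

```python
# Illustrative Paddle -> PyTorch checkpoint conversion (see paddle2pytorch.py in
# this repo for the actual script). Checkpoint layout and key names are assumed.
import paddle
import torch

def convert_checkpoint(paddle_path, torch_path):
    state = paddle.load(paddle_path)
    state = state.get('model', state)            # unwrap a nested state dict, if any
    converted = {}
    for name, value in state.items():
        array = value.numpy()
        # Linear weights are (in, out) in Paddle but (out, in) in PyTorch;
        # LayerNorm (1-D) and Conv2D patch-embed (4-D) weights are left alone.
        if name.endswith('.weight') and array.ndim == 2:
            array = array.T
        converted[name] = torch.from_numpy(array)
    torch.save(converted, torch_path)
```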

Main Results on ImageNet-1K

The fine-tuning instruction is in FINETUNE.md.

| Encoder  | Method   | Pre-train Epochs | FT acc@1 (%) | FT checkpoint/log |
|----------|----------|------------------|--------------|-------------------|
| ViT-B    | BEiT     | 800              | 83.2         | -                 |
| ViT-B    | MAE      | 1600             | 83.6         | -                 |
| ViT-B    | MIRL     | 300/800          | 83.5/84.1    | checkpoint/log    |
| ViT-S-54 | MIRL     | 300/800          | 84.4/84.8    | checkpoint/log    |
| ViT-B-24 | MIRL     | 300              | 84.7         | checkpoint/log    |
| ViT-L    | MaskFeat | 1600             | 85.7         | -                 |
| ViT-L    | HPM      | 800              | 85.8         | -                 |
| ViT-L    | MAE      | 1600             | 85.9         | -                 |
| ViT-B-48 | MIRL     | 300/800          | 85.3/86.2    | checkpoint/log    |
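Loading one of the fine-tuned checkpoints for evaluation typically looks like the sketch below; the model builder comes from this repo / PLSC, and the exact state-dict layout may differ, so treat the key handling as an assumption.

```python
# Hypothetical loading sketch for a fine-tuned checkpoint; build the ViT model
# with this repo's code first. The 'model' key unwrapping is an assumption.
import paddle

def load_for_eval(model, ckpt_path):
    state = paddle.load(ckpt_path)
    state = state.get('model', state)   # unwrap if the weights are nested
    model.set_state_dict(state)
    model.eval()
    return model
```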

Acknowledgement

Our project is based on MAE and its PaddlePaddle re-implementation in PLSC. Thanks for their wonderful work.

Citation

If you find this project useful in your research, please consider citing it:

@article{huang2023masked,
  title={Masked Image Residual Learning for Scaling Deeper Vision Transformers},
  author={Huang, Guoxi and Fu, Hongtao and Bors, Adrian G},
  journal={arXiv preprint arXiv:2309.14136},
  year={2023}
}
