Add MogaNet Implementation and Models #1691

Open · wants to merge 2 commits into main
23 changes: 23 additions & 0 deletions configs/_base_/models/moganet/moganet_base.py
@@ -0,0 +1,23 @@
# Model settings
model = dict(
type='ImageClassifier',
backbone=dict(
type='MogaNet', arch='base', drop_path_rate=0.2,
attn_force_fp32=True,
),
neck=dict(type='GlobalAveragePooling'),
head=dict(
type='LinearClsHead',
num_classes=1000,
in_channels=512,
loss=dict(
type='LabelSmoothLoss', label_smooth_val=0.1, mode='original'),
init_cfg=None,
),
init_cfg=dict(
type='TruncNormal', layer=['Conv2d', 'Linear'], std=.02, bias=0.),
train_cfg=dict(augments=[
dict(type='Mixup', alpha=0.8),
dict(type='CutMix', alpha=1.0),
]),
)
23 changes: 23 additions & 0 deletions configs/_base_/models/moganet/moganet_large.py
@@ -0,0 +1,23 @@
# Model settings
model = dict(
type='ImageClassifier',
backbone=dict(
type='MogaNet', arch='large', drop_path_rate=0.3,
attn_force_fp32=False,
),
neck=dict(type='GlobalAveragePooling'),
head=dict(
type='LinearClsHead',
num_classes=1000,
in_channels=640,
loss=dict(
type='LabelSmoothLoss', label_smooth_val=0.1, mode='original'),
init_cfg=None,
),
init_cfg=dict(
type='TruncNormal', layer=['Conv2d', 'Linear'], std=.02, bias=0.),
train_cfg=dict(augments=[
dict(type='Mixup', alpha=0.8),
dict(type='CutMix', alpha=1.0),
]),
)
23 changes: 23 additions & 0 deletions configs/_base_/models/moganet/moganet_small.py
@@ -0,0 +1,23 @@
# Model settings
model = dict(
type='ImageClassifier',
backbone=dict(
type='MogaNet', arch='small', drop_path_rate=0.1,
attn_force_fp32=True,
),
neck=dict(type='GlobalAveragePooling'),
head=dict(
type='LinearClsHead',
num_classes=1000,
in_channels=512,
loss=dict(
type='LabelSmoothLoss', label_smooth_val=0.1, mode='original'),
init_cfg=None,
),
init_cfg=dict(
type='TruncNormal', layer=['Conv2d', 'Linear'], std=.02, bias=0.),
train_cfg=dict(augments=[
dict(type='Mixup', alpha=0.8),
dict(type='CutMix', alpha=1.0),
]),
)
23 changes: 23 additions & 0 deletions configs/_base_/models/moganet/moganet_tiny.py
@@ -0,0 +1,23 @@
# Model settings
model = dict(
type='ImageClassifier',
backbone=dict(
type='MogaNet', arch='tiny', drop_path_rate=0.1,
attn_force_fp32=True,
),
neck=dict(type='GlobalAveragePooling'),
head=dict(
type='LinearClsHead',
num_classes=1000,
in_channels=256,
loss=dict(
type='LabelSmoothLoss', label_smooth_val=0.1, mode='original'),
init_cfg=None,
),
init_cfg=dict(
type='TruncNormal', layer=['Conv2d', 'Linear'], std=.02, bias=0.),
train_cfg=dict(augments=[
dict(type='Mixup', alpha=0.1),
dict(type='CutMix', alpha=1.0),
]),
)
23 changes: 23 additions & 0 deletions configs/_base_/models/moganet/moganet_xlarge.py
@@ -0,0 +1,23 @@
# Model settings
model = dict(
type='ImageClassifier',
backbone=dict(
type='MogaNet', arch='x-large', drop_path_rate=0.4,
attn_force_fp32=True,
),
neck=dict(type='GlobalAveragePooling'),
head=dict(
type='LinearClsHead',
num_classes=1000,
in_channels=960,
loss=dict(
type='LabelSmoothLoss', label_smooth_val=0.1, mode='original'),
init_cfg=None,
),
init_cfg=dict(
type='TruncNormal', layer=['Conv2d', 'Linear'], std=.02, bias=0.),
train_cfg=dict(augments=[
dict(type='Mixup', alpha=0.8),
dict(type='CutMix', alpha=1.0),
]),
)
23 changes: 23 additions & 0 deletions configs/_base_/models/moganet/moganet_xtiny.py
@@ -0,0 +1,23 @@
# Model settings
model = dict(
type='ImageClassifier',
backbone=dict(
type='MogaNet', arch='x-tiny', drop_path_rate=0.05,
attn_force_fp32=True,
),
neck=dict(type='GlobalAveragePooling'),
head=dict(
type='LinearClsHead',
num_classes=1000,
in_channels=192,
loss=dict(
type='LabelSmoothLoss', label_smooth_val=0.1, mode='original'),
init_cfg=None,
),
init_cfg=dict(
type='TruncNormal', layer=['Conv2d', 'Linear'], std=.02, bias=0.),
train_cfg=dict(augments=[
dict(type='Mixup', alpha=0.1),
dict(type='CutMix', alpha=1.0),
]),
)
81 changes: 81 additions & 0 deletions configs/moganet/README.md
@@ -0,0 +1,81 @@
# Efficient Multi-order Gated Aggregation Network

> [Efficient Multi-order Gated Aggregation Network](https://arxiv.org/abs/2211.03295)

<!-- [ALGORITHM] -->

## Abstract

Since the recent success of Vision Transformers (ViTs), explorations toward ViT-style architectures have triggered the resurgence of ConvNets. In this work, we explore the representation ability of modern ConvNets from a novel view of multi-order game-theoretic interaction, which reflects inter-variable interaction effects w.r.t. contexts of different scales based on game theory. Within the modern ConvNet framework, we tailor the two feature mixers with conceptually simple yet effective depthwise convolutions to facilitate middle-order information across spatial and channel spaces respectively. In this light, a new family of pure ConvNet architecture, dubbed MogaNet, is proposed, which shows excellent scalability and attains competitive results among state-of-the-art models with more efficient use of parameters on ImageNet and multifarious typical vision benchmarks, including COCO object detection, ADE20K semantic segmentation, 2D&3D human pose estimation, and video prediction. Typically, MogaNet hits 80.0% and 87.8% top-1 accuracy with 5.2M and 181M parameters on ImageNet, outperforming ParC-Net-S and ConvNeXt-L while saving 59% FLOPs and 17M parameters. The source code is available at https://github.com/Westlake-AI/MogaNet.

<div align=center>
<img src="https://user-images.githubusercontent.com/44519745/200625735-86bd2237-5bbe-43c1-ab37-049810b8d8a1.jpg" width="100%"/>
</div>
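
As a rough intuition for the architecture, the spatial mixer aggregates context from depthwise convolutions with several receptive-field scales and modulates the result with a gating branch. The snippet below is only an illustrative sketch of that idea; the kernel sizes, dilations and channel handling are placeholders, not the values used by the actual backbone in `mmpretrain/models/backbones/moganet.py`.

```python
import torch
import torch.nn as nn


class MultiOrderGatedAggregation(nn.Module):
    """Simplified sketch of multi-order gated spatial aggregation."""

    def __init__(self, channels: int):
        super().__init__()
        self.gate = nn.Conv2d(channels, channels, 1)
        self.value = nn.Conv2d(channels, channels, 1)
        # Depthwise convolutions covering different interaction orders
        # (small, medium and large effective receptive fields).
        self.dw_low = nn.Conv2d(channels, channels, 5, padding=2, groups=channels)
        self.dw_mid = nn.Conv2d(channels, channels, 5, padding=4, dilation=2, groups=channels)
        self.dw_high = nn.Conv2d(channels, channels, 7, padding=9, dilation=3, groups=channels)
        self.proj = nn.Conv2d(channels, channels, 1)
        self.act = nn.SiLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        v = self.value(x)
        # Aggregate multi-scale context, then modulate it with the gate.
        ctx = self.dw_low(v) + self.dw_mid(v) + self.dw_high(v)
        return self.proj(self.act(self.gate(x)) * self.act(ctx))


feat = torch.rand(1, 64, 56, 56)
print(MultiOrderGatedAggregation(64)(feat).shape)  # torch.Size([1, 64, 56, 56])
```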

## How to use it?

<!-- [TABS-BEGIN] -->

**Predict image**

```python
from mmpretrain import inference_model

predict = inference_model('moganet-tiny_3rdparty_8xb128_in1k', 'demo/bird.JPEG')
print(predict['pred_class'])
print(predict['pred_score'])
```

**Use the model**

```python
import torch
from mmpretrain import get_model

model = get_model('moganet-tiny_3rdparty_8xb128_in1k', pretrained=True)
inputs = torch.rand(1, 3, 224, 224)
out = model(inputs)
print(type(out))
# To extract features.
feats = model.extract_feat(inputs)
print(type(feats))
```

**Test Command**

Prepare your dataset according to the [docs](https://mmpretrain.readthedocs.io/en/latest/user_guides/dataset_prepare.html#prepare-dataset).

Test:

```shell
python tools/test.py configs/moganet/moganet-tiny_8xb128_in1k.py https://github.com/Westlake-AI/openmixup/releases/download/moganet-in1k-weights/moga_tiny_sz224_8xb128_fp16_ep300.pth
```

<!-- [TABS-END] -->

## Models and results

### Image Classification on ImageNet-1k

| Model | Pretrain | Params (M) | Flops (G) | Top-1 (%) | Top-5 (%) | Config | Download |
| :-------------------------------------- | :----------: | :--------: | :-------: | :-------: | :-------: | :-------: | :-------: |
| `moganet-xtiny_3rdparty_8xb128_in1k`\* | From scratch | 2.97 | 0.79 | 76.48 | 93.49 | [config](moganet-xtiny_8xb128_in1k.py) | [model](https://github.com/Lupin1998/mmpretrain/releases/download/moganet-in1k-weights/moganet-xtiny_3rdparty_8xb128_in1k.pth) |
| `moganet-tiny_3rdparty_8xb128_in1k`\* | From scratch | 5.20 | 1.09 | 77.24 | 93.51 | [config](moganet-tiny_8xb128_in1k.py) | [model](https://github.com/Lupin1998/mmpretrain/releases/download/moganet-in1k-weights/moganet-tiny_3rdparty_8xb128_in1k.pth) |
| `moganet-small_3rdparty_8xb128_in1k`\*   | From scratch | 25.35      | 4.94      | 83.38     | 96.58     | [config](moganet-small_8xb128_in1k.py) | [model](https://github.com/Lupin1998/mmpretrain/releases/download/moganet-in1k-weights/moganet-small_3rdparty_8xb128_in1k.pth) |
| `moganet-base_3rdparty_8xb128_in1k`\*    | From scratch | 43.72      | 9.88      | 84.20     | 96.77     | [config](moganet-base_8xb128_in1k.py) | [model](https://github.com/Lupin1998/mmpretrain/releases/download/moganet-in1k-weights/moganet-base_3rdparty_8xb128_in1k.pth) |
| `moganet-large_3rdparty_8xb128_in1k`\*   | From scratch | 82.48      | 15.84     | 84.76     | 97.15     | [config](moganet-large_8xb128_in1k.py) | [model](https://github.com/Lupin1998/mmpretrain/releases/download/moganet-in1k-weights/moganet-large_3rdparty_8xb128_in1k.pth) |
| `moganet-xlarge_3rdparty_16xb32_in1k`\*  | From scratch | 180.8      | 34.43     | 85.11     | 97.38     | [config](moganet-xlarge_16xb32_in1k.py) | [model](https://github.com/Lupin1998/mmpretrain/releases/download/moganet-in1k-weights/moganet-xlarge_3rdparty_16xb32_in1k.pth) |

*Models with \* are converted from the [official repo](https://github.com/Westlake-AI/MogaNet). The config files of these models are only for inference. We haven't reproduced the training results.*
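
The parameter and FLOP counts in the table can be re-checked locally. A minimal sketch, assuming this PR is installed and that `mmengine.analysis.get_model_complexity_info` is available in the current environment:

```python
from mmengine.analysis import get_model_complexity_info
from mmpretrain import get_model

# Build the architecture without downloading weights and measure its
# complexity at the 224x224 input resolution used in the table above.
model = get_model('moganet-tiny_3rdparty_8xb128_in1k', pretrained=False)
analysis = get_model_complexity_info(model, input_shape=(3, 224, 224))
print(analysis['params_str'], analysis['flops_str'])
```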

## Citation

```bibtex
@article{Li2022MogaNet,
title={Efficient Multi-order Gated Aggregation Network},
author={Siyuan Li and Zedong Wang and Zicheng Liu and Cheng Tan and Haitao Lin and Di Wu and Zhiyuan Chen and Jiangbin Zheng and Stan Z. Li},
journal={ArXiv},
year={2022},
volume={abs/2211.03295}
}
```
113 changes: 113 additions & 0 deletions configs/moganet/metafile.yml
@@ -0,0 +1,113 @@
Collections:
- Name: MogaNet
Metadata:
Training Data: ImageNet-1k
Architecture:
- Gating
- 1x1 Convolution
- LayerScale
Paper:
URL: https://arxiv.org/abs/2211.03295
Title: Efficient Multi-order Gated Aggregation Network
README: configs/moganet/README.md
Code:
Version: v1.0.0
URL: https://github.com/Lupin1998/mmpretrain/tree/main/mmpretrain/models/backbones/moganet.py

Models:
- Name: moganet-xtiny_3rdparty_8xb128_in1k
Metadata:
FLOPs: 843961073
Parameters: 3114270
In Collection: MogaNet
Results:
- Dataset: ImageNet-1k
Metrics:
Top 1 Accuracy: 76.48
Top 5 Accuracy: 93.49
Task: Image Classification
Weights: https://github.com/Lupin1998/mmpretrain/releases/download/moganet-in1k-weights/moganet-xtiny_3rdparty_8xb128_in1k.pth
Config: configs/moganet/moganet-xtiny_8xb128_in1k.py
Converted From:
Weights: https://github.com/Westlake-AI/openmixup/releases/download/moganet-in1k-weights/moga_xtiny_sz224_8xb128_fp16_ep300.pth
Code: https://github.com/Westlake-AI/openmixup
- Name: moganet-tiny_3rdparty_8xb128_in1k
Metadata:
FLOPs: 1168231104
Parameters: 5449449
In Collection: MogaNet
Results:
- Dataset: ImageNet-1k
Metrics:
Top 1 Accuracy: 77.24
Top 5 Accuracy: 93.51
Task: Image Classification
Weights: https://github.com/Lupin1998/mmpretrain/releases/download/moganet-in1k-weights/moganet-tiny_3rdparty_8xb128_in1k.pth
Config: configs/moganet/moganet-tiny_8xb128_in1k.py
Converted From:
Weights: https://github.com/Westlake-AI/openmixup/releases/download/moganet-in1k-weights/moga_tiny_sz224_8xb128_fp16_ep300.pth
Code: https://github.com/Westlake-AI/openmixup
- Name: moganet-small_3rdparty_8xb128_in1k
Metadata:
FLOPs: 5304284610
Parameters: 26566721
In Collection: MogaNet
Results:
- Dataset: ImageNet-1k
Metrics:
Top 1 Accuracy: 83.38
Top 5 Accuracy: 96.58
Task: Image Classification
Weights: https://github.com/Lupin1998/mmpretrain/releases/download/moganet-in1k-weights/moganet-small_3rdparty_8xb128_in1k.pth
Config: configs/moganet/moganet-small_8xb128_in1k.py
Converted From:
Weights: https://github.com/Westlake-AI/openmixup/releases/download/moganet-in1k-weights/moga_small_sz224_8xb128_fp16_ep300.pth
Code: https://github.com/Westlake-AI/openmixup
- Name: moganet-base_3rdparty_8xb128_in1k
Metadata:
FLOPs: 10608569221
Parameters: 45843742
In Collection: MogaNet
Results:
- Dataset: ImageNet-1k
Metrics:
Top 1 Accuracy: 84.20
Top 5 Accuracy: 96.77
Task: Image Classification
Weights: https://github.com/Lupin1998/mmpretrain/releases/download/moganet-in1k-weights/moganet-base_3rdparty_8xb128_in1k.pth
Config: configs/moganet/moganet-base_8xb128_in1k.py
Converted From:
Weights: https://github.com/Westlake-AI/openmixup/releases/download/moganet-in1k-weights/moga_base_sz224_8xb128_fp16_ep300.pth
Code: https://github.com/Westlake-AI/openmixup
- Name: moganet-large_3rdparty_8xb128_in1k
Metadata:
FLOPs: 17008070492
Parameters: 86486548
In Collection: MogaNet
Results:
- Dataset: ImageNet-1k
Metrics:
Top 1 Accuracy: 84.76
Top 5 Accuracy: 97.15
Task: Image Classification
Weights: https://github.com/Lupin1998/mmpretrain/releases/download/moganet-in1k-weights/moganet-large_3rdparty_8xb128_in1k.pth
Config: configs/moganet/moganet-large_8xb128_in1k.py
Converted From:
Weights: https://github.com/Westlake-AI/openmixup/releases/download/moganet-in1k-weights/moga_large_sz224_8xb64_accu2_ep300.pth
Code: https://github.com/Westlake-AI/openmixup
- Name: moganet-xlarge_3rdparty_16xb32_in1k
Metadata:
FLOPs: 36968931000
Parameters: 189582540
In Collection: MogaNet
Results:
- Dataset: ImageNet-1k
Metrics:
Top 1 Accuracy: 85.11
Top 5 Accuracy: 97.38
Task: Image Classification
Weights: https://github.com/Lupin1998/mmpretrain/releases/download/moganet-in1k-weights/moganet-xlarge_3rdparty_16xb32_in1k.pth
Config: configs/moganet/moganet-xlarge_16xb32_in1k.py
Converted From:
Weights: https://github.com/Westlake-AI/openmixup/releases/download/moganet-in1k-weights/moga_xlarge_ema_sz224_8xb32_accu2_ep300.pth
Code: https://github.com/Westlake-AI/openmixup
42 changes: 42 additions & 0 deletions configs/moganet/moganet-base_8xb128_in1k.py
@@ -0,0 +1,42 @@
_base_ = [
'../_base_/models/moganet/moganet_base.py',
'../_base_/datasets/imagenet_bs128_poolformer_small_224.py',
'../_base_/schedules/imagenet_bs1024_adamw_swin.py',
'../_base_/default_runtime.py',
]

# schedule settings
optim_wrapper = dict(
paramwise_cfg=dict(
norm_decay_mult=0.0,
bias_decay_mult=0.0,
custom_keys={
'.layer_scale': dict(decay_mult=0.0),
'.scale': dict(decay_mult=0.0),
}),
)

# learning policy
param_scheduler = [
# warm up learning rate scheduler
dict(
type='LinearLR',
start_factor=1e-3,
by_epoch=True,
end=5,
# update by iter
convert_to_iter_based=True),
# main learning rate scheduler
dict(type='CosineAnnealingLR', eta_min=1e-5, by_epoch=False, begin=5)
]

# runtime setting
custom_hooks = [
dict(type='PreciseBNHook', num_samples=8192, priority='ABOVE_NORMAL'),
dict(type='EMAHook', momentum=1e-4, priority='ABOVE_NORMAL')
]

# NOTE: `auto_scale_lr` is for automatically scaling LR
# based on the actual training batch size.
# base_batch_size = (8 GPUs) x (128 samples per GPU)
auto_scale_lr = dict(base_batch_size=1024)