Skip to content

Commit

Permalink
updated vit
Browse files Browse the repository at this point in the history
  • Loading branch information
anxiangsir committed Mar 29, 2022
1 parent 4833d54 commit 024196c
Show file tree
Hide file tree
Showing 39 changed files with 1,179 additions and 124 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -95,8 +95,8 @@ The supported methods are as follows:
- [x] [ArcFace_mxnet (CVPR'2019)](recognition/arcface_mxnet)
- [x] [ArcFace_torch (CVPR'2019)](recognition/arcface_torch)
- [x] [SubCenter ArcFace (ECCV'2020)](recognition/subcenter_arcface)
- [x] [PartialFC_mxnet (Arxiv'2020)](recognition/partial_fc)
- [x] [PartialFC_torch (Arxiv'2020)](recognition/arcface_torch)
- [x] [PartialFC_mxnet (CVPR'2022)](recognition/partial_fc)
- [x] [PartialFC_torch (CVPR'2022)](recognition/arcface_torch)
- [x] [VPL (CVPR'2021)](recognition/vpl)
- [x] [Arcface_oneflow](recognition/arcface_oneflow)
- [x] [ArcFace_Paddle (CVPR'2019)](recognition/arcface_paddle)
Expand Down
5 changes: 0 additions & 5 deletions recognition/arcface_torch/.gitignore

This file was deleted.

76 changes: 56 additions & 20 deletions recognition/arcface_torch/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@ identity on a single server.

## Requirements

- Install [PyTorch](http://pytorch.org) (torch>=1.6.0), our doc for [install.md](docs/install.md).
- Install [PyTorch](http://pytorch.org) (torch>=1.9.0), our doc for [install.md](docs/install.md).
- (Optional) Install [DALI](https://docs.nvidia.com/deeplearning/dali/user-guide/docs/), our doc for [install_dali.md](docs/install_dali.md).
- `pip install -r requirement.txt`.
- `pip install -r requirements.txt`.

## How to Training

Expand Down Expand Up @@ -58,26 +58,55 @@ For **ICCV2021-MFR-ALL** set, TAR is measured on all-to-all 1:1 protocal, with F
globalised multi-racial testset contains 242,143 identities and 1,624,305 images.


> 1. Large Scale Datasets
| Datasets | Backbone | **MFR-ALL** | IJB-C(1E-4) | IJB-C(1E-5) | Training Throughout | log |
|:-----------------|:------------|:------------|:------------|:------------|:--------------------|:------------------------------------------------------------------------------------------------------------------------------------------------|
| MS1MV3 | mobileface | 65.76 | 94.44 | 91.85 | ~13000 | [click me](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv3_mobileface_lr02/training.log) |
| Glint360K | mobileface | 69.83 | 95.17 | 92.58 | -11000 | [click me](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/glint360k_mobileface_lr02_bs4k/training.log) |
| WF42M-PFC-0.2 | mobileface | 73.80 | 95.40 | 92.64 | (16GPUs)~18583 | [click me](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/webface42m_mobilefacenet_pfc02_bs8k_16gpus/training.log) |
| MS1MV3 | r100 | 83.23 | 96.88 | 95.31 | ~3400 | [click me](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv3_r100_lr02/training.log) |
| Glint360K | r100 | 90.86 | 97.53 | 96.43 | ~5000 | [click me](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/glint360k_r100_lr02_bs4k_16gpus/training.log) |
| WF42M-PFC-0.2 | r50(bs4k) | 93.83 | 97.53 | 96.16 | (8 GPUs)~5900 | [click me](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/webface42m_r50_bs4k_pfc02/training.log) |
| WF42M-PFC-0.2 | r50(bs8k) | 93.96 | 97.46 | 96.12 | (16GPUs)~11000 | [click me](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/webface42m_r50_lr01_pfc02_bs8k_16gpus/training.log) |
| WF42M-PFC-0.2 | r50(bs4k) | 94.04 | 97.48 | 95.94 | (32GPUs)~17000 | click me |
| WF42M-PFC-0.0018 | r100(bs16k) | 93.08 | 97.51 | 95.88 | (32GPUs)~10000 | click me |
| WF42M-PFC-0.2 | r100(bs4k) | 96.69 | 97.85 | 96.63 | (16GPUs)~5200 | [click me](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/webface42m_r100_bs4k_pfc02/training.log) |

> 2. VIT For Face Recognition
| Datasets | Backbone | FLOPs | **MFR-ALL** | IJB-C(1E-4) | IJB-C(1E-5) | Training Throughout | log |
|:--------------|:-------------|:------|:------------|:------------|:------------|:--------------------|:---------|
| WF42M-PFC-0.3 | R18(bs4k) | 2.6 | 79.13 | 95.77 | 93.36 | - | click me |
| WF42M-PFC-0.3 | R50(bs4k) | 6.3 | 94.03 | 97.48 | 95.94 | - | click me |
| WF42M-PFC-0.3 | R100(bs4k) | 12.1 | 96.69 | 97.82 | 96.45 | - | click me |
| WF42M-PFC-0.3 | R200(bs4k) | 23.5 | 97.70 | 97.97 | 96.93 | - | click me |
| WF42M-PFC-0.3 | VIT-T(bs24k) | 1.5 | 92.24 | 97.31 | 95.97 | (64GPUs)~35000 | click me |
| WF42M-PFC-0.3 | VIT-S(bs24k) | 5.7 | 95.87 | 97.73 | 96.57 | (64GPUs)~25000 | click me |
| WF42M-PFC-0.3 | VIT-B(bs24k) | 11.4 | 97.42 | 97.90 | 97.04 | (64GPUs)~13800 | click me |
| WF42M-PFC-0.3 | VIT-L(bs24k) | 25.3 | 97.85 | 98.00 | 97.23 | (64GPUs)~9406 | click me |

WF42M means WebFace42M, `PFC-0.3` means negivate class centers sample rate is 0.3.

> 3. Noisy Datasets
| Datasets | Backbone | **MFR-ALL** | IJB-C(1E-4) | IJB-C(1E-5) | log |
|:-------------------------|:---------|:------------|:------------|:------------|:---------|
| WF12M-Flip(40%) | R50 | 43.87 | 88.35 | 80.78 | click me |
| WF12M-Flip(40%)-PFC-0.3* | R50 | 80.20 | 96.11 | 93.79 | click me |
| WF12M-Conflict | R50 | 79.93 | 95.30 | 91.56 | click me |
| WF12M-Conflict-PFC-0.3* | R50 | 91.68 | 97.28 | 95.75 | click me |

WF12M means WebFace12M, `+PFC-0.3*` denotes additional abnormal inter-class filtering.

| Datasets | Backbone | **MFR-ALL** | IJB-C(1E-4) | IJB-C(1E-5) | Training Throughout | log |
|:-------------------------|:-----------|:------------|:------------|:------------|:--------------------|:------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| MS1MV3 | mobileface | 65.76 | 94.44 | 91.85 | ~13000 | [log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv3_mobileface_lr02/training.log)\|[config](configs/ms1mv3_mobileface_lr02.py) |
| Glint360K | mobileface | 69.83 | 95.17 | 92.58 | -11000 | [log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/glint360k_mobileface_lr02_bs4k/training.log)\|[config](configs/glint360k_mobileface_lr02_bs4k.py) |
| WebFace42M-PartialFC-0.2 | mobileface | 73.80 | 95.40 | 92.64 | (16GPUs)~18583 | [log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/webface42m_mobilefacenet_pfc02_bs8k_16gpus/training.log)\|[config](configs/webface42m_mobilefacenet_pfc02_bs8k_16gpus.py) |
| MS1MV3 | r100 | 83.23 | 96.88 | 95.31 | ~3400 | [log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv3_r100_lr02/training.log)\|[config](configs/ms1mv3_r100_lr02.py) |
| Glint360K | r100 | 90.86 | 97.53 | 96.43 | ~5000 | [log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/glint360k_r100_lr02_bs4k_16gpus/training.log)\|[config](configs/glint360k_r100_lr02_bs4k_16gpus.py) |
| WebFace42M-PartialFC-0.2 | r50(bs4k) | 93.83 | 97.53 | 96.16 | (8 GPUs)~5900 | [log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/webface42m_r50_bs4k_pfc02/training.log)\|[config](configs/webface42m_r50_lr01_pfc02_bs4k_8gpus.py) |
| WebFace42M-PartialFC-0.2 | r50(bs8k) | 93.96 | 97.46 | 96.12 | (16GPUs)~11000 | [log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/webface42m_r50_lr01_pfc02_bs8k_16gpus/training.log)\|[config](configs/webface42m_r50_lr01_pfc02_bs8k_16gpus.py) |
| WebFace42M-PartialFC-0.2 | r50(bs4k) | 94.04 | 97.48 | 95.94 | (32GPUs)~17000 | log\|[config](configs/webface42m_r50_lr01_pfc02_bs4k_32gpus.py) |
| WebFace42M-PartialFC-0.2 | r100(bs4k) | 96.69 | 97.85 | 96.63 | (16GPUs)~5200 | [log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/webface42m_r100_bs4k_pfc02/training.log)\|[config](configs/webface42m_r100_lr01_pfc02_bs4k_16gpus.py) |
| WebFace42M-PartialFC-0.2 | r200 | - | - | - | - | log\|config |

`PartialFC-0.2` means negivate class centers sample rate is 0.2.


## Speed Benchmark
<div><img src="https://github.com/anxiangsir/insightface_arcface_log/blob/master/pfc_exp.png" width = "90%" /></div>


`arcface_torch` can train large-scale face recognition training set efficiently and quickly. When the number of
**Arcface-Torch** can train large-scale face recognition training set efficiently and quickly. When the number of
classes in training sets is greater than 1 Million, partial fc sampling strategy will get same
accuracy with several times faster training performance and smaller GPU memory.
Partial FC is a sparse variant of the model parallel architecture for large sacle face recognition. Partial FC use a
Expand All @@ -86,12 +115,12 @@ sparse part of the parameters will be updated, which can reduce a lot of GPU mem
we can scale trainset of 29 millions identities, the largest to date. Partial FC also supports multi-machine distributed
training and mixed precision training.

![Image text](https://github.com/anxiangsir/insightface_arcface_log/blob/master/partial_fc_v2.png)


More details see
[speed_benchmark.md](docs/speed_benchmark.md) in docs.

### 1. Training speed of different parallel methods (samples / second), Tesla V100 32GB * 8. (Larger is better)
> 1. Training speed of different parallel methods (samples / second), Tesla V100 32GB * 8. (Larger is better)
`-` means training failed because of gpu memory limitations.

Expand All @@ -104,7 +133,7 @@ More details see
| 16000000 | **-** | **-** | 2679 |
| 29000000 | **-** | **-** | **1855** |

### 2. GPU memory cost of different parallel methods (MB per GPU), Tesla V100 32GB * 8. (Smaller is better)
> 2. GPU memory cost of different parallel methods (MB per GPU), Tesla V100 32GB * 8. (Smaller is better)
| Number of Identities in Dataset | Data Parallel | Model Parallel | Partial FC 0.1 |
|:--------------------------------|:--------------|:---------------|:---------------|
Expand All @@ -126,11 +155,18 @@ More details see
pages={4690--4699},
year={2019}
}
@inproceedings{an2022pfc,
title={Killing Two Birds with One Stone: Efficient and Robust Training of Face Recognition CNNs by Partial FC},
author={An, Xiang and Deng, Jiangkang and Guo, Jia and Feng, Ziyong and Zhu, Xuhan and Jing, Yang and Tongliang, Liu},
booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition},
year={2022}
}
@inproceedings{an2020partical_fc,
title={Partial FC: Training 10 Million Identities on a Single Machine},
author={An, Xiang and Zhu, Xuhan and Xiao, Yang and Wu, Lan and Zhang, Ming and Gao, Yuan and Qin, Bin and
Zhang, Debing and Fu Ying},
booktitle={Arxiv 2010.05222},
booktitle={Proceedings of International Conference on Computer Vision Workshop},
pages={1445-1449},
year={2020}
}
```
62 changes: 61 additions & 1 deletion recognition/arcface_torch/backbones/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,69 @@ def get_model(name, **kwargs):
elif name == "r2060":
from .iresnet2060 import iresnet2060
return iresnet2060(False, **kwargs)

elif name == "mbf":
fp16 = kwargs.get("fp16", False)
num_features = kwargs.get("num_features", 512)
return get_mbf(fp16=fp16, num_features=num_features)

elif name == "mbf_large":
from .mobilefacenet import get_mbf_large
fp16 = kwargs.get("fp16", False)
num_features = kwargs.get("num_features", 512)
return get_mbf_large(fp16=fp16, num_features=num_features)

elif name == "vit_t":
num_features = kwargs.get("num_features", 512)
from .vit import VisionTransformer
return VisionTransformer(
img_size=112, patch_size=9, num_classes=num_features, embed_dim=256, depth=12,
num_heads=8, drop_path_rate=0.1, norm_layer="ln", mask_ratio=0.1)

elif name == "vit_t_dp005_mask0": # For WebFace42M
num_features = kwargs.get("num_features", 512)
from .vit import VisionTransformer
return VisionTransformer(
img_size=112, patch_size=9, num_classes=num_features, embed_dim=256, depth=12,
num_heads=8, drop_path_rate=0.05, norm_layer="ln", mask_ratio=0.0)

elif name == "vit_s":
num_features = kwargs.get("num_features", 512)
from .vit import VisionTransformer
return VisionTransformer(
img_size=112, patch_size=9, num_classes=num_features, embed_dim=512, depth=12,
num_heads=8, drop_path_rate=0.1, norm_layer="ln", mask_ratio=0.1)

elif name == "vit_s_dp005_mask_0": # For WebFace42M
num_features = kwargs.get("num_features", 512)
from .vit import VisionTransformer
return VisionTransformer(
img_size=112, patch_size=9, num_classes=num_features, embed_dim=512, depth=12,
num_heads=8, drop_path_rate=0.05, norm_layer="ln", mask_ratio=0.0)

elif name == "vit_b":
# this is a feature
num_features = kwargs.get("num_features", 512)
from .vit import VisionTransformer
return VisionTransformer(
img_size=112, patch_size=9, num_classes=num_features, embed_dim=512, depth=24,
num_heads=8, drop_path_rate=0.1, norm_layer="ln", mask_ratio=0.1, using_checkpoint=True)

elif name == "vit_b_dp005_mask_005": # For WebFace42M
# this is a feature
num_features = kwargs.get("num_features", 512)
from .vit import VisionTransformer
return VisionTransformer(
img_size=112, patch_size=9, num_classes=num_features, embed_dim=512, depth=24,
num_heads=8, drop_path_rate=0.05, norm_layer="ln", mask_ratio=0.05, using_checkpoint=True)

elif name == "vit_l_dp005_mask_005": # For WebFace42M
# this is a feature
num_features = kwargs.get("num_features", 512)
from .vit import VisionTransformer
return VisionTransformer(
img_size=112, patch_size=9, num_classes=num_features, embed_dim=768, depth=24,
num_heads=8, drop_path_rate=0.05, norm_layer="ln", mask_ratio=0.05, using_checkpoint=True)

else:
raise ValueError()
raise ValueError()
14 changes: 11 additions & 3 deletions recognition/arcface_torch/backbones/iresnet.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint

__all__ = ['iresnet18', 'iresnet34', 'iresnet50', 'iresnet100', 'iresnet200']

using_ckpt = False

def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
"""3x3 convolution with padding"""
Expand Down Expand Up @@ -43,7 +44,7 @@ def __init__(self, inplanes, planes, stride=1, downsample=None,
self.downsample = downsample
self.stride = stride

def forward(self, x):
def forard_impl(self, x):
identity = x
out = self.bn1(x)
out = self.conv1(out)
Expand All @@ -54,7 +55,13 @@ def forward(self, x):
if self.downsample is not None:
identity = self.downsample(x)
out += identity
return out
return out

def forward(self, x):
if self.training and using_ckpt:
return checkpoint(self.forard_imlp, x)
else:
return self.forard_impl(x)


class IResNet(nn.Module):
Expand All @@ -63,6 +70,7 @@ def __init__(self,
block, layers, dropout=0, num_features=512, zero_init_residual=False,
groups=1, width_per_group=64, replace_stride_with_dilation=None, fp16=False):
super(IResNet, self).__init__()
self.extra_gflops = 0.0
self.fp16 = fp16
self.inplanes = 64
self.dilation = 1
Expand Down

0 comments on commit 024196c

Please sign in to comment.