English | 中文

DBNet and DBNet++

DBNet: Real-time Scene Text Detection with Differentiable Binarization
DBNet++: Real-Time Scene Text Detection with Differentiable Binarization and Adaptive Scale Fusion

1. Introduction

DBNet

DBNet is a segmentation-based scene text detection method. Segmentation-based methods are gaining popularity for scene text detection because they can more accurately describe scene text of various shapes, such as curved text. The drawback of current segmentation-based SOTA methods is the binarization post-processing (converting probability maps into text bounding boxes), which often requires a manually set threshold (reducing prediction accuracy) and complex pixel-grouping algorithms (resulting in considerable time cost during inference). To address this problem, DBNet integrates an adaptive threshold called Differentiable Binarization (DB) into the architecture. DB simplifies post-processing and enhances the performance of text detection. Moreover, it can be removed at the inference stage without sacrificing performance. [1]

Figure 1. Overall DBNet architecture

The overall architecture of DBNet is presented in Figure 1. It consists of multiple stages:

  1. Feature extraction from a backbone at different scales. ResNet-50 is used as the backbone, and features are extracted from stages 2, 3, 4, and 5.
  2. The extracted features are upscaled and summed with the features of the previous stage in a cascade fashion.
  3. The resulting features are upscaled once again to match the size of the largest feature map (from stage 2) and concatenated along the channel axis.
  4. Then, the final feature map (shown in dark blue) is used to predict both the probability and threshold maps by applying a 3×3 convolution and two de-convolution operators with stride 2.
  5. The probability and threshold maps are merged into one approximate binary map by the Differentiable Binarization module (a minimal sketch of this step follows the list). The approximate binary map is used to generate text bounding boxes.
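
The core of this last step is the DB formula from the paper, B = 1 / (1 + exp(-k * (P - T))), where P is the probability map, T is the threshold map, and k is an amplifying factor (k = 50 in the paper and in the model config shown in Section 4.3). The NumPy sketch below only illustrates the formula, not MindOCR's actual implementation:

```python
import numpy as np

def differentiable_binarization(prob_map, thresh_map, k=50):
    """Approximate binary map from the DB paper: B = 1 / (1 + exp(-k * (P - T))).

    P is the probability map, T is the threshold map, and k is the amplifying
    factor (the paper uses k = 50). At inference this step can be dropped and
    the probability map thresholded directly."""
    return 1.0 / (1.0 + np.exp(-k * (prob_map - thresh_map)))

# Toy 2x2 maps: pixels whose probability exceeds the learned threshold are
# pushed close to 1, the others close to 0.
P = np.array([[0.9, 0.2], [0.6, 0.4]])
T = np.array([[0.3, 0.3], [0.5, 0.5]])
print(differentiable_binarization(P, T).round(3))
```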

DBNet++

DBNet++ is an extension of DBNet and thus largely replicates its architecture. The difference is that, instead of simply concatenating the extracted and scaled features from the backbone as DBNet does, DBNet++ fuses those features adaptively with the Adaptive Scale Fusion (ASF) module (Figure 2). ASF improves the scale robustness of the network by fusing features of different scales adaptively, which distinctly strengthens DBNet++'s ability to detect text instances of diverse scales. [2]

Figure 2. Overall DBNet++ architecture

Figure 3. Detailed architecture of the Adaptive Scale Fusion module

ASF consists of two attention modules, stage-wise attention and spatial attention, where the latter is integrated into the former as shown in Figure 3. The stage-wise attention module learns the weights of the feature maps of different scales, while the spatial attention module learns the attention across the spatial dimensions. The combination of these two modules leads to scale-robust feature fusion. DBNet++ performs better in detecting text instances of diverse scales, especially large-scale text instances where DBNet may generate inaccurate or discrete bounding boxes.
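
As a rough illustration of the fusion idea (per-stage, per-pixel weighting), here is a simplified NumPy sketch. The real ASF module derives its weights from convolutions plus a spatial-attention block rather than the random stand-ins used here, so treat this purely as a conceptual sketch:

```python
import numpy as np

def adaptive_scale_fusion_sketch(stage_feats, rng=np.random.default_rng(0)):
    """Conceptual sketch only: weight each stage's feature map with a
    per-stage, per-pixel score before fusing. The real ASF module learns
    these scores via convolutions and spatial attention; random numbers
    stand in for the learned attention here."""
    n = len(stage_feats)
    c, h, w = stage_feats[0].shape
    logits = rng.normal(size=(n, 1, h, w))
    # Softmax across stages -> at every pixel the stage weights sum to 1.
    weights = np.exp(logits) / np.exp(logits).sum(axis=0, keepdims=True)
    # Fuse: each stage contributes according to its per-pixel weight.
    return sum(wi * fi for wi, fi in zip(weights, stage_feats))  # (c, h, w)

feats = [np.random.rand(64, 160, 160).astype(np.float32) for _ in range(4)]
print(adaptive_scale_fusion_sketch(feats).shape)  # (64, 160, 160)
```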

2. General purpose models

Here we present general purpose models that were trained on a wide variety of scenarios (real-world photos, street views, documents, etc.) and challenges (straight text, curved text, long text lines, etc.) with two primary languages: Chinese and English. These models can be used off-the-shelf in your applications or to initialize your own models.

The models were trained on 12 public datasets (CTW, LSVT, RCTW-17, TextOCR, etc.) that contain a wide range of images. The training set has 153,511 images and the validation set has 9,786 images.
The test set consists of 598 images manually selected from the above-mentioned datasets.

| Model | Context | Backbone | Languages | F-score on Our Test Set | Throughput | Download |
|-------|---------|----------|-----------|-------------------------|------------|----------|
| DBNet | D910x8-MS2.0-G | ResNet-50 | Chinese + English | 83.41% | 256 img/s | ckpt \| mindir |
| DBNet++ | D910x4-MS2.0-G | ResNet-50 | Chinese + English | 84.30% | 104 img/s | ckpt \| mindir |

The input_shape for exported DBNet MindIR and DBNet++ MindIR in the links are (1,3,736,1280) and (1,3,1152,2048), respectively.

3. Results

DBNet and DBNet++ were trained on the ICDAR2015, MSRA-TD500, SCUT-CTW1500, Total-Text, and MLT2017 datasets. In addition, we conducted pre-training on the SynthText dataset and provided a URL to download pretrained weights. All training results are as follows:

ICDAR2015

| Model | Context | Backbone | Pretrained | Recall | Precision | F-score | Train Time | Throughput | Recipe | Download |
|-------|---------|----------|------------|--------|-----------|---------|------------|------------|--------|----------|
| DBNet | D910x1-MS2.0-G | MobileNetV3 | ImageNet | 76.31% | 78.27% | 77.28% | 10 s/epoch | 100 img/s | yaml | ckpt \| mindir |
| DBNet | D910x8-MS2.3-G | MobileNetV3 | ImageNet | 76.22% | 77.98% | 77.09% | 1.1 s/epoch | 960 img/s | yaml | Coming soon |
| DBNet | D910x1-MS2.0-G | ResNet-18 | ImageNet | 80.12% | 83.41% | 81.73% | 9.3 s/epoch | 108 img/s | yaml | ckpt \| mindir |
| DBNet | D910x1-MS2.0-G | ResNet-50 | ImageNet | 83.53% | 86.62% | 85.05% | 13.3 s/epoch | 75.2 img/s | yaml | ckpt \| mindir |
| DBNet | D910x8-MS2.2-G | ResNet-50 | ImageNet | 82.62% | 88.54% | 85.48% | 2.3 s/epoch | 435 img/s | yaml | Coming soon |
| DBNet++ | D910x1-MS2.0-G | ResNet-50 | SynthText | 85.70% | 87.81% | 86.74% | 17.7 s/epoch | 56 img/s | yaml | ckpt \| mindir |
| DBNet++ | D910x1-MS2.2-G | ResNet-50 | SynthText | 86.81% | 86.85% | 86.86% | 12.7 s/epoch | 78.2 img/s | yaml | ckpt \| mindir |

The input_shape for exported DBNet MindIR and DBNet++ MindIR in the links are (1,3,736,1280) and (1,3,1152,2048), respectively.

MSRA-TD500

| Model | Context | Backbone | Pretrained | Recall | Precision | F-score | Train Time | Throughput | Recipe | Download |
|-------|---------|----------|------------|--------|-----------|---------|------------|------------|--------|----------|
| DBNet | D910x1-MS2.0-G | ResNet-18 | SynthText | 79.90% | 88.07% | 83.78% | 5.6 s/epoch | 121.7 img/s | yaml | ckpt |
| DBNet | D910x1-MS2.0-G | ResNet-50 | SynthText | 84.02% | 87.48% | 85.71% | 9.6 s/epoch | 71.2 img/s | yaml | ckpt |

The MSRA-TD500 dataset has 300 training images and 200 testing images. Following the reference paper Real-time Scene Text Detection with Differentiable Binarization, we trained with an extra 400 training images from HUST-TR400. You can download all the datasets for training.

SCUT-CTW1500

| Model | Context | Backbone | Pretrained | Recall | Precision | F-score | Train Time | Throughput | Recipe | Download |
|-------|---------|----------|------------|--------|-----------|---------|------------|------------|--------|----------|
| DBNet | D910x1-MS2.0-G | ResNet-18 | SynthText | 85.68% | 85.33% | 85.50% | 8.2 s/epoch | 122.1 img/s | yaml | ckpt |
| DBNet | D910x1-MS2.0-G | ResNet-50 | SynthText | 87.83% | 84.71% | 86.25% | 14.0 s/epoch | 71.4 img/s | yaml | ckpt |

Total-Text

| Model | Context | Backbone | Pretrained | Recall | Precision | F-score | Train Time | Throughput | Recipe | Download |
|-------|---------|----------|------------|--------|-----------|---------|------------|------------|--------|----------|
| DBNet | D910x1-MS2.0-G | ResNet-18 | SynthText | 83.66% | 87.61% | 85.59% | 12.9 s/epoch | 96.9 img/s | yaml | ckpt |
| DBNet | D910x1-MS2.0-G | ResNet-50 | SynthText | 84.79% | 87.07% | 85.91% | 18.0 s/epoch | 69.1 img/s | yaml | ckpt |

MLT2017

| Model | Context | Backbone | Pretrained | Recall | Precision | F-score | Train Time | Throughput | Recipe | Download |
|-------|---------|----------|------------|--------|-----------|---------|------------|------------|--------|----------|
| DBNet | D910x8-MS2.0-G | ResNet-18 | SynthText | 73.62% | 83.93% | 78.44% | 20.9 s/epoch | 344.8 img/s | yaml | ckpt |
| DBNet | D910x8-MS2.0-G | ResNet-50 | SynthText | 76.04% | 84.51% | 80.05% | 23.6 s/epoch | 305.6 img/s | yaml | ckpt |

SynthText

| Model | Context | Backbone | Pretrained | Train Loss | Train Time | Throughput | Recipe | Download |
|-------|---------|----------|------------|------------|------------|------------|--------|----------|
| DBNet | D910x1-MS2.0-G | ResNet-18 | ImageNet | 2.41 | 7075 s/epoch | 121.37 img/s | yaml | ckpt |
| DBNet | D910x1-MS2.0-G | ResNet-50 | ImageNet | 2.25 | 10470 s/epoch | 82.02 img/s | yaml | ckpt |

Notes

  • Context: The training context is denoted as {device}x{pieces}-{MS version}{MS mode}, where the MindSpore mode can be G (graph mode) or F (pynative mode with ms function). For example, D910x8-G means training on 8 Ascend 910 NPUs in graph mode.
  • Note that the training time of DBNet is highly affected by data processing and varies on different machines.

4. Quick Start

4.1 Installation

Please refer to the installation instruction in MindOCR.

4.2 Dataset preparation

4.2.1 ICDAR2015 dataset

Please download the ICDAR2015 dataset, and convert the labels to the desired format referring to dataset_converters.

The prepared dataset file structure should be:

.
├── test
│   ├── images
│   │   ├── img_1.jpg
│   │   ├── img_2.jpg
│   │   └── ...
│   └── test_det_gt.txt
└── train
    ├── images
    │   ├── img_1.jpg
    │   ├── img_2.jpg
    │   └── ...
    └── train_det_gt.txt

4.2.2 MSRA-TD500 dataset

Please download the MSRA-TD500 dataset, and convert the labels to the desired format referring to dataset_converters.

The prepared dataset file structure should be:

MSRA-TD500
 ├── test
 │   ├── IMG_0059.gt
 │   ├── IMG_0059.JPG
 │   ├── IMG_0080.gt
 │   ├── IMG_0080.JPG
 │   ├── ...
 │   ├── test_det_gt.txt
 ├── train
 │   ├── IMG_0030.gt
 │   ├── IMG_0030.JPG
 │   ├── IMG_0063.gt
 │   ├── IMG_0063.JPG
 │   ├── ...
 │   ├── train_det_gt.txt

4.2.3 SCUT-CTW1500 dataset

Please download the SCUT-CTW1500 dataset, and convert the labels to the desired format referring to dataset_converters.

The prepared dataset file structure should be:

ctw1500
 ├── test_images
 │   ├── 1001.jpg
 │   ├── 1002.jpg
 │   ├── ...
 ├── train_images
 │   ├── 0001.jpg
 │   ├── 0002.jpg
 │   ├── ...
 ├── test_det_gt.txt
 ├── train_det_gt.txt

4.2.4 Total-Text dataset

Please download the Total-Text dataset, and convert the labels to the desired format referring to dataset_converters.

The prepared dataset file structure should be:

totaltext
 ├── Images
 │   ├── Train
 │   │   ├── img1001.jpg
 │   │   ├── img1002.jpg
 │   │   ├── ...
 │   ├── Test
 │   │   ├── img1.jpg
 │   │   ├── img2.jpg
 │   │   ├── ...
 ├── test_det_gt.txt
 ├── train_det_gt.txt

4.2.5 MLT2017 dataset

The MLT2017 dataset is a multilingual text detection and recognition dataset that includes nine languages: Chinese, Japanese, Korean, English, French, Arabic, Italian, German, and Hindi. Please download MLT2017 and extract the dataset. Then convert the .gif format images in the data to .jpg or .png format, and convert the labels to the desired format referring to dataset_converters.
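
One possible way to do the .gif conversion is with Pillow; this is an assumption on tooling (Pillow is not part of MindOCR) and the path below is a placeholder:

```python
from pathlib import Path
from PIL import Image  # assumes Pillow is installed: pip install Pillow

# Placeholder path: point it at your extracted MLT2017 directory.
for gif_path in Path("/path/to/MLT_2017").rglob("*.gif"):
    # Take the first frame and drop any palette/alpha channel.
    Image.open(gif_path).convert("RGB").save(gif_path.with_suffix(".jpg"), "JPEG")
    gif_path.unlink()  # remove the original .gif once converted
```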

The prepared dataset file structure should be:

MLT_2017
 ├── train
 │   ├── img_1.png
 │   ├── img_2.png
 │   ├── img_3.jpg
 │   ├── img_4.jpg
 │   ├── ...
 ├── validation
 │   ├── img_1.jpg
 │   ├── img_2.jpg
 │   ├── ...
 ├── train_det_gt.txt
 ├── validation_det_gt.txt

If you want to use your own dataset for training, please convert the labels to the desired format referring to dataset_converters. Then configure the yaml file, and use a single device or multiple devices to run train.py for training. For detailed information, please refer to the following tutorials.
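
For orientation, a converted detection label file generally pairs an image filename with a JSON list of polygon annotations on each line; the authoritative schema is defined by dataset_converters, and the field names and values below are only illustrative:

```python
import json

# Hypothetical line of a converted detection label file: an image filename,
# a TAB, then a JSON list of polygon annotations. Check dataset_converters
# for the authoritative format.
line = 'img_1.jpg\t[{"transcription": "text", "points": [[10, 20], [110, 20], [110, 60], [10, 60]]}]'
img_name, anno = line.split("\t", 1)
boxes = json.loads(anno)
print(img_name, len(boxes), boxes[0]["points"])
```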

4.2.6 SynthText dataset

Please download the SynthText dataset and process it as described in dataset_converters.

.
├── SynthText
│   ├── 1
│   │   ├── img_1.jpg
│   │   ├── img_2.jpg
│   │   └── ...
│   ├── 2
│   │   ├── img_1.jpg
│   │   ├── img_2.jpg
│   │   └── ...
│   ├── ...
│   ├── 200
│   │   ├── img_1.jpg
│   │   ├── img_2.jpg
│   │   └── ...
│   └── gt.mat

⚠️ Additionally, it is strongly recommended to pre-process the SynthText dataset before using it, as it contains some faulty data:

python tools/dataset_converters/convert.py --dataset_name=synthtext --task=det --label_dir=/path-to-data-dir/SynthText/gt.mat --output_path=/path-to-data-dir/SynthText/gt_processed.mat

This operation will generate a filtered output in the same format as the original SynthText.

4.3 Update yaml config file

Update the configs/det/dbnet/db_r50_icdar15.yaml configuration file with your data paths, specifically the following parts. The dataset_root will be concatenated with data_dir and label_file respectively to form the complete dataset directory and label file path.

...
train:
  ckpt_save_dir: './tmp_det'
  dataset_sink_mode: False
  dataset:
    type: DetDataset
    dataset_root: dir/to/dataset          <--- Update
    data_dir: train/images                <--- Update
    label_file: train/train_det_gt.txt    <--- Update
...
eval:
  dataset_sink_mode: False
  dataset:
    type: DetDataset
    dataset_root: dir/to/dataset          <--- Update
    data_dir: test/images                 <--- Update
    label_file: test/test_det_gt.txt      <--- Update
...
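
For example, with dataset_root, data_dir, and label_file set as above, the resolved paths look like this (a small illustrative check; the values are placeholders for your actual directories):

```python
import os

# Placeholder values; substitute your actual dataset location.
dataset_root = "/data/icdar2015"
data_dir = "train/images"
label_file = "train/train_det_gt.txt"

print(os.path.join(dataset_root, data_dir))    # /data/icdar2015/train/images
print(os.path.join(dataset_root, label_file))  # /data/icdar2015/train/train_det_gt.txt
```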

Optionally, change num_workers according to the number of CPU cores.

DBNet consists of 3 parts: backbone, neck, and head. Specifically:

model:
  type: det
  transform: null
  backbone:
    name: det_resnet50  # Only ResNet50 is supported at the moment
    pretrained: True    # Whether to use weights pretrained on ImageNet
  neck:
    name: DBFPN         # FPN part of the DBNet
    out_channels: 256
    bias: False
    use_asf: False      # Adaptive Scale Fusion module from DBNet++ (use it for DBNet++ only)
  head:
    name: DBHead
    k: 50               # amplifying factor for Differentiable Binarization
    bias: False
    adaptive: True      # True for training, False for inference

4.4 Training

  • Standalone training

Please set distribute in the yaml config file to False.

python tools/train.py -c=configs/det/dbnet/db_r50_icdar15.yaml

  • Distributed training

Please set distribute in the yaml config file to True.

# n is the number of GPUs/NPUs
mpirun --allow-run-as-root -n 2 python tools/train.py --config configs/det/dbnet/db_r50_icdar15.yaml

The training result (including checkpoints, per-epoch performance, and curves) will be saved in the directory specified by the arg ckpt_save_dir in the yaml config file. The default directory is ./tmp_det.

4.5 Evaluation

To evaluate the accuracy of the trained model, you can use eval.py. Please set the checkpoint path in the arg ckpt_load_path in the eval section of the yaml config file, set distribute to False, and then run:

python tools/eval.py -c=configs/det/dbnet/db_r50_icdar15.yaml

5. MindSpore Lite Inference

Please refer to the tutorial MindOCR Inference for model inference based on MindSpore Lite on Ascend 310, including the following steps:

  • Model Export

Please download the exported MindIR file first, or refer to the Model Export tutorial and use the following command to export the trained ckpt model to MindIR file:

python tools/export.py --model_name_or_config dbnet_resnet50 --data_shape 736 1280 --local_ckpt_path /path/to/local_ckpt.ckpt
# or
python tools/export.py --model_name_or_config configs/det/dbnet/db_r50_icdar15.yaml --data_shape 736 1280 --local_ckpt_path /path/to/local_ckpt.ckpt

The data_shape is the model input shape (height and width) for the MindIR file. The shape value of the MindIR in the download link can be found in the ICDAR2015 notes.

  • Environment Installation

Please refer to Environment Installation tutorial to configure the MindSpore Lite inference environment.

  • Model Conversion

Please refer to Model Conversion, and use the converter_lite tool for offline conversion of the MindIR file.

  • Inference

Assuming that you obtain output.mindir after model conversion, go to the deploy/py_infer directory, and use the following command for inference:

python infer.py \
    --input_images_dir=/your_path_to/test_images \
    --det_model_path=your_path_to/output.mindir \
    --det_model_name_or_config=../../configs/det/dbnet/db_r50_icdar15.yaml \
    --res_save_dir=results_dir

References

[1] Minghui Liao, Zhaoyi Wan, Cong Yao, Kai Chen, Xiang Bai. Real-time Scene Text Detection with Differentiable Binarization. arXiv:1911.08947, 2019

[2] Minghui Liao, Zhisheng Zou, Zhaoyi Wan, Cong Yao, Xiang Bai. Real-Time Scene Text Detection with Differentiable Binarization and Adaptive Scale Fusion. arXiv:2202.10304, 2022