This repository has been archived by the owner on Nov 3, 2022. It is now read-only.

The implementation of resnet may not be correct. #534

Open
CNOCycle opened this issue Mar 18, 2020 · 1 comment

Comments

@CNOCycle

The original resnet paper summarizes each layer's output size in Table 1.

The output size of the last conv5_x stage (before the average pooling) should be 7x7, but in the ResNet from keras-contrib it is 4x4.

The following code reproduces the result I describe:

# docker run --runtime=nvidia -it tensorflow/tensorflow:1.14.0-gpu-py3 bash
# apt update
# apt install -y git
# git clone https://github.com/keras-team/keras.git
# cd keras
# python setup.py install
# pip install git+https://www.github.com/keras-team/keras-contrib.git

import keras
from keras_contrib.applications.resnet import ResNet34
model = ResNet34(input_shape=(224,224,3), classes=1000)
model.summary()
batch_normalization_37 (BatchNo (None, 4, 4, 512)    2048        add_16[0][0]
__________________________________________________________________________________________________
activation_33 (Activation)      (None, 4, 4, 512)    0           batch_normalization_37[0][0]
__________________________________________________________________________________________________
global_average_pooling2d_1 (Glo (None, 512)          0           activation_33[0][0]
__________________________________________________________________________________________________
dense_1 (Dense)                 (None, 1000)         513000      global_average_pooling2d_1[0][0]
==================================================================================================
Total params: 21,827,624
Trainable params: 21,810,472
Non-trainable params: 17,152
__________________________________________________________________________________________________
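
For reference, here is a quick back-of-the-envelope check of the spatial sizes (my own sketch, not part of the repro; it assumes 'same' padding, so each stride-2 layer halves the size, rounding up). After conv1 (stride 2) and the max pool (stride 2), a 224x224 input is already 56x56, so the difference comes from the stride of the first residual stage:

def stage_sizes(start, strides):
    """Spatial size after each residual stage, given the stride of its first block."""
    sizes = []
    for s in strides:
        start = -(-start // s)  # ceiling division, mimicking 'same' padding
        sizes.append(start)
    return sizes

print(stage_sizes(56, [1, 2, 2, 2]))  # paper, Table 1: [56, 28, 14, 7]
print(stage_sizes(56, [2, 2, 2, 2]))  # keras-contrib:  [28, 14, 7, 4]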

I think the root cause is that the stride of the first cell is not handled carefully.

In the PyTorch implementation, the first cell's stride defaults to 1 and the other cells' strides are explicitly set to 2:

class ResNet(nn.Module):
    def __init__(self, block, layers, **kwargs):  # excerpt; other arguments omitted
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
                                       dilate=replace_stride_with_dilation[0])
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
                                       dilate=replace_stride_with_dilation[1])
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
                                       dilate=replace_stride_with_dilation[2])

In the keras-team implementation, the first cell's stride is explicitly set to 1 and the other cells' strides default to 2:

def ResNet50(include_top=True,
             weights='imagenet',
             input_tensor=None,
             input_shape=None,
             pooling=None,
             classes=1000,
             **kwargs):
    def stack_fn(x):
        x = stack1(x, 64, 3, stride1=1, name='conv2')
        x = stack1(x, 128, 4, name='conv3')
        x = stack1(x, 256, 6, name='conv4')
        x = stack1(x, 512, 3, name='conv5')

However, the keras-contrib implementation does not distinguish the first cell from the others when building each stage:

filters = initial_filters
for i, r in enumerate(repetitions):
    transition_dilation_rates = [transition_dilation_rate] * r
    transition_strides = [(1, 1)] * r
    if transition_dilation_rate == (1, 1):
        transition_strides[0] = (2, 2)
    block = _residual_block(block_fn, filters=filters,
                            stage=i, blocks=r,
                            is_first_layer=(i == 0),
                            dropout=dropout,
                            transition_dilation_rates=transition_dilation_rates,
                            transition_strides=transition_strides,
                            residual_unit=residual_unit)(block)
    filters *= 2
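
To make the effect concrete, here is a small trace (my own sketch, not keras-contrib code) of the transition strides the loop above assigns for ResNet34, where repetitions is [3, 4, 6, 3]:

repetitions = [3, 4, 6, 3]        # ResNet34
transition_dilation_rate = (1, 1)
for i, r in enumerate(repetitions):
    transition_strides = [(1, 1)] * r
    if transition_dilation_rate == (1, 1):
        transition_strides[0] = (2, 2)
    print('stage', i, 'first-block stride:', transition_strides[0])
# stage 0 first-block stride: (2, 2)   <- extra downsampling; the max pool already halved the size
# stage 1 first-block stride: (2, 2)
# stage 2 first-block stride: (2, 2)
# stage 3 first-block stride: (2, 2)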

This issue could be fixed easily; we only need an extra condition to check whether the current cell is the first one.

     for i, r in enumerate(repetitions):
         transition_dilation_rates = [transition_dilation_rate] * r
         transition_strides = [(1, 1)] * r
-        if transition_dilation_rate == (1, 1):
+        if transition_dilation_rate == (1, 1) and filters != initial_filters:
             transition_strides[0] = (2, 2)
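
As a quick sanity check (my own sketch, assuming the patch above has been applied), the feature map entering the global average pooling should then be 7x7:

from keras.layers import GlobalAveragePooling2D
from keras_contrib.applications.resnet import ResNet34

model = ResNet34(input_shape=(224, 224, 3), classes=1000)
gap = [l for l in model.layers if isinstance(l, GlobalAveragePooling2D)][0]
print(gap.input_shape)  # expected: (None, 7, 7, 512) after the fix, (None, 4, 4, 512) before
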
@CNOCycle
Author

After carefully comparing different implementations, I think that _residual_block is also not correct.

The original paper summarizes the residual function in Figure 5. For the basic block, the Add layer is followed by a relu layer.

However, the basic_block implemented in keras-contrib returns the result of the shortcut addition directly, without a final relu:

def basic_block(filters, stage, block, transition_strides=(1, 1),
                dilation_rate=(1, 1), is_first_block_of_first_layer=False,
                dropout=None, residual_unit=_bn_relu_conv):
    """Basic 3 X 3 convolution blocks for use on resnets with layers <= 34.
    Follows improved proposed scheme in http://arxiv.org/pdf/1603.05027v2.pdf
    """
    def f(input_features):
        conv_name_base, bn_name_base = _block_name_base(stage, block)
        if is_first_block_of_first_layer:
            # don't repeat bn->relu since we just did bn->relu->maxpool
            x = Conv2D(filters=filters, kernel_size=(3, 3),
                       strides=transition_strides,
                       dilation_rate=dilation_rate,
                       padding="same",
                       kernel_initializer="he_normal",
                       kernel_regularizer=l2(1e-4),
                       name=conv_name_base + '2a')(input_features)
        else:
            x = residual_unit(filters=filters, kernel_size=(3, 3),
                              strides=transition_strides,
                              dilation_rate=dilation_rate,
                              conv_name_base=conv_name_base + '2a',
                              bn_name_base=bn_name_base + '2a')(input_features)

        if dropout is not None:
            x = Dropout(dropout)(x)

        x = residual_unit(filters=filters, kernel_size=(3, 3),
                          conv_name_base=conv_name_base + '2b',
                          bn_name_base=bn_name_base + '2b')(x)

        return _shortcut(input_features, x)

    return f

This issue could be fixed by the following patch:

-        return _shortcut(input_features, x)
+        x = _shortcut(input_features, x)
+        return Activation("relu")(x)
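
For comparison, here is a minimal sketch of the basic block as drawn in the original paper: conv-BN-relu, conv-BN, add the shortcut, then relu. The layer arrangement follows the paper; the function name and the projection-shortcut condition are only illustrative and are not the keras-contrib ones:

from keras import backend as K
from keras.layers import Activation, BatchNormalization, Conv2D, add

def paper_basic_block(x, filters, stride=1):
    shortcut = x
    y = Conv2D(filters, (3, 3), strides=stride, padding='same')(x)
    y = BatchNormalization()(y)
    y = Activation('relu')(y)
    y = Conv2D(filters, (3, 3), padding='same')(y)
    y = BatchNormalization()(y)
    # projection shortcut when the spatial size or channel count changes
    if stride != 1 or K.int_shape(x)[-1] != filters:
        shortcut = Conv2D(filters, (1, 1), strides=stride)(x)
        shortcut = BatchNormalization()(shortcut)
    y = add([y, shortcut])
    return Activation('relu')(y)  # relu *after* the addition, as in the paper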

The second thing is that I'm not sure why the condition is_first_block_of_first_layer needs to be checked.
