Skip to content

Commit

Permalink
fix overwrite builtin layer destruction (#4732)
Browse files Browse the repository at this point in the history
* fix overwrite builtin layer destruction

* make modelbin class copyable

* test++
  • Loading branch information
nihui committed May 17, 2023
1 parent f893d24 commit 903ec7c
Show file tree
Hide file tree
Showing 4 changed files with 109 additions and 17 deletions.
5 changes: 5 additions & 0 deletions src/modelbin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,11 @@ ModelBin::~ModelBin()
{
}

Mat ModelBin::load(int /*w*/, int /*type*/) const
{
return Mat();
}

Mat ModelBin::load(int w, int h, int type) const
{
Mat m = load(w * h, type);
Expand Down
2 changes: 1 addition & 1 deletion src/modelbin.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ class NCNN_EXPORT ModelBin
// 2 = float16
// 3 = int8
// load vec
virtual Mat load(int w, int type) const = 0;
virtual Mat load(int w, int type) const;
// load image
virtual Mat load(int w, int h, int type) const;
// load dim
Expand Down
21 changes: 20 additions & 1 deletion src/net.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2025,7 +2025,26 @@ void Net::clear()
}
else
{
delete layer;
// check overwrite builtin layer destroyer
int index = -1;
const size_t overwrite_builtin_layer_registry_entry_count = d->overwrite_builtin_layer_registry.size();
for (size_t i = 0; i < overwrite_builtin_layer_registry_entry_count; i++)
{
if (d->overwrite_builtin_layer_registry[i].typeindex == layer->typeindex)
{
index = i;
break;
}
}

if (index != -1 && d->overwrite_builtin_layer_registry[index].destroyer)
{
d->overwrite_builtin_layer_registry[index].destroyer(layer, d->overwrite_builtin_layer_registry[index].userdata);
}
else
{
delete layer;
}
}
}
d->layers.clear();
Expand Down
98 changes: 83 additions & 15 deletions tests/test_squeezenet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -241,25 +241,93 @@ static int test_squeezenet(const ncnn::Option& opt, int load_model_type, float e
return check_top2(cls_scores, epsilon);
}

class MySoftmax : public ncnn::Layer
class MyConvolution : public ncnn::Layer
{
public:
MySoftmax()
MyConvolution()
{
one_blob_only = true;
support_inplace = true;
impl = ncnn::create_layer("Convolution");

one_blob_only = impl->one_blob_only;
support_inplace = impl->support_inplace;

support_packing = impl->support_packing;
support_vulkan = impl->support_vulkan;
support_bf16_storage = impl->support_bf16_storage;
support_fp16_storage = impl->support_fp16_storage;
support_int8_storage = impl->support_int8_storage;
support_image_storage = impl->support_image_storage;
}

~MyConvolution()
{
delete impl;
}

virtual int load_param(const ncnn::ParamDict& pd)
{
#if NCNN_VULKAN
impl->vkdev = vkdev;
#endif // NCNN_VULKAN

return impl->load_param(pd);
}

virtual int load_model(const ncnn::ModelBin& mb)
{
return impl->load_model(mb);
}

virtual int create_pipeline(const ncnn::Option& opt)
{
int ret = impl->create_pipeline(opt);

one_blob_only = impl->one_blob_only;
support_inplace = impl->support_inplace;

support_packing = impl->support_packing;
support_vulkan = impl->support_vulkan;
support_bf16_storage = impl->support_bf16_storage;
support_fp16_storage = impl->support_fp16_storage;
support_int8_storage = impl->support_int8_storage;
support_image_storage = impl->support_image_storage;

return ret;
}

virtual int destroy_pipeline(const ncnn::Option& opt)
{
return impl->destroy_pipeline(opt);
}

virtual int forward(const ncnn::Mat& bottom_blob, ncnn::Mat& top_blob, const ncnn::Option& opt) const
{
return impl->forward(bottom_blob, top_blob, opt);
}

#if NCNN_VULKAN
virtual int upload_model(ncnn::VkTransfer& cmd, const ncnn::Option& opt)
{
return impl->upload_model(cmd, opt);
}

virtual int forward_inplace(ncnn::Mat& bottom_top_blob, const ncnn::Option& /*opt*/) const
virtual int forward(const ncnn::VkMat& bottom_blob, ncnn::VkMat& top_blob, ncnn::VkCompute& cmd, const ncnn::Option& opt) const
{
bottom_top_blob.fill(0.f);
bottom_top_blob[123] = 0.5f;
bottom_top_blob[456] = 0.1f;
return 0;
return impl->forward(bottom_blob, top_blob, cmd, opt);
}

virtual int forward(const ncnn::VkImageMat& bottom_blob, ncnn::VkImageMat& top_blob, ncnn::VkCompute& cmd, const ncnn::Option& opt) const
{
return impl->forward(bottom_blob, top_blob, cmd, opt);
}
#endif // NCNN_VULKAN

private:
ncnn::Layer* impl;
};

DEFINE_LAYER_CREATOR(MySoftmax)
DEFINE_LAYER_CREATOR(MyConvolution)
DEFINE_LAYER_DESTROYER(MyConvolution)

static int test_squeezenet_overwrite_softmax(const ncnn::Option& opt, int load_model_type, float epsilon = 0.001)
{
Expand All @@ -279,7 +347,7 @@ static int test_squeezenet_overwrite_softmax(const ncnn::Option& opt, int load_m
if (load_model_type == 0)
{
// load from plain model file
squeezenet.register_custom_layer("Softmax", MySoftmax_layer_creator);
squeezenet.register_custom_layer("Convolution", MyConvolution_layer_creator, MyConvolution_layer_destroyer);
squeezenet.load_param(MODEL_DIR "/squeezenet_v1.1.param");

// test random feature disabled bits
Expand All @@ -296,7 +364,7 @@ static int test_squeezenet_overwrite_softmax(const ncnn::Option& opt, int load_m
if (load_model_type == 1)
{
// load from plain model memory
squeezenet.register_custom_layer("Softmax", MySoftmax_layer_creator);
squeezenet.register_custom_layer("Convolution", MyConvolution_layer_creator, MyConvolution_layer_destroyer);
param_str = read_file_string(MODEL_DIR "/squeezenet_v1.1.param");
model_data = read_file_content(MODEL_DIR "/squeezenet_v1.1.bin");
squeezenet.load_param_mem((const char*)param_str.c_str());
Expand All @@ -305,14 +373,14 @@ static int test_squeezenet_overwrite_softmax(const ncnn::Option& opt, int load_m
if (load_model_type == 2)
{
// load from binary model file
squeezenet.register_custom_layer(ncnn::layer_to_index("Softmax"), MySoftmax_layer_creator);
squeezenet.register_custom_layer(ncnn::layer_to_index("Convolution"), MyConvolution_layer_creator, MyConvolution_layer_destroyer);
squeezenet.load_param_bin(MODEL_DIR "/squeezenet_v1.1.param.bin");
squeezenet.load_model(MODEL_DIR "/squeezenet_v1.1.bin");
}
if (load_model_type == 3)
{
// load from binary model memory
squeezenet.register_custom_layer(ncnn::layer_to_index("Softmax"), MySoftmax_layer_creator);
squeezenet.register_custom_layer(ncnn::layer_to_index("Convolution"), MyConvolution_layer_creator, MyConvolution_layer_destroyer);
param_data = read_file_content(MODEL_DIR "/squeezenet_v1.1.param.bin");
model_data = read_file_content(MODEL_DIR "/squeezenet_v1.1.bin");
squeezenet.load_param((const unsigned char*)param_data);
Expand Down Expand Up @@ -345,7 +413,7 @@ static int test_squeezenet_overwrite_softmax(const ncnn::Option& opt, int load_m
cls_scores[j] = out[j];
}

return cls_scores[123] == 0.5f && cls_scores[456] == 0.1f ? 0 : -1;
return check_top2(cls_scores, epsilon);
}

int main()
Expand Down

0 comments on commit 903ec7c

Please sign in to comment.