
Add Attention Fuse pass #809

Closed

Conversation


@wjj19950828 wjj19950828 commented Jul 15, 2022

This PR contributes the following:

1. Add an Attention fuse pass (supported for ONNX Runtime only).
2. Add an enable_extra_ort_opt flag that controls whether extra ORT-only optimizations are applied; it defaults to false.

To enable it, run a command like:

paddle2onnx --model_dir msra_ner_pruned_infer_model/ --model_filename float32.pdmodel --params_filename float32.pdiparams --save_file ner_model_test_0713.onnx --opset_version 13 --enable_onnx_checker True --enable_dev_version True --enable_extra_ort_opt True

Performance tests

  • GPU: 100 warmup runs, then the average over 1000 inference runs
  • CPU: the variance with 10 warmup runs + 100 repeats was too large, so the measurement was changed to 100 warmup runs + 2000 repeats; details below:
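The warmup-then-average methodology above can be sketched in Python; here `run` stands in for one inference call (for example an onnxruntime `session.run`, which is not shown), so the numbers below are illustrative only:

```python
import time

def benchmark_ms(run, warmup, repeats):
    """Return the mean latency in ms of run() after `warmup` untimed calls."""
    for _ in range(warmup):
        run()  # untimed warmup iterations
    start = time.perf_counter()
    for _ in range(repeats):
        run()  # timed iterations
    return (time.perf_counter() - start) / repeats * 1000.0

# GPU setting in this PR: 100 warmup runs, mean over 1000 timed runs.
mean_ms = benchmark_ms(lambda: sum(range(1000)), warmup=100, repeats=1000)
```

Using `time.perf_counter` (monotonic, high resolution) and a large repeat count is what reduces the variance the CPU measurement initially suffered from.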

1. Unpruned model
With the paddle2onnx Attention fuse pass, the exported ONNX model contains an Attention node.
GPU:
100 warmup runs, average over 1000 inference runs; the measured time covers run + data copy, reported as (mean) in ms; w/o and w indicate whether the attention pass was hit:
[image: GPU benchmark table]
CPU:
100 warmup runs, average over 2000 inference runs, single thread; the measured time covers run + data copy, reported as (mean) in ms; w/o and w indicate whether the attention pass was hit:
[image: CPU benchmark table]

2. Pruned model
With the paddle2onnx Attention fuse pass, the exported ONNX model contains an Attention node.
GPU:
100 warmup runs, average over 1000 inference runs; the measured time covers run + data copy, reported as (mean) in ms; w/o and w indicate whether the attention pass was hit:
[image: GPU benchmark table]
CPU:
100 warmup runs, average over 2000 inference runs, single thread; the measured time covers run + data copy, reported as (mean) in ms; w/o and w indicate whether the attention pass was hit:
[image: CPU benchmark table]

#include "paddle2onnx/optimizer/fuse_constant_cast.h"
#include "paddle2onnx/optimizer/fuse_constant_reshape.h"
#include "paddle2onnx/optimizer/fuse_constant_unsqueeze.h"
#include "paddle2onnx/optimizer/fuse_paddle_conv_bias.h"
#include "paddle2onnx/optimizer/fuse_unsqueeze_conv2d_squeeze.h"

namespace paddle2onnx {
-MapperHelper* MapperHelper::helper = nullptr;
+MapperHelper *MapperHelper::helper = nullptr;
Collaborator:
The code style was reformatted here; please change it back.

Contributor Author:
Done.


std::string getPassName() const override { return "fuse_attention"; }

bool patternMatchPredicate(Node *node) override {
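For context, `patternMatchPredicate` is the cheap per-node check that decides whether a node might anchor the subgraph to fuse; a toy Python analogue of the idea (the dict-based graph and the specific check are illustrative, not the paddle2onnx or onnx-optimizer API):

```python
def pattern_match_predicate(node):
    # Anchor check only: in this sketch a candidate attention root is a
    # MatMul whose output feeds a Softmax; the full match walks further
    # up and down the graph before rewriting anything.
    return node["op"] == "MatMul" and "Softmax" in node["consumers"]

toy_graph = [
    {"op": "MatMul", "consumers": ["Softmax"]},  # candidate anchor
    {"op": "MatMul", "consumers": ["Add"]},      # rejected: no Softmax consumer
    {"op": "Relu",   "consumers": ["MatMul"]},   # rejected: wrong op type
]
anchors = [n for n in toy_graph if pattern_match_predicate(n)]
```

Keeping the predicate cheap matters because it runs on every node; the expensive structural match only runs on the few anchors that pass.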
Collaborator:
The original code left-aligns the pointer symbol (e.g. `Node* node`); newly added code should follow the same alignment.

Contributor Author:
Done.

// +------|---+
// | |
// Add

Collaborator:
There is a full before/after comparison diagram of the fuse here.

In the code below, also annotate the key nodes at each step of the fuse (rather than describing them only as QKV WEIGHT).

Contributor Author:
Done. Added the relevant comments.

@@ -35,6 +35,7 @@ struct OptimizerOption {
     passes.push_back("fuse_matmul_add_bias_into_gemm");
     passes.push_back("eliminate_identity");
     passes.push_back("eliminate_deadend");
+    passes.push_back("fuse_attention");
Collaborator:
This is not a mandatory fuse, so it needs an extra flag to control it.

Contributor Author:
Done. Added the enable_extra_ort_opt flag, defaulting to False.
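The flag-gated pass selection described here can be sketched as follows (Python for illustration only; the actual implementation is C++ inside paddle2onnx, and the pass names are taken from the diff context above):

```python
def build_pass_list(enable_extra_ort_opt=False):
    # Always-on passes (a subset, from the diff context above).
    passes = [
        "fuse_matmul_add_bias_into_gemm",
        "eliminate_identity",
        "eliminate_deadend",
    ]
    if enable_extra_ort_opt:
        # fuse_attention only benefits ONNX Runtime (it emits an ORT
        # contrib Attention op), so it is opt-in rather than default.
        passes.append("fuse_attention")
    return passes
```

Keeping the fuse behind a default-off flag means models exported for other runtimes are unaffected unless the user explicitly opts in.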

@jiangjiajun (Collaborator):

This PR needs the final experimental data and results.

@wjj19950828 (Contributor Author):

> This PR needs the final experimental data and results.

Added the relevant experimental data to the description; increased the repeat count on CPU to address the large variance.
