
Support set schema inference function in python #5940

Open
wants to merge 13 commits into base: main

Conversation

@OYCN (Contributor) commented Feb 16, 2024

Description

Support setting a schema inference callback function on a custom OpSchema before registering it in Python.

Example:

def func(ctx: onnx.shape_inference.InferenceContext):
    # get some info
    assert ctx.get_num_inputs() == 2
    value = ctx.get_input_type(0)
    ...
    # get or create output proto object
    output = ctx.get_output_type(0)
    # set type or shape
    ...
    # set the result proto
    ctx.set_output_type(0, output)

schema.set_type_and_shape_inference_function(func)
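
For context, here is a fuller sketch of how this could be used end to end, assuming the Python OpSchema constructor from #5019 and that get_input_type/get_output_type return onnx.TypeProto messages as in the example above; the CustomIdentity operator and its com.example domain are made up for illustration:

import onnx
from onnx import defs

def infer(ctx: onnx.shape_inference.InferenceContext):
    # Propagate the type and shape of input 0 to output 0.
    in0 = ctx.get_input_type(0)
    out = ctx.get_output_type(0)
    out.CopyFrom(in0)
    ctx.set_output_type(0, out)

# Hypothetical custom operator, built with the existing Python OpSchema API.
schema = defs.OpSchema(
    "CustomIdentity",
    "com.example",
    1,
    inputs=[defs.OpSchema.FormalParameter("X", "T")],
    outputs=[defs.OpSchema.FormalParameter("Y", "T")],
    type_constraints=[("T", ["tensor(float)"], "")],
)
schema.set_type_and_shape_inference_function(infer)  # the new API in this PR
defs.register_schema(schema)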

Note

Depends on #5906

Motivation and Context

Follow up of #5019

@OYCN (Contributor, Author) commented Feb 16, 2024

We spend a considerable amount of code on passing proto objects between C++ and Python. However, the repository seems well prepared for this feature.

Signed-off-by: opluss <[email protected]>
codecov bot commented Feb 16, 2024

Codecov Report

Attention: Patch coverage is 93.54839%, with 4 lines in your changes missing coverage. Please review.

Project coverage is 57.10%. Comparing base (83194ed) to head (73bed3b).
Report is 35 commits behind head on main.

Files                     Patch %   Lines
onnx/shape_inference.py   83.33%    4 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #5940      +/-   ##
==========================================
+ Coverage   56.95%   57.10%   +0.14%     
==========================================
  Files         506      506              
  Lines       30467    30989     +522     
  Branches     4592     4601       +9     
==========================================
+ Hits        17353    17696     +343     
- Misses      12285    12467     +182     
+ Partials      829      826       -3     

☔ View full report in Codecov by Sentry.

@OYCN (Contributor, Author) commented Mar 29, 2024

Hi @justinchuby ,

In this PR, we can implement shape inference on the Python side, similar to how it's done on the C++ side. If you have any suggestions for this implementation, I'm open to making adjustments accordingly.

@justinchuby justinchuby marked this pull request as ready for review March 29, 2024 18:38
@justinchuby justinchuby requested a review from a team as a code owner March 29, 2024 18:38
@justinchuby (Contributor) commented:
Thank you! Is it ready to be reviewed?

@OYCN OYCN changed the title [WIP] Support set schema inference function in python Support set schema inference function in python Mar 31, 2024
@OYCN (Contributor, Author) commented Mar 31, 2024

> Thank you! Is it ready to be reviewed?

Yes, I have removed the 'WIP' prefix from the title. Please feel free to leave any comments. :D

@justinchuby justinchuby self-assigned this Apr 1, 2024
)

assert ctx.get_num_inputs() == 2
in0 = ctx.get_input_type(0)
Contributor commented:
My concern with this is that it goes through serialization to access the type information, which is not really efficient. I would change the API so that it does not return a TypeProto but instead returns the type and the shape as regular Python objects.

Contributor commented:
I think I agree with Xavier, but I am also a bit confused. I see that the method implementation serializes proto values to strings and returns them. We could instead return a pointer to the C++ proto object (wrapped as a Python object). Is that your suggestion, Xavier?

Contributor (Author) commented:
Should we closely mimic the C++ API design, or should we use Python's native types for the interaction? Using the proto pointer for the interaction may require additional code to bind it to Python (if there is another way, please correct me), or we would need to pull in a third-party library.

out.tensor_type.shape.dim.add().dim_value = N
out.tensor_type.shape.dim.add().dim_value = La * Lb
out.tensor_type.shape.dim.add().dim_value = out_len[i]
ctx.set_output_type(i, out)
Contributor commented:
Same comment here: we should avoid serialization with something like set_output_type_and_shape(in0_type, (N, La * Lb, out_len[i])). The type would be created on the C++ side, there is no serialization, and it would be more efficient.
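
For illustration, a pseudocode sketch of the suggested non-serializing variant (both get_input_elem_type and set_output_type_and_shape are hypothetical names that do not exist in this PR, and i, N, La, Lb, out_len stand for the loop index and values computed in the test above):

def infer(ctx: onnx.shape_inference.InferenceContext):
    # The element type and shape cross the boundary as plain Python values;
    # the TypeProto would be built on the C++ side, with no serialization.
    elem_type = ctx.get_input_elem_type(0)                                 # hypothetical accessor
    ctx.set_output_type_and_shape(i, elem_type, (N, La * Lb, out_len[i]))  # hypothetical setter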

@xadupre (Contributor) commented Apr 9, 2024

It would be a nice feature to have.

@@ -114,6 +116,9 @@ PYBIND11_MODULE(onnx_cpp2py_export, onnx_cpp2py_export) {
#endif // ONNX_ML
);

// Avoid a segmentation fault if we do not free the Python function held by a custom schema
onnx_cpp2py_export.add_object("_cleanup", py::capsule([] { OpSchemaRegistry::OpSchemaDeregisterAll(); }));
Contributor commented:
Can you clarify when this gets invoked?

Contributor (Author) commented:

I think the segfault is caused by destruction order: the Python object (the inference function held by the custom schema) must be released before the Python interpreter shuts down, but the static container inside the schema factory is only destroyed after the main function returns, when the interpreter is already gone. Therefore, we need to destroy the Python object manually.

About '_cleanup' : https://pybind11.readthedocs.io/en/stable/advanced/misc.html#module-destructors

@gramalingam (Contributor) commented:
Thanks for creating the PR! It would be great to add this functionality. See my comments above.

@OYCN (Contributor, Author) left a review comment:
Forgive my mistake: I didn't realize that review replies need to be submitted manually, so my replies were stuck in pending status.

@@ -114,6 +116,9 @@ PYBIND11_MODULE(onnx_cpp2py_export, onnx_cpp2py_export) {
#endif // ONNX_ML
);

// Avoid a segmentation fault if we do not free the Python function held by a custom schema
onnx_cpp2py_export.add_object("_cleanup", py::capsule([] { OpSchemaRegistry::OpSchemaDeregisterAll(); }));
Contributor (Author) commented:
Perhaps a more graceful approach would be to collect the schemas registered from Python and deregister them during cleanup. However, I'm not sure if it's worth the effort. In most cases, invoking cleanup implies that Python is exiting.
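
A rough sketch of that idea on the Python side, assuming onnx.defs.deregister_schema exists with roughly this signature (which may differ) and using a hypothetical module-level registry:

import onnx.defs as defs

_py_registered = []  # (name, domain, since_version) of schemas registered from Python

def register_custom_schema(schema: defs.OpSchema) -> None:
    # Track every schema registered from Python so cleanup can be selective.
    defs.register_schema(schema)
    _py_registered.append((schema.name, schema.domain, schema.since_version))

def cleanup_python_schemas() -> None:
    # Deregister only the Python-registered schemas, leaving the built-in
    # C++ schemas untouched, instead of calling OpSchemaDeregisterAll.
    for name, domain, version in _py_registered:
        defs.deregister_schema(name, version, domain)  # assumed signature
    _py_registered.clear()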


@xadupre (Contributor) commented May 27, 2024

I would do this in a different way. I think we should export to Python all of the per-operator shape inference functions, and provide a Python version of the algorithm that drives shape inference. That way, adding a new shape inference function would be easier, and we could do local shape inference. This feature is needed when building a model: shapes can be used to write efficient sequences of ONNX operators.
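
For comparison, getting shapes while building a model from Python today typically means running whole-graph inference on an intermediate model with the existing onnx.shape_inference.infer_shapes API; a per-operator (local) inference entry point as suggested above would avoid re-running this on the full graph after every added node. A minimal example of the current workflow:

from onnx import TensorProto, helper, shape_inference

# Build a tiny one-node graph and run full-graph shape inference on it.
node = helper.make_node("Relu", ["X"], ["Y"])
graph = helper.make_graph(
    [node],
    "g",
    [helper.make_tensor_value_info("X", TensorProto.FLOAT, [2, 3])],
    [helper.make_tensor_value_info("Y", TensorProto.FLOAT, None)],
)
inferred = shape_inference.infer_shapes(helper.make_model(graph))
print(inferred.graph.output[0].type)  # tensor_type with shape [2, 3]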

@OYCN (Contributor, Author) commented May 30, 2024

Hi Xavier,

After thinking about it for a while (please correct me if my understanding is wrong): implementing a Python version of shape inference may not eliminate the need to pass ONNX protos between C++ and Python. In the current architecture, if we want to access any proto inside the shape inference function, we need to export InferenceContext from C++ to Python. Perhaps a more efficient way of exchanging the information carried by the protos is the key to solving the problem. This could be done with native Python types (though it is challenging to deliver graph protos that way), with type pointers, or by other means.

Best Regards.

Labels: none yet
Projects: Status: In progress
Linked issues: none yet
4 participants