scale parsed as float in ONNX scaled_dot_product_attention implementation #125158

Open
nicolas-mng opened this issue Apr 29, 2024 · 1 comment
Labels
module: onnx Related to torch.onnx triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module


@nicolas-mng

Hi,
I get errors like the one below when exporting a model to ONNX in which I manually pass a scale argument to the scaled dot product attention calls:

  File "/usr/local/lib/python3.10/dist-packages/torch/onnx/symbolic_opset14.py", line 176, in scaled_dot_product_attention
    query_scaled = g.op("Mul", query, g.op("Sqrt", scale))
  File "<@beartype(torch.onnx._internal.jit_utils.GraphContext.op) at 0x7f10f6b3feb0>", line 44, in op
beartype.roar.BeartypeCallHintParamViolation: Method torch.onnx._internal.jit_utils.GraphContext.op() parameter raw_args=0.03125 violates type hint typing.Union[torch.Tensor, torch.Value], as float 0.03125 not <protocol "torch.Tensor"> or <class "torch.Value">.

Looking at symbolic_opset14.py, I see that _maybe_get_const parses scale as a Python float, which triggers the error because g.op() only accepts a torch.Tensor or torch.Value, cf. here.

Is this intentional? If I remove this line, my model exports successfully.
Thanks
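
To illustrate the mismatch described above, here is a minimal pure-Python sketch. It does not use torch at all: `FakeValue`, `fake_op`, `as_value`, and `scale_query` are hypothetical stand-ins invented for this example, not part of torch.onnx. `fake_op` mimics a type-checked graph op that rejects raw Python numbers, and `as_value` shows one possible guard, assuming the intent is that a constant scale should be wrapped as a graph value before it reaches `Sqrt`.

```python
import math
from dataclasses import dataclass
from typing import Union


@dataclass(frozen=True)
class FakeValue:
    """Hypothetical stand-in for a graph value (NOT torch.Value)."""
    data: float


def fake_op(name: str, *args: FakeValue) -> FakeValue:
    """Hypothetical stand-in for a type-checked graph op.

    Like the beartype-decorated op in the traceback, it rejects raw
    Python numbers and only accepts graph values.
    """
    for arg in args:
        if not isinstance(arg, FakeValue):
            raise TypeError(
                f"{name}: argument {arg!r} is {type(arg).__name__}, "
                "expected FakeValue"
            )
    if name == "Sqrt":
        return FakeValue(math.sqrt(args[0].data))
    if name == "Mul":
        return FakeValue(args[0].data * args[1].data)
    raise ValueError(f"unsupported op: {name}")


def as_value(x: Union[float, FakeValue]) -> FakeValue:
    """Wrap a plain number as a constant value; pass graph values through."""
    return x if isinstance(x, FakeValue) else FakeValue(float(x))


def scale_query(query: FakeValue, scale: Union[float, FakeValue]) -> FakeValue:
    # Same shape as the failing line, Mul(query, Sqrt(scale)),
    # but with scale wrapped before it is handed to the op.
    return fake_op("Mul", query, fake_op("Sqrt", as_value(scale)))


if __name__ == "__main__":
    try:
        fake_op("Sqrt", 0.03125)  # raw float, as in the reported error
    except TypeError as exc:
        print("rejected:", exc)
    print(scale_query(FakeValue(2.0), 0.03125))
```

Whether the real exporter should wrap the float like this or skip the constant folding altogether is a design question for the maintainers; the sketch only shows why a bare float fails the type check and how wrapping avoids it.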

@cpuhrsch cpuhrsch added module: onnx Related to torch.onnx triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module labels Apr 30, 2024
@thiagocrepaldi
Collaborator

Propose a PR with a unit test, and we can review it for you.
