torch_dim_to_trt_axes does not handle dim=-1 correctly #898

Open
Thrsu opened this issue Nov 27, 2023 · 0 comments

Description:

According to the official PyTorch documentation, the default value of dim for torch.nn.functional.gumbel_softmax is -1, i.e. the last dimension. However, the torch_dim_to_trt_axes function does not handle negative dim values, so converting dim=-1 to a TensorRT axes bitmask fails instead of selecting the last axis.
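
For context, the failure boils down to Python's shift operator rejecting a negative count, which is what happens when the default dim=-1 reaches the bitmask loop unmodified:

# Independent of torch2trt: a negative shift count is rejected by Python,
# which is exactly what the bitmask loop runs into when dim=-1 is passed through.
axes = 0
d = -1
axes |= 1 << d  # ValueError: negative shift count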

Reproduce:

Here is a minimal script to reproduce the issue:

import torch
from torch.nn import Module
from torch2trt import torch2trt

para_0 = torch.randn([5, 5], dtype=torch.float32).cuda()  # input logits
para_1 = 2.0   # tau (temperature)
para_2 = True  # hard

class gumbel_softmax(Module):
    def forward(self, *args):
        # dim is left at its default of -1 (the last dimension)
        return torch.nn.functional.gumbel_softmax(args[0], para_1, para_2)

model = gumbel_softmax().float().eval().cuda()
model_trt = torch2trt(model, [para_0])  # fails while converting the internal max reduce

The traceback is as follows:

Traceback (most recent call last):
  ...
   model_trt = torch2trt(model, [para_0])
  File "/root/miniconda3/envs/nnsmith/lib/python3.9/site-packages/torch2trt-0.4.0-py3.9.egg/torch2trt/torch2trt.py", line 778, in torch2trt
    outputs = module(*inputs)
  File "/root/miniconda3/envs/nnsmith/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/root/miniconda3/envs/nnsmith/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1568, in _call_impl
    result = forward_call(*args, **kwargs)
  File "/share/OPERA/torch/reproduce/test_3138.py", line 11, in forward
    return torch.nn.functional.gumbel_softmax(args[0], para_1,para_2,)
  File "/root/miniconda3/envs/nnsmith/lib/python3.9/site-packages/torch2trt-0.4.0-py3.9.egg/torch2trt/torch2trt.py", line 300, in wrapper
    outputs = method(*args, **kwargs)
  File "/root/miniconda3/envs/nnsmith/lib/python3.9/site-packages/torch/nn/functional.py", line 1915, in gumbel_softmax
    index = y_soft.max(dim, keepdim=True)[1]
  File "/root/miniconda3/envs/nnsmith/lib/python3.9/site-packages/torch2trt-0.4.0-py3.9.egg/torch2trt/torch2trt.py", line 309, in wrapper
    converter["converter"](ctx)
  File "/root/miniconda3/envs/nnsmith/lib/python3.9/site-packages/torch2trt-0.4.0-py3.9.egg/torch2trt/converters/max.py", line 36, in convert_max
    __convert_max_reduce(ctx)
  File "/root/miniconda3/envs/nnsmith/lib/python3.9/site-packages/torch2trt-0.4.0-py3.9.egg/torch2trt/converters/max.py", line 26, in __convert_max_reduce
    layer = ctx.network.add_reduce(input_trt,  trt.ReduceOperation.MAX, torch_dim_to_trt_axes(dim), keepdim)
  File "/root/miniconda3/envs/nnsmith/lib/python3.9/site-packages/torch2trt-0.4.0-py3.9.egg/torch2trt/torch2trt.py", line 116, in torch_dim_to_trt_axes
    axes |= 1 << d 
ValueError: negative shift count
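
Possible fix:

The error comes from axes |= 1 << d being evaluated with a negative d. Below is a minimal sketch of one way the conversion could normalize negative dims before building the bitmask. Note that the ndim parameter is an assumption on my side: the current function only receives dim, so its callers (e.g. the max converter) would have to pass the input tensor's rank through.

def torch_dim_to_trt_axes(dim, ndim=None):
    # Sketch only, not the actual torch2trt implementation; the ndim argument is hypothetical.
    if not isinstance(dim, tuple):
        dim = (dim,)
    axes = 0
    for d in dim:
        if d < 0:
            # Negative dims count from the end, so the input rank is required.
            assert ndim is not None, "input rank needed to resolve a negative dim"
            d += ndim  # e.g. dim=-1 on a 2-D tensor becomes dim=1
        axes |= 1 << d
    return axes

With this normalization, torch_dim_to_trt_axes(-1, ndim=2) would return 0b10 (axis 1) instead of raising ValueError: negative shift count.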

Environment

  • torch: 2.1.1
  • torch2trt: 0.4.0
  • tensorrt: 8.6.1