
Incompatibility with torch.fx #188

Open
nicolas-dufour opened this issue May 18, 2022 · 6 comments
Labels
wontfix This will not be worked on

Comments

@nicolas-dufour

Describe the bug
When calling torchvision.models.feature_extraction.get_graph_node_names on a model that contains an einops operation, the call fails with the following error:
Traceback (most recent call last):
File "/home/Documents/project/run.py", line 96, in run
print(get_graph_node_names(model.encoder))
File "/home/anaconda3/envs/base/lib/python3.9/site-packages/torchvision/models/feature_extraction.py", line 239, in get_graph_node_names
train_tracer.trace(model.train())
File "/home/anaconda3/envs/base/lib/python3.9/site-packages/torch/fx/_symbolic_trace.py", line 566, in trace
self.create_node('output', 'output', (self.create_arg(fn(*args)),), {},
File "/home/Documents/project/encoder.py", line 192, in forward
latents = repeat(self.latents, " l d -> b l d", b=batch_size) * self.lr_mul
File "/home/anaconda3/envs/base/lib/python3.9/site-packages/einops/einops.py", line 537, in repeat
return reduce(tensor, pattern, reduction='repeat', **axes_lengths)
File "/home/anaconda3/envs/base/lib/python3.9/site-packages/einops/einops.py", line 410, in reduce
return _apply_recipe(recipe, tensor, reduction_type=reduction)
File "/home/anaconda3/base/scam/lib/python3.9/site-packages/einops/einops.py", line 231, in _apply_recipe
backend = get_backend(tensor)
File "/home/anaconda3/base/scam/lib/python3.9/site-packages/einops/_backends.py", line 52, in get_backend
raise RuntimeError('Tensor type unknown to einops {}'.format(type(tensor)))
RuntimeError: Tensor type unknown to einops <class 'torch.fx.proxy.Proxy'>
Reproduction steps
Steps to reproduce the behavior:
Create a model with repeat operation, then call get_graph_node_names on it

Expected behavior
einops operations such as repeat and rearrange should work with torch.fx

Your platform
einops 0.4.1, torch 1.11, torchvision 0.12, python 3.9

@nicolas-dufour nicolas-dufour added the bug Something isn't working label May 18, 2022
@arogozhnikov
Owner

Hmmm, torch.fx is ... awkward. Using symbolic tracing in a framework that has always bet on define-by-run breaks the rules.

Einops requires some information about tensor shapes, but torch.fx does not provide it during tracing (... breaking even more rules). Even the number of dimensions is not available to perform validation, so that's just unlikely to work.

I think an ok workaround would be to pre-script these layers first, somewhat like this:

# requires: from einops.layers.torch import Rearrange, Reduce
class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.layer1 = torch.jit.script(Rearrange('h w -> w h'))
        self.layer2 = torch.jit.script(Reduce('w h -> w', 'min'))
        
    def forward(self, x):
        x = self.layer1(x)
        return self.layer2(x)

... but it does not work with torch.fx, because torch.jit.ScriptModule has no implementation for serialization / deserialization there.
So... two torch things don't know how to interact with each other.

Another less appealing way is to trace modules:

# requires: from einops.layers.torch import Rearrange, Reduce
class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.layer1 = torch.jit.trace(Rearrange('h w -> w h'), [torch.zeros([10, 20])])
        self.layer2 = torch.jit.trace(Reduce('w h -> w', 'min'), [torch.zeros([20, 10])])
        
    def forward(self, x):
        # print(type(x))
        # print(x.shape)
        x = self.layer1(x)
        return self.layer2(x)

This works... kinda. After tracing, the modules are really just a bunch of simple operations like repeat, transpose, etc., but torch.fx does not understand them and stores the modules as pickles.

I don't see a reasonable way forward here.

@arogozhnikov arogozhnikov added wontfix This will not be worked on and removed bug Something isn't working labels Jul 3, 2022
@arogozhnikov arogozhnikov changed the title [BUG] Incompatibility with torch.fx Incompatibility with torch.fx Aug 24, 2022
@miteshksingh

Hi,

I stumbled across this bug too. Is the current fix to rewrite the model without einops?

@arogozhnikov
Owner

arogozhnikov commented Nov 17, 2022

yes

updated much later: maybe not, see proposals below

@miteshksingh

Okay. Thank you for the quick response!

@Jongchan

Jongchan commented Apr 7, 2023

Hello, all.

I am using einops.rearrange as part of tensor preprocessing, and I managed to bypass this problem with torch.fx's wrap mechanism. My specific use case is create_feature_extractor in torchvision, which uses torch.fx for tracing.

You may pass autowrap_functions or autowrap_modules when using get_graph_node_names.

Please refer to the following links for example usage:

Hope this can help 😄
(BTW, thank you for the advice, @songheony)

@Red-Eyed

Thank you @Jongchan

import torch
from einops import rearrange
torch.fx.wrap('rearrange')

For some reason, this fixed the error during fx export.
