Circular imports when importing einops and torch._dynamo #315

Open
befelix opened this issue Apr 11, 2024 · 6 comments

@befelix

befelix commented Apr 11, 2024

Describe the bug
Importing einops before torch._dynamo currently leads to warnings. I'm not sure whether this needs a fix on the PyTorch side or on the einops side. This is annoying for CI pipelines, where warnings are typically treated as errors. Note that with sorted imports, einops is typically imported before any torch namespace.

$ python -W "error" -c 'import einops; import torch._dynamo'

Traceback (most recent call last):
  File "/home/bfe2rng/Code/rlcore/venv/lib/python3.11/site-packages/einops/_torch_specific.py", line 106, in allow_ops_in_compiled_graph
    from torch._dynamo import allow_in_graph
ImportError: cannot import name 'allow_in_graph' from partially initialized module 'torch._dynamo' (most likely due to a circular import) (/home/bfe2rng/Code/rlcore/venv/lib/python3.11/site-packages/torch/_dynamo/__init__.py)

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/home/bfe2rng/Code/rlcore/venv/lib/python3.11/site-packages/torch/_dynamo/__init__.py", line 2, in <module>
    from . import allowed_functions, convert_frame, eval_frame, resume_execution
  File "/home/bfe2rng/Code/rlcore/venv/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 62, in <module>
    from .output_graph import OutputGraph
  File "/home/bfe2rng/Code/rlcore/venv/lib/python3.11/site-packages/torch/_dynamo/output_graph.py", line 89, in <module>
    from .variables.builder import GraphArg, TrackedFake, VariableBuilder, wrap_fx_proxy
  File "/home/bfe2rng/Code/rlcore/venv/lib/python3.11/site-packages/torch/_dynamo/variables/builder.py", line 143, in <module>
    from .optimizer import OptimizerVariable
  File "/home/bfe2rng/Code/rlcore/venv/lib/python3.11/site-packages/torch/_dynamo/variables/optimizer.py", line 5, in <module>
    from ..decorators import mark_static_address
  File "/home/bfe2rng/Code/rlcore/venv/lib/python3.11/site-packages/torch/_dynamo/decorators.py", line 284, in <module>
    allowed_functions.add_module_init_func("einops", _allow_in_graph_einops)
  File "/home/bfe2rng/Code/rlcore/venv/lib/python3.11/site-packages/torch/_dynamo/allowed_functions.py", line 489, in add_module_init_func
    init_func()
  File "/home/bfe2rng/Code/rlcore/venv/lib/python3.11/site-packages/torch/_dynamo/decorators.py", line 264, in _allow_in_graph_einops
    from einops._torch_specific import (  # noqa: F401
  File "/home/bfe2rng/Code/rlcore/venv/lib/python3.11/site-packages/einops/_torch_specific.py", line 127, in <module>
    allow_ops_in_compiled_graph()
  File "/home/bfe2rng/Code/rlcore/venv/lib/python3.11/site-packages/einops/_torch_specific.py", line 108, in allow_ops_in_compiled_graph
    warnings.warn("allow_ops_in_compiled_graph failed to import torch: ensure pytorch >=2.0", ImportWarning)
ImportWarning: allow_ops_in_compiled_graph failed to import torch: ensure pytorch >=2.0
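The `-W "error"` flag is what turns the final `ImportWarning` into the failure above. `ImportWarning` is ignored by Python's default warning filters, so without the flag the problem goes unnoticed. A minimal sketch of the same effect using only the standard `warnings` module; the message text is copied from the traceback, and `warn_like_einops` is a hypothetical stand-in, not einops code:

```python
import warnings

MESSAGE = "allow_ops_in_compiled_graph failed to import torch: ensure pytorch >=2.0"


def warn_like_einops() -> None:
    # Hypothetical stand-in for the warning emitted by einops/_torch_specific.py
    warnings.warn(MESSAGE, ImportWarning)


def outcome(as_errors: bool) -> str:
    """Report what happens to the ImportWarning under the chosen filter."""
    with warnings.catch_warnings(record=True) as caught:
        if as_errors:
            # In-process equivalent of running `python -W error`
            warnings.simplefilter("error")
        else:
            # Mirrors Python's default filter, which ignores ImportWarning
            warnings.simplefilter("ignore", ImportWarning)
        try:
            warn_like_einops()
        except ImportWarning:
            return "raised"
    return "recorded" if caught else "silenced"
```

`outcome(True)` returns `"raised"` (what CI sees), while `outcome(False)` returns `"silenced"` (what an interactive user sees).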

Reproduction steps
Steps to reproduce the behavior:

python -W "error" -c 'import einops; import torch._dynamo'

Expected behavior

No output; import order should not matter.

Your platform

Ubuntu 22.04, torch==2.2.2, einops==0.7.0

@befelix befelix added the bug Something isn't working label Apr 11, 2024
@arogozhnikov
Owner

arogozhnikov commented Apr 13, 2024

Hi @befelix, thank you for reporting. There is a circular import problem that we need to address.

Here is, step by step, what happens in python -W "error" -c 'import einops; import torch._dynamo':

  1. einops is imported.
  2. torch._dynamo is imported; it detects that einops has already been imported and calls https://github.com/pytorch/pytorch/blob/88a71594933b2464d9d8b6b3533c5a945a4ac2ff/torch/_dynamo/decorators.py#L322
  3. ... which imports from einops._torch_specific ...
  4. einops._torch_specific tries to register ops using allow_in_graph, which it first needs to import from torch._dynamo. Since torch._dynamo is still only partially initialized at that point, this is where we run into the circular import.
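These four steps can be reproduced without torch or einops. The sketch below writes two throwaway packages to a temporary directory; `toydynamo` and `toyeinops` are hypothetical stand-ins for torch._dynamo and einops, and the hook is heavily simplified (the real one runs from a submodule of torch._dynamo, not from its `__init__`). Importing them in the order from the report produces the same kind of "partially initialized module" error, while the reverse order is silent:

```python
import importlib
import sys
import tempfile
import textwrap
import warnings
from pathlib import Path

# Hypothetical toy packages: "toydynamo" plays torch._dynamo, "toyeinops" plays einops.
FILES = {
    "toyeinops/__init__.py": "",
    "toyeinops/torch_specific.py": """
        import warnings
        try:
            # Step 4: needs a name from toydynamo, which is still half-initialized
            from toydynamo import allow_in_graph
        except ImportError as exc:
            warnings.warn(f"registration failed: {exc}", ImportWarning)
    """,
    "toydynamo/__init__.py": """
        import sys
        # Step 2: toyeinops is already imported, so the hook runs during this import
        if "toyeinops" in sys.modules:
            import toyeinops.torch_specific  # Step 3

        def allow_in_graph(fn):  # only defined after the hook has already run
            return fn
    """,
}


def import_in_order(*names: str) -> list:
    """Import the toy packages in the given order and return any warning messages."""
    with tempfile.TemporaryDirectory() as root:
        for rel_path, source in FILES.items():
            path = Path(root, rel_path)
            path.parent.mkdir(parents=True, exist_ok=True)
            path.write_text(textwrap.dedent(source))
        sys.path.insert(0, root)
        importlib.invalidate_caches()
        try:
            with warnings.catch_warnings(record=True) as caught:
                warnings.simplefilter("always")
                for name in names:  # the first name here corresponds to Step 1
                    importlib.import_module(name)
            return [str(w.message) for w in caught]
        finally:
            # Undo the experiment so it can be repeated with another order
            sys.path.remove(root)
            for mod in [m for m in sys.modules if m.split(".")[0] in ("toyeinops", "toydynamo")]:
                del sys.modules[mod]
```

`import_in_order("toyeinops", "toydynamo")` yields one warning mentioning the partially initialized module, whereas `import_in_order("toydynamo", "toyeinops")` yields none, mirroring the order dependence in the report.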

cc @wconstab @jansel

@arogozhnikov
Owner

arogozhnikov commented Apr 14, 2024

It seems this torch PR tried to address the problem, but it didn't quite work:
https://github.com/pytorch/pytorch/pull/111835/files

As far as I can tell, that fix wouldn't work: einops needs allow_in_graph to be importable, which means the hook can only be called after torch._dynamo has been fully imported. Currently the hook is called during the import.

@jansel

jansel commented Apr 14, 2024

Should we delete the special einops handling from torch._dynamo?

@arogozhnikov
Owner

arogozhnikov commented Apr 14, 2024

It is preferable to defer that logic to 'on first compile' (or to when torch._dynamo is fully imported).

Otherwise, we had better just delete that handling code; it wouldn't work as intended.

jansel added a commit to pytorch/pytorch that referenced this issue Apr 15, 2024
See arogozhnikov/einops#315

ghstack-source-id: e5ad8f6da07887b7290fbae8c42af237a33dfbcb
Pull Request resolved: #124084
@jansel

jansel commented Apr 15, 2024

Can you check if this fixes it: pytorch/pytorch#124084

pytorchmergebot pushed a commit to pytorch/pytorch that referenced this issue Apr 18, 2024
@arogozhnikov
Owner

Thank you @jansel for merging that. I can confirm nightly torch==2.4.0.dev20240420+cpu does not have this issue.

Until the fix is released, users need to register functions manually:
https://github.com/arogozhnikov/einops/wiki/Using-torch.compile-with-einops
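To check whether a given environment is affected, for example after upgrading torch, the reproduction command can be wrapped as a small CI smoke test. A sketch using only the standard library; the helper name `imports_cleanly` is hypothetical:

```python
import subprocess
import sys


def imports_cleanly(*modules: str) -> bool:
    """Import `modules` in order in a fresh interpreter, treating warnings as errors."""
    code = "; ".join(f"import {name}" for name in modules)
    result = subprocess.run(
        [sys.executable, "-W", "error", "-c", code],  # same flags as the reproduction
        capture_output=True,
        text=True,
    )
    return result.returncode == 0
```

A fresh interpreter is used because import order only matters the first time a module is imported in a process. In CI one would assert `imports_cleanly("einops", "torch._dynamo")`; per the report above, this is expected to return False with torch==2.2.2 and einops==0.7.0.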

@arogozhnikov arogozhnikov added backend bug and removed bug Something isn't working labels Apr 20, 2024
@arogozhnikov arogozhnikov changed the title Warning when importing einops before torch._dynamo Circular imports when importing einops and torch._dynamo Apr 20, 2024
pytorch-bot bot pushed a commit to pytorch/pytorch that referenced this issue Apr 21, 2024
sanketpurandare pushed a commit to sanketpurandare/pytorch that referenced this issue Apr 22, 2024
andoorve pushed a commit to andoorve/pytorch that referenced this issue May 1, 2024
petrex pushed a commit to petrex/pytorch that referenced this issue May 3, 2024