datetime.now() is not supported by Dynamo #125171

Open
thiagocrepaldi opened this issue Apr 29, 2024 · 3 comments
Labels
module: dynamo, oncall: export, oncall: pt2, onnx-needs-info (needs information from the author / reporter before ONNX team can take action)

Comments

@thiagocrepaldi
Collaborator

thiagocrepaldi commented Apr 29, 2024

🐛 Describe the bug

I am actually not sure what the best behavior is here.

On one hand, when datetime.now() is assigned to a variable that is only used for printing, it could be ignored, just like the print itself.

However, what if the user model does something like return torch.rand() + datetime.now().second? Should it, in theory, be captured as a constant? Maybe this case is unusual enough that it should just be no-oped?
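Concretely, the two situations look something like this (illustrative sketch only, not from the actual model):

    from datetime import datetime
    import torch

    def logging_only(x):
        start = datetime.now()
        y = x + 1
        # the timestamp only feeds a print; dropping it (like the print) seems safe
        print(f"elapsed: {datetime.now() - start}")
        return y

    def feeds_compute(x):
        # the timestamp participates in the returned value;
        # should it be baked in as a constant, or no-oped?
        return torch.rand(x.shape) + datetime.now().second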

Both torch._dynamo.export and torch.export.export suffer from this issue.

Repro

    def test_export_with_datetime(self):
        from datetime import datetime
        class DateTimeModule(torch.nn.Module):
            def forward(self, x):
                start_time = datetime.now()
                y = x + 1 + start_time.second
                return y

        input = torch.randn(2, 3)
        model = DateTimeModule()
        _ = torch.export.export(model, args=(input,))  # torch._dynamo.exc.Unsupported: call_method UserDefinedObjectVariable(classmethod_descriptor) __call__ [] {}

Error:

torch._dynamo.exc.Unsupported: call_method UserDefinedObjectVariable(classmethod_descriptor) __call__ [] {}

from user code:
   File "/opt/pytorch/test/onnx/test_fx_to_onnx.py", line 740, in forward
    start_time = datetime.now()

With TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1

Traceback (most recent call last):
  File "/opt/pytorch/test/onnx/test_fx_to_onnx.py", line 750, in test_export_with_datetime
    _ = torch.export.export(model, args=(input,))
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/pytorch/torch/export/__init__.py", line 174, in export
    return _export(
           ^^^^^^^^
  File "/opt/pytorch/torch/export/_trace.py", line 840, in wrapper
    raise e
  File "/opt/pytorch/torch/export/_trace.py", line 823, in wrapper
    ep = fn(*args, **kwargs)
         ^^^^^^^^^^^^^^^^^^^
  File "/opt/pytorch/torch/export/exported_program.py", line 85, in wrapper
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/opt/pytorch/torch/export/_trace.py", line 1096, in _export
    gm_torch_level = _export_to_torch_ir(
                     ^^^^^^^^^^^^^^^^^^^^
  File "/opt/pytorch/torch/export/_trace.py", line 428, in _export_to_torch_ir
    gm_torch_level, _ = torch._dynamo.export(
                        ^^^^^^^^^^^^^^^^^^^^^
  File "/opt/pytorch/torch/_dynamo/eval_frame.py", line 1251, in inner
    result_traced = opt_f(*args, **kwargs)
                    ^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/pytorch/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/pytorch/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/pytorch/torch/_dynamo/eval_frame.py", line 403, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/opt/pytorch/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/pytorch/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/pytorch/torch/_dynamo/convert_frame.py", line 977, in catch_errors
    return callback(frame, cache_entry, hooks, frame_state, skip=1)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/pytorch/torch/_dynamo/convert_frame.py", line 411, in _convert_frame_assert
    return _compile(
           ^^^^^^^^^
  File "/opt/pytorch/torch/_utils_internal.py", line 70, in wrapper_function
    return function(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/envs/ptca/lib/python3.11/contextlib.py", line 81, in inner
    return func(*args, **kwds)
           ^^^^^^^^^^^^^^^^^^^
  File "/opt/pytorch/torch/_dynamo/convert_frame.py", line 700, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/pytorch/torch/_dynamo/utils.py", line 268, in time_wrapper
    r = func(*args, **kwargs)
        ^^^^^^^^^^^^^^^^^^^^^
  File "/opt/pytorch/torch/_dynamo/convert_frame.py", line 568, in compile_inner
    out_code = transform_code_object(code, transform)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/pytorch/torch/_dynamo/bytecode_transformation.py", line 1116, in transform_code_object
    transformations(instructions, code_options)
  File "/opt/pytorch/torch/_dynamo/convert_frame.py", line 173, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/opt/pytorch/torch/_dynamo/convert_frame.py", line 515, in transform
    tracer.run()
  File "/opt/pytorch/torch/_dynamo/symbolic_convert.py", line 2230, in run
    super().run()
  File "/opt/pytorch/torch/_dynamo/symbolic_convert.py", line 880, in run
    while self.step():
          ^^^^^^^^^^^
  File "/opt/pytorch/torch/_dynamo/symbolic_convert.py", line 795, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/opt/pytorch/torch/_dynamo/symbolic_convert.py", line 492, in wrapper
    return inner_fn(self, inst)
           ^^^^^^^^^^^^^^^^^^^^
  File "/opt/pytorch/torch/_dynamo/symbolic_convert.py", line 1841, in CALL
    self.call_function(fn, args, kwargs)
  File "/opt/pytorch/torch/_dynamo/symbolic_convert.py", line 735, in call_function
    self.push(fn.call_function(self, args, kwargs))
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/pytorch/torch/_dynamo/variables/user_defined.py", line 783, in call_function
    return self.call_method(tx, "__call__", args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/pytorch/torch/_dynamo/variables/user_defined.py", line 646, in call_method
    return super().call_method(tx, name, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/pytorch/torch/_dynamo/variables/base.py", line 313, in call_method
    unimplemented(f"call_method {self} {name} {args} {kwargs}")
  File "/opt/pytorch/torch/_dynamo/exc.py", line 212, in unimplemented
    raise Unsupported(msg)
torch._dynamo.exc.Unsupported: call_method UserDefinedObjectVariable(classmethod_descriptor) __call__ [] {}

from user code:
   File "/opt/pytorch/test/onnx/test_fx_to_onnx.py", line 740, in forward
    start_time = datetime.now()


To execute this test, run the following from the base repo dir:
     python test/onnx/test_fx_to_onnx.py -k test_export_with_datetime

This message can be suppressed by setting PYTORCH_PRINT_REPRO_ON_FAILURE=0
==================================================================================== short test summary info =====================================================================================
FAILED [0.9377s] test/onnx/test_fx_to_onnx.py::TestFxToOnnx::test_export_with_datetime - torch._dynamo.exc.Unsupported: call_method UserDefinedObjectVariable(classmethod_descriptor) __call__ [] {}
======================================================================================= 1 failed in 14.04s =======================================================================================
I0429 16:52:50.265000 139214601663680 torch/_dynamo/utils.py:326] TorchDynamo compilation metrics:
I0429 16:52:50.265000 139214601663680 torch/_dynamo/utils.py:326] Function                           Runtimes (s)
I0429 16:52:50.265000 139214601663680 torch/_dynamo/utils.py:326] -------------------------------  --------------
I0429 16:52:50.265000 139214601663680 torch/_dynamo/utils.py:326] _compile.<locals>.compile_inner               0
V0429 16:52:50.266000 139214601663680 torch/fx/experimental/symbolic_shapes.py:109] lru_cache_stats constrain_symbol_range: CacheInfo(hits=0, misses=0, maxsize=None, currsize=0)
V0429 16:52:50.266000 139214601663680 torch/fx/experimental/symbolic_shapes.py:109] lru_cache_stats evaluate_expr: CacheInfo(hits=0, misses=0, maxsize=256, currsize=0)
V0429 16:52:50.266000 139214601663680 torch/fx/experimental/symbolic_shapes.py:109] lru_cache_stats _simplify_floor_div: CacheInfo(hits=0, misses=0, maxsize=None, currsize=0)
V0429 16:52:50.266000 139214601663680 torch/fx/experimental/symbolic_shapes.py:109] lru_cache_stats _maybe_guard_rel: CacheInfo(hits=0, misses=0, maxsize=256, currsize=0)
V0429 16:52:50.266000 139214601663680 torch/fx/experimental/symbolic_shapes.py:109] lru_cache_stats _find: CacheInfo(hits=0, misses=0, maxsize=None, currsize=0)
V0429 16:52:50.266000 139214601663680 torch/fx/experimental/symbolic_shapes.py:109] lru_cache_stats has_hint: CacheInfo(hits=0, misses=0, maxsize=256, currsize=0)
V0429 16:52:50.266000 139214601663680 torch/fx/experimental/symbolic_shapes.py:109] lru_cache_stats size_hint: CacheInfo(hits=0, misses=0, maxsize=256, currsize=0)
V0429 16:52:50.266000 139214601663680 torch/fx/experimental/symbolic_shapes.py:109] lru_cache_stats simplify: CacheInfo(hits=0, misses=0, maxsize=None, currsize=0)
V0429 16:52:50.266000 139214601663680 torch/fx/experimental/symbolic_shapes.py:109] lru_cache_stats _update_divisible: CacheInfo(hits=0, misses=0, maxsize=None, currsize=0)
V0429 16:52:50.266000 139214601663680 torch/fx/experimental/symbolic_shapes.py:109] lru_cache_stats replace: CacheInfo(hits=0, misses=0, maxsize=None, currsize=0)
V0429 16:52:50.266000 139214601663680 torch/fx/experimental/symbolic_shapes.py:109] lru_cache_stats _maybe_evaluate_static: CacheInfo(hits=0, misses=0, maxsize=None, currsize=0)
V0429 16:52:50.266000 139214601663680 torch/fx/experimental/symbolic_shapes.py:109] lru_cache_stats get_implications: CacheInfo(hits=0, misses=0, maxsize=None, currsize=0)
V0429 16:52:50.266000 139214601663680 torch/fx/experimental/symbolic_shapes.py:109] lru_cache_stats get_axioms: CacheInfo(hits=0, misses=0, maxsize=None, currsize=0)
V0429 16:52:50.267000 139214601663680 torch/fx/experimental/symbolic_shapes.py:109] lru_cache_stats safe_expand: CacheInfo(hits=0, misses=0, maxsize=256, currsize=0)
V0429 16:52:50.267000 139214601663680 torch/fx/experimental/symbolic_shapes.py:109] lru_cache_stats uninteresting_files: CacheInfo(hits=0, misses=0, maxsize=None, currsize=0)

Versions

main

cc @ezyang @msaroufim @bdhirsh @anijain2305 @chauhang @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @avikchaudhuri @gmagogsfm @zhxchen17 @tugsbayasgalan @angelayi @suo @ydwu4

thiagocrepaldi added the oncall: pt2, oncall: export, and onnx-needs-info labels on Apr 29, 2024
@ezyang
Contributor

ezyang commented Apr 29, 2024

What's the real code actually doing? It seems reasonable for us to graph break here.

@thiagocrepaldi
Collaborator Author

thiagocrepaldi commented Apr 29, 2024

What's the real code actually doing? It seems reasonable for us to graph break here.

My scenario is full graph, though :)

In my actual model, the user called datetime.now() just for logging. I commented out the datetime.now() and logger.info calls and moved on. However, I wanted to streamline this scenario to spare users from having to change their code when we can handle it on our side:

def forward(self, ...):
    start_time = datetime.now()
    tt = (
        self.get_img_features(img_embeds)
        .to(target_device)
        .to(target_dtype)
        .reshape(-1, self.image_dim_out)
    )
    logger.info(f'img_embeds size: {img_embeds.size()}, loading time {datetime.now() - start_time}')
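
For reference, the manual workaround amounts to something like this (hypothetical minimal module, not the real one; get_img_features and friends omitted):

    import torch
    from datetime import datetime

    class ImgModule(torch.nn.Module):
        def forward(self, img_embeds):
            # datetime.now() and logger.info removed from the traced region
            return img_embeds.reshape(-1, img_embeds.shape[-1])

    model = ImgModule()
    img_embeds = torch.randn(4, 8)

    # exports cleanly once the timestamp/logging calls are out of forward
    ep = torch.export.export(model, args=(img_embeds,))

    # timing and logging done by the caller, outside the traced region
    start_time = datetime.now()
    out = model(img_embeds)
    print(f"img_embeds size: {img_embeds.size()}, loading time {datetime.now() - start_time}")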

When I was creating a NoOpMethodVariable to make datetime.now() a no-op during tracing (within torch/_dynamo/variables/misc.py and torch/_dynamo/variables/builder.py), I tried to think of a use case in which the user could use datetime.now() as an actual computational input that should not be no-oped.

@ezyang
Contributor

ezyang commented Apr 30, 2024

Ah, OK, then this is pursuant to #116106 and we need to rope in @angelayi and @yanboliang

The way I always imagined this sort of thing was that we'd have a notion of "poison" values. When we encounter a datetime.now(), or anything else that doesn't have side effects, we create a PoisonVariableTracker for it. As long as the poison doesn't escape or participate in compute (it's OK for poison to flow into logging statements, which are removed), it is fine and export succeeds.

This model makes a lot more sense when the logging calls are being dropped, as opposed to reordered to the end. IDK, there are probably other ways to do it. We could also just directly implement datetime.now(); it's not even that hard, and should be handled similarly to random.randint.
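
As a rough illustration of the "poison" idea (plain Python, not the actual Dynamo VariableTracker machinery):

    class Poison:
        """Toy stand-in for a hypothetical PoisonVariableTracker."""

        def __init__(self, origin):
            self.origin = origin

        def __format__(self, spec):
            # allowed: poison may flow into logging/f-strings, which export would drop
            return f"<poison from {self.origin}>"

        def _escape(self, *args, **kwargs):
            raise RuntimeError(f"poison from {self.origin} participated in compute")

        # any arithmetic means the poison escaped into real computation
        __add__ = __radd__ = __sub__ = __rsub__ = __mul__ = _escape


    start = Poison("datetime.now()")
    print(f"loading time {start}")   # fine: only reaches a (droppable) log call
    # start + 1                      # would raise: poison escaped into compute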
