Skip to content

Commit

Permalink
[export] Don't create a new fake mode if dynamo tracing
Browse files Browse the repository at this point in the history
  • Loading branch information
angelayi committed Apr 29, 2024
1 parent 96cc73d commit 2a019bf
Showing 1 changed file with 13 additions and 4 deletions.
17 changes: 13 additions & 4 deletions torch/_export/non_strict_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,10 +136,19 @@ def make_fake_inputs(nn_module, args, kwargs, dynamic_shapes):
"co_firstlineno": code.co_firstlineno,
}

fake_mode = FakeTensorMode(
shape_env=ShapeEnv(tracked_fakes=[], co_fields=co_fields),
allow_non_fake_inputs=True,
)
context = torch._guards.TracingContext.try_get()
if context is not None:
# This occurs when we are exporting within dynamo. There already exists
# a toplevel TracingContext with a fake mode, so we do not want to
# create another fake mode. In this scenario, we also shouldn't have any
# constraints since the toplevel tracing context should handle it.
assert len(constraints) == 0, "Found constraints when tracing with a toplevel tracing context."
fake_mode = context.fake_mode
else:
fake_mode = FakeTensorMode(
shape_env=ShapeEnv(tracked_fakes=[], co_fields=co_fields),
allow_non_fake_inputs=True,
)
if fake_mode.shape_env is None or fake_mode.shape_env.tracked_fakes is None:
raise ValueError(
"Detected fake_mode does not have a shape_env with tracked fakes. "
Expand Down

0 comments on commit 2a019bf

Please sign in to comment.