Skip to content

Commit

Permalink
[export] Don't create a new fake mode if dynamo tracing (#125185)
Browse files Browse the repository at this point in the history
Fixes #ISSUE_NUMBER

Pull Request resolved: #125185
Approved by: https://github.com/mikekgfb
  • Loading branch information
angelayi authored and pytorchmergebot committed May 9, 2024
1 parent 23e71ff commit 13545fe
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 4 deletions.
16 changes: 16 additions & 0 deletions test/export/test_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -4214,6 +4214,22 @@ def forward(self, x):
self.assertEqual(len(ep.constants), 1)
self.assertEqual(mod(inp), m(inp))

def test_export_as_backend(self):
def f(x, y):
return x + y

def my_custom_backend(gm, example_inputs):
gm = (
torch.export.export(gm, tuple(example_inputs), strict=False)
.run_decompositions()
.module()
)
return gm

inp = (torch.randn(3, 3), torch.randn(3, 3))
new_res = torch.compile(f, backend=my_custom_backend)(*inp)
self.assertTrue(torch.allclose(f(*inp), new_res))

def test_nonstrict_retrace_preserves_metadata(self):
class MyModule(torch.nn.Module):
def __init__(self):
Expand Down
19 changes: 15 additions & 4 deletions torch/_export/non_strict_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,10 +143,21 @@ def make_fake_inputs(nn_module, args, kwargs, dynamic_shapes):
"co_firstlineno": code.co_firstlineno,
}

fake_mode = FakeTensorMode(
shape_env=ShapeEnv(tracked_fakes=[], co_fields=co_fields),
allow_non_fake_inputs=True,
)
context = torch._guards.TracingContext.try_get()
if context is not None:
# This occurs when we are exporting within dynamo. There already exists
# a toplevel TracingContext with a fake mode, so we do not want to
# create another fake mode. In this scenario, we also shouldn't have any
# constraints since the toplevel tracing context should handle it.
assert (
len(constraints) == 0
), "Found constraints when tracing with a toplevel tracing context."
fake_mode = context.fake_mode
else:
fake_mode = FakeTensorMode(
shape_env=ShapeEnv(tracked_fakes=[], co_fields=co_fields),
allow_non_fake_inputs=True,
)
if fake_mode.shape_env is None or fake_mode.shape_env.tracked_fakes is None:
raise ValueError(
"Detected fake_mode does not have a shape_env with tracked fakes. "
Expand Down

0 comments on commit 13545fe

Please sign in to comment.