Skip to content

Commit

Permalink
[compiled autograd] compile fwd in inference mode
Browse files Browse the repository at this point in the history
ghstack-source-id: 8aa1e631d8bdaacfc8d3a6a1f88da32e50cbce71
Pull Request resolved: #125201
  • Loading branch information
xmfan committed Apr 29, 2024
1 parent 5585138 commit 1bb329d
Showing 1 changed file with 3 additions and 0 deletions.
3 changes: 3 additions & 0 deletions torch/_functorch/aot_autograd.py
Original file line number Diff line number Diff line change
Expand Up @@ -544,6 +544,9 @@ def convert(idx, x):
any(x.requires_grad for x in fake_flat_args if isinstance(x, Tensor))
and torch.is_grad_enabled()
)
if torch._dynamo.compiled_autograd.compiled_autograd_enabled:
print(f"Fwd is called within compiled autograd ctx, assume bwd will be handled by compiled autograd")
needs_autograd = False

with enable_python_dispatcher():
# Patch set_rng_state as set_rng_state with fake tensors is
Expand Down

0 comments on commit 1bb329d

Please sign in to comment.