Skip to content

Commit

Permalink
[dynamo] support inactive context managers across graph breaks (pytor…
Browse files Browse the repository at this point in the history
…ch#125203)

Fix pytorch#124900.

When we reconstruct `ContextWrappingVariables`s, we only reconstruct the context class, not the object. Normally, contexts are active (via `with ctx:`) and we initialize the context object in the resume function. But for the case of inactive contexts (contexts declared ahead of time before the `with` block), we do not reconstruct them properly in the optimized bytecode or resume function. So this PR adds initialization for inactive contexts in the resume function.

Pull Request resolved: pytorch#125203
Approved by: https://github.com/jansel
  • Loading branch information
williamwen42 authored and andoorve committed May 1, 2024
1 parent f5ded79 commit d3f04bd
Show file tree
Hide file tree
Showing 4 changed files with 73 additions and 2 deletions.
16 changes: 16 additions & 0 deletions test/dynamo/test_ctx_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -1304,6 +1304,22 @@ def inner_func(x):
self.assertEqual(fn(x), opt_fn(x))
self.assertEqual(fn(x).requires_grad, opt_fn(x).requires_grad)

def test_inactive_context_graph_break(self):
def fn(x):
x = x + 1
ctx = torch.set_grad_enabled(True)
torch._dynamo.graph_break()
with ctx:
x = x + 1
return x

x = torch.zeros(10, requires_grad=False)
cnts = torch._dynamo.testing.CompileCounter()
opt_fn = torch.compile(fn, backend=cnts)
self.assertEqual(fn(x), opt_fn(x))
self.assertEqual(fn(x).requires_grad, opt_fn(x).requires_grad)
self.assertEqual(cnts.frame_count, 2)


if __name__ == "__main__":
from torch._dynamo.test_case import run_tests
Expand Down
28 changes: 28 additions & 0 deletions torch/_dynamo/resume_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,6 +317,17 @@ def _filter_iter(l1, l2, cond):
return res


def _load_tuple_and_call(tup):
insts = []
if sys.version_info >= (3, 11):
insts.append(create_instruction("PUSH_NULL"))
insts.append(create_instruction("SWAP", arg=2))
for val in tup:
insts.append(create_instruction("LOAD_CONST", argval=val))
insts.extend(create_call_function(len(tup), False))
return insts


class ContinueExecutionCache:
cache = ExactWeakKeyDictionary()
generated_code_metadata = ExactWeakKeyDictionary()
Expand All @@ -341,6 +352,8 @@ def generate(
argnames: Tuple[str],
argnames_null: Tuple[str],
setup_fns: Tuple[ReenterWith],
stack_ctx_vars: Tuple[int, Tuple[Any]],
argnames_ctx_vars: Tuple[str, Tuple[Any]],
null_idxes: Tuple[int],
) -> types.CodeType:
assert offset is not None
Expand All @@ -359,6 +372,8 @@ def generate(
argnames,
argnames_null,
setup_fns,
stack_ctx_vars,
argnames_ctx_vars,
null_idxes,
)

Expand Down Expand Up @@ -420,6 +435,7 @@ def update(instructions: List[Instruction], code_options: Dict[str, Any]):
# map old hook targets to new targets generated by the hook
old_hook_target_remap = {}
null_idxes_i = 0
stack_ctx_vars_d = dict(stack_ctx_vars) # type: ignore[var-annotated,arg-type]
for i in range(nstack):
while (
null_idxes_i < len(null_idxes)
Expand All @@ -437,6 +453,12 @@ def update(instructions: List[Instruction], code_options: Dict[str, Any]):
old_hook_target = offset_to_inst[hook_target_offset]
meta.prefix_block_target_offset_remap.append(hook_target_offset)
old_hook_target_remap[old_hook_target] = exn_target
real_i = i + null_idxes_i
if real_i in stack_ctx_vars_d:
# current stack variable is a context var -
# load args for context variable and construct it
prefix.extend(_load_tuple_and_call(stack_ctx_vars_d[real_i]))

if is_py311_plus:
# reverse the mapping since targets of later/nested contexts are inserted
# into the mapping later, but show up earlier in the prefix.
Expand All @@ -446,6 +468,12 @@ def update(instructions: List[Instruction], code_options: Dict[str, Any]):

assert not hooks

# initialize inactive context vars in argnames
for name, vals in argnames_ctx_vars:
prefix.append(create_instruction("LOAD_FAST", argval=name))
prefix.extend(_load_tuple_and_call(vals))
prefix.append(create_instruction("STORE_FAST", argval=name))

# 3.12+: store NULL into variables that were NULL
if argnames_null:
assert sys.version_info >= (3, 12)
Expand Down
23 changes: 21 additions & 2 deletions torch/_dynamo/symbolic_convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -2282,6 +2282,23 @@ def create_call_resume_at(self, inst):
if sys.version_info < (3, 12):
assert len(argnames_null) == 0, "variables should not be NULL in < 3.12"

# Handle inactive context variables - inactive context variables
# are reconstructed to be the class, NOT the object.
# So the resume function needs to construct the context object
# from the class and the context object's target values.
# e.g. torch.set_grad_enabled(True) will be reconstructed as
# torch.set_grad_enabled
stack_ctx_vars = []
for i, var in enumerate(self.stack):
if type.__instancecheck__(ContextWrappingVariable, var):
stack_ctx_vars.append((i, tuple(var.target_values))) # type: ignore[attr-defined]
argnames_ctx_vars = []
for name in argnames:
if type.__instancecheck__(
ContextWrappingVariable, var := self.symbolic_locals[name]
):
argnames_ctx_vars.append((name, tuple(var.target_values))) # type: ignore[attr-defined]

cg = PyCodegen(self)

# Python does not allow null to be an arg to a function, so
Expand All @@ -2293,12 +2310,12 @@ def create_call_resume_at(self, inst):
if sys.version_info >= (3, 11):
# find indices of NullVariables
for i, var in enumerate(self.stack):
if isinstance(var, NullVariable):
if type.__instancecheck__(NullVariable, var):
null_idxes.append(i)
# generate bytecode to pop the nulls
null_cnt = 0
for i, var in enumerate(reversed(self.stack)):
if isinstance(var, NullVariable):
if type.__instancecheck__(NullVariable, var):
for j in range(2, i + 2 - null_cnt):
cg.append_output(create_instruction("SWAP", arg=j))
cg.extend_output(cg.pop_null())
Expand All @@ -2320,6 +2337,8 @@ def create_call_resume_at(self, inst):
argnames,
argnames_null,
tuple(b.resume_fn() for b in self.block_stack),
tuple(stack_ctx_vars),
tuple(argnames_ctx_vars),
tuple(null_idxes),
)

Expand Down
8 changes: 8 additions & 0 deletions torch/_dynamo/variables/lazy.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,17 @@ def realize(self):
assert self.vt is None
from ..symbolic_convert import InstructionTranslator
from .builder import VariableBuilder
from .ctx_manager import ContextWrappingVariable, NullContextVariable
from .misc import NullVariable

tx = InstructionTranslator.current_tx()
self.vt = VariableBuilder(tx, self.source)(self.value)

# we do not expect wrapping these variables in lazy VTs
assert not isinstance(
self.vt, (NullVariable, ContextWrappingVariable)
) or isinstance(self.vt, NullContextVariable)

del self.value
del self.source

Expand Down

0 comments on commit d3f04bd

Please sign in to comment.