[export] Fix for unflattening modules with duplicate tensors (#125192)
In the given test case, we have a ModuleList of 3 modules (`norm.0`, `norm.1`, `norm.2`) which share the same `weight` and `bias` tensors. However, when we trace, they all end up pointing to one state dict name (e.g. `norm.2`):
```
graph():
    %p_norms_0_weight : [num_users=0] = placeholder[target=p_norms_0_weight]
    %p_norms_0_bias : [num_users=0] = placeholder[target=p_norms_0_bias]
    %p_norms_1_weight : [num_users=0] = placeholder[target=p_norms_1_weight]
    %p_norms_1_bias : [num_users=0] = placeholder[target=p_norms_1_bias]
    %p_norms_2_weight : [num_users=3] = placeholder[target=p_norms_2_weight]
    %p_norms_2_bias : [num_users=3] = placeholder[target=p_norms_2_bias]
    %input_ : [num_users=1] = placeholder[target=input_]
    %native_layer_norm : [num_users=1] = call_function[target=torch.ops.aten.native_layer_norm.default](args = (%input_, [2, 2, 3], %p_norms_2_weight, %p_norms_2_bias, 1e-05), kwargs = {})
    %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%native_layer_norm, 0), kwargs = {})
    %native_layer_norm_1 : [num_users=1] = call_function[target=torch.ops.aten.native_layer_norm.default](args = (%getitem, [2, 2, 3], %p_norms_2_weight, %p_norms_2_bias, 1e-05), kwargs = {})
    %getitem_3 : [num_users=1] = call_function[target=operator.getitem](args = (%native_layer_norm_1, 0), kwargs = {})
    %native_layer_norm_2 : [num_users=1] = call_function[target=torch.ops.aten.native_layer_norm.default](args = (%getitem_3, [2, 2, 3], %p_norms_2_weight, %p_norms_2_bias, 1e-05), kwargs = {})
    %getitem_6 : [num_users=1] = call_function[target=operator.getitem](args = (%native_layer_norm_2, 0), kwargs = {})
    return (getitem_6,)
```
This causes an error in the unflattener: after constructing the submodule for `norm.0`, its graph still points to `norm.2.weight` and `norm.2.bias`:
```
graph():
    %p_norms_2_bias : [num_users=1] = placeholder[target=p_norms_2_bias]
    %p_norms_2_weight : [num_users=1] = placeholder[target=p_norms_2_weight]
    %input_ : [num_users=1] = placeholder[target=input_]
    %native_layer_norm : [num_users=1] = call_function[target=torch.ops.aten.native_layer_norm.default](args = (%input_, [2, 2, 3], %p_norms_2_weight, %p_norms_2_bias, 1e-05), kwargs = {})
    %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%native_layer_norm, 0), kwargs = {})
    return getitem
```
Since these attributes do not belong to the submodule's scope (`norm.0` vs. `norm.2`), they will not be added to the subgraph, causing an error.
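
A minimal sketch of the failing check, using hypothetical scope and FQN values (not the library's exact code): the unflattener only attaches an attribute to a submodule when the recorded state dict name falls under that submodule's scope.
```
# Illustrative names only: the unflattener installs an attribute on a
# submodule only if the state dict FQN falls under that submodule's scope.
scope = ["norms", "0"]                       # submodule currently being built
state_name = "norms.2.weight".split(".")     # single target recorded by tracing

# The prefix check fails ("norms.2" is not under "norms.0"), so the weight is
# never attached to the `norms.0` submodule and its placeholder dangles.
print(state_name[: len(scope)] == scope)     # False
```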

So this PR handles the duplicate state dict attributes by modifying the `inputs_to_state` dict to map from node names to a list of possible state dict target names.
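
A minimal sketch of the new mapping, with made-up placeholder names and a toy state dict; it mirrors the grouping-by-tensor-identity logic in the diff below but is not the exact implementation.
```
from collections import defaultdict
from typing import Dict, List, Tuple

import torch

# Hypothetical state dict in which three FQNs alias the same tensor object.
shared_weight = torch.randn(3)
state_dict = {
    "norms.0.weight": shared_weight,
    "norms.1.weight": shared_weight,
    "norms.2.weight": shared_weight,
}

# (placeholder name, state dict target) pairs, roughly what the graph
# signature's input specs would provide; the names here are illustrative.
input_specs = [
    ("p_norms_0_weight", "norms.0.weight"),
    ("p_norms_1_weight", "norms.1.weight"),
    ("p_norms_2_weight", "norms.2.weight"),
]

# Group entries by tensor identity ...
consts_map: Dict[int, List[Tuple[str, str]]] = defaultdict(list)
for node_name, target in input_specs:
    consts_map[id(state_dict[target])].append((node_name, target))

# ... then fan the full list of aliased targets back out to every placeholder.
inputs_to_state: Dict[str, List[str]] = {}
for pairs in consts_map.values():
    targets = [target for _, target in pairs]
    for node_name, _ in pairs:
        inputs_to_state[node_name] = targets

print(inputs_to_state["p_norms_2_weight"])
# ['norms.0.weight', 'norms.1.weight', 'norms.2.weight']
```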
Pull Request resolved: #125192
Approved by: https://github.com/zhxchen17
angelayi authored and pytorchmergebot committed May 1, 2024
1 parent af67704 commit a216d87
Showing 2 changed files with 94 additions and 20 deletions.
33 changes: 33 additions & 0 deletions test/export/test_unflatten.py
@@ -1,5 +1,6 @@
# Owner(s): ["oncall: export"]
# flake8: noqa
import copy
import dataclasses
import unittest
from contextlib import contextmanager
@@ -676,6 +677,38 @@ def forward(self, x):
fqn_list,
)

def test_duplicate_placeholder(self):
N, C, H, W = 1, 2, 2, 3

class MyModule(torch.nn.Module):
def __init__(self):
super().__init__()
layer = torch.nn.LayerNorm([C, H, W])
self.norms = torch.nn.ModuleList(
[
layer, # reuse layer norm
layer,
layer,
]
)

def forward(self, input_):
for i in range(len(self.norms)):
output = self.norms[i](input_)
input_ = output
return output

mod = MyModule()
input_ = torch.randn(N, C, H, W)

ep_strict = export(copy.deepcopy(mod), (input_,), strict=True)
umod = unflatten(ep_strict)
self.assertTrue(torch.allclose(umod(input_), mod(input_)))

ep_non_strict = export(copy.deepcopy(mod), (input_,), strict=False)
umod = unflatten(ep_non_strict)
self.assertTrue(torch.allclose(umod(input_), mod(input_)))


if __name__ == "__main__":
run_tests()
81 changes: 61 additions & 20 deletions torch/export/unflatten.py
@@ -1,10 +1,11 @@
import abc
import copy
import operator
from collections import defaultdict
from copy import deepcopy
from enum import Enum
from itertools import chain
from typing import Any, cast, Dict, List, Optional, Union
from typing import Any, cast, Dict, List, Optional, Tuple, Union

import torch
import torch.fx._pytree as fx_pytree
@@ -14,6 +15,7 @@
from torch.export.exported_program import (
ConstantArgument,
ExportedProgram,
InputKind,
ModuleCallSignature,
SymIntArgument,
TensorArgument,
@@ -212,22 +214,52 @@ def __init__(
attr_kind=_AttrKind.CONSTANT,
)

inputs_to_state: Dict[str, str] = {
**self.graph_signature.inputs_to_parameters,
**self.graph_signature.inputs_to_buffers,
**self.graph_signature.inputs_to_lifted_tensor_constants,
**self.graph_signature.inputs_to_lifted_custom_objs,
}
# This is to handle parameters/buffers that point to the same tensor
# object id -> list of (node_name, target_name)
consts_map: Dict[int, List[Tuple[str, str]]] = defaultdict(list)

def add_to_consts_map(obj_id, node_name, target_name):
name_list = consts_map[obj_id]
name_list.append((node_name, target_name))

for s in self.graph_signature.input_specs:
if s.kind == InputKind.PARAMETER or (
s.kind == InputKind.BUFFER and s.persistent
):
assert hasattr(s.arg, "name")
assert isinstance(s.target, str)
add_to_consts_map(
id(export_module.state_dict[s.target]), s.arg.name, s.target
)
elif (
(s.kind == InputKind.BUFFER and not s.persistent)
or s.kind == InputKind.CONSTANT_TENSOR
or s.kind == InputKind.CUSTOM_OBJ
):
assert hasattr(s.arg, "name")
assert isinstance(s.target, str)
add_to_consts_map(
id(export_module.constants[s.target]), s.arg.name, s.target
)

# node name -> list of possible targets
inputs_to_state: Dict[str, List[str]] = {}
for node_target in consts_map.values():
targets = [t[1] for t in node_target]
for n, _ in node_target:
inputs_to_state[n] = targets

_sink_params(self, inputs_to_state, [])
        # Check that all input nodes have been processed.
for module in self.modules():
if not isinstance(module, torch.fx.GraphModule):
for name, module in self.named_modules():
if not hasattr(module, "graph"):
continue
for node in module.graph.nodes:
if node.op != "placeholder":
continue
assert node.name not in inputs_to_state
assert (
node.name not in inputs_to_state
), f"{node.name} was not sunk into the module {name} which has the graph: {module.graph}"

# Cache so we don't have to compute this every time.
# NOTE: this needs to be kept in sync with the placeholders in
@@ -857,7 +889,7 @@ def _reorder_submodules(

def _sink_params(
module: torch.nn.Module,
inputs_to_state: Dict[str, str],
inputs_to_state: Dict[str, List[str]],
scope: List[str],
):
"""Sink params, buffers, and constants from graph inputs into get_attr nodes.
@@ -896,16 +928,25 @@ def _sink_params(
continue

if len(node.users) > 0:
state_name = inputs_to_state[node.name].split(".")
            # If there's a mismatch between scope name and state name, then there must be multiple scopes
            # pointing to the same state name, meaning some modules are shared. In such a case, we can simply
            # skip updating the current node because another later iteration will take care of this input
            # node when the unique match between scope and state name occurs.
            # To make sure this always happens, we should enforce the invariant that no placeholder node
            # in the unflattened graph appears in inputs_to_state dict, which means all the extra input
            # nodes have been handled.
if state_name[: len(scope)] != scope:
state_name = None
for sn in inputs_to_state[node.name]:
sn_split = sn.split(".")
if sn_split[: len(scope)] == scope:
state_name = sn_split
break

            # If there's a mismatch between scope name and state name, then
            # there must be multiple scopes pointing to the same state name,
            # meaning some modules are shared. In such a case, we can simply skip
            # updating the current node because another later iteration will
            # take care of this input node when the unique match between scope
            # and state name occurs. To make sure this always happens, we should
            # enforce the invariant that no placeholder node in the unflattened
            # graph appears in inputs_to_state dict, which means all the extra
            # input nodes have been handled.
if state_name is None:
continue

attr_path = state_name[len(scope) :]
state_attr = _recursive_getattr(module, attr_path)
assert isinstance(state_attr, (torch.Tensor, torch.ScriptObject))
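
For reference, a small sketch of how the multi-target mapping is consumed per scope by the resolution loop above (names and values assumed, not taken from a real export):
```
# Assumed names; mirrors the per-scope resolution loop added to _sink_params.
inputs_to_state = {
    "p_norms_2_weight": ["norms.0.weight", "norms.1.weight", "norms.2.weight"],
}

def resolve(scope, node_name):
    # Pick the first target whose FQN prefix matches the current scope; if
    # none matches, the placeholder is skipped and handled in another scope.
    for target in inputs_to_state[node_name]:
        parts = target.split(".")
        if parts[: len(scope)] == scope:
            return parts[len(scope):]
    return None

print(resolve(["norms", "0"], "p_norms_2_weight"))  # ['weight']
print(resolve(["norms", "1"], "p_norms_2_weight"))  # ['weight']
print(resolve(["other"], "p_norms_2_weight"))       # None -> skip this node
```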
