[export] Fix for unflattening modules with duplicate tensors (#125192)
In the given test case, we have a ModuleList of 3 modules (`norm.0`, `norm.1`, `norm.2`) which share the same `weight` and `bias` tensors. However, when we trace, all of their uses end up pointing at a single state dict name (e.g. `norm.2`):

```
graph():
    %p_norms_0_weight : [num_users=0] = placeholder[target=p_norms_0_weight]
    %p_norms_0_bias : [num_users=0] = placeholder[target=p_norms_0_bias]
    %p_norms_1_weight : [num_users=0] = placeholder[target=p_norms_1_weight]
    %p_norms_1_bias : [num_users=0] = placeholder[target=p_norms_1_bias]
    %p_norms_2_weight : [num_users=3] = placeholder[target=p_norms_2_weight]
    %p_norms_2_bias : [num_users=3] = placeholder[target=p_norms_2_bias]
    %input_ : [num_users=1] = placeholder[target=input_]
    %native_layer_norm : [num_users=1] = call_function[target=torch.ops.aten.native_layer_norm.default](args = (%input_, [2, 2, 3], %p_norms_2_weight, %p_norms_2_bias, 1e-05), kwargs = {})
    %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%native_layer_norm, 0), kwargs = {})
    %native_layer_norm_1 : [num_users=1] = call_function[target=torch.ops.aten.native_layer_norm.default](args = (%getitem, [2, 2, 3], %p_norms_2_weight, %p_norms_2_bias, 1e-05), kwargs = {})
    %getitem_3 : [num_users=1] = call_function[target=operator.getitem](args = (%native_layer_norm_1, 0), kwargs = {})
    %native_layer_norm_2 : [num_users=1] = call_function[target=torch.ops.aten.native_layer_norm.default](args = (%getitem_3, [2, 2, 3], %p_norms_2_weight, %p_norms_2_bias, 1e-05), kwargs = {})
    %getitem_6 : [num_users=1] = call_function[target=operator.getitem](args = (%native_layer_norm_2, 0), kwargs = {})
    return (getitem_6,)
```

This causes an error in the unflattener: after constructing the submodule for `norm.0`, its graph still points to `norm.2.weight` and `norm.2.bias`:

```
graph():
    %p_norms_2_bias : [num_users=1] = placeholder[target=p_norms_2_bias]
    %p_norms_2_weight : [num_users=1] = placeholder[target=p_norms_2_weight]
    %input_ : [num_users=1] = placeholder[target=input_]
    %native_layer_norm : [num_users=1] = call_function[target=torch.ops.aten.native_layer_norm.default](args = (%input_, [2, 2, 3], %p_norms_2_weight, %p_norms_2_bias, 1e-05), kwargs = {})
    %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%native_layer_norm, 0), kwargs = {})
    return getitem
```

Since these attributes are not within the submodule's scope (`norm.0` vs. `norm.2`), they are never added to the subgraph, causing an error. This PR handles duplicate state dict attributes by changing the `inputs_to_state` dict to map each node name to a list of possible state dict target names rather than a single name.

Pull Request resolved: #125192
Approved by: https://github.com/zhxchen17
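For reference, a minimal sketch of the kind of module that hits this case (the class name, shapes, and variable names here are illustrative, not the actual test from the PR):

```python
import torch
from torch.export import export, unflatten

class Foo(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.norms = torch.nn.ModuleList(
            torch.nn.LayerNorm([2, 2, 3]) for _ in range(3)
        )
        # Distinct modules, but all three share one weight and one bias
        # tensor, so after tracing every use points at a single placeholder
        # (norms.2 in the graphs above).
        for n in self.norms[1:]:
            n.weight = self.norms[0].weight
            n.bias = self.norms[0].bias

    def forward(self, input_):
        for n in self.norms:
            input_ = n(input_)
        return input_

ep = export(Foo(), (torch.randn(2, 2, 3),))
# Before this fix, unflatten() failed here: the subgraph built for norms.0
# referenced norms.2's weight/bias, which live outside norms.0's scope.
m = unflatten(ep)
```

With the list-valued `inputs_to_state` mapping, the unflattener can match a shared placeholder against whichever of its aliased targets falls inside the submodule currently being constructed.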