New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[export] Fix for unflattening modules with duplicate tensors #125192
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/125192
Note: Links to docs will display an error until the docs builds have been completed. ✅ No FailuresAs of commit 73213d7 with merge base 00dd4d5 (): This comment was automatically generated by Dr. CI and updates every 15 minutes. |
@angelayi has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator. |
9254a77
to
b2eabac
Compare
@angelayi has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator. |
b2eabac
to
73213d7
Compare
@angelayi has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator. |
@pytorchbot merge -f 'Landed internally' (Initiating merge automatically since Phabricator Diff has merged, using force because this PR might not pass merge_rules.json but landed internally) |
Merge startedYour change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Please use Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team |
…#125192) In the given test case, we have a ModuleList of 3 modules (`norm.0`, `norm.1`, `norm.2`) which share the same `weight` and `bias` tensors. However when we trace, they all end up pointing to one state dict name, (ex. `norm.2`). ``` graph(): %p_norms_0_weight : [num_users=0] = placeholder[target=p_norms_0_weight] %p_norms_0_bias : [num_users=0] = placeholder[target=p_norms_0_bias] %p_norms_1_weight : [num_users=0] = placeholder[target=p_norms_1_weight] %p_norms_1_bias : [num_users=0] = placeholder[target=p_norms_1_bias] %p_norms_2_weight : [num_users=3] = placeholder[target=p_norms_2_weight] %p_norms_2_bias : [num_users=3] = placeholder[target=p_norms_2_bias] %input_ : [num_users=1] = placeholder[target=input_] %native_layer_norm : [num_users=1] = call_function[target=torch.ops.aten.native_layer_norm.default](args = (%input_, [2, 2, 3], %p_norms_2_weight, %p_norms_2_bias, 1e-05), kwargs = {}) %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%native_layer_norm, 0), kwargs = {}) %native_layer_norm_1 : [num_users=1] = call_function[target=torch.ops.aten.native_layer_norm.default](args = (%getitem, [2, 2, 3], %p_norms_2_weight, %p_norms_2_bias, 1e-05), kwargs = {}) %getitem_3 : [num_users=1] = call_function[target=operator.getitem](args = (%native_layer_norm_1, 0), kwargs = {}) %native_layer_norm_2 : [num_users=1] = call_function[target=torch.ops.aten.native_layer_norm.default](args = (%getitem_3, [2, 2, 3], %p_norms_2_weight, %p_norms_2_bias, 1e-05), kwargs = {}) %getitem_6 : [num_users=1] = call_function[target=operator.getitem](args = (%native_layer_norm_2, 0), kwargs = {}) return (getitem_6,) ``` This causes an error in the unflattener where after constructing the submodules for `norm.0`, it will have the graph pointing to `norm.2.weight` and `norm.2.bias`: ``` graph(): %p_norms_2_bias : [num_users=1] = placeholder[target=p_norms_2_bias] %p_norms_2_weight : [num_users=1] = placeholder[target=p_norms_2_weight] %input_ : [num_users=1] = placeholder[target=input_] %native_layer_norm : [num_users=1] = call_function[target=torch.ops.aten.native_layer_norm.default](args = (%input_, [2, 2, 3], %p_norms_2_weight, %p_norms_2_bias, 1e-05), kwargs = {}) %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%native_layer_norm, 0), kwargs = {}) return getitem ``` Since the attributes are not within the same scope of the graph, (`norm.0` vs. `norm.2`), they will not be added to the subgraph, causing an error. So this PR handles the duplicate state dict attributes by modifying the `inputs_to_state` dict to map from node names to a list of possible state dict target names. Pull Request resolved: pytorch#125192 Approved by: https://github.com/zhxchen17
In the given test case, we have a ModuleList of 3 modules (
norm.0
,norm.1
,norm.2
) which share the sameweight
andbias
tensors. However when we trace, they all end up pointing to one state dict name, (ex.norm.2
).This causes an error in the unflattener where after constructing the submodules for
norm.0
, it will have the graph pointing tonorm.2.weight
andnorm.2.bias
:Since the attributes are not within the same scope of the graph, (
norm.0
vs.norm.2
), they will not be added to the subgraph, causing an error.So this PR handles the duplicate state dict attributes by modifying the
inputs_to_state
dict to map from node names to a list of possible state dict target names.