Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[export] Fix for unflattening modules with duplicate tensors #125192

Closed
wants to merge 1 commit into from

Conversation

angelayi
Copy link
Contributor

In the given test case, we have a ModuleList of 3 modules (norm.0, norm.1, norm.2) which share the same weight and bias tensors. However when we trace, they all end up pointing to one state dict name, (ex. norm.2).

graph():
    %p_norms_0_weight : [num_users=0] = placeholder[target=p_norms_0_weight]
    %p_norms_0_bias : [num_users=0] = placeholder[target=p_norms_0_bias]
    %p_norms_1_weight : [num_users=0] = placeholder[target=p_norms_1_weight]
    %p_norms_1_bias : [num_users=0] = placeholder[target=p_norms_1_bias]
    %p_norms_2_weight : [num_users=3] = placeholder[target=p_norms_2_weight]
    %p_norms_2_bias : [num_users=3] = placeholder[target=p_norms_2_bias]
    %input_ : [num_users=1] = placeholder[target=input_]
    %native_layer_norm : [num_users=1] = call_function[target=torch.ops.aten.native_layer_norm.default](args = (%input_, [2, 2, 3], %p_norms_2_weight, %p_norms_2_bias, 1e-05), kwargs = {})
    %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%native_layer_norm, 0), kwargs = {})
    %native_layer_norm_1 : [num_users=1] = call_function[target=torch.ops.aten.native_layer_norm.default](args = (%getitem, [2, 2, 3], %p_norms_2_weight, %p_norms_2_bias, 1e-05), kwargs = {})
    %getitem_3 : [num_users=1] = call_function[target=operator.getitem](args = (%native_layer_norm_1, 0), kwargs = {})
    %native_layer_norm_2 : [num_users=1] = call_function[target=torch.ops.aten.native_layer_norm.default](args = (%getitem_3, [2, 2, 3], %p_norms_2_weight, %p_norms_2_bias, 1e-05), kwargs = {})
    %getitem_6 : [num_users=1] = call_function[target=operator.getitem](args = (%native_layer_norm_2, 0), kwargs = {})
    return (getitem_6,)

This causes an error in the unflattener where after constructing the submodules for norm.0, it will have the graph pointing to norm.2.weight and norm.2.bias:

graph():
    %p_norms_2_bias : [num_users=1] = placeholder[target=p_norms_2_bias]
    %p_norms_2_weight : [num_users=1] = placeholder[target=p_norms_2_weight]
    %input_ : [num_users=1] = placeholder[target=input_]
    %native_layer_norm : [num_users=1] = call_function[target=torch.ops.aten.native_layer_norm.default](args = (%input_, [2, 2, 3], %p_norms_2_weight, %p_norms_2_bias, 1e-05), kwargs = {})
    %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%native_layer_norm, 0), kwargs = {})
    return getitem

Since the attributes are not within the same scope of the graph, (norm.0 vs. norm.2), they will not be added to the subgraph, causing an error.

So this PR handles the duplicate state dict attributes by modifying the inputs_to_state dict to map from node names to a list of possible state dict target names.

Copy link

pytorch-bot bot commented Apr 29, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/125192

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 73213d7 with merge base 00dd4d5 (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot
Copy link
Contributor

@angelayi has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@facebook-github-bot
Copy link
Contributor

@angelayi has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@facebook-github-bot
Copy link
Contributor

@angelayi has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@facebook-github-bot
Copy link
Contributor

@pytorchbot merge -f 'Landed internally'

(Initiating merge automatically since Phabricator Diff has merged, using force because this PR might not pass merge_rules.json but landed internally)

@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Please use -f as last resort and instead consider -i/--ignore-current to continue the merge ignoring current failures. This will allow currently pending tests to finish and report signal before the merge.

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

petrex pushed a commit to petrex/pytorch that referenced this pull request May 3, 2024
…#125192)

In the given test case, we have a ModuleList of 3 modules (`norm.0`, `norm.1`, `norm.2`) which share the same `weight` and `bias` tensors. However when we trace, they all end up pointing to one state dict name, (ex. `norm.2`).
```
graph():
    %p_norms_0_weight : [num_users=0] = placeholder[target=p_norms_0_weight]
    %p_norms_0_bias : [num_users=0] = placeholder[target=p_norms_0_bias]
    %p_norms_1_weight : [num_users=0] = placeholder[target=p_norms_1_weight]
    %p_norms_1_bias : [num_users=0] = placeholder[target=p_norms_1_bias]
    %p_norms_2_weight : [num_users=3] = placeholder[target=p_norms_2_weight]
    %p_norms_2_bias : [num_users=3] = placeholder[target=p_norms_2_bias]
    %input_ : [num_users=1] = placeholder[target=input_]
    %native_layer_norm : [num_users=1] = call_function[target=torch.ops.aten.native_layer_norm.default](args = (%input_, [2, 2, 3], %p_norms_2_weight, %p_norms_2_bias, 1e-05), kwargs = {})
    %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%native_layer_norm, 0), kwargs = {})
    %native_layer_norm_1 : [num_users=1] = call_function[target=torch.ops.aten.native_layer_norm.default](args = (%getitem, [2, 2, 3], %p_norms_2_weight, %p_norms_2_bias, 1e-05), kwargs = {})
    %getitem_3 : [num_users=1] = call_function[target=operator.getitem](args = (%native_layer_norm_1, 0), kwargs = {})
    %native_layer_norm_2 : [num_users=1] = call_function[target=torch.ops.aten.native_layer_norm.default](args = (%getitem_3, [2, 2, 3], %p_norms_2_weight, %p_norms_2_bias, 1e-05), kwargs = {})
    %getitem_6 : [num_users=1] = call_function[target=operator.getitem](args = (%native_layer_norm_2, 0), kwargs = {})
    return (getitem_6,)
```
This causes an error in the unflattener where after constructing the submodules for `norm.0`, it will have the graph pointing to `norm.2.weight` and `norm.2.bias`:
```
graph():
    %p_norms_2_bias : [num_users=1] = placeholder[target=p_norms_2_bias]
    %p_norms_2_weight : [num_users=1] = placeholder[target=p_norms_2_weight]
    %input_ : [num_users=1] = placeholder[target=input_]
    %native_layer_norm : [num_users=1] = call_function[target=torch.ops.aten.native_layer_norm.default](args = (%input_, [2, 2, 3], %p_norms_2_weight, %p_norms_2_bias, 1e-05), kwargs = {})
    %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%native_layer_norm, 0), kwargs = {})
    return getitem
```
Since the attributes are not within the same scope of the graph, (`norm.0` vs. `norm.2`), they will not be added to the subgraph, causing an error.

So this PR handles the duplicate state dict attributes by modifying the `inputs_to_state` dict to map from node names to a list of possible state dict target names.
Pull Request resolved: pytorch#125192
Approved by: https://github.com/zhxchen17
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

5 participants