Skip to content

Commit

Permalink
Refactored _remove_auto_functionalization_from_graph_helper (pytorch#…
Browse files Browse the repository at this point in the history
…125180)

Summary:
Refactored the function to remove multiple list slicing and use unused variable.

Test Plan:
python test/run_test.py

Reviewers: @drisspg

Subscribers:

Tasks: [T187526123](https://www.internalfb.com/intern/tasks/?t=187526123) [T93492332](https://www.internalfb.com/intern/tasks/?t=93492332)

Tags: @pytorchbot merge -r viable/strict
Pull Request resolved: pytorch#125180
Approved by: https://github.com/drisspg
  • Loading branch information
jainapurva authored and petrex committed May 3, 2024
1 parent c9fef5e commit 33898e8
Showing 1 changed file with 3 additions and 5 deletions.
8 changes: 3 additions & 5 deletions torch/export/_remove_auto_functionalized_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,14 +43,12 @@ def _remove_auto_functionalization_from_graph_helper(ep, auto_functionalize_node
)

# If the result of getitem was used in an output node, update the output spec with the correct name
adusted_index = user.args[1] - len(func._schema.returns)
original_arg = original_kwargs[mutable_args_names[adusted_index]]
adjusted_index = user.args[1] - len(func._schema.returns)
original_arg = original_kwargs[mutable_args_names[adjusted_index]]

# This is a little fragile/implementation dependent, but the order of the mutable args is the same as the order
# of the getitem calls following the HOP.
user.replace_all_uses_with(
original_kwargs[mutable_args_names[adusted_index]]
)
user.replace_all_uses_with(original_arg)

if len(func._schema.returns) == 1:
# If the function has 1 return then it will just directly return the
Expand Down

0 comments on commit 33898e8

Please sign in to comment.