From ec6dad7bc83787047bb12ab0e05fee151d0a56b2 Mon Sep 17 00:00:00 2001 From: Apurva Date: Mon, 29 Apr 2024 12:06:15 -0700 Subject: [PATCH] Refactored _remove_auto_functionalization_from_graph_helper Summary: Refactored the function to remove multiple list slicing and use unused variable. Test Plan: python test/run_test.py Reviewers: @drisspg Subscribers: Tasks: Tags: [ghstack-poisoned] --- torch/export/_remove_auto_functionalized_pass.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/torch/export/_remove_auto_functionalized_pass.py b/torch/export/_remove_auto_functionalized_pass.py index 1d6d53192bea..c1cea8ec005f 100644 --- a/torch/export/_remove_auto_functionalized_pass.py +++ b/torch/export/_remove_auto_functionalized_pass.py @@ -43,14 +43,12 @@ def _remove_auto_functionalization_from_graph_helper(ep, auto_functionalize_node ) # If the result of getitem was used in an output node, update the output spec with the correct name - adusted_index = user.args[1] - len(func._schema.returns) - original_arg = original_kwargs[mutable_args_names[adusted_index]] + adjusted_index = user.args[1] - len(func._schema.returns) + original_arg = original_kwargs[mutable_args_names[adjusted_index]] # This is a little fragile/implementation dependent, but the order of the mutable args is the same as the order # of the getitem calls following the HOP. - user.replace_all_uses_with( - original_kwargs[mutable_args_names[adusted_index]] - ) + user.replace_all_uses_with(original_arg) if len(func._schema.returns) == 1: # If the function has 1 return then it will just directly return the