-
Notifications
You must be signed in to change notification settings - Fork 214
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
TorchFX PTQ backend #2764
Open
daniil-lyakhov
wants to merge
8
commits into
openvinotoolkit:develop
Choose a base branch
from
daniil-lyakhov:dl/torch_fx_quantization_init
base: develop
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+1,776
−61
Open
TorchFX PTQ backend #2764
Changes from all commits
Commits
Show all changes
8 commits
Select commit
Hold shift + click to select a range
b77256d
TorchFX quantization init
daniil-lyakhov ec34211
Test code is removed
daniil-lyakhov f539fde
Reference graph are updated
daniil-lyakhov 8205d9c
torch-fx tests are added to pre-commit
daniil-lyakhov 2d5a02b
Comments
daniil-lyakhov 71b49f7
Model transformer minor refactoring
daniil-lyakhov afbd370
Comments
daniil-lyakhov 72d0f90
Comments
daniil-lyakhov File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -20,6 +20,7 @@ | |
|
||
class BackendType(Enum): | ||
TORCH = "Torch" | ||
TORCH_FX = "TorchFX" | ||
TENSORFLOW = "Tensorflow" | ||
ONNX = "ONNX" | ||
OPENVINO = "OpenVINO" | ||
|
@@ -33,6 +34,7 @@ def get_available_backends() -> List[BackendType]: | |
""" | ||
frameworks = [ | ||
("torch", BackendType.TORCH), | ||
("torch.fx", BackendType.TORCH_FX), | ||
("tensorflow", BackendType.TENSORFLOW), | ||
("onnx", BackendType.ONNX), | ||
("openvino.runtime", BackendType.OPENVINO), | ||
|
@@ -51,14 +53,27 @@ def get_available_backends() -> List[BackendType]: | |
|
||
def is_torch_model(model: TModel) -> bool: | ||
""" | ||
Returns True if the model is an instance of torch.nn.Module, otherwise False. | ||
Returns True if the model is an instance of torch.nn.Module and not a torch.fx.GraphModule, otherwise False. | ||
:param model: A target model. | ||
:return: True if the model is an instance of torch.nn.Module, otherwise False. | ||
:return: True if the model is an instance of torch.nn.Module and not torch.fx.GraphModule, otherwise False. | ||
""" | ||
import torch | ||
import torch.fx | ||
|
||
return isinstance(model, torch.nn.Module) | ||
return not isinstance(model, torch.fx.GraphModule) and isinstance(model, torch.nn.Module) | ||
|
||
|
||
def is_torch_fx_model(model: TModel) -> bool: | ||
""" | ||
Returns True if the model is an instance of torch.fx.GraphModule, otherwise False. | ||
:param model: A target model. | ||
:return: True if the model is an instance of torch.fx.GraphModule, otherwise False. | ||
""" | ||
import torch.fx | ||
|
||
return isinstance(model, torch.fx.GraphModule) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Are you going to support ExportedProgram (https://pytorch.org/docs/stable/export.html#torch.export.ExportedProgram)? |
||
|
||
|
||
def is_tensorflow_model(model: TModel) -> bool: | ||
|
@@ -118,6 +133,9 @@ def get_backend(model: TModel) -> BackendType: | |
""" | ||
available_backends = get_available_backends() | ||
|
||
if BackendType.TORCH_FX in available_backends and is_torch_fx_model(model): | ||
return BackendType.TORCH_FX | ||
|
||
if BackendType.TORCH in available_backends and is_torch_model(model): | ||
return BackendType.TORCH | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
# Copyright (c) 2024 Intel Corporation | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,114 @@ | ||
# Copyright (c) 2024 Intel Corporation | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from collections import defaultdict | ||
|
||
# from functools import partial | ||
from typing import Callable, List, Union | ||
|
||
import torch | ||
import torch.fx | ||
from torch.fx.passes.split_utils import split_by_tags | ||
|
||
from nncf.common.graph.model_transformer import ModelTransformer | ||
from nncf.common.graph.transformations.commands import Command | ||
from nncf.common.graph.transformations.commands import TransformationPriority | ||
from nncf.common.graph.transformations.commands import TransformationType | ||
from nncf.torch.graph.transformations.commands import PTModelExtractionCommand | ||
from nncf.torch.graph.transformations.layout import PTTransformationLayout | ||
|
||
|
||
class FXApplyTransformationCommand(Command): | ||
def __init__( | ||
self, | ||
transformation_fn: Callable[[torch.fx.GraphModule], None], | ||
priority: Union[TransformationPriority, int] = TransformationPriority.DEFAULT_PRIORITY, | ||
): | ||
super().__init__(TransformationType.INSERT) | ||
self.tranformation_fn = transformation_fn | ||
self.priority = priority | ||
|
||
|
||
class FXModelTransformer(ModelTransformer): | ||
""" | ||
Applies transformations upon Torch FX model. | ||
""" | ||
|
||
def __init__(self, model: torch.fx.GraphModule): | ||
super().__init__(model) | ||
|
||
self._command_transformation_ordered_pairs = [ | ||
(FXApplyTransformationCommand, self._apply_transformation), | ||
(PTModelExtractionCommand, self._apply_model_extraction), | ||
] | ||
|
||
def transform(self, transformation_layout: PTTransformationLayout) -> torch.fx.GraphModule: | ||
# TODO(dlyakhov): Manage priorities of transformations. | ||
transformations = transformation_layout.transformations | ||
aggregated_transformations = defaultdict(list) | ||
for transformation in transformations: | ||
aggregated_transformations[transformation.__class__].append(transformation) | ||
|
||
model = self._model | ||
for transformation_cls, transformation_fn in self._command_transformation_ordered_pairs: | ||
transformations = aggregated_transformations[transformation_cls] | ||
if transformations: | ||
model = transformation_fn(model, transformations) | ||
|
||
# Do not eliminate dead code as | ||
# the dead code is computing statistics :) | ||
# model.graph.eliminate_dead_code() | ||
model.recompile() | ||
return model | ||
|
||
@staticmethod | ||
def _apply_model_extraction( | ||
model: torch.fx.GraphModule, | ||
transformations: List[PTModelExtractionCommand], | ||
) -> torch.fx.GraphModule: | ||
transformation = transformations[-1] | ||
assert len(transformation.input_node_names) == 1 | ||
assert transformation.input_node_names == transformation.output_node_names | ||
node_name = transformation.input_node_names[0] | ||
|
||
tags = ["before", "extracted", "after"] | ||
i = 0 | ||
for node in model.graph.nodes: | ||
if node.name == node_name: | ||
node.tag = tags[1] | ||
weights = [node.all_input_nodes[1]] | ||
while weights: | ||
w_node = weights.pop() | ||
assert w_node.tag in tags[0:2] | ||
w_node.tag = tags[1] | ||
weights.extend(w_node.all_input_nodes) | ||
i = 2 | ||
continue | ||
node.tag = tags[i] | ||
|
||
splitted_gm = split_by_tags(model, tags) | ||
return splitted_gm.extracted | ||
|
||
@staticmethod | ||
def get_graph_node_by_name(graph: torch.fx.Graph, name: str) -> torch.fx.Node: | ||
for node in graph.nodes: | ||
if node.name == name: | ||
return node | ||
raise RuntimeError(f"Node with name {name} is not found") | ||
|
||
@staticmethod | ||
def _apply_transformation( | ||
model: torch.fx.GraphModule, | ||
transformations: List[FXApplyTransformationCommand], | ||
) -> torch.fx.GraphModule: | ||
for transformation in transformations: | ||
transformation.tranformation_fn(model) | ||
return model |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,121 @@ | ||
# Copyright (c) 2024 Intel Corporation | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from typing import Tuple | ||
|
||
import torch.fx | ||
|
||
import nncf.torch.graph.operator_metatypes as om | ||
from nncf.common.graph import NNCFGraph | ||
from nncf.common.graph import NNCFNode | ||
from nncf.common.graph.layer_attributes import Dtype | ||
from nncf.common.graph.operator_metatypes import UnknownMetatype | ||
from nncf.common.logging import nncf_logger | ||
from nncf.torch.graph.graph import PTNNCFGraph | ||
from nncf.torch.graph.operator_metatypes import PT_OPERATOR_METATYPES | ||
|
||
|
||
class GraphConverter: | ||
""" | ||
Builds the NNCFGraph from an torch.fx.GraphModule instance. | ||
""" | ||
|
||
@staticmethod | ||
def _get_node_type_and_metatype(node: torch.fx.Node) -> Tuple[str, om.OperatorMetatype]: | ||
if node.op == "placeholder": | ||
node_type = "input" | ||
node_metatype = om.PTInputNoopMetatype | ||
elif node.op == "output": | ||
node_type = "output" | ||
node_metatype = om.PTOutputNoopMetatype | ||
elif node.op == "get_attr": | ||
node_type = "get_attr" | ||
node_metatype = om.PTConstNoopMetatype | ||
elif node.op in ("call_function",): | ||
if hasattr(node.target, "overloadpacket"): | ||
node_type = str(node.target.overloadpacket).split(".")[1] | ||
elif node.target.__name__ == "getitem": | ||
node_type = "__getitem__" | ||
else: | ||
# TODO(dlyakhov): get correct nodes types from this nodes as well | ||
node_type = str(node.target) | ||
node_metatype = PT_OPERATOR_METATYPES.get_operator_metatype_by_op_name(node_type) | ||
else: | ||
node_type = node.op | ||
node_metatype = UnknownMetatype | ||
if node_metatype is UnknownMetatype: | ||
nncf_logger.info(f"Unknown metatype for node: {node}") | ||
return node_type, node_metatype | ||
|
||
@staticmethod | ||
def create_nncf_graph(model: torch.fx.GraphModule) -> NNCFGraph: | ||
""" | ||
Creates NNCFGraph from GraphModule. | ||
All nodes from model which have valid metatype are added to NNCFGraph. | ||
Then, corresponding edges are added to the NNCFGraph with shape, type, output and input port ids. | ||
:param model: torch fx GraphModule. | ||
:return: NNCFGraph. | ||
""" | ||
|
||
nncf_graph = PTNNCFGraph() | ||
|
||
for source_node in model.graph.nodes: | ||
node_type, node_metatype = GraphConverter._get_node_type_and_metatype(source_node) | ||
|
||
nncf_graph.add_nncf_node( | ||
node_name=source_node.name, | ||
node_type=node_type, | ||
node_metatype=node_metatype, | ||
) | ||
|
||
for source_node in model.graph.nodes: | ||
source_nncf_node = nncf_graph.get_node_by_name(source_node.name) | ||
for idx, dist_node in enumerate(source_node.users): | ||
dist_node_id = nncf_graph.get_node_by_name(dist_node.name).node_id | ||
input_port_id, output_port_id, tensor_shape = GraphConverter.get_edge_params( | ||
model, source_node, source_nncf_node, dist_node, idx | ||
) | ||
|
||
nncf_graph.add_edge_between_nncf_nodes( | ||
source_nncf_node.node_id, | ||
dist_node_id, | ||
tensor_shape=tensor_shape, | ||
input_port_id=input_port_id, | ||
output_port_id=output_port_id, | ||
dtype=Dtype.FLOAT, | ||
) | ||
|
||
return nncf_graph | ||
|
||
@staticmethod | ||
def get_edge_params( | ||
model, source_node: torch.fx.Node, source_nncf_node: NNCFNode, dist_node: torch.fx.Node, output_idx: int | ||
): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please add type hints. |
||
output_port_id = 0 | ||
if source_node.op in ("get_attr",): | ||
tensor_shape = tuple(getattr(model, source_node.target).shape) | ||
elif "val" in source_node.meta: | ||
if source_nncf_node.metatype is om.PTBatchNormMetatype: | ||
tensor = source_node.meta["val"][0] | ||
elif source_nncf_node.metatype is om.PTSplitMetatype: | ||
tensor = source_node.meta["val"][output_idx] | ||
# Assume every split outputs corresponds to an unique output_port_id | ||
output_port_id = output_idx | ||
else: | ||
tensor = source_node.meta["val"] | ||
tensor_shape = tuple(tensor.shape) | ||
else: | ||
nncf_logger.info(f"Edge shape between {source_node.name} and {dist_node.name} is unknown.") | ||
tensor_shape = None | ||
|
||
input_port_id = dist_node.all_input_nodes.index(source_node) | ||
return input_port_id, output_port_id, tensor_shape |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
# Copyright (c) 2024 Intel Corporation | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Which ones are new depending on where they are installed?