Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

TorchFX PTQ backend #2764

Open
wants to merge 8 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ install-torch-test:
pip install -e .
pip install "git+https://github.com/openvinotoolkit/open_model_zoo.git@37f60eb#egg=accuracy_checker&subdirectory=tools/accuracy_checker"
pip install -r tests/torch/requirements.txt
pip install -r tests/torch/fx/requirements.txt
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Which ones are new depending on where they are installed?

pip install -r tests/cross_fw/install/requirements.txt
pip install -r tests/cross_fw/examples/requirements.txt

Expand Down
14 changes: 13 additions & 1 deletion nncf/common/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,10 @@ def create(model: TModel) -> NNCFGraph:
if model_backend == BackendType.OPENVINO:
from nncf.openvino.graph.nncf_graph_builder import GraphConverter

return GraphConverter.create_nncf_graph(model)
if model_backend == BackendType.TORCH_FX:
from nncf.experimental.torch.fx.nncf_graph_builder import GraphConverter

return GraphConverter.create_nncf_graph(model)
if model_backend == BackendType.TORCH:
return model.nncf.get_graph()
Expand Down Expand Up @@ -72,6 +76,10 @@ def create(model: TModel, inplace: bool = False) -> ModelTransformer:
from nncf.torch.model_transformer import PTModelTransformer

return PTModelTransformer(model)
if model_backend == BackendType.TORCH_FX:
from nncf.experimental.torch.fx.model_transformer import FXModelTransformer

return FXModelTransformer(model)
raise nncf.UnsupportedBackendError(
"Cannot create backend-specific model transformer because {} is not supported!".format(model_backend.value)
)
Expand All @@ -95,7 +103,7 @@ def create(model: TModel) -> Engine:
from nncf.openvino.engine import OVNativeEngine

return OVNativeEngine(model)
if model_backend == BackendType.TORCH:
if model_backend in (BackendType.TORCH, BackendType.TORCH_FX):
from nncf.torch.engine import PTEngine

return PTEngine(model)
Expand Down Expand Up @@ -151,6 +159,10 @@ def create(model: TModel, dataset: Dataset) -> aggregator.StatisticsAggregator:
from nncf.torch.statistics.aggregator import PTStatisticsAggregator

return PTStatisticsAggregator(dataset)
if model_backend == BackendType.TORCH_FX:
from nncf.experimental.torch.fx.statistics.aggregator import FXStatisticsAggregator

return FXStatisticsAggregator(dataset)
raise nncf.UnsupportedBackendError(
"Cannot create backend-specific statistics aggregator because {} is not supported!".format(
model_backend.value
Expand Down
4 changes: 2 additions & 2 deletions nncf/common/graph/patterns/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def _get_backend_hw_patterns_map(backend: BackendType) -> Dict[HWFusedPatternNam

registry = OPENVINO_HW_FUSED_PATTERNS.registry_dict
return registry
if backend == BackendType.TORCH:
if backend in (BackendType.TORCH, BackendType.TORCH_FX):
from nncf.torch.hardware.fused_patterns import PT_HW_FUSED_PATTERNS

registry = PT_HW_FUSED_PATTERNS.registry_dict
Expand Down Expand Up @@ -73,7 +73,7 @@ def _get_backend_ignored_patterns_map(

registry = OPENVINO_IGNORED_PATTERNS.registry_dict
return registry
if backend == BackendType.TORCH:
if backend in (BackendType.TORCH, BackendType.TORCH_FX):
from nncf.torch.quantization.ignored_patterns import PT_IGNORED_PATTERNS

registry = PT_IGNORED_PATTERNS.registry_dict
Expand Down
24 changes: 21 additions & 3 deletions nncf/common/utils/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

class BackendType(Enum):
TORCH = "Torch"
TORCH_FX = "TorchFX"
TENSORFLOW = "Tensorflow"
ONNX = "ONNX"
OPENVINO = "OpenVINO"
Expand All @@ -33,6 +34,7 @@ def get_available_backends() -> List[BackendType]:
"""
frameworks = [
("torch", BackendType.TORCH),
("torch.fx", BackendType.TORCH_FX),
("tensorflow", BackendType.TENSORFLOW),
("onnx", BackendType.ONNX),
("openvino.runtime", BackendType.OPENVINO),
Expand All @@ -51,14 +53,27 @@ def get_available_backends() -> List[BackendType]:

def is_torch_model(model: TModel) -> bool:
"""
Returns True if the model is an instance of torch.nn.Module, otherwise False.
Returns True if the model is an instance of torch.nn.Module and not a torch.fx.GraphModule, otherwise False.
:param model: A target model.
:return: True if the model is an instance of torch.nn.Module, otherwise False.
:return: True if the model is an instance of torch.nn.Module and not torch.fx.GraphModule, otherwise False.
"""
import torch
import torch.fx

return isinstance(model, torch.nn.Module)
return not isinstance(model, torch.fx.GraphModule) and isinstance(model, torch.nn.Module)


def is_torch_fx_model(model: TModel) -> bool:
"""
Returns True if the model is an instance of torch.fx.GraphModule, otherwise False.
:param model: A target model.
:return: True if the model is an instance of torch.fx.GraphModule, otherwise False.
"""
import torch.fx

return isinstance(model, torch.fx.GraphModule)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.



def is_tensorflow_model(model: TModel) -> bool:
Expand Down Expand Up @@ -118,6 +133,9 @@ def get_backend(model: TModel) -> BackendType:
"""
available_backends = get_available_backends()

if BackendType.TORCH_FX in available_backends and is_torch_fx_model(model):
return BackendType.TORCH_FX

if BackendType.TORCH in available_backends and is_torch_model(model):
return BackendType.TORCH

Expand Down
10 changes: 10 additions & 0 deletions nncf/experimental/torch/fx/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# Copyright (c) 2024 Intel Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
114 changes: 114 additions & 0 deletions nncf/experimental/torch/fx/model_transformer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
# Copyright (c) 2024 Intel Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from collections import defaultdict

# from functools import partial
from typing import Callable, List, Union

import torch
import torch.fx
from torch.fx.passes.split_utils import split_by_tags

from nncf.common.graph.model_transformer import ModelTransformer
from nncf.common.graph.transformations.commands import Command
from nncf.common.graph.transformations.commands import TransformationPriority
from nncf.common.graph.transformations.commands import TransformationType
from nncf.torch.graph.transformations.commands import PTModelExtractionCommand
from nncf.torch.graph.transformations.layout import PTTransformationLayout


class FXApplyTransformationCommand(Command):
def __init__(
self,
transformation_fn: Callable[[torch.fx.GraphModule], None],
priority: Union[TransformationPriority, int] = TransformationPriority.DEFAULT_PRIORITY,
):
super().__init__(TransformationType.INSERT)
self.tranformation_fn = transformation_fn
self.priority = priority


class FXModelTransformer(ModelTransformer):
"""
Applies transformations upon Torch FX model.
"""

def __init__(self, model: torch.fx.GraphModule):
super().__init__(model)

self._command_transformation_ordered_pairs = [
(FXApplyTransformationCommand, self._apply_transformation),
(PTModelExtractionCommand, self._apply_model_extraction),
]

def transform(self, transformation_layout: PTTransformationLayout) -> torch.fx.GraphModule:
# TODO(dlyakhov): Manage priorities of transformations.
transformations = transformation_layout.transformations
aggregated_transformations = defaultdict(list)
for transformation in transformations:
aggregated_transformations[transformation.__class__].append(transformation)

model = self._model
for transformation_cls, transformation_fn in self._command_transformation_ordered_pairs:
transformations = aggregated_transformations[transformation_cls]
if transformations:
model = transformation_fn(model, transformations)

# Do not eliminate dead code as
# the dead code is computing statistics :)
# model.graph.eliminate_dead_code()
model.recompile()
return model

@staticmethod
def _apply_model_extraction(
model: torch.fx.GraphModule,
transformations: List[PTModelExtractionCommand],
) -> torch.fx.GraphModule:
transformation = transformations[-1]
assert len(transformation.input_node_names) == 1
assert transformation.input_node_names == transformation.output_node_names
node_name = transformation.input_node_names[0]

tags = ["before", "extracted", "after"]
i = 0
for node in model.graph.nodes:
if node.name == node_name:
node.tag = tags[1]
weights = [node.all_input_nodes[1]]
while weights:
w_node = weights.pop()
assert w_node.tag in tags[0:2]
w_node.tag = tags[1]
weights.extend(w_node.all_input_nodes)
i = 2
continue
node.tag = tags[i]

splitted_gm = split_by_tags(model, tags)
return splitted_gm.extracted

@staticmethod
def get_graph_node_by_name(graph: torch.fx.Graph, name: str) -> torch.fx.Node:
for node in graph.nodes:
if node.name == name:
return node
raise RuntimeError(f"Node with name {name} is not found")

@staticmethod
def _apply_transformation(
model: torch.fx.GraphModule,
transformations: List[FXApplyTransformationCommand],
) -> torch.fx.GraphModule:
for transformation in transformations:
transformation.tranformation_fn(model)
return model
121 changes: 121 additions & 0 deletions nncf/experimental/torch/fx/nncf_graph_builder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
# Copyright (c) 2024 Intel Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Tuple

import torch.fx

import nncf.torch.graph.operator_metatypes as om
from nncf.common.graph import NNCFGraph
from nncf.common.graph import NNCFNode
from nncf.common.graph.layer_attributes import Dtype
from nncf.common.graph.operator_metatypes import UnknownMetatype
from nncf.common.logging import nncf_logger
from nncf.torch.graph.graph import PTNNCFGraph
from nncf.torch.graph.operator_metatypes import PT_OPERATOR_METATYPES


class GraphConverter:
"""
Builds the NNCFGraph from an torch.fx.GraphModule instance.
"""

@staticmethod
def _get_node_type_and_metatype(node: torch.fx.Node) -> Tuple[str, om.OperatorMetatype]:
if node.op == "placeholder":
node_type = "input"
node_metatype = om.PTInputNoopMetatype
elif node.op == "output":
node_type = "output"
node_metatype = om.PTOutputNoopMetatype
elif node.op == "get_attr":
node_type = "get_attr"
node_metatype = om.PTConstNoopMetatype
elif node.op in ("call_function",):
if hasattr(node.target, "overloadpacket"):
node_type = str(node.target.overloadpacket).split(".")[1]
elif node.target.__name__ == "getitem":
node_type = "__getitem__"
else:
# TODO(dlyakhov): get correct nodes types from this nodes as well
node_type = str(node.target)
node_metatype = PT_OPERATOR_METATYPES.get_operator_metatype_by_op_name(node_type)
else:
node_type = node.op
node_metatype = UnknownMetatype
if node_metatype is UnknownMetatype:
nncf_logger.info(f"Unknown metatype for node: {node}")
return node_type, node_metatype

@staticmethod
def create_nncf_graph(model: torch.fx.GraphModule) -> NNCFGraph:
"""
Creates NNCFGraph from GraphModule.
All nodes from model which have valid metatype are added to NNCFGraph.
Then, corresponding edges are added to the NNCFGraph with shape, type, output and input port ids.
:param model: torch fx GraphModule.
:return: NNCFGraph.
"""

nncf_graph = PTNNCFGraph()

for source_node in model.graph.nodes:
node_type, node_metatype = GraphConverter._get_node_type_and_metatype(source_node)

nncf_graph.add_nncf_node(
node_name=source_node.name,
node_type=node_type,
node_metatype=node_metatype,
)

for source_node in model.graph.nodes:
source_nncf_node = nncf_graph.get_node_by_name(source_node.name)
for idx, dist_node in enumerate(source_node.users):
dist_node_id = nncf_graph.get_node_by_name(dist_node.name).node_id
input_port_id, output_port_id, tensor_shape = GraphConverter.get_edge_params(
model, source_node, source_nncf_node, dist_node, idx
)

nncf_graph.add_edge_between_nncf_nodes(
source_nncf_node.node_id,
dist_node_id,
tensor_shape=tensor_shape,
input_port_id=input_port_id,
output_port_id=output_port_id,
dtype=Dtype.FLOAT,
)

return nncf_graph

@staticmethod
def get_edge_params(
model, source_node: torch.fx.Node, source_nncf_node: NNCFNode, dist_node: torch.fx.Node, output_idx: int
):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please add type hints.

output_port_id = 0
if source_node.op in ("get_attr",):
tensor_shape = tuple(getattr(model, source_node.target).shape)
elif "val" in source_node.meta:
if source_nncf_node.metatype is om.PTBatchNormMetatype:
tensor = source_node.meta["val"][0]
elif source_nncf_node.metatype is om.PTSplitMetatype:
tensor = source_node.meta["val"][output_idx]
# Assume every split outputs corresponds to an unique output_port_id
output_port_id = output_idx
else:
tensor = source_node.meta["val"]
tensor_shape = tuple(tensor.shape)
else:
nncf_logger.info(f"Edge shape between {source_node.name} and {dist_node.name} is unknown.")
tensor_shape = None

input_port_id = dist_node.all_input_nodes.index(source_node)
return input_port_id, output_port_id, tensor_shape
10 changes: 10 additions & 0 deletions nncf/experimental/torch/fx/quantization/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# Copyright (c) 2024 Intel Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Loading
Loading