Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Test ResNet Ops Support #28

Closed
wants to merge 27 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
2ba43bd
E2M1 compression. (#2745)
andreyanufr Jun 28, 2024
beb9c04
Update Fast-/BC algorithms (#2747)
nikita-malininn Jun 28, 2024
a25e56d
Support TF 2.15 (#2609)
andrey-churkin Jul 1, 2024
082841c
[PT] Quantization of addmm function (#2713)
AlexanderDokuchaev Jul 1, 2024
e8ea252
Remove reference data for TF 2.8 from the tests (#2770)
andrey-churkin Jul 1, 2024
f5ad4ea
Revert "Update Fast-/BC algorithms (#2747)" (#2771)
nikita-malininn Jul 1, 2024
1f07622
Extending functions by lstsq, svd, eye (#2774)
ljaljushkin Jul 2, 2024
aeaaf19
Adjust NNCF to numpy 2.0 api (#2772)
AlexanderDokuchaev Jul 3, 2024
a644a1e
Update Fast-/BC with MatMul/Gemm support (#2776)
nikita-malininn Jul 3, 2024
fd0e33c
E8M0 scale for E2M1 weights. (#2767)
andreyanufr Jul 3, 2024
3ef206c
TorchFX quantization init
daniil-lyakhov May 21, 2024
58606d0
Test code is removed
daniil-lyakhov Jun 25, 2024
f0b926e
Reference graph are updated
daniil-lyakhov Jun 26, 2024
988c3a2
torch-fx tests are added to pre-commit
daniil-lyakhov Jun 26, 2024
acdf603
Comments
daniil-lyakhov Jun 27, 2024
55658a5
Model transformer minor refactoring
daniil-lyakhov Jul 1, 2024
1a1ef4a
Comments
daniil-lyakhov Jul 2, 2024
ac37c1d
Comments
daniil-lyakhov Jul 2, 2024
c48ac6c
Rebase
daniil-lyakhov Jul 3, 2024
c2f3201
Docstrings / SDPA unfold is removed
daniil-lyakhov Jul 3, 2024
745e6f0
Rebase
daniil-lyakhov Jul 3, 2024
6ba8a34
Comments
daniil-lyakhov Jul 4, 2024
b56a3a4
Add pytest for resnet for FX
anzr299 Jul 8, 2024
5c0d504
Add ruff fix for unused variables
anzr299 Jul 8, 2024
125a10f
add test for mobilenet_v3_small and vit_b_16.
anzr299 Jul 8, 2024
d9e4c41
Edit operator metatypes to include torch.nn.hardswish_ and torch.drop…
anzr299 Jul 8, 2024
e532407
Remove debug print statements
anzr299 Jul 8, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
The table of contents is too big for display.
Diff view
Diff view
  •  
  •  
  •  
2 changes: 1 addition & 1 deletion .github/workflows/precommit.yml
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ jobs:
lfs: true
- uses: actions/setup-python@0a5c61591373683505ea898e09a3ea4f39ef2b9c # v5.0.0
with:
python-version: 3.8.18
python-version: 3.9.19
cache: pip
- name: Install NNCF and test requirements
run: make install-tensorflow-test
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -429,7 +429,7 @@ conda install -c conda-forge nncf
- Python\* 3.8 or later
- Supported frameworks:
- PyTorch\* >=2.2, <2.4
- TensorFlow\* >=2.8.4, <=2.12.1
- TensorFlow\* >=2.8.4, <=2.15.1
- ONNX\* ==1.16.0
- OpenVINO\* >=2022.3.0

Expand Down
3 changes: 2 additions & 1 deletion constraints.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@ onnx==1.16.0
onnxruntime==1.17.1

# TensorFlow
tensorflow==2.12.1
tensorflow==2.12.1; python_version < '3.9'
tensorflow==2.15.1; python_version >= '3.9'

# Tests and examples
pytest==8.0.2
Expand Down
4 changes: 3 additions & 1 deletion docs/Installation.md
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ as well as the supported versions of Python:

| NNCF | OpenVINO | PyTorch | ONNX | TensorFlow | Python |
|-----------|------------|----------|----------|------------|--------|
| `develop` | `2024.2.0` | `2.3.0` | `1.16.0` | `2.12.0` | `3.8` |
| `develop` | `2024.2.0` | `2.3.0` | `1.16.0` | `2.15.1` | `3.8`* |
| `2.11.0` | `2024.2.0` | `2.3.0` | `1.16.0` | `2.12.0` | `3.8` |
| `2.10.0` | `2024.1.0` | `2.2.1` | `1.16.0` | `2.12.0` | `3.8` |
| `2.9.0` | `2024.0.0` | `2.1.2` | `1.13.1` | `2.12.0` | `3.8` |
Expand All @@ -53,3 +53,5 @@ as well as the supported versions of Python:
| `2.6.0` | `2023.1.0` | `2.0.1` | `1.13.1` | `2.12.0` | `3.8` |
| `2.5.0` | `2023.0.0` | `1.13.1` | `1.13.1` | `2.11.1` | `3.8` |
| `2.4.0` | `2022.1.0` | `1.12.1` | `1.12.0` | `2.8.2` | `3.8` |

> (*) Python 3.9 or higher is required for TensorFlow 2.15.1
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
tensorflow~=2.12.0
tensorflow~=2.12.0; python_version < '3.9'
tensorflow~=2.15.1; python_version >= '3.9'
tensorflow-datasets
tqdm
openvino==2024.2
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,12 @@

import tensorflow as tf
import tensorflow.keras.backend as K
from packaging import version

from examples.tensorflow.common.object_detection.architecture import nn_ops

tensorflow_version = version.parse(version.parse(tf.__version__).base_version)


class CSPDarknet53:
"""Class to build CSPDarknet53"""
Expand All @@ -25,12 +28,17 @@ def DarknetConv2D_BN_Mish(self, *args, **kwargs):
"""Darknet Convolution2D followed by SyncBatchNormalization and Mish."""
no_bias_kwargs = {"use_bias": False}
no_bias_kwargs.update(kwargs)

if tensorflow_version < version.parse("2.15"):
mish = tf.keras.layers.Activation(self.mish)
else:
mish = tf.keras.layers.Activation("mish")

return nn_ops.compose(
nn_ops.DarknetConv2D(*args, **no_bias_kwargs),
# TODO(nsavelyev) replace by BatchNormalization(synchronized=True) once support for TF < 2.12 is dropped
tf.keras.layers.experimental.SyncBatchNormalization(),
# TODO(nsavelyev) change to tf.keras.activations.mish after upgrade to TF 2.13
tf.keras.layers.Activation(self.mish),
mish,
)

def csp_resblock_body(self, x, num_filters, num_blocks, all_narrow=True):
Expand Down
3 changes: 2 additions & 1 deletion examples/tensorflow/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@ absl-py==1.0.0
tensorflow
tensorflow_datasets==4.2.0
tensorflow_hub
tensorflow_addons==0.20.0
tensorflow_addons==0.20.0; python_version < '3.9'
tensorflow_addons==0.23.0; python_version >= '3.9'
tensorflow-metadata==1.13.0
opencv-python
pycocotools==2.0.6
14 changes: 13 additions & 1 deletion nncf/common/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,10 @@ def create(model: TModel) -> NNCFGraph:
if model_backend == BackendType.OPENVINO:
from nncf.openvino.graph.nncf_graph_builder import GraphConverter

return GraphConverter.create_nncf_graph(model)
if model_backend == BackendType.TORCH_FX:
from nncf.experimental.torch.fx.nncf_graph_builder import GraphConverter

return GraphConverter.create_nncf_graph(model)
if model_backend == BackendType.TORCH:
return model.nncf.get_graph()
Expand Down Expand Up @@ -72,6 +76,10 @@ def create(model: TModel, inplace: bool = False) -> ModelTransformer:
from nncf.torch.model_transformer import PTModelTransformer

return PTModelTransformer(model)
if model_backend == BackendType.TORCH_FX:
from nncf.experimental.torch.fx.model_transformer import FXModelTransformer

return FXModelTransformer(model)
raise nncf.UnsupportedBackendError(
"Cannot create backend-specific model transformer because {} is not supported!".format(model_backend.value)
)
Expand All @@ -95,7 +103,7 @@ def create(model: TModel) -> Engine:
from nncf.openvino.engine import OVNativeEngine

return OVNativeEngine(model)
if model_backend == BackendType.TORCH:
if model_backend in (BackendType.TORCH, BackendType.TORCH_FX):
from nncf.torch.engine import PTEngine

return PTEngine(model)
Expand Down Expand Up @@ -151,6 +159,10 @@ def create(model: TModel, dataset: Dataset) -> aggregator.StatisticsAggregator:
from nncf.torch.statistics.aggregator import PTStatisticsAggregator

return PTStatisticsAggregator(dataset)
if model_backend == BackendType.TORCH_FX:
from nncf.experimental.torch.fx.statistics.aggregator import FXStatisticsAggregator

return FXStatisticsAggregator(dataset)
raise nncf.UnsupportedBackendError(
"Cannot create backend-specific statistics aggregator because {} is not supported!".format(
model_backend.value
Expand Down
4 changes: 2 additions & 2 deletions nncf/common/graph/patterns/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def _get_backend_hw_patterns_map(backend: BackendType) -> Dict[HWFusedPatternNam

registry = OPENVINO_HW_FUSED_PATTERNS.registry_dict
return registry
if backend == BackendType.TORCH:
if backend in (BackendType.TORCH, BackendType.TORCH_FX):
from nncf.torch.hardware.fused_patterns import PT_HW_FUSED_PATTERNS

registry = PT_HW_FUSED_PATTERNS.registry_dict
Expand Down Expand Up @@ -73,7 +73,7 @@ def _get_backend_ignored_patterns_map(

registry = OPENVINO_IGNORED_PATTERNS.registry_dict
return registry
if backend == BackendType.TORCH:
if backend in (BackendType.TORCH, BackendType.TORCH_FX):
from nncf.torch.quantization.ignored_patterns import PT_IGNORED_PATTERNS

registry = PT_IGNORED_PATTERNS.registry_dict
Expand Down
24 changes: 21 additions & 3 deletions nncf/common/utils/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

class BackendType(Enum):
TORCH = "Torch"
TORCH_FX = "TorchFX"
TENSORFLOW = "Tensorflow"
ONNX = "ONNX"
OPENVINO = "OpenVINO"
Expand All @@ -33,6 +34,7 @@ def get_available_backends() -> List[BackendType]:
"""
frameworks = [
("torch", BackendType.TORCH),
("torch.fx", BackendType.TORCH_FX),
("tensorflow", BackendType.TENSORFLOW),
("onnx", BackendType.ONNX),
("openvino.runtime", BackendType.OPENVINO),
Expand All @@ -51,14 +53,27 @@ def get_available_backends() -> List[BackendType]:

def is_torch_model(model: TModel) -> bool:
"""
Returns True if the model is an instance of torch.nn.Module, otherwise False.
Returns True if the model is an instance of torch.nn.Module and not a torch.fx.GraphModule, otherwise False.

:param model: A target model.
:return: True if the model is an instance of torch.nn.Module, otherwise False.
:return: True if the model is an instance of torch.nn.Module and not torch.fx.GraphModule, otherwise False.
"""
import torch
import torch.fx

return isinstance(model, torch.nn.Module)
return not isinstance(model, torch.fx.GraphModule) and isinstance(model, torch.nn.Module)


def is_torch_fx_model(model: TModel) -> bool:
"""
Returns True if the model is an instance of torch.fx.GraphModule, otherwise False.

:param model: A target model.
:return: True if the model is an instance of torch.fx.GraphModule, otherwise False.
"""
import torch.fx

return isinstance(model, torch.fx.GraphModule)


def is_tensorflow_model(model: TModel) -> bool:
Expand Down Expand Up @@ -118,6 +133,9 @@ def get_backend(model: TModel) -> BackendType:
"""
available_backends = get_available_backends()

if BackendType.TORCH_FX in available_backends and is_torch_fx_model(model):
return BackendType.TORCH_FX

if BackendType.TORCH in available_backends and is_torch_model(model):
return BackendType.TORCH

Expand Down
10 changes: 10 additions & 0 deletions nncf/experimental/torch/fx/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# Copyright (c) 2024 Intel Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
112 changes: 112 additions & 0 deletions nncf/experimental/torch/fx/model_transformer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
# Copyright (c) 2024 Intel Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from collections import defaultdict
from typing import Callable, List, Union

import torch
import torch.fx
from torch.fx.passes.split_utils import split_by_tags

from nncf.common.graph.model_transformer import ModelTransformer
from nncf.common.graph.transformations.commands import Command
from nncf.common.graph.transformations.commands import TransformationPriority
from nncf.common.graph.transformations.commands import TransformationType
from nncf.torch.graph.transformations.commands import PTModelExtractionCommand
from nncf.torch.graph.transformations.layout import PTTransformationLayout


class FXApplyTransformationCommand(Command):
def __init__(
self,
transformation_fn: Callable[[torch.fx.GraphModule], None],
priority: Union[TransformationPriority, int] = TransformationPriority.DEFAULT_PRIORITY,
):
super().__init__(TransformationType.INSERT)
self.tranformation_fn = transformation_fn
self.priority = priority


class FXModelTransformer(ModelTransformer):
"""
Applies transformations upon Torch FX model.
"""

def __init__(self, model: torch.fx.GraphModule):
super().__init__(model)

self._command_transformation_ordered_pairs = [
(FXApplyTransformationCommand, self._apply_transformation),
(PTModelExtractionCommand, self._apply_model_extraction),
]

def transform(self, transformation_layout: PTTransformationLayout) -> torch.fx.GraphModule:
# TODO(dlyakhov): Manage priorities of transformations.
transformations = transformation_layout.transformations
aggregated_transformations = defaultdict(list)
for transformation in transformations:
aggregated_transformations[transformation.__class__].append(transformation)

model = self._model
for transformation_cls, transformation_fn in self._command_transformation_ordered_pairs:
transformations = aggregated_transformations[transformation_cls]
if transformations:
model = transformation_fn(model, transformations)

# Do not use model.graph.eliminate_dead_code()
# because the computational statistics code
# is interpolated as dead code.
model.recompile()
return model

@staticmethod
def _apply_model_extraction(
model: torch.fx.GraphModule,
transformations: List[PTModelExtractionCommand],
) -> torch.fx.GraphModule:
transformation = transformations[-1]
assert len(transformation.input_node_names) == 1
assert transformation.input_node_names == transformation.output_node_names
node_name = transformation.input_node_names[0]

tags = ["before", "extracted", "after"]
i = 0
for node in model.graph.nodes:
if node.name == node_name:
node.tag = tags[1]
weights = [node.all_input_nodes[1]]
while weights:
w_node = weights.pop()
assert w_node.tag in tags[0:2]
w_node.tag = tags[1]
weights.extend(w_node.all_input_nodes)
i = 2
continue
node.tag = tags[i]

splitted_gm = split_by_tags(model, tags)
return splitted_gm.extracted

@staticmethod
def get_graph_node_by_name(graph: torch.fx.Graph, name: str) -> torch.fx.Node:
for node in graph.nodes:
if node.name == name:
return node
raise RuntimeError(f"Node with name {name} is not found")

@staticmethod
def _apply_transformation(
model: torch.fx.GraphModule,
transformations: List[FXApplyTransformationCommand],
) -> torch.fx.GraphModule:
for transformation in transformations:
transformation.tranformation_fn(model)
return model
Loading
Loading