Skip to content

Commit

Permalink
factory
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexanderDokuchaev committed Jan 10, 2025
1 parent d0ab1d9 commit d2efe95
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 19 deletions.
51 changes: 36 additions & 15 deletions nncf/common/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TypeVar
from typing import TypeVar, cast

import nncf
from nncf.common.engine import Engine
Expand All @@ -35,19 +35,27 @@ def create(model: TModel) -> NNCFGraph:
"""
model_backend = get_backend(model)
if model_backend == BackendType.ONNX:
from nncf.onnx.graph.nncf_graph_builder import GraphConverter
from onnx import ModelProto # type: ignore

return GraphConverter.create_nncf_graph(model)
from nncf.onnx.graph.nncf_graph_builder import GraphConverter as ONNXGraphConverter

return ONNXGraphConverter.create_nncf_graph(cast(ModelProto, model))
if model_backend == BackendType.OPENVINO:
from nncf.openvino.graph.nncf_graph_builder import GraphConverter
from openvino import Model # type: ignore

from nncf.openvino.graph.nncf_graph_builder import GraphConverter as OVGraphConverter

return GraphConverter.create_nncf_graph(model)
return OVGraphConverter.create_nncf_graph(cast(Model, model))
if model_backend == BackendType.TORCH_FX:
from nncf.experimental.torch.fx.nncf_graph_builder import GraphConverter
from torch.fx import GraphModule

from nncf.experimental.torch.fx.nncf_graph_builder import GraphConverter as FXGraphConverter

return GraphConverter.create_nncf_graph(model)
return FXGraphConverter.create_nncf_graph(cast(GraphModule, model))
if model_backend == BackendType.TORCH:
return model.nncf.get_graph()
from nncf.torch.nncf_network import NNCFNetwork

return cast(NNCFNetwork, model).nncf.get_graph()
raise nncf.UnsupportedBackendError(
"Cannot create backend-specific graph because {} is not supported!".format(model_backend.value)
)
Expand All @@ -65,21 +73,28 @@ def create(model: TModel, inplace: bool = False) -> ModelTransformer:
"""
model_backend = get_backend(model)
if model_backend == BackendType.ONNX:
from onnx import ModelProto

from nncf.onnx.graph.model_transformer import ONNXModelTransformer

return ONNXModelTransformer(model)
return ONNXModelTransformer(cast(ModelProto, model))
if model_backend == BackendType.OPENVINO:
from openvino import Model

from nncf.openvino.graph.model_transformer import OVModelTransformer

return OVModelTransformer(model, inplace=inplace)
return OVModelTransformer(cast(Model, model), inplace=inplace)
if model_backend == BackendType.TORCH:
from nncf.torch.model_transformer import PTModelTransformer
from nncf.torch.nncf_network import NNCFNetwork

return PTModelTransformer(model)
return PTModelTransformer(cast(NNCFNetwork, model))
if model_backend == BackendType.TORCH_FX:
from torch.fx import GraphModule

from nncf.experimental.torch.fx.model_transformer import FXModelTransformer

return FXModelTransformer(model)
return FXModelTransformer(cast(GraphModule, model))
raise nncf.UnsupportedBackendError(
"Cannot create backend-specific model transformer because {} is not supported!".format(model_backend.value)
)
Expand All @@ -96,17 +111,23 @@ def create(model: TModel) -> Engine:
"""
model_backend = get_backend(model)
if model_backend == BackendType.ONNX:
from onnx import ModelProto

from nncf.onnx.engine import ONNXEngine

return ONNXEngine(model)
return ONNXEngine(cast(ModelProto, model))
if model_backend == BackendType.OPENVINO:
from openvino import Model

from nncf.openvino.engine import OVNativeEngine

return OVNativeEngine(model)
return OVNativeEngine(cast(Model, model))
if model_backend in (BackendType.TORCH, BackendType.TORCH_FX):
from torch.nn import Module

from nncf.torch.engine import PTEngine

return PTEngine(model)
return PTEngine(cast(Module, model))
raise nncf.UnsupportedBackendError(
"Cannot create backend-specific engine because {} is not supported!".format(model_backend.value)
)
Expand Down
5 changes: 3 additions & 2 deletions nncf/onnx/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict
from typing import Any, Dict

import numpy as np
import onnxruntime as rt
from onnx import ModelProto

from nncf.common.engine import Engine

Expand All @@ -22,7 +23,7 @@ class ONNXEngine(Engine):
Engine for ONNX backend using ONNXRuntime to infer the model.
"""

def __init__(self, model, **rt_session_options):
def __init__(self, model: ModelProto, **rt_session_options: Any):
self.input_names = set()
rt_session_options["providers"] = ["CPUExecutionProvider"]
serialized_model = model.SerializeToString()
Expand Down
5 changes: 3 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -98,10 +98,10 @@ files = [
"nncf/common/collector.py",
# "nncf/common/composite_compression.py",
# "nncf/common/compression.py",
# "nncf/common/deprecation.py",
"nncf/common/deprecation.py",
"nncf/common/engine.py",
"nncf/common/exporter.py",
# "nncf/common/factory.py",
"nncf/common/factory.py",
"nncf/common/hook_handle.py",
"nncf/common/insertion_point_graph.py",
"nncf/common/plotting.py",
Expand All @@ -114,6 +114,7 @@ files = [
"nncf/common/accuracy_aware_training",
"nncf/common/graph",
"nncf/common/initialization",
"nncf/common/logging/logger.py",
"nncf/common/sparsity",
"nncf/common/tensor_statistics",
"nncf/common/utils",
Expand Down

0 comments on commit d2efe95

Please sign in to comment.