diff --git a/nncf/common/factory.py b/nncf/common/factory.py index ca0b9d745bb..38a1671564b 100644 --- a/nncf/common/factory.py +++ b/nncf/common/factory.py @@ -9,7 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import TypeVar +from typing import TypeVar, cast import nncf from nncf.common.engine import Engine @@ -35,19 +35,27 @@ def create(model: TModel) -> NNCFGraph: """ model_backend = get_backend(model) if model_backend == BackendType.ONNX: - from nncf.onnx.graph.nncf_graph_builder import GraphConverter + from onnx import ModelProto # type: ignore - return GraphConverter.create_nncf_graph(model) + from nncf.onnx.graph.nncf_graph_builder import GraphConverter as ONNXGraphConverter + + return ONNXGraphConverter.create_nncf_graph(cast(ModelProto, model)) if model_backend == BackendType.OPENVINO: - from nncf.openvino.graph.nncf_graph_builder import GraphConverter + from openvino import Model # type: ignore + + from nncf.openvino.graph.nncf_graph_builder import GraphConverter as OVGraphConverter - return GraphConverter.create_nncf_graph(model) + return OVGraphConverter.create_nncf_graph(cast(Model, model)) if model_backend == BackendType.TORCH_FX: - from nncf.experimental.torch.fx.nncf_graph_builder import GraphConverter + from torch.fx import GraphModule + + from nncf.experimental.torch.fx.nncf_graph_builder import GraphConverter as FXGraphConverter - return GraphConverter.create_nncf_graph(model) + return FXGraphConverter.create_nncf_graph(cast(GraphModule, model)) if model_backend == BackendType.TORCH: - return model.nncf.get_graph() + from nncf.torch.nncf_network import NNCFNetwork + + return cast(NNCFNetwork, model).nncf.get_graph() raise nncf.UnsupportedBackendError( "Cannot create backend-specific graph because {} is not supported!".format(model_backend.value) ) @@ -65,21 +73,28 @@ def create(model: TModel, inplace: bool = False) -> ModelTransformer: """ model_backend = get_backend(model) if model_backend == BackendType.ONNX: + from onnx import ModelProto + from nncf.onnx.graph.model_transformer import ONNXModelTransformer - return ONNXModelTransformer(model) + return ONNXModelTransformer(cast(ModelProto, model)) if model_backend == BackendType.OPENVINO: + from openvino import Model + from nncf.openvino.graph.model_transformer import OVModelTransformer - return OVModelTransformer(model, inplace=inplace) + return OVModelTransformer(cast(Model, model), inplace=inplace) if model_backend == BackendType.TORCH: from nncf.torch.model_transformer import PTModelTransformer + from nncf.torch.nncf_network import NNCFNetwork - return PTModelTransformer(model) + return PTModelTransformer(cast(NNCFNetwork, model)) if model_backend == BackendType.TORCH_FX: + from torch.fx import GraphModule + from nncf.experimental.torch.fx.model_transformer import FXModelTransformer - return FXModelTransformer(model) + return FXModelTransformer(cast(GraphModule, model)) raise nncf.UnsupportedBackendError( "Cannot create backend-specific model transformer because {} is not supported!".format(model_backend.value) ) @@ -96,17 +111,23 @@ def create(model: TModel) -> Engine: """ model_backend = get_backend(model) if model_backend == BackendType.ONNX: + from onnx import ModelProto + from nncf.onnx.engine import ONNXEngine - return ONNXEngine(model) + return ONNXEngine(cast(ModelProto, model)) if model_backend == BackendType.OPENVINO: + from openvino import Model + from nncf.openvino.engine import OVNativeEngine - return OVNativeEngine(model) + return OVNativeEngine(cast(Model, model)) if model_backend in (BackendType.TORCH, BackendType.TORCH_FX): + from torch.nn import Module + from nncf.torch.engine import PTEngine - return PTEngine(model) + return PTEngine(cast(Module, model)) raise nncf.UnsupportedBackendError( "Cannot create backend-specific engine because {} is not supported!".format(model_backend.value) ) diff --git a/nncf/onnx/engine.py b/nncf/onnx/engine.py index 5b21a822616..543f9c73a6b 100644 --- a/nncf/onnx/engine.py +++ b/nncf/onnx/engine.py @@ -9,10 +9,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict +from typing import Any, Dict import numpy as np import onnxruntime as rt +from onnx import ModelProto from nncf.common.engine import Engine @@ -22,7 +23,7 @@ class ONNXEngine(Engine): Engine for ONNX backend using ONNXRuntime to infer the model. """ - def __init__(self, model, **rt_session_options): + def __init__(self, model: ModelProto, **rt_session_options: Any): self.input_names = set() rt_session_options["providers"] = ["CPUExecutionProvider"] serialized_model = model.SerializeToString() diff --git a/pyproject.toml b/pyproject.toml index 87953dcdbe3..0912257a16d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -101,7 +101,7 @@ files = [ # "nncf/common/deprecation.py", "nncf/common/engine.py", "nncf/common/exporter.py", - # "nncf/common/factory.py", + "nncf/common/factory.py", "nncf/common/hook_handle.py", "nncf/common/insertion_point_graph.py", "nncf/common/plotting.py",