[mypy] statistics.py factory.py #3187

Merged
51 changes: 36 additions & 15 deletions nncf/common/factory.py
@@ -9,7 +9,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TypeVar
from typing import TypeVar, cast

import nncf
from nncf.common.engine import Engine
@@ -35,19 +35,27 @@ def create(model: TModel) -> NNCFGraph:
"""
model_backend = get_backend(model)
if model_backend == BackendType.ONNX:
from nncf.onnx.graph.nncf_graph_builder import GraphConverter
from onnx import ModelProto # type: ignore

return GraphConverter.create_nncf_graph(model)
from nncf.onnx.graph.nncf_graph_builder import GraphConverter as ONNXGraphConverter

return ONNXGraphConverter.create_nncf_graph(cast(ModelProto, model))
if model_backend == BackendType.OPENVINO:
from nncf.openvino.graph.nncf_graph_builder import GraphConverter
from openvino import Model # type: ignore

from nncf.openvino.graph.nncf_graph_builder import GraphConverter as OVGraphConverter

return GraphConverter.create_nncf_graph(model)
return OVGraphConverter.create_nncf_graph(cast(Model, model))
if model_backend == BackendType.TORCH_FX:
from nncf.experimental.torch.fx.nncf_graph_builder import GraphConverter
from torch.fx import GraphModule

from nncf.experimental.torch.fx.nncf_graph_builder import GraphConverter as FXGraphConverter

return GraphConverter.create_nncf_graph(model)
return FXGraphConverter.create_nncf_graph(cast(GraphModule, model))
if model_backend == BackendType.TORCH:
return model.nncf.get_graph()
from nncf.torch.nncf_network import NNCFNetwork

return cast(NNCFNetwork, model).nncf.get_graph()
raise nncf.UnsupportedBackendError(
"Cannot create backend-specific graph because {} is not supported!".format(model_backend.value)
)
@@ -65,21 +73,28 @@ def create(model: TModel, inplace: bool = False) -> ModelTransformer:
"""
model_backend = get_backend(model)
if model_backend == BackendType.ONNX:
from onnx import ModelProto

from nncf.onnx.graph.model_transformer import ONNXModelTransformer

return ONNXModelTransformer(model)
return ONNXModelTransformer(cast(ModelProto, model))
if model_backend == BackendType.OPENVINO:
from openvino import Model

from nncf.openvino.graph.model_transformer import OVModelTransformer

return OVModelTransformer(model, inplace=inplace)
return OVModelTransformer(cast(Model, model), inplace=inplace)
if model_backend == BackendType.TORCH:
from nncf.torch.model_transformer import PTModelTransformer
from nncf.torch.nncf_network import NNCFNetwork

return PTModelTransformer(model)
return PTModelTransformer(cast(NNCFNetwork, model))
if model_backend == BackendType.TORCH_FX:
from torch.fx import GraphModule

from nncf.experimental.torch.fx.model_transformer import FXModelTransformer

return FXModelTransformer(model)
return FXModelTransformer(cast(GraphModule, model))
raise nncf.UnsupportedBackendError(
"Cannot create backend-specific model transformer because {} is not supported!".format(model_backend.value)
)
@@ -96,17 +111,23 @@ def create(model: TModel) -> Engine:
"""
model_backend = get_backend(model)
if model_backend == BackendType.ONNX:
from onnx import ModelProto

from nncf.onnx.engine import ONNXEngine

return ONNXEngine(model)
return ONNXEngine(cast(ModelProto, model))
if model_backend == BackendType.OPENVINO:
from openvino import Model

from nncf.openvino.engine import OVNativeEngine

return OVNativeEngine(model)
return OVNativeEngine(cast(Model, model))
if model_backend in (BackendType.TORCH, BackendType.TORCH_FX):
from torch.nn import Module

from nncf.torch.engine import PTEngine

return PTEngine(model)
return PTEngine(cast(Module, model))
raise nncf.UnsupportedBackendError(
"Cannot create backend-specific engine because {} is not supported!".format(model_backend.value)
)
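The factory methods dispatch on a runtime backend check, so mypy cannot narrow the `TModel` type variable inside each branch; the change resolves this by `cast()`-ing the model to the concrete backend type before calling the backend-specific converter, transformer, or engine. A minimal sketch of the same pattern, using hypothetical stand-in classes instead of the real onnx/torch types:

```python
from enum import Enum
from typing import TypeVar, cast

TModel = TypeVar("TModel")


class BackendKind(Enum):
    ONNX = "onnx"
    TORCH = "torch"


class OnnxLikeModel:
    """Hypothetical stand-in for onnx.ModelProto."""

    def SerializeToString(self) -> bytes:
        return b"onnx-bytes"


class TorchLikeModel:
    """Hypothetical stand-in for torch.nn.Module."""

    def extra_repr(self) -> str:
        return "torch-module"


def get_backend(model: object) -> BackendKind:
    # Simplified stand-in for nncf.common.utils.backend.get_backend().
    return BackendKind.ONNX if isinstance(model, OnnxLikeModel) else BackendKind.TORCH


def serialize(model: TModel) -> str:
    backend = get_backend(model)
    if backend == BackendKind.ONNX:
        # The enum comparison does not narrow TModel for the type checker,
        # so cast() asserts the concrete type this branch actually handles.
        return cast(OnnxLikeModel, model).SerializeToString().decode()
    return cast(TorchLikeModel, model).extra_repr()


print(serialize(OnnxLikeModel()))   # onnx-bytes
print(serialize(TorchLikeModel()))  # torch-module
```

Note that `cast()` is purely a static assertion: it performs no runtime check, so an incorrect backend mapping would still fail only when the backend-specific call is made.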
111 changes: 21 additions & 90 deletions nncf/common/statistics.py
@@ -9,7 +9,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional
from dataclasses import dataclass
from dataclasses import fields
from typing import Iterator, Optional, Tuple

from nncf.api.statistics import Statistics
from nncf.common.pruning.statistics import FilterPruningStatistics
@@ -22,116 +24,45 @@


@api()
class NNCFStatistics(Statistics):
@dataclass
class NNCFStatistics:
"""
Groups statistics for all available NNCF compression algorithms.
Statistics are present only if the algorithm has been started.
"""

def __init__(self):
"""
Initializes nncf statistics.
"""
self._storage = {}

@property
def magnitude_sparsity(self) -> Optional[MagnitudeSparsityStatistics]:
"""
Returns statistics of the magnitude sparsity algorithm. If statistics
have not been collected, `None` will be returned.

:return: Instance of the `MagnitudeSparsityStatistics` class.
"""
return self._storage.get("magnitude_sparsity")

@property
def rb_sparsity(self) -> Optional[RBSparsityStatistics]:
"""
Returns statistics of the RB-sparsity algorithm. If statistics
have not been collected, `None` will be returned.

:return: Instance of the `RBSparsityStatistics` class.
"""
return self._storage.get("rb_sparsity")

@property
def movement_sparsity(self) -> Optional[MovementSparsityStatistics]:
"""
Returns statistics of the movement sparsity algorithm. If statistics
have not been collected, `None` will be returned.

:return: Instance of the `MovementSparsityStatistics` class.
"""
return self._storage.get("movement_sparsity")

@property
def const_sparsity(self) -> Optional[ConstSparsityStatistics]:
"""
Returns statistics of the const sparsity algorithm. If statistics
have not been collected, `None` will be returned.

:return: Instance of the `ConstSparsityStatistics` class.
"""
return self._storage.get("const_sparsity")
const_sparsity: Optional[ConstSparsityStatistics] = None
filter_pruning: Optional[FilterPruningStatistics] = None
magnitude_sparsity: Optional[MagnitudeSparsityStatistics] = None
movement_sparsity: Optional[MovementSparsityStatistics] = None
quantization: Optional[QuantizationStatistics] = None
rb_sparsity: Optional[RBSparsityStatistics] = None

@property
def quantization(self) -> Optional[QuantizationStatistics]:
"""
Returns statistics of the quantization algorithm. If statistics
have not been collected, `None` will be returned.

:return: Instance of the `QuantizationStatistics` class.
"""
return self._storage.get("quantization")

@property
def filter_pruning(self) -> Optional[FilterPruningStatistics]:
"""
Returns statistics of the filter pruning algorithm. If statistics
have not been collected, `None` will be returned.

:return: Instance of the `FilterPruningStatistics` class.
"""
return self._storage.get("filter_pruning")

def register(self, algorithm_name: str, stats: Statistics):
def register(self, algorithm_name: str, stats: Statistics) -> None:
"""
Registers statistics for the algorithm.

:param algorithm_name: Name of the algorithm. Should be one of the following
* magnitude_sparsity
* rb_sparsity
* const_sparsity
* quantization
* filter_pruning
* magnitude_sparsity
* movement_sparsity
* quantization
* rb_sparsity

:param stats: Statistics of the algorithm.
"""

available_algorithms = [
"magnitude_sparsity",
"rb_sparsity",
"movement_sparsity",
"const_sparsity",
"quantization",
"filter_pruning",
]
available_algorithms = [f.name for f in fields(self)]
if algorithm_name not in available_algorithms:
raise ValueError(
f"Can not register statistics for the algorithm. Unknown name of the algorithm: {algorithm_name}."
)

self._storage[algorithm_name] = stats
setattr(self, algorithm_name, stats)

def to_str(self) -> str:
"""
Calls `to_str()` method for all registered statistics of the algorithm and returns
a sum-up string.

:return: A representation of the NNCF statistics as a human-readable string.
"""
pretty_string = "\n\n".join([stats.to_str() for stats in self._storage.values()])
pretty_string = "\n\n".join([str(x[1].to_str()) for x in self])
return pretty_string

def __iter__(self):
return iter(self._storage.items())
def __iter__(self) -> Iterator[Tuple[str, Statistics]]:
return iter([(f.name, getattr(self, f.name)) for f in fields(self) if getattr(self, f.name) is not None])
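Rewriting `NNCFStatistics` as a dataclass keeps the public attribute access (`nncf_stats.quantization`, and so on) while dropping the hand-written properties and the internal storage dict; the dataclass field names now double as the set of valid algorithm names. A small sketch of how `fields()`, `register()`, and `__iter__` interact, using hypothetical string-valued statistics purely for illustration:

```python
from dataclasses import dataclass, fields
from typing import Iterator, Optional, Tuple


@dataclass
class DemoStatistics:
    """Toy analogue of NNCFStatistics; values are plain strings here."""

    quantization: Optional[str] = None
    filter_pruning: Optional[str] = None

    def register(self, algorithm_name: str, stats: str) -> None:
        # The dataclass fields define the set of valid algorithm names.
        if algorithm_name not in [f.name for f in fields(self)]:
            raise ValueError(f"Unknown name of the algorithm: {algorithm_name}.")
        setattr(self, algorithm_name, stats)

    def __iter__(self) -> Iterator[Tuple[str, str]]:
        # Iterate only over statistics that were actually registered.
        return iter(
            [(f.name, getattr(self, f.name)) for f in fields(self) if getattr(self, f.name) is not None]
        )


stats = DemoStatistics()
stats.register("quantization", "8-bit weight/activation statistics")
print(list(stats))  # [('quantization', '8-bit weight/activation statistics')]
```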
11 changes: 3 additions & 8 deletions nncf/common/utils/tensorboard.py
@@ -12,6 +12,7 @@
from functools import singledispatch
from typing import Any, Dict, Union

from nncf.api.statistics import Statistics
from nncf.common.pruning.statistics import FilterPruningStatistics
from nncf.common.sparsity.statistics import ConstSparsityStatistics
from nncf.common.sparsity.statistics import MagnitudeSparsityStatistics
@@ -27,7 +28,7 @@ def prepare_for_tensorboard(nncf_stats: NNCFStatistics) -> Dict[str, float]:
:param nncf_stats: NNCF Statistics.
:return: A dict storing name and value of the scalar.
"""
tensorboard_stats = {}
tensorboard_stats: Dict[str, float] = {}
for algorithm_name, stats in nncf_stats:
tensorboard_stats.update(convert_to_dict(stats, algorithm_name))

@@ -36,13 +37,7 @@

@singledispatch
def convert_to_dict(
stats: Union[
FilterPruningStatistics,
MagnitudeSparsityStatistics,
RBSparsityStatistics,
ConstSparsityStatistics,
MovementSparsityStatistics,
],
stats: Statistics,
algorithm_name: str,
) -> Dict[Any, Any]:
return {}
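Because `convert_to_dict` is a `singledispatch` function, the concrete implementation is chosen from the runtime type of `stats`; the base definition only needs an annotation broad enough to cover every registered overload, which is why the `Union` can be replaced by the common `Statistics` base class. A small sketch of that mechanism with hypothetical statistics classes:

```python
from functools import singledispatch
from typing import Any, Dict


class Statistics:
    """Hypothetical stand-in for nncf.api.statistics.Statistics."""


class PruningStatistics(Statistics):
    pruning_level = 0.5


@singledispatch
def convert_to_dict(stats: Statistics, algorithm_name: str) -> Dict[Any, Any]:
    # Fallback: statistics types without a registered overload contribute no scalars.
    return {}


@convert_to_dict.register
def _(stats: PruningStatistics, algorithm_name: str) -> Dict[Any, Any]:
    # Dispatch happens on the runtime type of `stats`, not on the base annotation.
    return {f"{algorithm_name}/pruning_level": stats.pruning_level}


print(convert_to_dict(PruningStatistics(), "filter_pruning"))  # {'filter_pruning/pruning_level': 0.5}
print(convert_to_dict(Statistics(), "quantization"))           # {}
```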
5 changes: 3 additions & 2 deletions nncf/onnx/engine.py
@@ -9,10 +9,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict
from typing import Any, Dict

import numpy as np
import onnxruntime as rt
from onnx import ModelProto

from nncf.common.engine import Engine

@@ -22,7 +23,7 @@ class ONNXEngine(Engine):
Engine for ONNX backend using ONNXRuntime to infer the model.
"""

def __init__(self, model, **rt_session_options):
def __init__(self, model: ModelProto, **rt_session_options: Any):
self.input_names = set()
rt_session_options["providers"] = ["CPUExecutionProvider"]
serialized_model = model.SerializeToString()
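The change here only annotates the constructor, which reflects how the engine is used: the `ModelProto` is serialized and handed to an ONNX Runtime session on CPU. A minimal sketch of that flow, assuming compatible `onnx`/`onnxruntime` versions and using a throwaway identity model purely for illustration:

```python
import numpy as np
import onnxruntime as rt
from onnx import TensorProto, helper

# Build a tiny identity model (illustration only, not an NNCF model).
inp = helper.make_tensor_value_info("x", TensorProto.FLOAT, [1, 3])
out = helper.make_tensor_value_info("y", TensorProto.FLOAT, [1, 3])
node = helper.make_node("Identity", ["x"], ["y"])
model = helper.make_model(helper.make_graph([node], "demo", [inp], [out]))

# Mirrors ONNXEngine.__init__: serialize the ModelProto and force the CPU provider.
session = rt.InferenceSession(model.SerializeToString(), providers=["CPUExecutionProvider"])
print(session.run(None, {"x": np.ones((1, 3), dtype=np.float32)}))
```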
8 changes: 7 additions & 1 deletion pyproject.toml
@@ -96,19 +96,25 @@ files = [
"nncf/api",
"nncf/data",
"nncf/common/collector.py",
# "nncf/common/composite_compression.py",
Collaborator comment: Finally, it should soon be squeezed to just nncf/common.

# "nncf/common/compression.py",
# "nncf/common/deprecation.py",
"nncf/common/engine.py",
"nncf/common/exporter.py",
"nncf/common/factory.py",
"nncf/common/hook_handle.py",
"nncf/common/insertion_point_graph.py",
"nncf/common/logging/logger.py",
"nncf/common/plotting.py",
"nncf/common/schedulers.py",
"nncf/common/scopes.py",
"nncf/common/stateful_classes_registry.py",
"nncf/common/statistics.py",
"nncf/common/strip.py",
"nncf/common/tensor.py",
"nncf/common/accuracy_aware_training",
"nncf/common/graph",
"nncf/common/initialization",
"nncf/common/logging/logger.py",
"nncf/common/sparsity",
"nncf/common/tensor_statistics",
"nncf/common/utils",