Skip to content

Commit

Permalink
statistics
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexanderDokuchaev committed Jan 10, 2025
1 parent 983a94d commit d0ab1d9
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 99 deletions.
111 changes: 21 additions & 90 deletions nncf/common/statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional
from dataclasses import dataclass
from dataclasses import fields
from typing import Iterator, Optional, Tuple

from nncf.api.statistics import Statistics
from nncf.common.pruning.statistics import FilterPruningStatistics
Expand All @@ -22,116 +24,45 @@


@api()
class NNCFStatistics(Statistics):
@dataclass
class NNCFStatistics:
"""
Groups statistics for all available NNCF compression algorithms.
Statistics are present only if the algorithm has been started.
"""

def __init__(self):
"""
Initializes nncf statistics.
"""
self._storage = {}

@property
def magnitude_sparsity(self) -> Optional[MagnitudeSparsityStatistics]:
"""
Returns statistics of the magnitude sparsity algorithm. If statistics
have not been collected, `None` will be returned.
:return: Instance of the `MagnitudeSparsityStatistics` class.
"""
return self._storage.get("magnitude_sparsity")

@property
def rb_sparsity(self) -> Optional[RBSparsityStatistics]:
"""
Returns statistics of the RB-sparsity algorithm. If statistics
have not been collected, `None` will be returned.
:return: Instance of the `RBSparsityStatistics` class.
"""
return self._storage.get("rb_sparsity")

@property
def movement_sparsity(self) -> Optional[MovementSparsityStatistics]:
"""
Returns statistics of the movement sparsity algorithm. If statistics
have not been collected, `None` will be returned.
:return: Instance of the `MovementSparsityStatistics` class.
"""
return self._storage.get("movement_sparsity")

@property
def const_sparsity(self) -> Optional[ConstSparsityStatistics]:
"""
Returns statistics of the const sparsity algorithm. If statistics
have not been collected, `None` will be returned.
:return: Instance of the `ConstSparsityStatistics` class.
"""
return self._storage.get("const_sparsity")
const_sparsity: Optional[ConstSparsityStatistics] = None
filter_pruning: Optional[FilterPruningStatistics] = None
magnitude_sparsity: Optional[MagnitudeSparsityStatistics] = None
movement_sparsity: Optional[MovementSparsityStatistics] = None
quantization: Optional[QuantizationStatistics] = None
rb_sparsity: Optional[RBSparsityStatistics] = None

@property
def quantization(self) -> Optional[QuantizationStatistics]:
"""
Returns statistics of the quantization algorithm. If statistics
have not been collected, `None` will be returned.
:return: Instance of the `QuantizationStatistics` class.
"""
return self._storage.get("quantization")

@property
def filter_pruning(self) -> Optional[FilterPruningStatistics]:
"""
Returns statistics of the filter pruning algorithm. If statistics
have not been collected, `None` will be returned.
:return: Instance of the `FilterPruningStatistics` class.
"""
return self._storage.get("filter_pruning")

def register(self, algorithm_name: str, stats: Statistics):
def register(self, algorithm_name: str, stats: Statistics) -> None:
"""
Registers statistics for the algorithm.
:param algorithm_name: Name of the algorithm. Should be one of the following
* magnitude_sparsity
* rb_sparsity
* const_sparsity
* quantization
* filter_pruning
* magnitude_sparsity
* movement_sparsity
* quantization
* rb_sparsity
:param stats: Statistics of the algorithm.
"""

available_algorithms = [
"magnitude_sparsity",
"rb_sparsity",
"movement_sparsity",
"const_sparsity",
"quantization",
"filter_pruning",
]
available_algorithms = [f.name for f in fields(self)]
if algorithm_name not in available_algorithms:
raise ValueError(
f"Can not register statistics for the algorithm. Unknown name of the algorithm: {algorithm_name}."
)

self._storage[algorithm_name] = stats
setattr(self, algorithm_name, stats)

def to_str(self) -> str:
"""
Calls `to_str()` method for all registered statistics of the algorithm and returns
a sum-up string.
:return: A representation of the NNCF statistics as a human-readable string.
"""
pretty_string = "\n\n".join([stats.to_str() for stats in self._storage.values()])
pretty_string = "\n\n".join([str(x[1].to_str()) for x in self])
return pretty_string

def __iter__(self):
return iter(self._storage.items())
def __iter__(self) -> Iterator[Tuple[str, Statistics]]:
return iter([(f.name, getattr(self, f.name)) for f in fields(self) if getattr(self, f.name) is not None])
11 changes: 3 additions & 8 deletions nncf/common/utils/tensorboard.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from functools import singledispatch
from typing import Any, Dict, Union

from nncf.api.statistics import Statistics
from nncf.common.pruning.statistics import FilterPruningStatistics
from nncf.common.sparsity.statistics import ConstSparsityStatistics
from nncf.common.sparsity.statistics import MagnitudeSparsityStatistics
Expand All @@ -27,7 +28,7 @@ def prepare_for_tensorboard(nncf_stats: NNCFStatistics) -> Dict[str, float]:
:param nncf_stats: NNCF Statistics.
:return: A dict storing name and value of the scalar.
"""
tensorboard_stats = {}
tensorboard_stats: Dict[str, float] = {}
for algorithm_name, stats in nncf_stats:
tensorboard_stats.update(convert_to_dict(stats, algorithm_name))

Expand All @@ -36,13 +37,7 @@ def prepare_for_tensorboard(nncf_stats: NNCFStatistics) -> Dict[str, float]:

@singledispatch
def convert_to_dict(
stats: Union[
FilterPruningStatistics,
MagnitudeSparsityStatistics,
RBSparsityStatistics,
ConstSparsityStatistics,
MovementSparsityStatistics,
],
stats: Statistics,
algorithm_name: str,
) -> Dict[Any, Any]:
return {}
Expand Down
7 changes: 6 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -96,14 +96,19 @@ files = [
"nncf/api",
"nncf/data",
"nncf/common/collector.py",
# "nncf/common/composite_compression.py",
# "nncf/common/compression.py",
# "nncf/common/deprecation.py",
"nncf/common/engine.py",
"nncf/common/exporter.py",
# "nncf/common/factory.py",
"nncf/common/hook_handle.py",
"nncf/common/insertion_point_graph.py",
"nncf/common/logging/logger.py",
"nncf/common/plotting.py",
"nncf/common/schedulers.py",
"nncf/common/scopes.py",
"nncf/common/stateful_classes_registry.py",
"nncf/common/statistics.py",
"nncf/common/strip.py",
"nncf/common/tensor.py",
"nncf/common/accuracy_aware_training",
Expand Down

0 comments on commit d0ab1d9

Please sign in to comment.