Computation of compression parameters via OpenVINO models (openvinotoolkit#2727)

### Changes

- Implemented OpenVINO model graphs that are used to calculate compressed and decompressed weights. Since these models are compiled, the calculation becomes significantly faster, especially for larger models and int4 compression.
- This functionality is exposed by two methods in `weight_lowering.py` (see the usage sketch after this list):
  - `do_int_quantization()` is used for computing a compressed weight. Possible signatures:
    - `weight` -> `compressed_weight`, `scale`, (`zero_point` for asymmetric compression)
    - `weight`, `scale`, (`zero_point`) -> `compressed_weight`, `scale`, (`zero_point`)
  - `calculate_quantized_dequantized_weight()` is used for computing a decompressed weight. Possible signatures:
    - `weight` -> `decompressed_weight`
    - `weight`, `scale`, (`zero_point`) -> `decompressed_weight`
    - `weight` -> `decompressed_weight`, `compressed_weight`, `scale`, (`zero_point`)
    - `weight`, `scale`, (`zero_point`) -> `decompressed_weight`, `compressed_weight`, `scale`, (`zero_point`)
  - The output `scale` and `zero_point` are the same as the ones given as input (if they were given at all).
- Computation is done via OV models only if the openvino package is installed and the input tensors are not torch tensors.
- Introduced a new NNCF Tensor backend for storing instances of `openvino.Tensor`. The implementation of this backend is limited to the required functionality; e.g., addition of OV Tensors is not supported because it is not needed.
- The introduction of OV Tensors is required for seamless handling of tensors in `bf16`, `u4` and `i4` data types. For example, `bf16` constants are read from an OpenVINO LLM and given as inputs to a compressing OpenVINO model, and `u4` and `i4` compressed weights are seamlessly inserted into the resulting compressed OpenVINO model.
- Added an `as_numpy_tensor()` method to convert an NNCF Tensor to the numpy backend. Currently, only OV -> NP conversion is required.
- All calculations are aligned with the reference numpy implementation. Some performance and memory sacrifices had to be made for this alignment.
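
As a usage sketch for the list above, here is a minimal pure-NumPy reference of what the simplest `do_int_quantization()` and `calculate_quantized_dequantized_weight()` signatures compute for symmetric compression. The helper names and the exact scale/rounding formulas are illustrative assumptions, not the NNCF implementation:

```python
import numpy as np

def reference_int_quantization(weight: np.ndarray, num_bits: int = 4):
    # Per-row symmetric quantization: map the max-magnitude value of each
    # output channel onto the signed integer grid [-2**(b-1), 2**(b-1) - 1].
    level_low, level_high = -(2 ** (num_bits - 1)), 2 ** (num_bits - 1) - 1
    scale = np.max(np.abs(weight), axis=-1, keepdims=True) / level_high
    scale = np.maximum(scale, np.finfo(np.float32).eps)  # guard against all-zero rows
    compressed = np.clip(np.round(weight / scale), level_low, level_high).astype(np.int8)
    return compressed, scale

def reference_dequantization(compressed: np.ndarray, scale: np.ndarray) -> np.ndarray:
    # The decompressed weight, analogous to calculate_quantized_dequantized_weight().
    return compressed.astype(np.float32) * scale

weight = np.random.rand(128, 64).astype(np.float32)
compressed_weight, scale = reference_int_quantization(weight)  # weight -> compressed_weight, scale
decompressed_weight = reference_dequantization(compressed_weight, scale)
```

Per the signatures above, when a precomputed `scale` (and `zero_point`) is passed in, only the quantization step is redone and the given parameters are returned unchanged.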

Data-free asymmetric compression:

![image](https://github.com/user-attachments/assets/efd76b2f-1a3e-4037-8165-0bd5812de94d)

Data-free symmetric compression:

![image](https://github.com/user-attachments/assets/c61b98c6-cc96-4125-b21e-90c7d0827e22)

Data-aware compression:

![image](https://github.com/user-attachments/assets/b9823594-9915-4ca5-9e50-7bffa6777104)


### Reason for changes

Reducing model compression time. Only the OpenVINO model compression backend is affected.

### Related tickets

139047

### Tests

- `tests/openvino/native/quantization/test_ov_modeling_compression.py::test_quantization_alignment` -- checks alignment with the reference numpy implementation
- `tests/openvino/native/test_openvino_modeling.py` -- checks OV modeling framework hyperparameters
- `tests/openvino/native/test_tensor.py` -- NNCF OV Tensor backend tests

Validation jobs:
- `NNCF/job/manual/job/post_training_weight_compression/299/`
- `NNCF/job/nightly/job/test_examples/650`
- OVVP validation ✅
- optimum-intel test job: https://github.com/huggingface/optimum-intel/actions/runs/12912964434/job/36009036879?pr=734
nikita-savelyevv authored Jan 23, 2025
1 parent b6f2e75 commit f3f232f
Showing 32 changed files with 2,195 additions and 293 deletions.
1 change: 1 addition & 0 deletions docs/api/source/conf.py
```diff
@@ -145,6 +145,7 @@ def collect_api_entities() -> APIInfo:
     "nncf.tensor.functions.torch_linalg",
     "nncf.tensor.functions.torch_io",
     "nncf.tensor.functions.numpy_io",
+    "nncf.tensor.functions.ov_numeric",
 ]

 with mock(mock_modules):
```
44 changes: 32 additions & 12 deletions nncf/common/logging/logger.py
```diff
@@ -11,11 +11,41 @@

 import logging
 import sys
-from typing import Set
+from functools import lru_cache
+from typing import cast
+
+
+class NNCFLogger(logging.Logger):
+    def __init__(self, name: str, level: int = logging.NOTSET):
+        super().__init__(name, level)
+
+    @lru_cache(None)
+    def _log_once(self, level: int, msg: str) -> None:
+        self.log(level, msg)
+
+    def debug_once(self, msg: str) -> None:
+        """
+        Log a message at the DEBUG level, ensuring the message is logged only once.
+        """
+        self._log_once(logging.DEBUG, msg)
+
+    def info_once(self, msg: str) -> None:
+        """
+        Log a message at the INFO level, ensuring the message is logged only once.
+        """
+        self._log_once(logging.INFO, msg)
+
+    def warning_once(self, msg: str) -> None:
+        """
+        Log a message at the WARNING level, ensuring the message is logged only once.
+        """
+        self._log_once(logging.WARNING, msg)
+
+
 NNCF_LOGGER_NAME = "nncf"

-nncf_logger = logging.getLogger(NNCF_LOGGER_NAME)
+logging.setLoggerClass(NNCFLogger)
+nncf_logger = cast(NNCFLogger, logging.getLogger(NNCF_LOGGER_NAME))
 nncf_logger.propagate = False

 stdout_handler = logging.StreamHandler(sys.stdout)
@@ -60,16 +90,6 @@ def disable_logging() -> None:
     nncf_logger.handlers = []


-class DuplicateFilter:
-    def __init__(self) -> None:
-        self.msgs: Set[str] = set()
-
-    def filter(self, rec: logging.LogRecord) -> bool:
-        retval = rec.msg not in self.msgs
-        self.msgs.add(rec.msg)
-        return retval
-
-
 NNCFDeprecationWarning = FutureWarning
```
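
A short usage sketch of the new `*_once` helpers (the message text is illustrative): repeated calls with an identical message produce a single log record, replacing what the removed `DuplicateFilter` approximated at the handler level.

```python
from nncf.common.logging.logger import nncf_logger

for _ in range(3):
    # Emitted only once: lru_cache on NNCFLogger._log_once memoizes the
    # (level, msg) pair, so the second and third calls are no-ops.
    nncf_logger.warning_once("Falling back to the numpy implementation.")
```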
15 changes: 15 additions & 0 deletions nncf/common/utils/backend.py
```diff
@@ -16,6 +16,13 @@

 TModel = TypeVar("TModel")

+try:
+    import openvino  # type: ignore # noqa: F401
+
+    _OPENVINO_AVAILABLE = True
+except ImportError:
+    _OPENVINO_AVAILABLE = False
+

 class BackendType(Enum):
     TORCH = "Torch"
@@ -159,3 +166,11 @@ def copy_model(model: TModel) -> TModel:
         model = TFModelTransformer(model).transform(TFTransformationLayout())
         return model
     return deepcopy(model)
+
+
+def is_openvino_available() -> bool:
+    """
+    Check if OpenVINO is available.
+    :return: True if openvino package is installed, False otherwise.
+    """
+    return _OPENVINO_AVAILABLE
```
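
A sketch of the gating pattern this enables; the dispatching function and both implementation stubs below are hypothetical names, not part of the commit:

```python
from nncf.common.utils.backend import is_openvino_available

def _compress_with_numpy(weight):
    ...  # reference numpy path (hypothetical stub)

def _compress_with_ov_model(weight):
    ...  # compiled OpenVINO model path (hypothetical stub)

def compress(weight):
    # Dispatch to the optimized OV path only when the openvino package is importable.
    impl = _compress_with_ov_model if is_openvino_available() else _compress_with_numpy
    return impl(weight)
```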
103 changes: 103 additions & 0 deletions nncf/common/utils/caching.py
```python
# Copyright (c) 2025 Intel Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import inspect
from contextlib import contextmanager
from functools import wraps
from typing import Any, Callable, Dict, Iterator, TypeVar, cast


class ResultsCache:
    """
    A container for results decorated with @cache_results decorator.
    """

    def __init__(self) -> None:
        self._enabled = True
        # Stores the results of the decorated function
        self._cache: Dict[Any, Any] = {}
        # Stores the number of times the cached result was accessed
        self._access_count: Dict[Any, int] = {}

    def enable(self) -> None:
        self._enabled = True

    def disable(self) -> None:
        self._enabled = False

    def enabled(self) -> bool:
        return self._enabled

    def access_count(self) -> Dict[Any, int]:
        return copy.deepcopy(self._access_count)

    def clear(self) -> None:
        self._cache.clear()
        self._access_count.clear()

    def __getitem__(self, key: Any) -> Any:
        self._access_count[key] += 1
        return self._cache[key]

    def __setitem__(self, key: Any, value: Any) -> None:
        self._access_count[key] = 0
        self._cache[key] = value

    def __contains__(self, key: Any) -> bool:
        return key in self._cache


TFunc = TypeVar("TFunc", bound=Callable[..., Any])


def cache_results(cache: ResultsCache) -> Callable[[TFunc], TFunc]:
    """
    Decorator to cache results of a function. When decorated function is called with the same set of arguments, it
    will return the cached result instead of recomputing it. If it was the first call with such set of arguments, the
    result will be computed and stored in the cache. The cache is stored in the `cache` object. Function arguments
    must be hashable.

    :param cache: A cache container where results will be stored.
    """

    def decorator(func: TFunc) -> TFunc:
        @wraps(func)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            if not cache.enabled():
                return func(*args, **kwargs)
            sig = inspect.signature(func)
            new_kwargs = {name: arg for name, arg in zip(sig.parameters, args)}
            new_kwargs.update(kwargs)
            cache_key = (func.__name__, frozenset(new_kwargs.items()))
            if cache_key in cache:
                return cache[cache_key]
            result = func(*args, **kwargs)
            cache[cache_key] = result
            return result

        return cast(TFunc, wrapper)

    return decorator


@contextmanager
def disable_results_caching(cache: ResultsCache) -> Iterator[None]:
    """
    Context manager to disable caching of results for a block of code.

    :param cache: A cache container where results are stored.
    """
    if cache.enabled():
        cache.disable()
        yield
        cache.enable()
    else:
        yield
```
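
A usage sketch of the caching utilities above (the cached function is illustrative); note that all arguments are folded into the cache key and therefore must be hashable:

```python
from nncf.common.utils.caching import ResultsCache, cache_results, disable_results_caching

MODEL_CACHE = ResultsCache()

@cache_results(MODEL_CACHE)
def build_model(shape: tuple, dtype: str) -> str:
    print("compiling...")  # runs only on a cache miss
    return f"model<{shape}, {dtype}>"

build_model((128, 64), "int4")      # miss: prints "compiling..."
build_model((128, 64), "int4")      # hit: returns the cached result
with disable_results_caching(MODEL_CACHE):
    build_model((128, 64), "int4")  # recomputed: caching is off in this block
```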
57 changes: 50 additions & 7 deletions nncf/openvino/graph/node_utils.py
```diff
@@ -13,6 +13,7 @@

 import numpy as np
 import openvino.runtime as ov
+import openvino.runtime.op as op
 import openvino.runtime.opset13 as opset

 import nncf
@@ -41,6 +42,8 @@
 from nncf.openvino.graph.metatypes.openvino_metatypes import OVMatMulMetatype
 from nncf.openvino.graph.metatypes.openvino_metatypes import OVOpMetatype
 from nncf.openvino.graph.metatypes.openvino_metatypes import get_node_metatype
+from nncf.tensor import Tensor
+from nncf.tensor import TensorBackend

 InplaceInsertionFnType = Callable[[ov.Node, int, str], ov.Node]
@@ -97,26 +100,27 @@ def get_number_if_op(model: ov.Model) -> int:
     """

     def cnt_if_op(model: ov.Model, cnt: int) -> int:
-        for op in model.get_ops():
-            if get_node_metatype(op) == OVIfMetatype:
+        for model_op in model.get_ops():
+            if get_node_metatype(model_op) == OVIfMetatype:
                 cnt += 1
-                cnt = cnt_if_op(op.get_function(0), cnt)
-                cnt = cnt_if_op(op.get_function(1), cnt)
+                cnt = cnt_if_op(model_op.get_function(0), cnt)
+                cnt = cnt_if_op(model_op.get_function(1), cnt)
         return cnt

     return cnt_if_op(model, 0)


-def get_const_value(const_node: ov.Node) -> np.ndarray:
+def get_const_value(const_node: ov.Node, cast_bf16_to_fp32: bool = True) -> np.ndarray:
     """
     Returns the constant tensor for the node.
     This method is applicable only for the floating-point constant data.

     :param const_node: OpenVINO node.
+    :param cast_bf16_to_fp32: Whether to cast bf16 node data to fp32 or not. If False and the node contains bf16 data,
+        the resulting bf16 value will be returned encoded inside a numpy.float16 array.
     :return: The constant value.
     """
-    if const_node.get_element_type() == ov.Type.bf16:
-        # Fixed FP32 data type as the result for BF16 constant
+    if const_node.get_element_type() == ov.Type.bf16 and cast_bf16_to_fp32:
         return const_node.get_data(dtype=np.float32)
     return const_node.data
@@ -635,3 +639,42 @@ def get_activation_channel_axis(node: NNCFNode, port_id: int, input_shape: Tuple
         channel_axis = activations_layout.index(OVLayoutElem.C_IN)

     return channel_axis
+
+
+def convert_op(node: ov.Node, target_dtype: ov.Type) -> ov.Node:
+    """
+    Return a subgraph which converts the given node output to the target data type. If the output is already in the
+    target data type then the given node is returned.
+
+    :param node: The input node to convert.
+    :param target_dtype: The target data type to convert the input node to.
+    :return: The converted node.
+    """
+    if node.get_element_type() == target_dtype:
+        return node
+    return opset.convert(node, target_dtype)
+
+
+def non_convertable_divide_op(a: ov.Node, b: ov.Node) -> ov.Node:
+    """
+    Creates a "non-convertable" divide operation. It won't be converted to a*(1/b).
+    """
+    divide_node = a / b
+    divide_node.get_rt_info()["nonconvertable_divide_0"] = True
+    return divide_node
+
+
+def create_ov_const_from_tensor(x: Tensor, dtype: ov.Type, name: Optional[str] = None) -> op.Constant:
+    """
+    Create an OpenVINO Constant node from the given tensor.
+
+    :param x: Data tensor. Supports NumPy and OV tensor backends. If x backend is OV, the constant node is created
+        directly from underlying OV tensor.
+    :param dtype: Data type of the constant.
+    :param name: Optional name of the constant.
+    :return: OpenVINO Constant node.
+    """
+    if x.backend == TensorBackend.ov:
+        assert x.data.get_element_type() == dtype
+        return opset.constant(x.data, name=name, shared_memory=True)
+    const = opset.constant(x.data, dtype=dtype, name=name)
+    return const
```
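
A sketch of how the new graph helpers compose when building a compression subgraph; this assumes the openvino package is installed, and the shapes and values are illustrative:

```python
import numpy as np
import openvino.runtime as ov
import openvino.runtime.opset13 as opset

from nncf.openvino.graph.node_utils import convert_op, non_convertable_divide_op

weight = opset.constant(np.random.rand(4, 4).astype(np.float32))
scale = opset.constant(np.full((4, 1), 0.5, dtype=np.float32))

w = convert_op(weight, ov.Type.f32)      # dtype already matches -> node returned as-is
q = non_convertable_divide_op(w, scale)  # Divide tagged via rt_info, kept as a true division
```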
16 changes: 16 additions & 0 deletions nncf/openvino/optimized_functions/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# Copyright (c) 2025 Intel Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from nncf.openvino.optimized_functions.functions import astype as astype
from nncf.openvino.optimized_functions.functions import do_int_quantization as do_int_quantization
from nncf.openvino.optimized_functions.functions import quantize_dequantize_weight as quantize_dequantize_weight
from nncf.openvino.optimized_functions.models import OVModelParameters as OVModelParameters
from nncf.openvino.optimized_functions.models import clear_ov_model_cache as clear_ov_model_cache