add test on scale estimation
kshpv committed Jan 24, 2025
1 parent 568809c commit 8c7efd6
Showing 5 changed files with 111 additions and 25 deletions.
1 change: 0 additions & 1 deletion nncf/torch/engine.py
@@ -48,5 +48,4 @@ def infer(
return self._model(**input_data)
if isinstance(input_data, tuple):
return self._model(*input_data)

return self._model(input_data)
45 changes: 41 additions & 4 deletions tests/cross_fw/test_templates/template_test_weights_compression.py
@@ -11,6 +11,7 @@
import math
from abc import ABC
from abc import abstractmethod
from copy import deepcopy
from typing import List, TypeVar

import numpy as np
@@ -143,24 +144,60 @@ def get_scale_estimation_ref():
"""

def test_scale_estimation(self, mocker):
"""Checks that scales match the reference."""
calc_q_params_spy = mocker.spy(ScaleEstimation, "calculate_quantization_params")
model = self.get_model_for_test_scale_estimation()

# prepare dataset with one input tensor
-input = np.arange(0, 8 * 8, dtype=np.float32).reshape(1, 8, 8)
-input[0, 4] *= 100  # make one channel relatively higher.
+input = np.arange(0, 4 * 8, dtype=np.float32).reshape(1, 4, 8)
input = self.to_tensor(input)
dataset = Dataset([input])

_ = compress_weights(
model,
mode=CompressWeightsMode.INT4_ASYM,
ratio=1.0,
-group_size=4,
+group_size=8,
scale_estimation=True,
all_layers=True,
dataset=dataset,
)
reference = self.get_scale_estimation_ref()
assert fns.allclose(Tensor(reference), calc_q_params_spy.spy_return[0])
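The new reference scales later in this diff have one entry per output row: with a (16, 8) weight and group_size=8, each row forms exactly one quantization group. For orientation, here is a minimal NumPy sketch of plain min/max asymmetric group scales — the starting point that scale estimation then refines; the function name is illustrative, not NNCF's API:

```python
import numpy as np

def min_max_asym_group_scales(weight: np.ndarray, group_size: int, bits: int = 4) -> np.ndarray:
    # Split each row into groups of `group_size` columns: (rows, n_groups, group_size).
    rows, cols = weight.shape
    grouped = weight.reshape(rows, cols // group_size, group_size)
    w_min = grouped.min(axis=-1, keepdims=True)
    w_max = grouped.max(axis=-1, keepdims=True)
    # Asymmetric quantization spreads 2**bits levels over the full [min, max] range.
    return (w_max - w_min) / (2**bits - 1)

weight = np.arange(0, 16 * 8, dtype=np.float32).reshape(16, 8)
print(min_max_asym_group_scales(weight, group_size=8).shape)  # (16, 1, 1): one scale per row
```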

@abstractmethod
def get_orig_weight(model: TModel) -> Tensor:
"""Returns original weight."""

@abstractmethod
def get_decompressed_weight(compressed_model: TModel, input: TTensor) -> Tensor:
"""Returns decompressed weight"""

def test_scale_estimation_outlier_channel_has_lowest_error(self):
"""Checks that outlier channel has a lowest error after quantization."""
OUTLIER_CHANNEL = 4
model = self.get_model_for_test_scale_estimation()

# prepare dataset with one input tensor
input = np.arange(0, 4 * 8, dtype=np.float32).reshape(1, 4, 8)
input[:, :, OUTLIER_CHANNEL] *= 1000  # Amplify one channel; it should end up with the lowest quantization error.
input = self.to_tensor(input)
dataset = Dataset([input])

compressed_model = compress_weights(
deepcopy(model),
mode=CompressWeightsMode.INT4_ASYM,
ratio=1.0,
group_size=-1,
scale_estimation=True,
all_layers=True,
dataset=dataset,
)

decompressed_weight = self.get_decompressed_weight(compressed_model, input)
original_weight = self.get_orig_weight(model)
diff = (decompressed_weight - original_weight) ** 2
layer_err = fns.mean(diff, axis=0) / fns.mean(original_weight**2, axis=0)
assert fns.argsort(layer_err)[0] == OUTLIER_CHANNEL
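The assertion works on a per-input-channel relative error: the squared reconstruction error of each weight column, averaged over the output rows and normalized by that column's energy. A self-contained NumPy sketch of the same metric, with synthetic stand-ins for the original and decompressed weights:

```python
import numpy as np

rng = np.random.default_rng(0)
original = rng.normal(size=(16, 8)).astype(np.float32)  # stand-in for the fp32 weight
# Fake quantization noise stands in for the round-trip through compression.
decompressed = original + rng.normal(scale=1e-2, size=(16, 8)).astype(np.float32)

diff = (decompressed - original) ** 2
# Averaging over axis 0 (the 16 output rows) leaves one value per input channel.
layer_err = diff.mean(axis=0) / (original**2).mean(axis=0)
print(layer_err.shape)           # (8,)
print(np.argsort(layer_err)[0])  # index of the channel with the lowest relative error
```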
4 changes: 2 additions & 2 deletions tests/openvino/native/models.py
@@ -1189,9 +1189,9 @@ def _create_ov_model(self):

class MatMul(OVReferenceModel):
def _create_ov_model(self):
-input_node = opset.parameter([1, 8, 8], name="Input")
+input_node = opset.parameter([1, 4, 8], name="Input")

-weights_data = np.arange(0, 8 * 8, dtype=np.float32).reshape(8, 8)
+weights_data = np.arange(0, 16 * 8, dtype=np.float32).reshape(16, 8)
weights_node = opset.constant(weights_data, dtype=np.float32, name="Weights")

matmul_node = opset.matmul(input_node, weights_node, transpose_a=False, transpose_b=True, name="MatMul")
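The reshaped model still satisfies the MatMul contract: with transpose_b=True, the (1, 4, 8) input is multiplied by the transposed (16, 8) weights, giving a (1, 4, 16) activation. A quick NumPy equivalent of the shape arithmetic:

```python
import numpy as np

input_data = np.arange(0, 4 * 8, dtype=np.float32).reshape(1, 4, 8)
weights = np.arange(0, 16 * 8, dtype=np.float32).reshape(16, 8)
print((input_data @ weights.T).shape)  # (1, 4, 16), matching transpose_b=True
```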
43 changes: 35 additions & 8 deletions tests/openvino/native/quantization/test_weights_compression.py
@@ -26,6 +26,7 @@
from nncf.common.utils.debug import nncf_debug
from nncf.data.dataset import Dataset
from nncf.experimental.common.tensor_statistics.collectors import AggregatorBase
from nncf.openvino.graph.model_transformer import OVModelTransformer
from nncf.openvino.graph.node_utils import get_const_value
from nncf.parameters import BackupMode
from nncf.quantization import compress_weights
@@ -1524,13 +1525,39 @@ def get_model_for_test_scale_estimation():
def get_scale_estimation_ref():
return np.array(
[
-[[0.2], [0.41354424]],
-[[0.6782236], [0.9470368]],
-[[1.1691767], [1.4355733]],
-[[1.7025099], [1.9689066]],
-[[2.2722175], [2.543369]],
-[[2.8146443], [3.0858421]],
-[[3.3025098], [3.5689068]],
-[[3.8358433], [4.1022396]],
+[[0.473328]],
+[[0.929023]],
+[[1.446527]],
+[[1.920595]],
+[[2.517053]],
+[[3.030101]],
+[[3.584278]],
+[[4.04351]],
+[[4.620007]],
+[[5.165322]],
+[[5.710637]],
+[[6.122580]],
+[[6.655914]],
+[[7.237173]],
+[[7.722581]],
+[[8.255914]],
]
)

@staticmethod
def get_orig_weight(model: ov.Model) -> Tensor:
for op in model.get_ordered_ops():
op_name = op.get_friendly_name()
if op.get_type_name() == "Constant" and op_name == "Weights":
return Tensor(op.data)

@staticmethod
def get_decompressed_weight(compressed_model: ov.Model, input: np.ndarray) -> Tensor:
# Insert an extra model output on the convert node to read back the decompressed weight.
node = [op for op in compressed_model.get_ops() if op.get_friendly_name() == "Weights/fq_weights_1/convert"][0]
output = node.output(0)
extra_outputs = [(output, 0, None)]
model = OVModelTransformer._insert_outputs(compressed_model, extra_outputs)
compiled_model = ov.compile_model(model, device_name="CPU")
weight_output = compiled_model(input)[1]
return Tensor(weight_output)
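Note that OVModelTransformer._insert_outputs is a private NNCF helper. Outside the test suite, a similar effect should be achievable through OpenVINO's public Model.add_outputs; a hedged sketch under that assumption (the helper name read_intermediate is illustrative):

```python
import numpy as np
import openvino as ov

def read_intermediate(model: ov.Model, node_name: str, input_data: np.ndarray) -> np.ndarray:
    # Find the target node by friendly name and expose its first output tensor.
    node = next(op for op in model.get_ops() if op.get_friendly_name() == node_name)
    model.add_outputs([node.output(0)])
    compiled = ov.compile_model(model, device_name="CPU")
    # The newly registered output is appended after the model's original results.
    return list(compiled(input_data).values())[-1]
```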
43 changes: 33 additions & 10 deletions tests/torch/ptq/test_weights_compression.py
@@ -22,6 +22,7 @@
from nncf import SensitivityMetric
from nncf.quantization import compress_weights
from nncf.quantization.advanced_parameters import AdvancedCompressionParameters
from nncf.tensor import Tensor
from nncf.tensor import TensorDataType
from nncf.torch import wrap_model
from nncf.torch.quantization.layers import INT4AsymmetricWeightsDecompressor
@@ -63,6 +64,9 @@ def forward(self, x):
x = layer(x)
return x

def get_weight_names_in_exec_order(self):
return [f"{i}_weight" for i in range(len(self.main_values))]


class MatMulModel(torch.nn.Module):
def __init__(self, weight: torch.Tensor = torch.ones(size=(256, 256), dtype=torch.float32)):
@@ -375,27 +379,46 @@ def cast_to(x: torch.Tensor, dtype: TensorDataType) -> torch.Tensor:

@staticmethod
def check_weights(model: torch.nn.Module, ref_ids: List[int]) -> None:
-low_precision_nodes = {f"{i}_weight" for i in ref_ids}
+all_names = model.get_weight_names_in_exec_order()
+low_precision_nodes = [all_names[i] for i in ref_ids]
for op_name, op in model.nncf.external_op.items():
for name in low_precision_nodes:
if name in op_name:
assert isinstance(op, INT4SymmetricWeightsDecompressor)

@staticmethod
def get_model_for_test_scale_estimation():
-return LinearModel(torch.arange(0, 8 * 8, dtype=torch.float32).reshape(8, 8))
+return LinearModel(torch.arange(0, 8 * 16, dtype=torch.float32).reshape(16, 8))

@staticmethod
def get_scale_estimation_ref():
return torch.tensor(
[
-[[0.200000], [0.413544]],
-[[0.678224], [0.947037]],
-[[1.169177], [1.435573]],
-[[1.702510], [1.968907]],
-[[2.272218], [2.543369]],
-[[2.814644], [3.085842]],
-[[3.302510], [3.568907]],
-[[3.835843], [4.102240]],
+[[0.473328]],
+[[0.929023]],
+[[1.446527]],
+[[1.920595]],
+[[2.517054]],
+[[3.030102]],
+[[3.584279]],
+[[4.043509]],
+[[4.620008]],
+[[5.165322]],
+[[5.710637]],
+[[6.122581]],
+[[6.655914]],
+[[7.237174]],
+[[7.722580]],
+[[8.255914]],
]
)

@staticmethod
def get_orig_weight(model: torch.nn.Module) -> Tensor:
return Tensor(model.linear.weight)

@staticmethod
def get_decompressed_weight(compressed_model: torch.nn.Module, input: torch.Tensor) -> Tensor:
weight = compressed_model.linear.weight
unpacked_w = compressed_model.nncf.external_op.weights_decompressor_linear_weight(weight)
return Tensor(unpacked_w)
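For reference, asymmetric INT4 decompression is just the inverse affine map. A minimal sketch of the arithmetic such a decompressor performs — an assumption about the math, not NNCF's actual INT4AsymmetricWeightsDecompressor code:

```python
import torch

def dequantize_int4_asym(codes: torch.Tensor, scale: torch.Tensor, zero_point: torch.Tensor) -> torch.Tensor:
    # codes hold unsigned 4-bit values in [0, 15]; scale/zero_point broadcast per group.
    return (codes.to(torch.float32) - zero_point.to(torch.float32)) * scale

w = dequantize_int4_asym(torch.tensor([[0, 15]]), scale=torch.tensor(0.5), zero_point=torch.tensor(7.0))
print(w)  # tensor([[-3.5000,  4.0000]])
```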
