diff --git a/docs/api/source/conf.py b/docs/api/source/conf.py index 55a40d04220..b623b1a9f4e 100644 --- a/docs/api/source/conf.py +++ b/docs/api/source/conf.py @@ -145,6 +145,9 @@ def collect_api_entities() -> APIInfo: "nncf.tensor.functions.torch_linalg", "nncf.tensor.functions.torch_io", "nncf.tensor.functions.numpy_io", + "nncf.tensor.functions.tf_numeric", + "nncf.tensor.functions.tf_io", + "nncf.tensor.functions.tf_linalg", ] with mock(mock_modules): diff --git a/nncf/tensor/definitions.py b/nncf/tensor/definitions.py index cf9518f0ea4..7a43335cacd 100644 --- a/nncf/tensor/definitions.py +++ b/nncf/tensor/definitions.py @@ -20,6 +20,7 @@ class TensorBackend(Enum): """ numpy = auto() + tf = auto() torch = auto() diff --git a/nncf/tensor/functions/__init__.py b/nncf/tensor/functions/__init__.py index 9b6b66df746..baa207c6e17 100644 --- a/nncf/tensor/functions/__init__.py +++ b/nncf/tensor/functions/__init__.py @@ -74,6 +74,11 @@ def _initialize_backends(): import nncf.tensor.functions.numpy_linalg import nncf.tensor.functions.numpy_numeric + with contextlib.suppress(ImportError): + import nncf.tensor.functions.tf_io + import nncf.tensor.functions.tf_linalg + import nncf.tensor.functions.tf_numeric + with contextlib.suppress(ImportError): import nncf.tensor.functions.torch_io import nncf.tensor.functions.torch_linalg diff --git a/nncf/tensor/functions/dispatcher.py b/nncf/tensor/functions/dispatcher.py index 193d5a2b15a..df2c23e54b7 100644 --- a/nncf/tensor/functions/dispatcher.py +++ b/nncf/tensor/functions/dispatcher.py @@ -97,6 +97,10 @@ def get_numeric_backend_fn(fn_name: str, backend: TensorBackend) -> Callable: from nncf.tensor.functions import torch_numeric return getattr(torch_numeric, fn_name) + if backend == TensorBackend.tf: + from nncf.tensor.functions import tf_numeric + + return getattr(tf_numeric, fn_name) def get_io_backend_fn(fn_name: str, backend: TensorBackend) -> Callable: @@ -111,6 +115,10 @@ def get_io_backend_fn(fn_name: str, backend: TensorBackend) -> Callable: from nncf.tensor.functions import numpy_io return getattr(numpy_io, fn_name) + if backend == TensorBackend.tf: + from nncf.tensor.functions import tf_io + + return getattr(tf_io, fn_name) if backend == TensorBackend.torch: from nncf.tensor.functions import torch_io diff --git a/nncf/tensor/functions/tf_io.py b/nncf/tensor/functions/tf_io.py new file mode 100644 index 00000000000..bf97ab43121 --- /dev/null +++ b/nncf/tensor/functions/tf_io.py @@ -0,0 +1,28 @@ +# Copyright (c) 2024 Intel Corporation +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+from typing import Dict, Optional
+
+import tensorflow as tf
+from safetensors.tensorflow import load_file as tf_load_file
+from safetensors.tensorflow import save_file as tf_save_file
+
+from nncf.tensor import TensorDeviceType
+from nncf.tensor.functions import io
+
+
+def load_file(file_path: str, *, device: Optional[TensorDeviceType] = None) -> Dict[str, tf.Tensor]:
+    # `device` is accepted for API parity with the other backends; the
+    # safetensors TensorFlow loader does not take a device argument.
+    return tf_load_file(file_path)
+
+
+@io.save_file.register(tf.Tensor)
+def _(data: Dict[str, tf.Tensor], file_path: str) -> None:
+    return tf_save_file(data, file_path)
diff --git a/nncf/tensor/functions/tf_linalg.py b/nncf/tensor/functions/tf_linalg.py
new file mode 100644
index 00000000000..f0a5b8db290
--- /dev/null
+++ b/nncf/tensor/functions/tf_linalg.py
@@ -0,0 +1,129 @@
+# Copyright (c) 2024 Intel Corporation
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import warnings
+from typing import Optional, Tuple, Union
+
+import tensorflow as tf
+
+from nncf.tensor.functions import linalg
+
+
+@linalg.norm.register(tf.Tensor)
+def _(
+    a: tf.Tensor,
+    ord: Optional[Union[str, float, int]] = None,
+    axis: Optional[Union[int, Tuple[int, ...]]] = None,
+    keepdims: bool = False,
+) -> tf.Tensor:
+    if ord is None:
+        ord = "euclidean"
+    rank = tf.rank(a)
+    if rank == 2 and axis is None:
+        axis = (0, 1)
+
+    with tf.device(a.device):
+        # tf.linalg.norm does not implement numpy's matrix-norm conventions for
+        # ord in {"nuc", +/-1, +/-2, +/-inf}, so those cases are emulated below.
+        if ord == "nuc" and isinstance(axis, tuple) and len(axis) != 1:
+            if rank != 2:
+                raise ValueError("ord='nuc' is only supported for 2D tensors")
+            s = tf.linalg.svd(a, compute_uv=False)
+            return tf.reduce_sum(s, axis=-1)
+
+        if ord == -1 and isinstance(axis, tuple) and len(axis) != 1:
+            if rank != 2:
+                raise ValueError("ord=-1 is only supported for 2D tensors")
+            return tf.reduce_min(tf.reduce_sum(tf.abs(a), axis=axis[0]), keepdims=keepdims)
+
+        if ord == 1 and isinstance(axis, tuple) and len(axis) != 1:
+            if rank != 2:
+                raise ValueError("ord=1 is only supported for 2D tensors")
+            return tf.reduce_max(tf.reduce_sum(tf.abs(a), axis=axis[0]), keepdims=keepdims)
+
+        if ord == -2 and isinstance(axis, tuple) and len(axis) != 1:
+            if rank != 2:
+                raise ValueError("ord=-2 is only supported for 2D tensors")
+            s = tf.linalg.svd(a, compute_uv=False)
+            return tf.reduce_min(s, axis=-1)
+
+        if ord == 2 and isinstance(axis, tuple) and len(axis) != 1:
+            if rank != 2:
+                raise ValueError("ord=2 is only supported for 2D tensors")
+            s = tf.linalg.svd(a, compute_uv=False)
+            return tf.reduce_max(s, axis=-1)
+
+        if ord == float("inf") and isinstance(axis, tuple) and len(axis) != 1:
+            if rank != 2:
+                raise ValueError("ord=inf is only supported for 2D tensors")
+            return tf.reduce_max(tf.reduce_sum(tf.abs(a), axis=axis[1]), keepdims=keepdims)
+
+        if ord == -float("inf") and isinstance(axis, tuple) and len(axis) != 1:
+            if rank != 2:
+                raise ValueError("ord=-inf is only supported for 2D tensors")
+            return tf.reduce_min(tf.reduce_sum(tf.abs(a), axis=axis[1]), keepdims=keepdims)
+
+        return tf.linalg.norm(a, ord=ord, axis=axis, keepdims=keepdims)
+
+
+@linalg.cholesky.register(tf.Tensor)
+def _(a: tf.Tensor, upper: bool = 
False) -> tf.Tensor:
+    with tf.device(a.device):
+        cholesky = tf.linalg.cholesky(a)
+        if upper:
+            perm = list(range(tf.rank(a)))
+            perm[-1], perm[-2] = perm[-2], perm[-1]
+            cholesky = tf.transpose(cholesky, perm=perm)
+        return cholesky
+
+
+@linalg.cholesky_inverse.register(tf.Tensor)
+def _(a: tf.Tensor, upper: bool = False) -> tf.Tensor:
+    with tf.device(a.device):
+        if upper:
+            # tf.linalg.cholesky_solve expects a lower-triangular factor, so an
+            # upper factor is transposed first.
+            perm = list(range(tf.rank(a)))
+            perm[-1], perm[-2] = perm[-2], perm[-1]
+            a = tf.transpose(a, perm=perm)
+
+        eye = tf.eye(a.shape[0], dtype=a.dtype)
+        return tf.linalg.cholesky_solve(a, eye)
+
+
+@linalg.inv.register(tf.Tensor)
+def _(a: tf.Tensor) -> tf.Tensor:
+    with tf.device(a.device):
+        return tf.linalg.inv(a)
+
+
+@linalg.pinv.register(tf.Tensor)
+def _(a: tf.Tensor) -> tf.Tensor:
+    with tf.device(a.device):
+        return tf.linalg.pinv(a)
+
+
+@linalg.lstsq.register(tf.Tensor)
+def _(a: tf.Tensor, b: tf.Tensor, driver: Optional[str] = None) -> tf.Tensor:
+    with tf.device(a.device):
+        if driver is not None:
+            warnings.warn("Specifying a driver is not supported by TensorFlow's lstsq; the argument is ignored")
+        if tf.rank(b) == 1:
+            # Promote a 1-D right-hand side to a column vector, as tf.linalg.lstsq expects.
+            b = tf.expand_dims(b, axis=0)
+            perm = list(range(tf.rank(b)))
+            perm[-1], perm[-2] = perm[-2], perm[-1]
+            b = tf.transpose(b, perm=perm)
+
+        return tf.linalg.lstsq(a, b)
+
+
+@linalg.svd.register(tf.Tensor)
+def _(a: tf.Tensor, full_matrices: Optional[bool] = True) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
+    with tf.device(a.device):
+        s, u, v = tf.linalg.svd(a, full_matrices=full_matrices)
+
+        # Match the numpy/torch convention of returning V^H rather than V.
+        return u, s, tf.transpose(v, conjugate=True)
diff --git a/nncf/tensor/functions/tf_numeric.py b/nncf/tensor/functions/tf_numeric.py
new file mode 100644
index 00000000000..f6e6228966d
--- /dev/null
+++ b/nncf/tensor/functions/tf_numeric.py
@@ -0,0 +1,545 @@
+# Copyright (c) 2024 Intel Corporation
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any, Callable, List, Optional, Sequence, Tuple, Union
+
+import numpy as np
+import tensorflow as tf
+
+from nncf.tensor import TensorDataType
+from nncf.tensor import TensorDeviceType
+from nncf.tensor.definitions import TensorBackend
+from nncf.tensor.definitions import TypeInfo
+from nncf.tensor.functions import numeric
+from nncf.tensor.tensor import TTensor
+
+DTYPE_MAP = {
+    TensorDataType.float16: tf.float16,
+    TensorDataType.bfloat16: tf.bfloat16,
+    TensorDataType.float32: tf.float32,
+    TensorDataType.float64: tf.float64,
+    TensorDataType.int8: tf.int8,
+    TensorDataType.int32: tf.int32,
+    TensorDataType.int64: tf.int64,
+    TensorDataType.uint8: tf.uint8,
+}
+
+DEVICE_MAP = {TensorDeviceType.CPU: "CPU", TensorDeviceType.GPU: "GPU"}
+
+DTYPE_MAP_REV = {v: k for k, v in DTYPE_MAP.items()}
+DEVICE_MAP_REV = {v: k for k, v in DEVICE_MAP.items()}
+
+
+def convert_to_tf_device(device: Optional[TensorDeviceType]) -> Optional[str]:
+    return DEVICE_MAP[device] if device is not None else None
+
+
+def convert_to_tf_dtype(dtype: Optional[TensorDataType]) -> Optional[tf.DType]:
+    return DTYPE_MAP[dtype] if dtype is not None else None
+
+
+@numeric.device.register(tf.Tensor)
+def _(a: tf.Tensor) -> TensorDeviceType:
+    # a.device is a full device string, e.g. "/job:localhost/replica:0/task:0/device:CPU:0".
+    if "CPU" in a.device:
+        return DEVICE_MAP_REV["CPU"]
+    if "GPU" in a.device:
+        return DEVICE_MAP_REV["GPU"]
+
+
+@numeric.backend.register(tf.Tensor)
+def _(a: tf.Tensor) -> TensorBackend:
+    return TensorBackend.tf
+
+
+@numeric.squeeze.register(tf.Tensor)
+def _(a: tf.Tensor, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> tf.Tensor:
+    with tf.device(a.device):
+        return tf.squeeze(a, axis)
+
+
+@numeric.flatten.register(tf.Tensor)
+def _(a: tf.Tensor) -> tf.Tensor:
+    with tf.device(a.device):
+        return tf.reshape(a, [-1])
+
+
+@numeric.max.register(tf.Tensor)
+def _(a: tf.Tensor, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False) -> tf.Tensor:
+    with tf.device(a.device):
+        return tf.reduce_max(a, axis=axis, keepdims=keepdims)
+
+
+@numeric.min.register(tf.Tensor)
+def _(a: tf.Tensor, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False) -> tf.Tensor:
+    with tf.device(a.device):
+        return tf.reduce_min(a, axis=axis, keepdims=keepdims)
+
+
+@numeric.abs.register(tf.Tensor)
+def _(a: tf.Tensor) -> tf.Tensor:
+    with tf.device(a.device):
+        return tf.abs(a)
+
+
+@numeric.astype.register(tf.Tensor)
+def _(a: tf.Tensor, dtype: TensorDataType) -> tf.Tensor:
+    with tf.device(a.device):
+        return tf.cast(a, DTYPE_MAP[dtype])
+
+
+@numeric.dtype.register(tf.Tensor)
+def _(a: tf.Tensor) -> TensorDataType:
+    return DTYPE_MAP_REV[a.dtype]
+
+
+@numeric.reshape.register(tf.Tensor)
+def _(a: tf.Tensor, shape: Tuple[int, ...]) -> tf.Tensor:
+    with tf.device(a.device):
+        return tf.reshape(a, shape)
+
+
+@numeric.all.register(tf.Tensor)
+def _(a: tf.Tensor, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> tf.Tensor:
+    with tf.device(a.device):
+        if axis is None:
+            return tf.reduce_all(a)
+        return tf.reduce_all(a, axis=axis)
+
+
+@numeric.allclose.register(tf.Tensor)
+def _(
+    a: tf.Tensor, b: Union[tf.Tensor, float], rtol: float = 1e-05, atol: float = 1e-08, equal_nan: bool = False
+) -> bool:
+    with tf.device(a.device):
+        return bool(tf.experimental.numpy.allclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan))
+
+
+@numeric.any.register(tf.Tensor)
+def _(a: tf.Tensor, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> tf.Tensor:
+    with tf.device(a.device):
+        if axis is None:
+            return tf.reduce_any(a)
+        return tf.reduce_any(a, axis=axis)
+
+
+@numeric.count_nonzero.register(tf.Tensor)
+def _(a: tf.Tensor, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> tf.Tensor:
+    with tf.device(a.device):
+        return tf.math.count_nonzero(a, axis=axis)
+
+
+@numeric.isempty.register(tf.Tensor)
+def _(a: tf.Tensor) -> bool:
+    return bool(tf.equal(tf.size(a), 0))
+
+
+@numeric.isclose.register(tf.Tensor)
+def _(
+    a: tf.Tensor, b: Union[tf.Tensor, float], rtol: float = 1e-05, atol: float = 1e-08, equal_nan: bool = False
+) -> tf.Tensor:
+    with tf.device(a.device):
+        return tf.experimental.numpy.isclose(a, b, atol=atol, rtol=rtol, equal_nan=equal_nan)
+
+
+@numeric.maximum.register(tf.Tensor)
+def _(x1: tf.Tensor, x2: Union[tf.Tensor, float]) -> tf.Tensor:
+    with tf.device(x1.device):
+        return tf.maximum(x1, x2)
+
+
+@numeric.minimum.register(tf.Tensor)
+def _(x1: tf.Tensor, x2: Union[tf.Tensor, float]) -> tf.Tensor:
+    with tf.device(x1.device):
+        return tf.minimum(x1, x2)
+
+
+@numeric.ones_like.register(tf.Tensor)
+def _(a: tf.Tensor) -> tf.Tensor:
+    with tf.device(a.device):
+        return tf.ones_like(a)
+
+
+@numeric.where.register(tf.Tensor)
+def _(condition: tf.Tensor, x: Union[tf.Tensor, float, bool], y: Union[tf.Tensor, float, bool]) -> tf.Tensor:
+    with tf.device(condition.device):
+        return tf.where(condition, x, y)
+
+
+@numeric.zeros_like.register(tf.Tensor)
+def _(a: tf.Tensor) -> tf.Tensor:
+    with tf.device(a.device):
+        return tf.zeros_like(a)
+
+
+@numeric.stack.register(tf.Tensor)
+def _(x: List[tf.Tensor], axis: int = 0) -> tf.Tensor:
+    with tf.device(x[0].device):
+        return tf.stack(x, axis=axis)
+
+
+@numeric.concatenate.register(tf.Tensor)
+def _(x: List[tf.Tensor], axis: int = 0) -> tf.Tensor:
+    with tf.device(x[0].device):
+        return tf.concat(x, axis=axis)
+
+
+@numeric.unstack.register(tf.Tensor)
+def _(x: tf.Tensor, axis: int = 0) -> List[tf.Tensor]:
+    with tf.device(x.device):
+        if not list(x.shape):
+            # A 0-D tensor cannot be unstacked; lift it to shape (1,) first.
+            x = tf.expand_dims(x, 0)
+        return tf.unstack(x, axis=axis)
+
+
+@numeric.moveaxis.register(tf.Tensor)
+def _(a: tf.Tensor, source: Union[int, Tuple[int, ...]], destination: Union[int, Tuple[int, ...]]) -> tf.Tensor:
+    with tf.device(a.device):
+        return tf.experimental.numpy.moveaxis(a, source, destination)
+
+
+@numeric.mean.register(tf.Tensor)
+def _(
+    a: tf.Tensor,
+    axis: Optional[Union[int, Tuple[int, ...]]] = None,
+    keepdims: bool = False,
+    dtype: Optional[TensorDataType] = None,
+) -> tf.Tensor:
+    with tf.device(a.device):
+        a = tf.cast(a, DTYPE_MAP[dtype]) if dtype is not None else a
+        return tf.reduce_mean(a, axis=axis, keepdims=keepdims)
+
+
+@numeric.median.register(tf.Tensor)
+def _(
+    a: tf.Tensor,
+    axis: Optional[Union[int, Tuple[int, ...]]] = None,
+    keepdims: bool = False,
+) -> tf.Tensor:
+    with tf.device(a.device):
+        if axis is None:
+            a = tf.reshape(a, [-1])
+        else:
+            if isinstance(axis, int):
+                axis = (axis,)
+            # Move the reduced axes to the back and flatten them into a single
+            # axis, so the median can be taken along the last axis only.
+            destination_axis = tuple([-(i + 1) for i in range(len(axis))])
+            a = tf.experimental.numpy.moveaxis(a, axis, destination_axis)
+            last_axis = 1
+            for i in range(len(axis)):
+                last_axis *= a.shape[-(i + 1)]
+            new_shape = a.shape[: -len(axis)] + [last_axis]
+            a = tf.reshape(a, new_shape)
+        # TF has no median op: take the top half of the sorted values and pick
+        # (or average) the middle element(s).
+        k = 1 + a.shape[-1] // 2
+        top_k = tf.math.top_k(a, k=k, sorted=True).values
+        if a.shape[-1] % 2 == 0:
+            median = (tf.gather(top_k, indices=[k - 2], axis=-1) + tf.gather(top_k, indices=[k - 1], axis=-1)) / 2
+        else:
+            median = tf.gather(top_k, indices=[k - 1], axis=-1)
+        median = tf.squeeze(median, axis=-1)
+        if keepdims:
+            for axe in sorted(axis, key=lambda x: abs(x)):
+                median = tf.expand_dims(median, axe)
+
+        return median
+
+
+@numeric.round.register(tf.Tensor)
+def _(a: tf.Tensor, decimals: int = 0) -> tf.Tensor:
+    # tf.round has no decimals argument: scale, round, then rescale.
+    scale_factor = 10**decimals
+    scaled_tensor = a * scale_factor
+    with tf.device(a.device):
+        rounded_tensor = tf.round(scaled_tensor)
+        return rounded_tensor / scale_factor
+
+
+@numeric.power.register(tf.Tensor)
+def _(a: tf.Tensor, exponent: Union[tf.Tensor, float]) -> tf.Tensor:
+    with tf.device(a.device):
+        return tf.pow(a, exponent)
+
+
+@numeric.quantile.register(tf.Tensor)
+def _(
+    a: tf.Tensor,
+    q: Union[float, List[float]],
+    axis: Optional[Union[int, Tuple[int]]] = None,
+    keepdims: bool = False,
+) -> tf.Tensor:
+    # TF has no quantile op; fall back to numpy and convert the result back.
+    a_np = a.numpy()
+    quantile_np = np.quantile(a_np, q=q, axis=axis, keepdims=keepdims)
+    with tf.device(a.device):
+        return tf.constant(quantile_np)
+
+
+@numeric.percentile.register(tf.Tensor)
+def _(
+    a: tf.Tensor,
+    q: Union[float, List[float]],
+    axis: Union[int, Tuple[int, ...], List[int]],
+    keepdims: bool = False,
+) -> List[Union[tf.Tensor, np.generic]]:
+    with tf.device(a.device):
+        # Percentiles are quantiles scaled by 100.
+        q = [x / 100 for x in q] if isinstance(q, (list, tuple)) else q / 100
+        return numeric.quantile(a, q=q, axis=axis, keepdims=keepdims)
+
+
+@numeric._binary_op_nowarn.register(tf.Tensor)
+def _(a: tf.Tensor, b: Union[tf.Tensor, float], operator_fn: Callable) -> tf.Tensor:
+    with tf.device(a.device):
+        return operator_fn(a, b)
+
+
+@numeric._binary_reverse_op_nowarn.register(tf.Tensor)
+def _(a: tf.Tensor, b: Union[tf.Tensor, float], operator_fn: Callable) -> tf.Tensor:
+    with tf.device(a.device):
+        return operator_fn(b, a)
+
+
+@numeric.clip.register(tf.Tensor)
+def _(a: tf.Tensor, a_min: Union[tf.Tensor, float], a_max: Union[tf.Tensor, float]) -> tf.Tensor:
+    with tf.device(a.device):
+        return tf.clip_by_value(a, a_min, a_max)
+
+
+@numeric.finfo.register(tf.Tensor)
+def _(a: tf.Tensor) -> TypeInfo:
+    ti = tf.experimental.numpy.finfo(a.dtype)
+    return TypeInfo(ti.eps, ti.max, ti.min)
+
+
+@numeric.as_tensor_like.register(tf.Tensor)
+def _(a: tf.Tensor, data: Any) -> tf.Tensor:
+    with tf.device(a.device):
+        return tf.convert_to_tensor(data)
+
+
+@numeric.item.register(tf.Tensor)
+def _(a: tf.Tensor) -> Union[int, float, bool]:
+    return a.numpy().item()
+
+
+@numeric.sum.register(tf.Tensor)
+def _(a: tf.Tensor, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False) -> tf.Tensor:
+    with tf.device(a.device):
+        return tf.reduce_sum(a, axis=axis, keepdims=keepdims)
+
+
+@numeric.multiply.register(tf.Tensor)
+def _(x1: tf.Tensor, x2: Union[tf.Tensor, float]) -> tf.Tensor:
+    with tf.device(x1.device):
+        return tf.multiply(x1, x2)
+
+
+@numeric.var.register(tf.Tensor)
+def _(
+    a: tf.Tensor, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False, ddof: int = 0
+) -> tf.Tensor:
+    with tf.device(a.device):
+        tf_var = tf.math.reduce_variance(a, axis=axis, keepdims=keepdims)
+        if ddof:
+            # Rescale the biased variance to apply the ddof correction (int axis or None).
+            n = tf.shape(a)[axis] if axis is not None else tf.size(a)
+            tf_var *= float(n) / float(n - ddof)
+        return tf_var
+
+
+@numeric.size.register(tf.Tensor)
+def _(a: tf.Tensor) -> int:
+    return int(tf.size(a))
+
+
+@numeric.matmul.register(tf.Tensor)
+def _(x1: tf.Tensor, x2: tf.Tensor) -> tf.Tensor:
+    with tf.device(x1.device):
+        return tf.matmul(x1, x2)
+
+
+@numeric.unsqueeze.register(tf.Tensor)
+def _(a: tf.Tensor, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> tf.Tensor:
+    with tf.device(a.device):
+        return tf.expand_dims(a, axis=axis)
+
+
+@numeric.transpose.register(tf.Tensor)
+def _(a: tf.Tensor, axes: Optional[Tuple[int, ...]] = None) -> tf.Tensor:
+    with 
tf.device(a.device):
+        return tf.transpose(a, perm=axes)
+
+
+@numeric.argsort.register(tf.Tensor)
+def _(a: tf.Tensor, axis: int = -1, descending: bool = False, stable: bool = False) -> tf.Tensor:
+    with tf.device(a.device):
+        direction = "DESCENDING" if descending else "ASCENDING"
+        return tf.argsort(a, axis=axis, direction=direction, stable=stable)
+
+
+@numeric.diag.register(tf.Tensor)
+def _(a: tf.Tensor, k: int = 0) -> tf.Tensor:
+    with tf.device(a.device):
+        rank = tf.rank(a)
+        if rank == 1:
+            return tf.linalg.diag(a, k=k)
+        elif rank == 2:
+            return tf.linalg.diag_part(a, k=k)
+        else:
+            raise ValueError("Input tensor must be 1D or 2D.")
+
+
+@numeric.logical_or.register(tf.Tensor)
+def _(x1: tf.Tensor, x2: tf.Tensor) -> tf.Tensor:
+    with tf.device(x1.device):
+        return tf.logical_or(x1, x2)
+
+
+@numeric.masked_mean.register(tf.Tensor)
+def _(
+    x: tf.Tensor, mask: Optional[tf.Tensor], axis: Union[int, Tuple[int, ...], List[int]], keepdims: bool = False
+) -> tf.Tensor:
+    with tf.device(x.device):
+        if mask is None:
+            return tf.reduce_mean(x, axis=axis, keepdims=keepdims)
+        # Masked positions are excluded from both the sum and the element count;
+        # slices where everything is masked yield 0 instead of NaN.
+        flipped_mask = ~mask
+        valid_counts = tf.reduce_sum(tf.cast(flipped_mask, x.dtype), axis=axis, keepdims=keepdims)
+        masked_x = tf.where(mask, tf.zeros_like(x), x)
+        valid_sum = tf.reduce_sum(masked_x, axis=axis, keepdims=keepdims)
+
+        ret = valid_sum / valid_counts
+        ret = tf.where(tf.math.is_nan(ret), tf.zeros_like(ret), ret)
+
+        return ret
+
+
+@numeric.masked_median.register(tf.Tensor)
+def _(
+    x: tf.Tensor, mask: Optional[tf.Tensor], axis: Union[int, Tuple[int, ...], List[int]], keepdims: bool = False
+) -> tf.Tensor:
+    if mask is None:
+        return numeric.median(x, axis=axis, keepdims=keepdims)
+
+    # TF has no NaN-aware median: set masked values to NaN and fall back to
+    # numpy's nanquantile.
+    masked_x = tf.where(mask, np.nan, x)
+    np_masked_x = masked_x.numpy()
+    np_masked_median = np.nanquantile(np_masked_x, 0.5, axis=axis, keepdims=keepdims)
+
+    with tf.device(x.device):
+        ret = tf.constant(np_masked_median)
+        ret = tf.where(tf.math.is_nan(ret), tf.zeros_like(ret), ret)
+
+        return ret
+
+
+@numeric.expand_dims.register(tf.Tensor)
+def _(a: tf.Tensor, axis: Union[int, Tuple[int, ...], List[int]]) -> tf.Tensor:
+    # Mirrors numpy.expand_dims semantics, including multi-axis support.
+    if type(axis) not in (tuple, list):
+        axis = (axis,)
+
+    if len(set(axis)) != len(axis):
+        raise ValueError("repeated axis")
+
+    out_ndim = len(axis) + a.ndim
+
+    norm_axis = []
+    for ax in axis:
+        if ax < -out_ndim or ax >= out_ndim:
+            raise ValueError(f"axis {ax} is out of bounds for array of dimension {out_ndim}")
+        norm_axis.append(ax + out_ndim if ax < 0 else ax)
+
+    shape_it = iter(a.shape)
+    shape = [1 if ax in norm_axis else next(shape_it) for ax in range(out_ndim)]
+    return tf.reshape(a, shape)
+
+
+@numeric.clone.register(tf.Tensor)
+def _(a: tf.Tensor) -> tf.Tensor:
+    with tf.device(a.device):
+        return tf.identity(a)
+
+
+@numeric.searchsorted.register(tf.Tensor)
+def _(a: tf.Tensor, v: tf.Tensor, side: str = "left", sorter: Optional[tf.Tensor] = None) -> tf.Tensor:
+    # NOTE: `sorter` is ignored: tf.searchsorted has no such argument, so the
+    # input is sorted directly below.
+    if side not in ["right", "left"]:
+        raise ValueError(f"Invalid value for 'side': {side}. Expected 'right' or 'left'.")
+    if a.ndim != 1:
+        raise ValueError(f"Input tensor 'a' must be 1-D. 
Received {a.ndim}-D tensor.")
+    sorted_a = tf.sort(a)
+    return tf.searchsorted(sorted_sequence=sorted_a, values=v, side=side)
+
+
+def zeros(
+    shape: Tuple[int, ...],
+    *,
+    dtype: Optional[TensorDataType] = None,
+    device: Optional[TensorDeviceType] = None,
+) -> tf.Tensor:
+    if dtype is not None:
+        dtype = DTYPE_MAP[dtype]
+    if device is not None:
+        device = DEVICE_MAP[device]
+    with tf.device(device):
+        return tf.zeros(shape, dtype=dtype)
+
+
+def eye(
+    n: int,
+    m: Optional[int] = None,
+    *,
+    dtype: Optional[TensorDataType] = None,
+    device: Optional[TensorDeviceType] = None,
+) -> tf.Tensor:
+    if dtype is not None:
+        dtype = DTYPE_MAP[dtype]
+    if device is not None:
+        device = DEVICE_MAP[device]
+    p_args = (n,) if m is None else (n, m)
+    with tf.device(device):
+        return tf.eye(*p_args, dtype=dtype)
+
+
+def arange(
+    start: float,
+    end: float,
+    step: float,
+    *,
+    dtype: Optional[TensorDataType] = None,
+    device: Optional[TensorDeviceType] = None,
+) -> tf.Tensor:
+    if dtype is not None:
+        dtype = DTYPE_MAP[dtype]
+    if device is not None:
+        device = DEVICE_MAP[device]
+    with tf.device(device):
+        return tf.range(start, end, step, dtype=dtype)
+
+
+def from_numpy(ndarray: np.ndarray) -> tf.Tensor:
+    # numpy data lives in host memory, so the resulting tensor is pinned to CPU.
+    with tf.device("CPU"):
+        return tf.constant(ndarray)
+
+
+@numeric.log2.register(tf.Tensor)
+def _(a: tf.Tensor) -> tf.Tensor:
+    with tf.device(a.device):
+        # Cast the constant to a.dtype so non-float32 inputs do not hit a dtype
+        # mismatch in the division.
+        return tf.math.log(a) / tf.math.log(tf.constant(2.0, dtype=a.dtype))
+
+
+@numeric.ceil.register(tf.Tensor)
+def _(a: tf.Tensor) -> tf.Tensor:
+    with tf.device(a.device):
+        return tf.math.ceil(a)
+
+
+def tensor(
+    data: Union[TTensor, Sequence[float]],
+    *,
+    dtype: Optional[TensorDataType] = None,
+    device: Optional[TensorDeviceType] = None,
+) -> tf.Tensor:
+    device = convert_to_tf_device(device)
+    dtype = convert_to_tf_dtype(dtype)
+    with tf.device(device):
+        return tf.constant(data, dtype=dtype)
diff --git a/nncf/tensor/tensor.py b/nncf/tensor/tensor.py
index 20cc73ea1f6..bb6db19d982 100644
--- a/nncf/tensor/tensor.py
+++ b/nncf/tensor/tensor.py
@@ -135,7 +135,7 @@ def __rfloordiv__(self, other: Union[Tensor, float]) -> Tensor:
         return _call_function("_binary_reverse_op_nowarn", self, other, operator.floordiv)
 
     def __ifloordiv__(self, other: Union[Tensor, float]) -> Tensor:
-        self._data /= unwrap_tensor_data(other)
+        self._data //= unwrap_tensor_data(other)
         return self
 
     def __matmul__(self, other: Union[Tensor, float]) -> Tensor:
diff --git a/tests/cross_fw/test_templates/template_test_nncf_tensor.py b/tests/cross_fw/test_templates/template_test_nncf_tensor.py
index 7cff6938fac..33bac309a57 100644
--- a/tests/cross_fw/test_templates/template_test_nncf_tensor.py
+++ b/tests/cross_fw/test_templates/template_test_nncf_tensor.py
@@ -113,7 +113,8 @@ def test_operators_tensor(self, op_name):
         assert res.dtype == res_nncf.data.dtype
         assert all(res == res_nncf.data)
         assert isinstance(res_nncf, Tensor)
-        assert res_nncf.device == nncf_tensor_a.device
+        if not (self.backend() == TensorBackend.tf and self.device() == TensorDeviceType.CPU):
+            assert res_nncf.device == nncf_tensor_a.device
 
     @pytest.mark.parametrize("op_name", OPERATOR_MAP.keys())
     def test_operators_int(self, op_name):
@@ -129,7 +130,8 @@ def test_operators_int(self, op_name):
         assert res.dtype == res_nncf.data.dtype
         assert all(res == res_nncf.data)
         assert isinstance(res_nncf, Tensor)
-        assert res_nncf.device == nncf_tensor_a.device
+        if not (self.backend() == TensorBackend.tf and self.device() == TensorDeviceType.CPU):
+            assert res_nncf.device == nncf_tensor_a.device
 
     @pytest.mark.parametrize("op_name", 
BINARY_OPERATORS) def test_operators_int_rev(self, op_name): @@ -145,7 +147,11 @@ def test_operators_int_rev(self, op_name): assert res.dtype == res_nncf.data.dtype assert all(res == res_nncf.data) assert isinstance(res_nncf, Tensor) - assert res_nncf.device == nncf_tensor_a.device + if not ( + (self.backend() == TensorBackend.tf and self.device() == TensorDeviceType.CPU) + or (self.backend() == TensorBackend.tf and self.device() == TensorDeviceType.GPU and op_name == "pow") + ): + assert res_nncf.device == nncf_tensor_a.device @pytest.mark.parametrize("op_name", COMPARISON_OPERATOR_MAP.keys()) def test_comparison_tensor(self, op_name): @@ -159,7 +165,7 @@ def test_comparison_tensor(self, op_name): res = fn(tensor_a, tensor_b) res_nncf = fn(nncf_tensor_a, nncf_tensor_b) - assert res == res_nncf + assert res_nncf == res assert isinstance(res_nncf, Tensor) @pytest.mark.parametrize("op_name", COMPARISON_OPERATOR_MAP.keys()) @@ -173,7 +179,7 @@ def test_comparison_int(self, op_name): res = fn(tensor_a, value) res_nncf = fn(nncf_tensor_a, value) - assert res == res_nncf + assert res_nncf == res assert isinstance(res_nncf, Tensor) @pytest.mark.parametrize("op_name", COMPARISON_OPERATOR_MAP.keys()) @@ -187,7 +193,7 @@ def test_comparison_int_rev(self, op_name): res = fn(value, tensor_a) res_nncf = fn(value, nncf_tensor_a) - assert res == res_nncf + assert res_nncf == res assert isinstance(res_nncf, Tensor) @pytest.mark.parametrize( @@ -390,7 +396,8 @@ def test_getitem_for_index(self): res = nncf_tensor[1] assert res == 1 assert isinstance(res, Tensor) - assert res.device == nncf_tensor.device + if not (self.backend() == TensorBackend.tf and self.device() == TensorDeviceType.CPU): + assert res.device == nncf_tensor.device @pytest.mark.parametrize("is_tensor_indecies", (False, True)) def test_getitem_for_indecies(self, is_tensor_indecies): @@ -527,7 +534,8 @@ def test_fn_where(self): res = fns.where(tensor > 0, 1, 0) assert all(res.data == tensor_ref) assert isinstance(res, Tensor) - assert res.device == tensor.device + if not (self.backend() == TensorBackend.tf and self.device() == TensorDeviceType.CPU): + assert res.device == tensor.device @pytest.mark.parametrize( "val, ref", @@ -558,19 +566,23 @@ def test_isempty(self, val, ref): assert isinstance(res, bool) @pytest.mark.parametrize( - "x1, x2, rtol, atol, ref", + "x1, x2, is_tensor, rtol, atol, ref", ( - ([0.1], [0.1], None, None, True), - ([0.1], [0.10001], None, None, False), - ([0.1], [0.10001], 0.1, None, True), - ([0.1], [0.10001], None, 0.1, True), - ([0.1], [0.20001], None, 0.1, False), - ([0.1], 0.1, None, None, True), + ([0.1], [0.1], True, None, None, True), + ([0.1], [0.10001], True, None, None, False), + ([0.1], [0.10001], True, 0.1, None, True), + ([0.1], [0.10001], True, None, 0.1, True), + ([0.1], [0.20001], True, None, 0.1, False), + ([0.1], 0.1, True, None, None, True), + ([0.1], 0.1, False, None, None, True), ), ) - def test_fn_allclose(self, x1, x2, rtol, atol, ref): + def test_fn_allclose(self, x1, x2, is_tensor, rtol, atol, ref): tensor1 = Tensor(self.to_tensor(x1)) - tensor2 = Tensor(self.to_tensor(x2)) + if is_tensor: + tensor2 = Tensor(self.to_tensor(x2)) + else: + tensor2 = x2 if rtol is not None: res = fns.allclose(tensor1, tensor2, rtol=rtol) elif atol is not None: @@ -804,6 +816,7 @@ def test_fn_median(self, x, axis, keepdims, ref): (1.1, 0, 1.0), ([1.1, 0.9], 0, [1.0, 1.0]), ([1.11, 0.91], 1, [1.1, 0.9]), + ([5.5, 3.3], -1, [10.0, 0.0]), ), ) def test_fn_round(self, val, decimals, ref): @@ -1045,6 +1058,13 @@ def 
test_fn_var(self, x, axis, keepdims, ddof, ref): True, [[1.53063197]], ), + ( + [[0.8, 0.2, 0.2], [0.1, 0.7, 0.1]], + "nuc", + (0, 1), + False, + [1.53063197], + ), ( [[0.8, 0.2, 0.2], [0.1, 0.7, 0.1]], float("inf"), @@ -1059,6 +1079,49 @@ def test_fn_var(self, x, axis, keepdims, ddof, ref): False, 0.9364634205074938, ), + ( + [[0.8, 0.2, 0.2], [0.1, 0.7, 0.1]], + 2, + 0, + False, + [0.8062258, 0.72801095, 0.22360681], + ), + ( + [[0.8, 0.2, 0.2], [0.1, 0.7, 0.1]], + 1, + None, + False, + 0.9, + ), + ( + [[0.8, 0.2, 0.2], [0.1, 0.7, 0.1]], + -1, + None, + False, + 0.3, + ), + ( + [[0.8, 0.2, 0.2], [0.1, 0.7, 0.1]], + -2, + None, + False, + 0.59416854, + ), + ( + [[0.8, 0.2, 0.2], [0.1, 0.7, 0.1]], + float("inf"), + None, + False, + 1.2, + ), + ( + [[0.8, 0.2, 0.2], [0.1, 0.7, 0.1]], + -float("inf"), + None, + False, + 0.9, + ), + ([[[1.0, 1.0], [1.0, 1.0]], [[1.0, 1.0], [1.0, 1.0]]], None, None, False, 2.82842708), ), ) def test_fn_linalg_norm(self, x, ord, axis, keepdims, ref): @@ -1101,7 +1164,8 @@ def test_fn_matmul(self, m1, m2, ref): assert isinstance(res, Tensor) assert fns.allclose(res.data, ref_tensor) - assert res.device == tensor1.device + if not (self.backend() == TensorBackend.tf and self.device() == TensorDeviceType.CPU): + assert res.device == tensor1.device @pytest.mark.parametrize( "val, axis, ref", @@ -1528,6 +1592,8 @@ def test_fn_eye(self, n, m, ref): for dtype in TensorDataType: if dtype == TensorDataType.bfloat16 and self.backend() == TensorBackend.numpy: continue + if (not dtype.is_float()) and self.backend() == TensorBackend.tf and self.device() == TensorDeviceType.GPU: + continue tensor_a = fns.eye(n, m, backend=self.backend(), dtype=dtype, device=self.device()) assert isinstance(tensor_a, Tensor) assert tensor_a.device == self.device() @@ -1547,18 +1613,20 @@ def test_fn_arange(self, start, end, stop, ref): args.append(end) if stop is not None: args.append(stop) - ref = Tensor(self.to_tensor(ref)) + for dtype in [TensorDataType.int32, TensorDataType.float32]: + tensor_ref = Tensor(fns.astype(self.to_tensor(ref), dtype)) tensor_a = fns.arange(*tuple(args), backend=self.backend(), dtype=dtype, device=self.device()) assert isinstance(tensor_a, Tensor) assert tensor_a.device == self.device() assert tensor_a.backend == self.backend() assert tensor_a.dtype == dtype - assert fns.all(tensor_a == ref) + assert fns.all(tensor_a == tensor_ref) def test_fn_from_numpy(self): ndarray = np.array([1, 2]) - ref = Tensor(self.to_cpu(self.to_tensor(ndarray))) + ref_cpu = self.to_cpu(self.to_tensor(ndarray)) + ref = Tensor(ref_cpu) tensor = fns.from_numpy(ndarray, backend=ref.backend) assert isinstance(tensor, Tensor) assert tensor.device == ref.device @@ -1727,6 +1795,13 @@ def test_save_load_symlink_error(self, tmp_path): @pytest.mark.parametrize("data", [[3.0, 2.0, 2.0], [1, 2, 3]]) @pytest.mark.parametrize("dtype", [TensorDataType.float32, TensorDataType.int32, TensorDataType.uint8, None]) def test_fn_tensor(self, data, dtype): + if ( + self.backend() == TensorBackend.tf + and dtype is not None + and not dtype.is_float() + and (data == [3.0, 2.0, 2.0]) + ): + pytest.skip("TF backend does not support non-float dtypes for float data") nncf_tensor = fns.tensor(data, backend=self.backend(), dtype=dtype, device=self.device()) backend_tensor = Tensor(self.to_tensor(data)) if dtype is not None: diff --git a/tests/tensorflow/test_tensor.py b/tests/tensorflow/test_tensor.py new file mode 100644 index 00000000000..f76de5deb21 --- /dev/null +++ b/tests/tensorflow/test_tensor.py @@ -0,0 
+1,112 @@ +# Copyright (c) 2024 Intel Corporation +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import pytest +import tensorflow as tf + +from nncf.tensor import Tensor +from nncf.tensor import TensorDataType +from nncf.tensor.definitions import TensorBackend +from nncf.tensor.definitions import TensorDeviceType +from tests.cross_fw.test_templates.template_test_nncf_tensor import TemplateTestNNCFTensorOperators + + +def cast_to(x: tf.Tensor, dtype: TensorDataType) -> tf.Tensor: + if dtype is TensorDataType.float32: + return tf.cast(x, tf.float32) + if dtype is TensorDataType.float16: + return tf.cast(x, tf.float16) + raise NotImplementedError + + +class TestTFNNCFTensorOperators(TemplateTestNNCFTensorOperators): + @staticmethod + def to_tensor(x): + with tf.device("CPU"): + return tf.constant(x) + + @staticmethod + def to_cpu(x): + return x + + @staticmethod + def cast_to(x: tf.Tensor, dtype: TensorDataType) -> tf.Tensor: + return cast_to(x, dtype) + + @staticmethod + def backend() -> TensorBackend: + return TensorBackend.tf + + @staticmethod + def device() -> TensorDeviceType: + return TensorDeviceType.CPU + + @pytest.mark.skip("Desired slicing is not supported for TensorFlow") + @pytest.mark.parametrize("is_tensor_indecies", (False, True)) + def test_getitem_for_indecies(self, is_tensor_indecies): + pass + + @pytest.mark.skip("TensorFlow throws different kind of exceptions") + @pytest.mark.parametrize( + "val, axis, exception_type, exception_match", + ( + ([[[[1], [2]], [[1], [2]]]], (0, 1), ValueError, "not equal to one"), + ([[[[1], [2]], [[1], [2]]]], 42, IndexError, "out of"), + ([[[[1], [2]], [[1], [2]]]], (0, 42), IndexError, "out of"), + ), + ) + def test_squeeze_axis_error(self, val, axis, exception_type, exception_match): + pass + + +@pytest.mark.skipif(len(tf.config.list_physical_devices("GPU")) == 0, reason="Skipping for CPU-only setups") +class TestGPUTFNNCFTensorOperators(TemplateTestNNCFTensorOperators): + @staticmethod + def to_tensor(x): + with tf.device("GPU"): + return tf.constant(x) + + @staticmethod + def to_cpu(x): + with tf.device("CPU"): + return tf.constant(x.numpy()) + + @staticmethod + def cast_to(x: tf.Tensor, dtype: TensorDataType) -> tf.Tensor: + return cast_to(x, dtype) + + def test_device(self): + tensor = Tensor(self.to_tensor([1])) + assert tensor.device == TensorDeviceType.GPU + + @staticmethod + def backend() -> TensorBackend: + return TensorBackend.tf + + @staticmethod + def device() -> TensorDeviceType: + return TensorDeviceType.GPU + + @pytest.mark.skip("Desired slicing is not supported for TensorFlow") + @pytest.mark.parametrize("is_tensor_indecies", (False, True)) + def test_getitem_for_indecies(self, is_tensor_indecies): + pass + + @pytest.mark.skip("TensorFlow throws different kind of exceptions") + @pytest.mark.parametrize( + "val, axis, exception_type, exception_match", + ( + ([[[[1], [2]], [[1], [2]]]], (0, 1), ValueError, "not equal to one"), + ([[[[1], [2]], [[1], [2]]]], 42, IndexError, "out of"), + ([[[[1], [2]], [[1], [2]]]], (0, 42), IndexError, 
"out of"), + ), + ) + def test_squeeze_axis_error(self, val, axis, exception_type, exception_match): + pass