From 46a6ccd4ed93d36dd24183b69bd64204c634debe Mon Sep 17 00:00:00 2001 From: Maxim Vafin Date: Mon, 7 Oct 2024 19:12:12 +0200 Subject: [PATCH] List decompositions for torch.export (#26878) ### Details: - Add `get_export_decomposition_list()` and pass the resulting table to `ExportedProgram.run_decompositions(decomp_table=...)` instead of relying on the default decomposition table - Register `aten.col2im.default` in the FX op table and enable the col2im layer test for torch.export ### Tickets: - *ticket-id* --- .../openvino/frontend/pytorch/fx_decoder.py | 8 +- .../pytorch/torchdynamo/decompositions.py | 205 +++++++++++++++++- src/frontends/pytorch/src/op_table.cpp | 1 + .../pytorch_tests/pytorch_layer_test_class.py | 19 +- .../layer_tests/pytorch_tests/test_col2im.py | 1 + tests/layer_tests/pytorch_tests/test_eye.py | 20 +- tests/model_hub_tests/pytorch/torch_utils.py | 5 +- .../moc_frontend/pytorch_frontend_utils.py | 11 +- 8 files changed, 231 insertions(+), 39 deletions(-) diff --git a/src/bindings/python/src/openvino/frontend/pytorch/fx_decoder.py b/src/bindings/python/src/openvino/frontend/pytorch/fx_decoder.py index d9dae251aa64e7..a7e9f895b5334b 100644 --- a/src/bindings/python/src/openvino/frontend/pytorch/fx_decoder.py +++ b/src/bindings/python/src/openvino/frontend/pytorch/fx_decoder.py @@ -4,14 +4,14 @@ # flake8: noqa # mypy: ignore-errors +import logging +import torch + from openvino.frontend.pytorch.py_pytorch_frontend import _FrontEndPytorchDecoder as Decoder from openvino.frontend.pytorch.py_pytorch_frontend import _Type as DecoderType -from openvino.runtime import op, PartialShape, Type as OVType, OVAny, Shape +from openvino.runtime import PartialShape, Type as OVType, OVAny, Shape from openvino.frontend.pytorch.utils import make_constant, fetch_attr, pt_to_ov_type_map, torch_tensor_to_ov_const -import torch - -import logging logger = logging.getLogger(__name__) logger.setLevel(logging.WARNING) diff --git a/src/bindings/python/src/openvino/frontend/pytorch/torchdynamo/decompositions.py b/src/bindings/python/src/openvino/frontend/pytorch/torchdynamo/decompositions.py index 368dbc4cbfa358..eb117f56ab167d 100644 --- a/src/bindings/python/src/openvino/frontend/pytorch/torchdynamo/decompositions.py +++ 
b/src/bindings/python/src/openvino/frontend/pytorch/torchdynamo/decompositions.py @@ -46,7 +46,9 @@ def convolution_backward( return grad_input, grad_weight, grad_bias + if len(get_decompositions([aten._scaled_dot_product_flash_attention.default])) == 0: + @register_decomposition(aten._scaled_dot_product_flash_attention.default) def scaled_dot_product_flash_attention( query, @@ -101,16 +103,197 @@ def scaled_dot_product_flash_attention( def get_aot_decomposition_list(): - return ([torch.ops.aten._scaled_dot_product_flash_attention.default, - torch.ops.aten._softmax.default, - torch.ops.aten._softmax_backward_data.default, - torch.ops.aten.convolution_backward.default, - torch.ops.aten.gelu_backward.default, - torch.ops.aten.native_group_norm.default, - torch.ops.aten.native_group_norm_backward.default, - torch.ops.aten.native_layer_norm.default, - torch.ops.aten.native_layer_norm_backward.default, - torch.ops.aten.slice_backward.default]) + return [ + torch.ops.aten._scaled_dot_product_flash_attention.default, + torch.ops.aten._softmax.default, + torch.ops.aten._softmax_backward_data.default, + torch.ops.aten.convolution_backward.default, + torch.ops.aten.gelu_backward.default, + torch.ops.aten.native_group_norm.default, + torch.ops.aten.native_group_norm_backward.default, + torch.ops.aten.native_layer_norm.default, + torch.ops.aten.native_layer_norm_backward.default, + torch.ops.aten.slice_backward.default, + ] + def get_inf_decomposition_list(): - return ([torch.ops.aten.nll_loss_forward.default]) + return [torch.ops.aten.nll_loss_forward.default] + + +def get_export_decomposition_list(): + # List of decompositions from torch._decomp.core_aten_decompositions + # removed _backward ops and ops supported without decomposition + decomp = [ + torch.ops.aten.addcdiv, + torch.ops.aten.addcdiv_, + torch.ops.aten.addcmul, + torch.ops.aten.addcmul_, + torch.ops.aten.addr, + torch.ops.aten.affine_grid_generator, + torch.ops.aten.all, + torch.ops.aten.aminmax, + 
torch.ops.aten.arange.default, + torch.ops.aten.arange.start, + torch.ops.aten.baddbmm, + torch.ops.aten.binary_cross_entropy, + torch.ops.aten.binary_cross_entropy_with_logits, + torch.ops.aten.block_diag, + torch.ops.aten.celu, + torch.ops.aten.celu_, + torch.ops.aten.clamp_max, + torch.ops.aten.clamp_min, + torch.ops.aten.count_nonzero, + torch.ops.aten.linalg_cross, + torch.ops.aten.cudnn_batch_norm, + torch.ops.aten.deg2rad, + torch.ops.aten.deg2rad_, + torch.ops.aten.detach, + torch.ops.aten.diag_embed, + torch.ops.aten.dot, + torch.ops.aten.vdot, + torch.ops.aten.elu, + torch.ops.aten.elu_, + torch.ops.aten._embedding_bag, + torch.ops.aten.empty_like, + torch.ops.aten._euclidean_dist.default, + torch.ops.aten.expand_as, + torch.ops.aten.eye, + torch.ops.aten.fill, + torch.ops.aten.fill_, + torch.ops.aten.floor_divide, + torch.ops.aten.frac, + torch.ops.aten.frac_, + torch.ops.aten._fused_moving_avg_obs_fq_helper, + torch.ops.aten.gelu_, + torch.ops.aten.glu, + torch.ops.aten.hardshrink, + torch.ops.aten.hardsigmoid, + torch.ops.aten.hardsigmoid_, + torch.ops.aten.hardswish, + torch.ops.aten.hardswish_, + torch.ops.aten.hardtanh_, + torch.ops.aten.heaviside, + torch.ops.aten.heaviside_, + torch.ops.aten.huber_loss, + torch.ops.aten.im2col, + torch.ops.aten.index_add, + torch.ops.aten.index_add_, + torch.ops.aten.index_copy, + torch.ops.aten.index_copy_, + torch.ops.aten.index_fill, + torch.ops.aten.index_fill_, + torch.ops.aten.isin, + torch.ops.aten.isneginf, + torch.ops.aten.isposinf, + torch.ops.aten.l1_loss, + torch.ops.aten.leaky_relu_, + torch.ops.aten.lerp, + torch.ops.aten.lerp_, + torch.ops.aten.linspace, + torch.ops.aten.logaddexp, + torch.ops.aten.logaddexp2, + torch.ops.aten.logit, + torch.ops.aten.logit_, + torch.ops.aten.log_sigmoid_forward, + torch.ops.aten.logspace, + torch.ops.aten.logsumexp.default, + torch.ops.aten.masked_fill, + torch.ops.aten.masked_fill_, + torch.ops.aten.mish, + torch.ops.aten.mish_, + torch.ops.aten.mse_loss, + 
torch.ops.aten.multi_margin_loss, + torch.ops.aten.multilabel_margin_loss_forward, + torch.ops.aten.mv, + torch.ops.aten.mvlgamma, + torch.ops.aten.mvlgamma_, + torch.ops.aten.nansum, + torch.ops.aten.nan_to_num, + torch.ops.aten.nan_to_num_, + torch.ops.aten.narrow, + torch.ops.aten.new_empty, + torch.ops.aten.new_full, + torch.ops.aten.new_ones, + torch.ops.aten.new_zeros, + torch.ops.aten.nll_loss_forward, + torch.ops.aten.norm, + torch.ops.aten.ones, + torch.ops.aten.ones_like, + torch.ops.aten._prelu_kernel, + torch.ops.aten._reshape_alias, + torch.ops.aten.rad2deg, + torch.ops.aten.rad2deg_, + torch.ops.aten.reflection_pad1d, + torch.ops.aten.reflection_pad2d, + torch.ops.aten.reflection_pad3d, + torch.ops.aten.replication_pad1d, + torch.ops.aten.replication_pad2d, + torch.ops.aten.replication_pad3d, + torch.ops.aten.renorm, + torch.ops.aten.renorm_, + torch.ops.aten.resize_as, + torch.ops.aten.roll, + torch.ops.aten.rot90, + torch.ops.aten.rrelu_with_noise, + torch.ops.aten.rrelu_with_noise_, + torch.ops.aten.rsub, + torch.ops.aten.select_scatter, + torch.ops.aten.sgn, + torch.ops.aten.sgn_, + torch.ops.aten.silu, + torch.ops.aten.silu_, + torch.ops.aten.sinc, + torch.ops.aten.sinc_, + torch.ops.aten.smooth_l1_loss, + torch.ops.aten.soft_margin_loss, + torch.ops.aten.softplus, + torch.ops.aten.softshrink, + torch.ops.aten.special_entr, + torch.ops.aten.special_log_ndtr, + torch.ops.aten.special_xlog1py, + torch.ops.aten.split.Tensor, + torch.ops.aten.split_with_sizes_copy, + torch.ops.aten.squeeze.default, + torch.ops.aten.squeeze.dim, + torch.ops.aten.std, + torch.ops.aten.std_mean, + torch.ops.aten.stack, + torch.ops.aten.sum.default, + torch.ops.aten.sum.out, + torch.ops.aten.t, + torch.ops.aten.take, + torch.ops.aten.threshold, + torch.ops.aten.threshold_, + torch.ops.aten.trace, + torch.ops.aten.transpose.int, + torch.ops.aten.tril, + torch.ops.aten.tril_, + torch.ops.aten.triu, + torch.ops.aten.triu_, + torch.ops.aten.unbind, + 
torch.ops.aten.unfold_copy, + torch.ops.aten._unsafe_index, + torch.ops.aten.unsafe_split.Tensor, + torch.ops.aten.unsafe_split_with_sizes, + torch.ops.aten._unsafe_view, + torch.ops.aten.view_as_complex, + torch.ops.aten.xlogy, + torch.ops.aten.xlogy_, + torch.ops.aten.zero, + torch.ops.aten.zero_, + torch.ops.aten.zeros, + torch.ops.aten.zeros_like, + torch.ops.aten._weight_norm_interface, + ] + try: + from packaging import version + if version.parse(torch.__version__) >= version.parse("2.3"): + decomp += [ + torch.ops.aten._lazy_clone, + torch.ops.aten._test_parallel_materialize, + torch.ops.aten._chunk_cat, + ] + except ImportError: + pass + return decomp diff --git a/src/frontends/pytorch/src/op_table.cpp b/src/frontends/pytorch/src/op_table.cpp index 1e4ecfc1e1367f..31cf99a2e1b9d7 100644 --- a/src/frontends/pytorch/src/op_table.cpp +++ b/src/frontends/pytorch/src/op_table.cpp @@ -787,6 +787,7 @@ const std::unordered_map get_supported_ops_fx() { {"aten.clamp_min.default", op::translate_1to1_match_2_inputs_align_types}, {"aten.clamp_min.Tensor", op::translate_1to1_match_2_inputs_align_types}, {"aten.clone.default", op::skip_node}, // ignore clone operators that are inserted by PyTorch autograd + {"aten.col2im.default", op::translate_col2im}, {"aten.constant_pad_nd.default", op::translate_constant_pad_nd_fx}, {"aten.convolution.default", op::translate_convolution}, {"aten.copy.default", op::translate_copy_fx}, diff --git a/tests/layer_tests/pytorch_tests/pytorch_layer_test_class.py b/tests/layer_tests/pytorch_tests/pytorch_layer_test_class.py index a2f54076de9d7f..5bf019db3c131e 100644 --- a/tests/layer_tests/pytorch_tests/pytorch_layer_test_class.py +++ b/tests/layer_tests/pytorch_tests/pytorch_layer_test_class.py @@ -5,17 +5,18 @@ import warnings from copy import deepcopy import os - +import torch +import pytest +import logging import numpy as np + from common.constants import test_device, test_precision from openvino.frontend.pytorch.ts_decoder import 
TorchScriptPythonDecoder - from openvino.frontend import FrontEndManager from openvino.runtime import Core, Type, PartialShape import openvino.properties.hint as hints -import torch -from packaging import version -import pytest + +logging.basicConfig(level=logging.DEBUG) def skip_check(param): @@ -124,13 +125,9 @@ def numpy_to_torch_recursively(x): from torch.export import export em = export(model, tuple(torch_inputs)) - if version.parse(torch.__version__) >= version.parse("2.3"): - em = em.run_decompositions() - gm = em.module() - print(gm.code) converted_model = convert_model( - em, example_input=torch_inputs) + em, example_input=torch_inputs, verbose=True) self._resolve_input_shape_dtype( converted_model, ov_inputs, dynamic_shapes) smodel = model @@ -242,7 +239,7 @@ def convert_via_mo(self, model, example_input, trace_model, dynamic_shapes, ov_i if not dynamic_shapes: input_shapes = [inp.shape for inp in ov_inputs] kwargs["input"] = input_shapes - om = convert_model(decoder, **kwargs) + om = convert_model(decoder, verbose=True, **kwargs) self._resolve_input_shape_dtype(om, ov_inputs, dynamic_shapes) return smodel, om diff --git a/tests/layer_tests/pytorch_tests/test_col2im.py b/tests/layer_tests/pytorch_tests/test_col2im.py index 8cb7ea96cb8391..1dc44557c359fb 100644 --- a/tests/layer_tests/pytorch_tests/test_col2im.py +++ b/tests/layer_tests/pytorch_tests/test_col2im.py @@ -40,6 +40,7 @@ def forward(self, x): @pytest.mark.nightly @pytest.mark.precommit + @pytest.mark.precommit_torch_export @pytest.mark.parametrize("output_size,kernel_size", [([4, 5], [2, 2])]) @pytest.mark.parametrize("dilation", [1, 2, [1, 2]]) @pytest.mark.parametrize("padding", [0, 5, [2, 3]]) diff --git a/tests/layer_tests/pytorch_tests/test_eye.py b/tests/layer_tests/pytorch_tests/test_eye.py index 37b850088844cd..f93e77a8b2844a 100644 --- a/tests/layer_tests/pytorch_tests/test_eye.py +++ b/tests/layer_tests/pytorch_tests/test_eye.py @@ -3,6 +3,7 @@ import pytest import torch +from 
packaging import version from pytorch_layer_test_class import PytorchLayerTest @@ -14,7 +15,6 @@ def _prepare_input(self, m, n=None): return (np.array(m, dtype="int32"), ) return (np.array(m, dtype="int32"), np.array(n, dtype="int32")) - def create_model(self, num_inputs, dtype): import torch dtype_map = { @@ -45,29 +45,31 @@ def __init__(self, dtype): def forward(self, x, y): return torch.eye(x, y, dtype=self.dtype) - - ref_net = None - - return aten_eye_1_input(pt_dtype) if num_inputs == 1 else aten_eye_2_inputs(pt_dtype), ref_net, ("aten::eye", "aten::IntImplicit") + model = aten_eye_1_input(pt_dtype) if num_inputs == 1 else aten_eye_2_inputs(pt_dtype) + return model, None, ["aten::eye", "aten::IntImplicit"] @pytest.mark.nightly @pytest.mark.precommit @pytest.mark.precommit_torch_export @pytest.mark.parametrize("dtype", ["bool", "int8", "uint8", "int32", "int64", "float32", "float64"]) @pytest.mark.parametrize("m", [2, 3, 4, 5]) - @pytest.mark.skipif(torch.__version__ < '2.3.0', reason="`aten.eye` is not supported in PyTorch versions earlier than 2.3.") def test_eye_square(self, dtype, m, ie_device, precision, ir_version): + if PytorchLayerTest.use_torch_export() and version.parse(torch.__version__) < version.parse("2.3"): + pytest.skip("Not supported in PyTorch versions earlier than 2.3.") if ie_device == "GPU": pytest.xfail(reason="eye is not supported on GPU") - self._test(*self.create_model(1, dtype), ie_device, precision, ir_version, kwargs_to_prepare_input={"m": m}) + self._test(*self.create_model(1, dtype), ie_device, precision, + ir_version, kwargs_to_prepare_input={"m": m}) @pytest.mark.nightly @pytest.mark.precommit @pytest.mark.precommit_torch_export @pytest.mark.parametrize("dtype", ["bool", "int8", "uint8", "int32", "int64", "float32", "float64"]) @pytest.mark.parametrize(("m", "n"), [[2, 2], [3, 4], [5, 3]]) - @pytest.mark.skipif(torch.__version__ < '2.3.0', reason="`aten.eye` is not supported in PyTorch versions earlier than 2.3.") def 
test_eye(self, dtype, m, n, ie_device, precision, ir_version): + if (PytorchLayerTest.use_torch_export() and version.parse(torch.__version__) < version.parse("2.3")): + pytest.skip("Not supported in PyTorch versions earlier than 2.3.") if ie_device == "GPU": pytest.xfail(reason="eye is not supported on GPU") - self._test(*self.create_model(2, dtype), ie_device, precision, ir_version, kwargs_to_prepare_input={"m": m, "n": n}) + self._test(*self.create_model(2, dtype), ie_device, precision, + ir_version, kwargs_to_prepare_input={"m": m, "n": n}) diff --git a/tests/model_hub_tests/pytorch/torch_utils.py b/tests/model_hub_tests/pytorch/torch_utils.py index 09826b058c7855..5b351c6317e9bd 100644 --- a/tests/model_hub_tests/pytorch/torch_utils.py +++ b/tests/model_hub_tests/pytorch/torch_utils.py @@ -75,7 +75,10 @@ def convert_model_impl(self, model_obj): pt_res = model_obj(**self.example) graph = export(model_obj, tuple(), self.example) if version.parse(torch.__version__) >= version.parse("2.2"): - graph = graph.run_decompositions() + from torch._decomp import get_decompositions + from openvino.frontend.pytorch.torchdynamo.decompositions import get_export_decomposition_list + decomp = get_decompositions(get_export_decomposition_list()) + graph = graph.run_decompositions(decomp_table=decomp) gm = graph.module() print(gm.code) diff --git a/tools/ovc/openvino/tools/ovc/moc_frontend/pytorch_frontend_utils.py b/tools/ovc/openvino/tools/ovc/moc_frontend/pytorch_frontend_utils.py index b79b24e9ce76a3..dfe25f27d13d7d 100644 --- a/tools/ovc/openvino/tools/ovc/moc_frontend/pytorch_frontend_utils.py +++ b/tools/ovc/openvino/tools/ovc/moc_frontend/pytorch_frontend_utils.py @@ -40,15 +40,20 @@ def extract_module_extensions(args): except: pass if not is_good_version: - raise RuntimeError( - "NNCF models produced by nncf<2.6 are not supported directly. Please upgrade nncf or export to ONNX first.") + raise RuntimeError("NNCF models produced by nncf<2.6 are not " + "supported directly. 
Please upgrade nncf or " + "export to ONNX first.") inputs = prepare_torch_inputs(example_inputs) if not isinstance(model, (TorchScriptPythonDecoder, TorchFXPythonDecoder)): if hasattr(torch, "export") and isinstance(model, (torch.export.ExportedProgram)): from packaging import version if version.parse(torch.__version__) >= version.parse("2.2"): - model = model.run_decompositions() + from torch._decomp import get_decompositions + from openvino.frontend.pytorch.torchdynamo.decompositions import get_export_decomposition_list + decomp = get_decompositions(get_export_decomposition_list()) + model = model.run_decompositions(decomp_table=decomp) gm = model.module() + log.debug(gm.code) decoder = TorchFXPythonDecoder(gm) else: decoder = TorchScriptPythonDecoder(