Skip to content

Commit

Permalink
List decompositions for torch.export (#26878)
Browse files Browse the repository at this point in the history
### Details:
 - *item1*
 - *...*

### Tickets:
 - *ticket-id*
  • Loading branch information
mvafin authored Oct 7, 2024
1 parent 9027e1d commit 46a6ccd
Show file tree
Hide file tree
Showing 8 changed files with 231 additions and 39 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,14 @@
# flake8: noqa
# mypy: ignore-errors

import logging
import torch

from openvino.frontend.pytorch.py_pytorch_frontend import _FrontEndPytorchDecoder as Decoder
from openvino.frontend.pytorch.py_pytorch_frontend import _Type as DecoderType
from openvino.runtime import op, PartialShape, Type as OVType, OVAny, Shape
from openvino.runtime import PartialShape, Type as OVType, OVAny, Shape
from openvino.frontend.pytorch.utils import make_constant, fetch_attr, pt_to_ov_type_map, torch_tensor_to_ov_const

import torch

import logging
logger = logging.getLogger(__name__)
logger.setLevel(logging.WARNING)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,9 @@ def convolution_backward(

return grad_input, grad_weight, grad_bias


if len(get_decompositions([aten._scaled_dot_product_flash_attention.default])) == 0:

@register_decomposition(aten._scaled_dot_product_flash_attention.default)
def scaled_dot_product_flash_attention(
query,
Expand Down Expand Up @@ -101,16 +103,197 @@ def scaled_dot_product_flash_attention(


def get_aot_decomposition_list():
return ([torch.ops.aten._scaled_dot_product_flash_attention.default,
torch.ops.aten._softmax.default,
torch.ops.aten._softmax_backward_data.default,
torch.ops.aten.convolution_backward.default,
torch.ops.aten.gelu_backward.default,
torch.ops.aten.native_group_norm.default,
torch.ops.aten.native_group_norm_backward.default,
torch.ops.aten.native_layer_norm.default,
torch.ops.aten.native_layer_norm_backward.default,
torch.ops.aten.slice_backward.default])
return [
torch.ops.aten._scaled_dot_product_flash_attention.default,
torch.ops.aten._softmax.default,
torch.ops.aten._softmax_backward_data.default,
torch.ops.aten.convolution_backward.default,
torch.ops.aten.gelu_backward.default,
torch.ops.aten.native_group_norm.default,
torch.ops.aten.native_group_norm_backward.default,
torch.ops.aten.native_layer_norm.default,
torch.ops.aten.native_layer_norm_backward.default,
torch.ops.aten.slice_backward.default,
]


def get_inf_decomposition_list():
return ([torch.ops.aten.nll_loss_forward.default])
return [torch.ops.aten.nll_loss_forward.default]


def get_export_decomposition_list():
# List of decompositions from torch._decomp.core_aten_decompositions
# removed _backward ops and ops supported without decomposition
decomp = [
torch.ops.aten.addcdiv,
torch.ops.aten.addcdiv_,
torch.ops.aten.addcmul,
torch.ops.aten.addcmul_,
torch.ops.aten.addr,
torch.ops.aten.affine_grid_generator,
torch.ops.aten.all,
torch.ops.aten.aminmax,
torch.ops.aten.arange.default,
torch.ops.aten.arange.start,
torch.ops.aten.baddbmm,
torch.ops.aten.binary_cross_entropy,
torch.ops.aten.binary_cross_entropy_with_logits,
torch.ops.aten.block_diag,
torch.ops.aten.celu,
torch.ops.aten.celu_,
torch.ops.aten.clamp_max,
torch.ops.aten.clamp_min,
torch.ops.aten.count_nonzero,
torch.ops.aten.linalg_cross,
torch.ops.aten.cudnn_batch_norm,
torch.ops.aten.deg2rad,
torch.ops.aten.deg2rad_,
torch.ops.aten.detach,
torch.ops.aten.diag_embed,
torch.ops.aten.dot,
torch.ops.aten.vdot,
torch.ops.aten.elu,
torch.ops.aten.elu_,
torch.ops.aten._embedding_bag,
torch.ops.aten.empty_like,
torch.ops.aten._euclidean_dist.default,
torch.ops.aten.expand_as,
torch.ops.aten.eye,
torch.ops.aten.fill,
torch.ops.aten.fill_,
torch.ops.aten.floor_divide,
torch.ops.aten.frac,
torch.ops.aten.frac_,
torch.ops.aten._fused_moving_avg_obs_fq_helper,
torch.ops.aten.gelu_,
torch.ops.aten.glu,
torch.ops.aten.hardshrink,
torch.ops.aten.hardsigmoid,
torch.ops.aten.hardsigmoid_,
torch.ops.aten.hardswish,
torch.ops.aten.hardswish_,
torch.ops.aten.hardtanh_,
torch.ops.aten.heaviside,
torch.ops.aten.heaviside_,
torch.ops.aten.huber_loss,
torch.ops.aten.im2col,
torch.ops.aten.index_add,
torch.ops.aten.index_add_,
torch.ops.aten.index_copy,
torch.ops.aten.index_copy_,
torch.ops.aten.index_fill,
torch.ops.aten.index_fill_,
torch.ops.aten.isin,
torch.ops.aten.isneginf,
torch.ops.aten.isposinf,
torch.ops.aten.l1_loss,
torch.ops.aten.leaky_relu_,
torch.ops.aten.lerp,
torch.ops.aten.lerp_,
torch.ops.aten.linspace,
torch.ops.aten.logaddexp,
torch.ops.aten.logaddexp2,
torch.ops.aten.logit,
torch.ops.aten.logit_,
torch.ops.aten.log_sigmoid_forward,
torch.ops.aten.logspace,
torch.ops.aten.logsumexp.default,
torch.ops.aten.masked_fill,
torch.ops.aten.masked_fill_,
torch.ops.aten.mish,
torch.ops.aten.mish_,
torch.ops.aten.mse_loss,
torch.ops.aten.multi_margin_loss,
torch.ops.aten.multilabel_margin_loss_forward,
torch.ops.aten.mv,
torch.ops.aten.mvlgamma,
torch.ops.aten.mvlgamma_,
torch.ops.aten.nansum,
torch.ops.aten.nan_to_num,
torch.ops.aten.nan_to_num_,
torch.ops.aten.narrow,
torch.ops.aten.new_empty,
torch.ops.aten.new_full,
torch.ops.aten.new_ones,
torch.ops.aten.new_zeros,
torch.ops.aten.nll_loss_forward,
torch.ops.aten.norm,
torch.ops.aten.ones,
torch.ops.aten.ones_like,
torch.ops.aten._prelu_kernel,
torch.ops.aten._reshape_alias,
torch.ops.aten.rad2deg,
torch.ops.aten.rad2deg_,
torch.ops.aten.reflection_pad1d,
torch.ops.aten.reflection_pad2d,
torch.ops.aten.reflection_pad3d,
torch.ops.aten.replication_pad1d,
torch.ops.aten.replication_pad2d,
torch.ops.aten.replication_pad3d,
torch.ops.aten.renorm,
torch.ops.aten.renorm_,
torch.ops.aten.resize_as,
torch.ops.aten.roll,
torch.ops.aten.rot90,
torch.ops.aten.rrelu_with_noise,
torch.ops.aten.rrelu_with_noise_,
torch.ops.aten.rsub,
torch.ops.aten.select_scatter,
torch.ops.aten.sgn,
torch.ops.aten.sgn_,
torch.ops.aten.silu,
torch.ops.aten.silu_,
torch.ops.aten.sinc,
torch.ops.aten.sinc_,
torch.ops.aten.smooth_l1_loss,
torch.ops.aten.soft_margin_loss,
torch.ops.aten.softplus,
torch.ops.aten.softshrink,
torch.ops.aten.special_entr,
torch.ops.aten.special_log_ndtr,
torch.ops.aten.special_xlog1py,
torch.ops.aten.split.Tensor,
torch.ops.aten.split_with_sizes_copy,
torch.ops.aten.squeeze.default,
torch.ops.aten.squeeze.dim,
torch.ops.aten.std,
torch.ops.aten.std_mean,
torch.ops.aten.stack,
torch.ops.aten.sum.default,
torch.ops.aten.sum.out,
torch.ops.aten.t,
torch.ops.aten.take,
torch.ops.aten.threshold,
torch.ops.aten.threshold_,
torch.ops.aten.trace,
torch.ops.aten.transpose.int,
torch.ops.aten.tril,
torch.ops.aten.tril_,
torch.ops.aten.triu,
torch.ops.aten.triu_,
torch.ops.aten.unbind,
torch.ops.aten.unfold_copy,
torch.ops.aten._unsafe_index,
torch.ops.aten.unsafe_split.Tensor,
torch.ops.aten.unsafe_split_with_sizes,
torch.ops.aten._unsafe_view,
torch.ops.aten.view_as_complex,
torch.ops.aten.xlogy,
torch.ops.aten.xlogy_,
torch.ops.aten.zero,
torch.ops.aten.zero_,
torch.ops.aten.zeros,
torch.ops.aten.zeros_like,
torch.ops.aten._weight_norm_interface,
]
try:
from packaging import version
if version.parse(torch.__version__) >= version.parse("2.3"):
decomp += [
torch.ops.aten._lazy_clone,
torch.ops.aten._test_parallel_materialize,
torch.ops.aten._chunk_cat,
]
except ImportError:
pass
return decomp
1 change: 1 addition & 0 deletions src/frontends/pytorch/src/op_table.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -787,6 +787,7 @@ const std::unordered_map<std::string, CreatorFunction> get_supported_ops_fx() {
{"aten.clamp_min.default", op::translate_1to1_match_2_inputs_align_types<opset10::Maximum>},
{"aten.clamp_min.Tensor", op::translate_1to1_match_2_inputs_align_types<opset10::Maximum>},
{"aten.clone.default", op::skip_node}, // ignore clone operators that are inserted by PyTorch autograd
{"aten.col2im.default", op::translate_col2im},
{"aten.constant_pad_nd.default", op::translate_constant_pad_nd_fx},
{"aten.convolution.default", op::translate_convolution},
{"aten.copy.default", op::translate_copy_fx},
Expand Down
19 changes: 8 additions & 11 deletions tests/layer_tests/pytorch_tests/pytorch_layer_test_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,18 @@
import warnings
from copy import deepcopy
import os

import torch
import pytest
import logging
import numpy as np

from common.constants import test_device, test_precision
from openvino.frontend.pytorch.ts_decoder import TorchScriptPythonDecoder

from openvino.frontend import FrontEndManager
from openvino.runtime import Core, Type, PartialShape
import openvino.properties.hint as hints
import torch
from packaging import version
import pytest

logging.basicConfig(level=logging.DEBUG)


def skip_check(param):
Expand Down Expand Up @@ -124,13 +125,9 @@ def numpy_to_torch_recursively(x):
from torch.export import export

em = export(model, tuple(torch_inputs))
if version.parse(torch.__version__) >= version.parse("2.3"):
em = em.run_decompositions()
gm = em.module()
print(gm.code)

converted_model = convert_model(
em, example_input=torch_inputs)
em, example_input=torch_inputs, verbose=True)
self._resolve_input_shape_dtype(
converted_model, ov_inputs, dynamic_shapes)
smodel = model
Expand Down Expand Up @@ -242,7 +239,7 @@ def convert_via_mo(self, model, example_input, trace_model, dynamic_shapes, ov_i
if not dynamic_shapes:
input_shapes = [inp.shape for inp in ov_inputs]
kwargs["input"] = input_shapes
om = convert_model(decoder, **kwargs)
om = convert_model(decoder, verbose=True, **kwargs)
self._resolve_input_shape_dtype(om, ov_inputs, dynamic_shapes)
return smodel, om

Expand Down
1 change: 1 addition & 0 deletions tests/layer_tests/pytorch_tests/test_col2im.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ def forward(self, x):

@pytest.mark.nightly
@pytest.mark.precommit
@pytest.mark.precommit_torch_export
@pytest.mark.parametrize("output_size,kernel_size", [([4, 5], [2, 2])])
@pytest.mark.parametrize("dilation", [1, 2, [1, 2]])
@pytest.mark.parametrize("padding", [0, 5, [2, 3]])
Expand Down
20 changes: 11 additions & 9 deletions tests/layer_tests/pytorch_tests/test_eye.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import pytest
import torch
from packaging import version

from pytorch_layer_test_class import PytorchLayerTest

Expand All @@ -14,7 +15,6 @@ def _prepare_input(self, m, n=None):
return (np.array(m, dtype="int32"), )
return (np.array(m, dtype="int32"), np.array(n, dtype="int32"))


def create_model(self, num_inputs, dtype):
import torch
dtype_map = {
Expand Down Expand Up @@ -45,29 +45,31 @@ def __init__(self, dtype):
def forward(self, x, y):
return torch.eye(x, y, dtype=self.dtype)


ref_net = None

return aten_eye_1_input(pt_dtype) if num_inputs == 1 else aten_eye_2_inputs(pt_dtype), ref_net, ("aten::eye", "aten::IntImplicit")
model = aten_eye_1_input(pt_dtype) if num_inputs == 1 else aten_eye_2_inputs(pt_dtype)
return model, None, ["aten::eye", "aten::IntImplicit"]

@pytest.mark.nightly
@pytest.mark.precommit
@pytest.mark.precommit_torch_export
@pytest.mark.parametrize("dtype", ["bool", "int8", "uint8", "int32", "int64", "float32", "float64"])
@pytest.mark.parametrize("m", [2, 3, 4, 5])
@pytest.mark.skipif(torch.__version__ < '2.3.0', reason="`aten.eye` is not supported in PyTorch versions earlier than 2.3.")
def test_eye_square(self, dtype, m, ie_device, precision, ir_version):
if PytorchLayerTest.use_torch_export() and version.parse(torch.__version__) < version.parse("2.3"):
pytest.skip("Not supported in PyTorch versions earlier than 2.3.")
if ie_device == "GPU":
pytest.xfail(reason="eye is not supported on GPU")
self._test(*self.create_model(1, dtype), ie_device, precision, ir_version, kwargs_to_prepare_input={"m": m})
self._test(*self.create_model(1, dtype), ie_device, precision,
ir_version, kwargs_to_prepare_input={"m": m})

@pytest.mark.nightly
@pytest.mark.precommit
@pytest.mark.precommit_torch_export
@pytest.mark.parametrize("dtype", ["bool", "int8", "uint8", "int32", "int64", "float32", "float64"])
@pytest.mark.parametrize(("m", "n"), [[2, 2], [3, 4], [5, 3]])
@pytest.mark.skipif(torch.__version__ < '2.3.0', reason="`aten.eye` is not supported in PyTorch versions earlier than 2.3.")
def test_eye(self, dtype, m, n, ie_device, precision, ir_version):
if (PytorchLayerTest.use_torch_export() and version.parse(torch.__version__) < version.parse("2.3")):
pytest.skip("Not supported in PyTorch versions earlier than 2.3.")
if ie_device == "GPU":
pytest.xfail(reason="eye is not supported on GPU")
self._test(*self.create_model(2, dtype), ie_device, precision, ir_version, kwargs_to_prepare_input={"m": m, "n": n})
self._test(*self.create_model(2, dtype), ie_device, precision,
ir_version, kwargs_to_prepare_input={"m": m, "n": n})
5 changes: 4 additions & 1 deletion tests/model_hub_tests/pytorch/torch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,10 @@ def convert_model_impl(self, model_obj):
pt_res = model_obj(**self.example)
graph = export(model_obj, tuple(), self.example)
if version.parse(torch.__version__) >= version.parse("2.2"):
graph = graph.run_decompositions()
from torch._decomp import get_decompositions
from openvino.frontend.pytorch.torchdynamo.decompositions import get_export_decomposition_list
decomp = get_decompositions(get_export_decomposition_list())
graph = graph.run_decompositions(decomp_table=decomp)

gm = graph.module()
print(gm.code)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,15 +40,20 @@ def extract_module_extensions(args):
except:
pass
if not is_good_version:
raise RuntimeError(
"NNCF models produced by nncf<2.6 are not supported directly. Please upgrade nncf or export to ONNX first.")
raise RuntimeError("NNCF models produced by nncf<2.6 are not "
"supported directly. Please upgrade nncf or "
"export to ONNX first.")
inputs = prepare_torch_inputs(example_inputs)
if not isinstance(model, (TorchScriptPythonDecoder, TorchFXPythonDecoder)):
if hasattr(torch, "export") and isinstance(model, (torch.export.ExportedProgram)):
from packaging import version
if version.parse(torch.__version__) >= version.parse("2.2"):
model = model.run_decompositions()
from torch._decomp import get_decompositions
from openvino.frontend.pytorch.torchdynamo.decompositions import get_export_decomposition_list
decomp = get_decompositions(get_export_decomposition_list())
model = model.run_decompositions(decomp_table=decomp)
gm = model.module()
log.debug(gm.code)
decoder = TorchFXPythonDecoder(gm)
else:
decoder = TorchScriptPythonDecoder(
Expand Down

0 comments on commit 46a6ccd

Please sign in to comment.