diff --git a/backends/arm/arm_partitioner.py b/backends/arm/arm_partitioner.py index 8726533b34..f73d97480b 100644 --- a/backends/arm/arm_partitioner.py +++ b/backends/arm/arm_partitioner.py @@ -45,6 +45,7 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool: exir_ops.edge.aten.div.Tensor, exir_ops.edge.aten.split_with_sizes_copy.default, exir_ops.edge.aten.full.default, + exir_ops.edge.aten.mul.Tensor, exir_ops.edge.aten._native_batch_norm_legit_no_training.default, exir_ops.edge.aten.avg_pool2d.default, exir_ops.edge.aten.sigmoid.default, diff --git a/backends/arm/operators/__init__.py b/backends/arm/operators/__init__.py index bfd043a803..94a16d8c94 100644 --- a/backends/arm/operators/__init__.py +++ b/backends/arm/operators/__init__.py @@ -17,6 +17,7 @@ op_hardtanh, op_mean_dim, op_mm, + op_mul, op_permute, op_quant, op_repeat, diff --git a/backends/arm/operators/op_mul.py b/backends/arm/operators/op_mul.py new file mode 100644 index 0000000000..e9cbfcbd7c --- /dev/null +++ b/backends/arm/operators/op_mul.py @@ -0,0 +1,83 @@ +# Copyright 2024 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import List + +import executorch.backends.arm.tosa_quant_utils as tqutils +import executorch.backends.arm.tosa_utils as tutils + +import serializer.tosa_serializer as ts +import torch + +from executorch.backends.arm.operators.node_visitor import ( + NodeVisitor, + register_node_visitor, +) +from executorch.backends.arm.tosa_mapping import TosaArg +from serializer.tosa_serializer import TosaOp + + +@register_node_visitor +class MulVisitor(NodeVisitor): + target = "aten.mul.Tensor" + + def define_node( + self, + node: torch.fx.Node, + tosa_graph: ts.TosaSerializer, + inputs: List[TosaArg], + output: TosaArg, + is_quant_node: bool, + ) -> None: + + if is_quant_node: + input_A = inputs[0] + input_B = inputs[1] + input_A_qargs = tqutils.get_quant_node_args(node.args[0]) + input_B_qargs = tqutils.get_quant_node_args(node.args[1]) + + input_A.shape = tutils.tosa_shape(input_A.shape, input_A.dim_order) + input_B.shape = tutils.tosa_shape(input_B.shape, input_B.dim_order) + output_shape = tutils.tosa_shape(output.shape, output.dim_order) + + # Rescale inputs to INT32 with zp=0 + input_A_rescaled = tqutils.build_rescale_to_int32( + tosa_graph, + input_A, + input_A_qargs.zp, + rescale_scale=1.0, + ) + input_B_rescaled = tqutils.build_rescale_to_int32( + tosa_graph, + input_B, + input_B_qargs.zp, + rescale_scale=1.0, + ) + + mul_output = tosa_graph.addIntermediate(output_shape, ts.DType.INT32) + + # Do the INT32 Mul + attr = ts.TosaSerializerAttribute() + attr.MulAttribute(shift=0) + tosa_graph.addOperator( + TosaOp.Op().MUL, + [ + input_A_rescaled.name, + input_B_rescaled.name, + ], + [mul_output.name], + attr, + ) + + tqutils.rescale_node_back_to_int8( + node, mul_output, input_A_qargs.scale * input_B_qargs.scale, tosa_graph + ) + + else: + attr = ts.TosaSerializerAttribute() + attr.MulAttribute(shift=0) + tosa_graph.addOperator( + TosaOp.Op().MUL, [inputs[0].name, inputs[1].name], [output.name], attr + ) diff --git a/backends/arm/quantizer/arm_quantizer.py b/backends/arm/quantizer/arm_quantizer.py index 3a4e697048..8d5edf386a 100644 --- a/backends/arm/quantizer/arm_quantizer.py +++ b/backends/arm/quantizer/arm_quantizer.py @@ -73,6 +73,7 @@ def _supported_symmetric_quantized_operators() -> Dict[str, List[OperatorPattern [torch.nn.AdaptiveAvgPool2d], [F.adaptive_avg_pool2d], ], + "mul": [torch.mul], "sub": [[torch.sub]], } return copy.deepcopy(supported_operators) diff --git a/backends/arm/test/ops/test_mul.py b/backends/arm/test/ops/test_mul.py new file mode 100644 index 0000000000..dee8b62f1b --- /dev/null +++ b/backends/arm/test/ops/test_mul.py @@ -0,0 +1,154 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# Copyright 2024 Arm Limited and/or its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import unittest + +import torch +from executorch.backends.arm.test import common +from executorch.backends.arm.test.tester.arm_tester import ArmTester +from parameterized import parameterized + +test_data_sute = [ + # (test_name, input, other,) See torch.mul() for info + ( + "op_mul_rank1_ones", + torch.ones(5), + torch.ones(5), + ), + ( + "op_mul_rank2_rand", + torch.rand(4, 5), + torch.rand(1, 5), + ), + ( + "op_mul_rank3_randn", + torch.randn(10, 5, 2), + torch.randn(10, 5, 2), + ), + ( + "op_mul_rank4_randn", + torch.randn(5, 10, 25, 20), + torch.randn(5, 10, 25, 20), + ), + ( + "op_mul_rank4_ones_mul_negative", + torch.ones(1, 10, 25, 20), + (-1) * torch.ones(5, 10, 25, 20), + ), + ( + "op_mul_rank4_negative_large_rand", + (-200) * torch.rand(5, 10, 25, 20), + torch.rand(5, 1, 1, 20), + ), + ( + "op_mul_rank4_large_randn", + 200 * torch.randn(5, 10, 25, 20), + torch.rand(5, 10, 25, 1), + ), +] + + +class TestMul(unittest.TestCase): + class Mul(torch.nn.Module): + + def forward( + self, + input_: torch.Tensor, + other_: torch.Tensor, + ): + return input_ * other_ + + def _test_mul_tosa_MI_pipeline( + self, module: torch.nn.Module, test_data: tuple[torch.Tensor, torch.Tensor] + ): + ( + ArmTester( + module, + example_inputs=test_data, + compile_spec=common.get_tosa_compile_spec(permute_memory_to_nhwc=True), + ) + .export() + .check_count({"torch.ops.aten.mul.Tensor": 1}) + .check_not(["torch.ops.quantized_decomposed"]) + .to_edge() + .partition() + .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) + .to_executorch() + .run_method_and_compare_outputs(inputs=test_data) + ) + + def _test_mul_tosa_BI_pipeline( + self, module: torch.nn.Module, test_data: tuple[torch.Tensor, torch.Tensor] + ): + ( + ArmTester( + module, + example_inputs=test_data, + compile_spec=common.get_tosa_compile_spec(permute_memory_to_nhwc=True), + ) + .quantize() + .export() + .check_count({"torch.ops.aten.mul.Tensor": 1}) + .check(["torch.ops.quantized_decomposed"]) + .to_edge() + .partition() + .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) + .to_executorch() + .run_method_and_compare_outputs(inputs=test_data, qtol=1.0) + ) + + def _test_mul_u55_BI_pipeline( + self, module: torch.nn.Module, test_data: tuple[torch.Tensor, torch.Tensor] + ): + ( + ArmTester( + module, + example_inputs=test_data, + compile_spec=common.get_u55_compile_spec(permute_memory_to_nhwc=True), + ) + .quantize() + .export() + .check_count({"torch.ops.aten.mul.Tensor": 1}) + .check(["torch.ops.quantized_decomposed"]) + .to_edge() + .partition() + .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) + .to_executorch() + ) + + @parameterized.expand(test_data_sute) + def test_mul_tosa_MI( + self, + test_name: str, + input_: torch.Tensor, + other_: torch.Tensor, + ): + test_data = (input_, other_) + self._test_mul_tosa_MI_pipeline(self.Mul(), test_data) + + @parameterized.expand(test_data_sute) + def test_mul_tosa_BI( + self, + test_name: str, + input_: torch.Tensor, + other_: torch.Tensor, + ): + + test_data = (input_, other_) + self._test_mul_tosa_BI_pipeline(self.Mul(), test_data) + + # Expected to fail since RESCALE cannot be fused with MUL in Vela. + @parameterized.expand(test_data_sute) + @unittest.expectedFailure + def test_mul_u55_BI( + self, + test_name: str, + input_: torch.Tensor, + other_: torch.Tensor, + ): + test_data = (input_, other_) + self._test_mul_u55_BI_pipeline(self.Mul(), test_data)