Skip to content

Commit

Permalink
Add mul-op for Arm backend
Browse files Browse the repository at this point in the history
Differential Revision: D61341327

Pull Request resolved: #4730
  • Loading branch information
oscarandersson8218 authored Aug 21, 2024
1 parent b66d62a commit d8a00e6
Show file tree
Hide file tree
Showing 5 changed files with 240 additions and 0 deletions.
1 change: 1 addition & 0 deletions backends/arm/arm_partitioner.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
exir_ops.edge.aten.div.Tensor,
exir_ops.edge.aten.split_with_sizes_copy.default,
exir_ops.edge.aten.full.default,
exir_ops.edge.aten.mul.Tensor,
exir_ops.edge.aten._native_batch_norm_legit_no_training.default,
exir_ops.edge.aten.avg_pool2d.default,
exir_ops.edge.aten.sigmoid.default,
Expand Down
1 change: 1 addition & 0 deletions backends/arm/operators/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
op_hardtanh,
op_mean_dim,
op_mm,
op_mul,
op_permute,
op_quant,
op_repeat,
Expand Down
83 changes: 83 additions & 0 deletions backends/arm/operators/op_mul.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
# Copyright 2024 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import List

import executorch.backends.arm.tosa_quant_utils as tqutils
import executorch.backends.arm.tosa_utils as tutils

import serializer.tosa_serializer as ts
import torch

from executorch.backends.arm.operators.node_visitor import (
NodeVisitor,
register_node_visitor,
)
from executorch.backends.arm.tosa_mapping import TosaArg
from serializer.tosa_serializer import TosaOp


@register_node_visitor
class MulVisitor(NodeVisitor):
    """Node visitor that serializes ``aten.mul.Tensor`` as a TOSA ``MUL``."""

    target = "aten.mul.Tensor"

    @staticmethod
    def _emit_mul(tosa_graph, operand_names, result_name):
        """Append a ``MUL`` operator (with ``shift=0``) to ``tosa_graph``."""
        mul_attr = ts.TosaSerializerAttribute()
        mul_attr.MulAttribute(shift=0)
        tosa_graph.addOperator(
            TosaOp.Op().MUL,
            operand_names,
            [result_name],
            mul_attr,
        )

    def define_node(
        self,
        node: torch.fx.Node,
        tosa_graph: ts.TosaSerializer,
        inputs: List[TosaArg],
        output: TosaArg,
        is_quant_node: bool,
    ) -> None:
        """Add the operators implementing ``node`` to ``tosa_graph``.

        Args:
            node: The multiplication node; its first two ``args`` are used to
                look up quantization parameters on the quantized path.
            tosa_graph: Serializer that receives the generated operators.
            inputs: Operands of the multiplication; only the first two
                entries are read.
            output: Destination of the result.
            is_quant_node: Selects the quantized lowering when true,
                otherwise a single ``MUL`` is emitted directly.
        """
        # Non-quantized path: one MUL straight from the inputs to the output.
        if not is_quant_node:
            self._emit_mul(
                tosa_graph, [inputs[0].name, inputs[1].name], output.name
            )
            return

        lhs = inputs[0]
        rhs = inputs[1]
        lhs_qargs = tqutils.get_quant_node_args(node.args[0])
        rhs_qargs = tqutils.get_quant_node_args(node.args[1])

        # Note: the shapes of the input arguments are overwritten in place
        # with their dim-order-adjusted form before being rescaled.
        lhs.shape = tutils.tosa_shape(lhs.shape, lhs.dim_order)
        rhs.shape = tutils.tosa_shape(rhs.shape, rhs.dim_order)
        result_shape = tutils.tosa_shape(output.shape, output.dim_order)

        # Bring both operands to INT32 with zero point 0; the scale is left
        # untouched here (1.0) and accounted for in the final rescale.
        lhs_int32 = tqutils.build_rescale_to_int32(
            tosa_graph,
            lhs,
            lhs_qargs.zp,
            rescale_scale=1.0,
        )
        rhs_int32 = tqutils.build_rescale_to_int32(
            tosa_graph,
            rhs,
            rhs_qargs.zp,
            rescale_scale=1.0,
        )

        product = tosa_graph.addIntermediate(result_shape, ts.DType.INT32)

        # The multiplication itself is carried out on the INT32 tensors.
        self._emit_mul(
            tosa_graph, [lhs_int32.name, rhs_int32.name], product.name
        )

        # The product of two quantized values carries the product of the
        # two input scales.
        combined_scale = lhs_qargs.scale * rhs_qargs.scale
        tqutils.rescale_node_back_to_int8(
            node, product, combined_scale, tosa_graph
        )
1 change: 1 addition & 0 deletions backends/arm/quantizer/arm_quantizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ def _supported_symmetric_quantized_operators() -> Dict[str, List[OperatorPattern
[torch.nn.AdaptiveAvgPool2d],
[F.adaptive_avg_pool2d],
],
"mul": [torch.mul],
"sub": [[torch.sub]],
}
return copy.deepcopy(supported_operators)
Expand Down
154 changes: 154 additions & 0 deletions backends/arm/test/ops/test_mul.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# Copyright 2024 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch
from executorch.backends.arm.test import common
from executorch.backends.arm.test.tester.arm_tester import ArmTester
from parameterized import parameterized

# Parameter sets shared by every test in this file. Each entry is expanded by
# ``parameterized.expand`` into (test_name, input_, other_).
# NOTE(review): the name is presumably a typo for ``test_data_suite``; it is
# kept unchanged because the test methods below reference it.
test_data_sute = [
    # (test_name, input, other,) See torch.mul() for info
    # Rank 1, identical shapes.
    (
        "op_mul_rank1_ones",
        torch.ones(5),
        torch.ones(5),
    ),
    # Rank 2, ``other`` has a size-1 leading dimension.
    (
        "op_mul_rank2_rand",
        torch.rand(4, 5),
        torch.rand(1, 5),
    ),
    # Rank 3, identical shapes.
    (
        "op_mul_rank3_randn",
        torch.randn(10, 5, 2),
        torch.randn(10, 5, 2),
    ),
    # Rank 4, identical shapes.
    (
        "op_mul_rank4_randn",
        torch.randn(5, 10, 25, 20),
        torch.randn(5, 10, 25, 20),
    ),
    # Rank 4, ``input`` has a size-1 leading dimension; ``other`` is all -1.
    (
        "op_mul_rank4_ones_mul_negative",
        torch.ones(1, 10, 25, 20),
        (-1) * torch.ones(5, 10, 25, 20),
    ),
    # Rank 4, ``input`` scaled by -200; ``other`` has two size-1 dimensions.
    (
        "op_mul_rank4_negative_large_rand",
        (-200) * torch.rand(5, 10, 25, 20),
        torch.rand(5, 1, 1, 20),
    ),
    # Rank 4, ``input`` scaled by 200; ``other`` has a size-1 last dimension.
    (
        "op_mul_rank4_large_randn",
        200 * torch.randn(5, 10, 25, 20),
        torch.rand(5, 10, 25, 1),
    ),
]


class TestMul(unittest.TestCase):
    """Tests lowering of elementwise multiplication through the Arm backend."""

    class Mul(torch.nn.Module):
        """Module whose forward pass multiplies its two inputs."""

        def forward(
            self,
            input_: torch.Tensor,
            other_: torch.Tensor,
        ):
            product = input_ * other_
            return product

    def _test_mul_tosa_MI_pipeline(
        self, module: torch.nn.Module, test_data: tuple[torch.Tensor, torch.Tensor]
    ):
        """Run the non-quantized TOSA pipeline and compare outputs."""
        tester = ArmTester(
            module,
            example_inputs=test_data,
            compile_spec=common.get_tosa_compile_spec(permute_memory_to_nhwc=True),
        )
        tester = tester.export()
        # Exactly one mul and no quantization ops are expected after export.
        tester = tester.check_count({"torch.ops.aten.mul.Tensor": 1})
        tester = tester.check_not(["torch.ops.quantized_decomposed"])
        tester = tester.to_edge()
        tester = tester.partition()
        # The whole graph should end up in a single delegate call.
        tester = tester.check_count(
            {"torch.ops.higher_order.executorch_call_delegate": 1}
        )
        tester = tester.to_executorch()
        tester.run_method_and_compare_outputs(inputs=test_data)

    def _test_mul_tosa_BI_pipeline(
        self, module: torch.nn.Module, test_data: tuple[torch.Tensor, torch.Tensor]
    ):
        """Run the quantized TOSA pipeline and compare outputs with ``qtol=1.0``."""
        tester = ArmTester(
            module,
            example_inputs=test_data,
            compile_spec=common.get_tosa_compile_spec(permute_memory_to_nhwc=True),
        )
        tester = tester.quantize()
        tester = tester.export()
        # Exactly one mul, plus quantization ops, are expected after export.
        tester = tester.check_count({"torch.ops.aten.mul.Tensor": 1})
        tester = tester.check(["torch.ops.quantized_decomposed"])
        tester = tester.to_edge()
        tester = tester.partition()
        # The whole graph should end up in a single delegate call.
        tester = tester.check_count(
            {"torch.ops.higher_order.executorch_call_delegate": 1}
        )
        tester = tester.to_executorch()
        tester.run_method_and_compare_outputs(inputs=test_data, qtol=1.0)

    def _test_mul_u55_BI_pipeline(
        self, module: torch.nn.Module, test_data: tuple[torch.Tensor, torch.Tensor]
    ):
        """Run the quantized pipeline with the U55 compile spec.

        Unlike the TOSA pipelines, this one stops after ``to_executorch`` and
        does not compare outputs.
        """
        tester = ArmTester(
            module,
            example_inputs=test_data,
            compile_spec=common.get_u55_compile_spec(permute_memory_to_nhwc=True),
        )
        tester = tester.quantize()
        tester = tester.export()
        tester = tester.check_count({"torch.ops.aten.mul.Tensor": 1})
        tester = tester.check(["torch.ops.quantized_decomposed"])
        tester = tester.to_edge()
        tester = tester.partition()
        tester = tester.check_count(
            {"torch.ops.higher_order.executorch_call_delegate": 1}
        )
        tester.to_executorch()

    @parameterized.expand(test_data_sute)
    def test_mul_tosa_MI(
        self,
        test_name: str,
        input_: torch.Tensor,
        other_: torch.Tensor,
    ):
        """Non-quantized TOSA test, once per entry in the shared data set."""
        self._test_mul_tosa_MI_pipeline(self.Mul(), (input_, other_))

    @parameterized.expand(test_data_sute)
    def test_mul_tosa_BI(
        self,
        test_name: str,
        input_: torch.Tensor,
        other_: torch.Tensor,
    ):
        """Quantized TOSA test, once per entry in the shared data set."""
        self._test_mul_tosa_BI_pipeline(self.Mul(), (input_, other_))

    # Expected to fail since RESCALE cannot be fused with MUL in Vela.
    @parameterized.expand(test_data_sute)
    @unittest.expectedFailure
    def test_mul_u55_BI(
        self,
        test_name: str,
        input_: torch.Tensor,
        other_: torch.Tensor,
    ):
        """Quantized U55 test; currently marked as an expected failure."""
        self._test_mul_u55_BI_pipeline(self.Mul(), (input_, other_))

0 comments on commit d8a00e6

Please sign in to comment.