-
Notifications
You must be signed in to change notification settings - Fork 441
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Differential Revision: D61341327 Pull Request resolved: #4730
- Loading branch information
1 parent
b66d62a
commit d8a00e6
Showing
5 changed files
with
240 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -17,6 +17,7 @@ | |
op_hardtanh, | ||
op_mean_dim, | ||
op_mm, | ||
op_mul, | ||
op_permute, | ||
op_quant, | ||
op_repeat, | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,83 @@ | ||
# Copyright 2024 Arm Limited and/or its affiliates. | ||
# | ||
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
from typing import List | ||
|
||
import executorch.backends.arm.tosa_quant_utils as tqutils | ||
import executorch.backends.arm.tosa_utils as tutils | ||
|
||
import serializer.tosa_serializer as ts | ||
import torch | ||
|
||
from executorch.backends.arm.operators.node_visitor import ( | ||
NodeVisitor, | ||
register_node_visitor, | ||
) | ||
from executorch.backends.arm.tosa_mapping import TosaArg | ||
from serializer.tosa_serializer import TosaOp | ||
|
||
|
||
@register_node_visitor | ||
class MulVisitor(NodeVisitor): | ||
target = "aten.mul.Tensor" | ||
|
||
def define_node( | ||
self, | ||
node: torch.fx.Node, | ||
tosa_graph: ts.TosaSerializer, | ||
inputs: List[TosaArg], | ||
output: TosaArg, | ||
is_quant_node: bool, | ||
) -> None: | ||
|
||
if is_quant_node: | ||
input_A = inputs[0] | ||
input_B = inputs[1] | ||
input_A_qargs = tqutils.get_quant_node_args(node.args[0]) | ||
input_B_qargs = tqutils.get_quant_node_args(node.args[1]) | ||
|
||
input_A.shape = tutils.tosa_shape(input_A.shape, input_A.dim_order) | ||
input_B.shape = tutils.tosa_shape(input_B.shape, input_B.dim_order) | ||
output_shape = tutils.tosa_shape(output.shape, output.dim_order) | ||
|
||
# Rescale inputs to INT32 with zp=0 | ||
input_A_rescaled = tqutils.build_rescale_to_int32( | ||
tosa_graph, | ||
input_A, | ||
input_A_qargs.zp, | ||
rescale_scale=1.0, | ||
) | ||
input_B_rescaled = tqutils.build_rescale_to_int32( | ||
tosa_graph, | ||
input_B, | ||
input_B_qargs.zp, | ||
rescale_scale=1.0, | ||
) | ||
|
||
mul_output = tosa_graph.addIntermediate(output_shape, ts.DType.INT32) | ||
|
||
# Do the INT32 Mul | ||
attr = ts.TosaSerializerAttribute() | ||
attr.MulAttribute(shift=0) | ||
tosa_graph.addOperator( | ||
TosaOp.Op().MUL, | ||
[ | ||
input_A_rescaled.name, | ||
input_B_rescaled.name, | ||
], | ||
[mul_output.name], | ||
attr, | ||
) | ||
|
||
tqutils.rescale_node_back_to_int8( | ||
node, mul_output, input_A_qargs.scale * input_B_qargs.scale, tosa_graph | ||
) | ||
|
||
else: | ||
attr = ts.TosaSerializerAttribute() | ||
attr.MulAttribute(shift=0) | ||
tosa_graph.addOperator( | ||
TosaOp.Op().MUL, [inputs[0].name, inputs[1].name], [output.name], attr | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,154 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# Copyright 2024 Arm Limited and/or its affiliates. | ||
# All rights reserved. | ||
# | ||
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
import unittest | ||
|
||
import torch | ||
from executorch.backends.arm.test import common | ||
from executorch.backends.arm.test.tester.arm_tester import ArmTester | ||
from parameterized import parameterized | ||
|
||
test_data_sute = [ | ||
# (test_name, input, other,) See torch.mul() for info | ||
( | ||
"op_mul_rank1_ones", | ||
torch.ones(5), | ||
torch.ones(5), | ||
), | ||
( | ||
"op_mul_rank2_rand", | ||
torch.rand(4, 5), | ||
torch.rand(1, 5), | ||
), | ||
( | ||
"op_mul_rank3_randn", | ||
torch.randn(10, 5, 2), | ||
torch.randn(10, 5, 2), | ||
), | ||
( | ||
"op_mul_rank4_randn", | ||
torch.randn(5, 10, 25, 20), | ||
torch.randn(5, 10, 25, 20), | ||
), | ||
( | ||
"op_mul_rank4_ones_mul_negative", | ||
torch.ones(1, 10, 25, 20), | ||
(-1) * torch.ones(5, 10, 25, 20), | ||
), | ||
( | ||
"op_mul_rank4_negative_large_rand", | ||
(-200) * torch.rand(5, 10, 25, 20), | ||
torch.rand(5, 1, 1, 20), | ||
), | ||
( | ||
"op_mul_rank4_large_randn", | ||
200 * torch.randn(5, 10, 25, 20), | ||
torch.rand(5, 10, 25, 1), | ||
), | ||
] | ||
|
||
|
||
class TestMul(unittest.TestCase): | ||
class Mul(torch.nn.Module): | ||
|
||
def forward( | ||
self, | ||
input_: torch.Tensor, | ||
other_: torch.Tensor, | ||
): | ||
return input_ * other_ | ||
|
||
def _test_mul_tosa_MI_pipeline( | ||
self, module: torch.nn.Module, test_data: tuple[torch.Tensor, torch.Tensor] | ||
): | ||
( | ||
ArmTester( | ||
module, | ||
example_inputs=test_data, | ||
compile_spec=common.get_tosa_compile_spec(permute_memory_to_nhwc=True), | ||
) | ||
.export() | ||
.check_count({"torch.ops.aten.mul.Tensor": 1}) | ||
.check_not(["torch.ops.quantized_decomposed"]) | ||
.to_edge() | ||
.partition() | ||
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) | ||
.to_executorch() | ||
.run_method_and_compare_outputs(inputs=test_data) | ||
) | ||
|
||
def _test_mul_tosa_BI_pipeline( | ||
self, module: torch.nn.Module, test_data: tuple[torch.Tensor, torch.Tensor] | ||
): | ||
( | ||
ArmTester( | ||
module, | ||
example_inputs=test_data, | ||
compile_spec=common.get_tosa_compile_spec(permute_memory_to_nhwc=True), | ||
) | ||
.quantize() | ||
.export() | ||
.check_count({"torch.ops.aten.mul.Tensor": 1}) | ||
.check(["torch.ops.quantized_decomposed"]) | ||
.to_edge() | ||
.partition() | ||
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) | ||
.to_executorch() | ||
.run_method_and_compare_outputs(inputs=test_data, qtol=1.0) | ||
) | ||
|
||
def _test_mul_u55_BI_pipeline( | ||
self, module: torch.nn.Module, test_data: tuple[torch.Tensor, torch.Tensor] | ||
): | ||
( | ||
ArmTester( | ||
module, | ||
example_inputs=test_data, | ||
compile_spec=common.get_u55_compile_spec(permute_memory_to_nhwc=True), | ||
) | ||
.quantize() | ||
.export() | ||
.check_count({"torch.ops.aten.mul.Tensor": 1}) | ||
.check(["torch.ops.quantized_decomposed"]) | ||
.to_edge() | ||
.partition() | ||
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) | ||
.to_executorch() | ||
) | ||
|
||
@parameterized.expand(test_data_sute) | ||
def test_mul_tosa_MI( | ||
self, | ||
test_name: str, | ||
input_: torch.Tensor, | ||
other_: torch.Tensor, | ||
): | ||
test_data = (input_, other_) | ||
self._test_mul_tosa_MI_pipeline(self.Mul(), test_data) | ||
|
||
@parameterized.expand(test_data_sute) | ||
def test_mul_tosa_BI( | ||
self, | ||
test_name: str, | ||
input_: torch.Tensor, | ||
other_: torch.Tensor, | ||
): | ||
|
||
test_data = (input_, other_) | ||
self._test_mul_tosa_BI_pipeline(self.Mul(), test_data) | ||
|
||
# Expected to fail since RESCALE cannot be fused with MUL in Vela. | ||
@parameterized.expand(test_data_sute) | ||
@unittest.expectedFailure | ||
def test_mul_u55_BI( | ||
self, | ||
test_name: str, | ||
input_: torch.Tensor, | ||
other_: torch.Tensor, | ||
): | ||
test_data = (input_, other_) | ||
self._test_mul_u55_BI_pipeline(self.Mul(), test_data) |