From 3e6e8eba4ae0b032e09cb10bd684297f2c42516a Mon Sep 17 00:00:00 2001
From: Bradley Davis
Date: Tue, 30 Jan 2024 13:33:29 -0800
Subject: [PATCH 1/2] allow type promotions from bool if common type is non-bool (#985)

Summary:
For mixed elementwise ops, allow bools to be cast to the common dtype. This mimics the behavior of PyTorch when doing things like multiplying a float16 tensor by a bool tensor (the result is a float16 tensor).

Reviewed By: khabinov

Differential Revision: D52770344
---
 .../compiler/ops/common/elementwise.py     | 31 ++++--
 .../compiler/test_op_common_elementwise.py | 98 +++++++++++++++++++
 2 files changed, 123 insertions(+), 6 deletions(-)
 create mode 100644 tests/unittest/compiler/test_op_common_elementwise.py

diff --git a/python/aitemplate/compiler/ops/common/elementwise.py b/python/aitemplate/compiler/ops/common/elementwise.py
index a5bc5847a..1881074bf 100644
--- a/python/aitemplate/compiler/ops/common/elementwise.py
+++ b/python/aitemplate/compiler/ops/common/elementwise.py
@@ -23,7 +23,7 @@
 from aitemplate.compiler.op_registry import OP_REGISTRY
 from aitemplate.compiler.ops.common.epilogue import FuncEnum
 from aitemplate.compiler.ops.common.int_elementwise import INT_ELEMENTWISE_FUNC
-
+from aitemplate.compiler.ops.tensor import cast
 from aitemplate.utils import shape_utils

 # pylint: disable=C0103,W0221,W0102,C0301,W0223,R1724
@@ -225,12 +225,31 @@ def __call__(self, *args: Tensor) -> Tensor:
             symbolic_args.append(arg._attrs["int_var"].symbolic_value())
         elif isinstance(arg, Tensor):
             converted_args.append(arg)
+            arg_dtype = normalize_dtype(arg.dtype())
             if common_dtype is None:
-                common_dtype = normalize_dtype(arg.dtype())
-            elif normalize_dtype(arg.dtype()) != common_dtype:
-                raise NotImplementedError(
-                    f"Type promotions are not supported; got dtype {arg.dtype()}, but expected {common_dtype}"
-                )
+                common_dtype = arg_dtype
+            elif arg_dtype != common_dtype:
+                if arg.dtype() == "bool" and common_dtype != "bool":
+                    # If this arg is bool and the common type is not bool, cast it to the common type.
+                    converted_args[-1] = cast()(
+                        x=converted_args[-1], dtype=common_dtype
+                    )
+                elif (
+                    arg.dtype() != "bool"
+                    and common_dtype == "bool"
+                    and len(converted_args) >= 2
+                ):
+                    # If this arg is non-bool and the common type is bool,
+                    # cast all previous bool args to the non-bool type.
+                    common_dtype = arg_dtype
+                    for i in range(0, len(converted_args) - 1):
+                        converted_args[i] = cast()(
+                            x=converted_args[i], dtype=common_dtype
+                        )
+                else:
+                    raise NotImplementedError(
+                        f"Type promotions are not supported; got dtype {arg.dtype()}, but expected {common_dtype}"
+                    )
             symbolic_args.append(arg._attrs.get("symbolic_value", None))
         else:
             raise RuntimeError(
diff --git a/tests/unittest/compiler/test_op_common_elementwise.py b/tests/unittest/compiler/test_op_common_elementwise.py
new file mode 100644
index 000000000..403445034
--- /dev/null
+++ b/tests/unittest/compiler/test_op_common_elementwise.py
@@ -0,0 +1,98 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import unittest
+
+from aitemplate.compiler import compile_model, ops
+
+from aitemplate.compiler.base import Tensor
+from aitemplate.compiler.ops.common.epilogue import FuncEnum
+from aitemplate.testing import detect_target
+from aitemplate.testing.test_utils import (
+    get_random_torch_tensor,
+    get_torch_empty_tensor,
+)
+
+
+def _make_graph():
+    X0 = Tensor(
+        shape=[3, 5, 7, 9],
+        dtype="float16",
+        name="X0",
+        is_input=True,
+    )
+
+    Y = ops.elementwise(FuncEnum.ABS)(ops.elementwise(FuncEnum.SIN)(X0))
+
+    Y._attrs["is_output"] = True
+    Y._attrs["name"] = "Y"
+    return Y
+
+
+class OpCommonElementwiseTestCase(unittest.TestCase):
+    def test_elementwise_type_promotion_bool_rhs(self):
+        X0 = Tensor(
+            shape=[3, 5, 2],
+            dtype="float16",
+            name="X0",
+            is_input=True,
+        )
+        X1 = Tensor(
+            shape=[3, 5, 2],
+            dtype="bool",
+            name="X1",
+            is_input=True,
+        )
+        Y = ops.elementwise(FuncEnum.MUL)(X0, X1)
+        Y._attrs["name"] = "output0"
+        Y._attrs["is_output"] = True
+        target = detect_target()
+        module = compile_model(
+            Y,
+            target,
+            "./tmp",
+            "test_elementwise_type_promotion_bool_rhs",
+        )
+        x0_pt = get_random_torch_tensor([3, 5, 2], "float16")
+        x1_pt = get_random_torch_tensor([3, 5, 2], "bool")
+        out_pt = get_torch_empty_tensor([3, 5, 2], "float16")
+        module.run_with_tensors({"X0": x0_pt, "X1": x1_pt}, {"output0": out_pt})
+
+    def test_elementwise_type_promotion_bool_lhs(self):
+        X0 = Tensor(
+            shape=[3, 5, 2],
+            dtype="bool",
+            name="X1",
+            is_input=True,
+        )
+        X1 = Tensor(
+            shape=[3, 5, 2],
+            dtype="float16",
+            name="X0",
+            is_input=True,
+        )
+        Y = ops.elementwise(FuncEnum.MUL)(X0, X1)
+        Y._attrs["name"] = "output0"
+        Y._attrs["is_output"] = True
+        target = detect_target()
+        module = compile_model(
+            Y,
+            target,
+            "./tmp",
+            "test_elementwise_type_promotion_bool_lhs",
+        )
+        x0_pt = get_random_torch_tensor([3, 5, 2], "float16")
+        x1_pt = get_random_torch_tensor([3, 5, 2], "bool")
+        out_pt = get_torch_empty_tensor([3, 5, 2], "float16")
+        module.run_with_tensors({"X0": x0_pt, "X1": x1_pt}, {"output0": out_pt})
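
For reference, the eager-PyTorch promotion rule this first patch mirrors can be checked with a minimal sketch (illustrative only, not part of the diff; shapes and names are arbitrary):

    import torch

    x = torch.randn(3, 5, 2, dtype=torch.float16)
    b = torch.rand(3, 5, 2) > 0.5   # a torch.bool mask
    y = x * b                       # the bool operand is promoted, not the float16 one
    assert y.dtype == torch.float16
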
From 5ee2e16c9ab883b3d7a5e36bf7a27df6250fea4f Mon Sep 17 00:00:00 2001
From: Bradley Davis
Date: Tue, 30 Jan 2024 13:33:29 -0800
Subject: [PATCH 2/2] add support for permute call with tuple args (#986)

Summary:
PyTorch allows permute ops to be called on tensors with tuple args, e.g. `tensor.permute((1, 0, 2))`. This diff adds support for unwrapping the nested tuple/list.

Reviewed By: khabinov

Differential Revision: D53236434
---
 fx2ait/fx2ait/converters/ait_converters.py | 12 +++-
 .../test/converters/test_ait_permute.py    | 67 +++++++++++++++++++
 2 files changed, 78 insertions(+), 1 deletion(-)
 create mode 100644 fx2ait/fx2ait/test/converters/test_ait_permute.py

diff --git a/fx2ait/fx2ait/converters/ait_converters.py b/fx2ait/fx2ait/converters/ait_converters.py
index 5149481a2..3c21094ba 100644
--- a/fx2ait/fx2ait/converters/ait_converters.py
+++ b/fx2ait/fx2ait/converters/ait_converters.py
@@ -376,10 +376,20 @@ def acc_ops_permute(
 ) -> ConverterOutput:
     input_val = kwargs["input"]
     if not isinstance(input_val, AITTensor):
-        raise ValueError(f"Unexpected input for {name}: {input_val}")
+        raise ValueError(f"Unexpected input for {name}: input={input_val}")

     permutation = kwargs["permutation"]

+    if (
+        isinstance(permutation, (list, tuple))
+        and permutation
+        and isinstance(permutation[0], (list, tuple))
+    ):
+        # If permutation is a nested list or tuple, unwrap one level.
+        # This is needed for some valid invocations of permute like
+        # t.permute((2, 0, 1)).
+        permutation = permutation[0]
+
     return permute()(input_val, permutation)


diff --git a/fx2ait/fx2ait/test/converters/test_ait_permute.py b/fx2ait/fx2ait/test/converters/test_ait_permute.py
new file mode 100644
index 000000000..9a910cee2
--- /dev/null
+++ b/fx2ait/fx2ait/test/converters/test_ait_permute.py
@@ -0,0 +1,67 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+#!/usr/bin/env fbpython
+# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
+
+import torch
+from fx2ait.acc_tracer import acc_ops
+from fx2ait.tools.common_fx2ait import AITTestCase
+
+
+class TestPermuteConverter(AITTestCase):
+    def test_permute_torch_op(
+        self,
+    ):
+        class TestModule(torch.nn.Module):
+            def forward(self, x: torch.Tensor) -> torch.Tensor:
+                return torch.permute(x, (2, 0, 1))
+
+        model = TestModule().half().cuda()
+        inputs = [torch.randn(32, 256, 256).cuda().half()]
+        self.run_test(
+            model,
+            inputs,
+            expected_ops={acc_ops.permute},
+        )
+
+    def test_permute_op_on_tensor_tuple(
+        self,
+    ):
+        class TestModule(torch.nn.Module):
+            def forward(self, x: torch.Tensor) -> torch.Tensor:
+                return x.permute((2, 0, 1))
+
+        model = TestModule().half().cuda()
+        inputs = [torch.randn(32, 256, 256).cuda().half()]
+        self.run_test(
+            model,
+            inputs,
+            expected_ops={acc_ops.permute},
+        )
+
+    def test_permute_op_on_tensor_args(
+        self,
+    ):
+        class TestModule(torch.nn.Module):
+            def forward(self, x: torch.Tensor) -> torch.Tensor:
+                return x.permute(2, 0, 1)
+
+        model = TestModule().half().cuda()
+        inputs = [torch.randn(32, 256, 256).cuda().half()]
+        self.run_test(
+            model,
+            inputs,
+            expected_ops={acc_ops.permute},
+        )
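
For reference, the call forms exercised by the converter tests above are all equivalent in eager PyTorch; a minimal sketch (illustrative only, not part of the diff; shapes are arbitrary):

    import torch

    t = torch.randn(32, 256, 256)
    a = t.permute(2, 0, 1)            # dims as separate args
    b = t.permute((2, 0, 1))          # dims as a single tuple, the case unwrapped by the converter
    c = torch.permute(t, (2, 0, 1))   # functional form
    assert a.shape == b.shape == c.shape == torch.Size([256, 32, 256])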