Skip to content

Commit

Permalink
allow type promotions from bool if common type is non-bool (facebooki…
Browse files Browse the repository at this point in the history
…ncubator#985)

Summary:

for mixed elementwise ops, allow bools to be cast to common dtype. this behavior mimics the behavior of PyTorch when doing things like multiplying a float16 tensor by a bool tensor (results in a float16 tensor).

Differential Revision: D52770344
  • Loading branch information
bradleyhd authored and facebook-github-bot committed Jan 30, 2024
1 parent 89fd517 commit d06ccca
Show file tree
Hide file tree
Showing 2 changed files with 129 additions and 6 deletions.
31 changes: 25 additions & 6 deletions python/aitemplate/compiler/ops/common/elementwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from aitemplate.compiler.op_registry import OP_REGISTRY
from aitemplate.compiler.ops.common.epilogue import FuncEnum
from aitemplate.compiler.ops.common.int_elementwise import INT_ELEMENTWISE_FUNC

from aitemplate.compiler.ops.tensor import cast
from aitemplate.utils import shape_utils

# pylint: disable=C0103,W0221,W0102,C0301,W0223,R1724
Expand Down Expand Up @@ -225,12 +225,31 @@ def __call__(self, *args: Tensor) -> Tensor:
symbolic_args.append(arg._attrs["int_var"].symbolic_value())
elif isinstance(arg, Tensor):
converted_args.append(arg)
arg_dtype = normalize_dtype(arg.dtype())
if common_dtype is None:
common_dtype = normalize_dtype(arg.dtype())
elif normalize_dtype(arg.dtype()) != common_dtype:
raise NotImplementedError(
f"Type promotions are not supported; got dtype {arg.dtype()}, but expected {common_dtype}"
)
common_dtype = arg_dtype
elif arg_dtype != common_dtype:
if arg.dtype() == "bool" and common_dtype != "bool":
# If this arg is bool, and the common is not bool, cast to the common type.
converted_args[-1] = cast()(
x=converted_args[-1], dtype=common_dtype
)
elif (
arg.dtype() != "bool"
and common_dtype == "bool"
and len(converted_args) >= 2
):
# If this arg is non-bool and the common type is bool,
# cast all previous bool args to the non-bool type.
common_dtype = arg_dtype
for i in range(0, len(converted_args) - 1):
converted_args[i] = cast()(
x=converted_args[i], dtype=common_dtype
)
else:
raise NotImplementedError(
f"Type promotions are not supported; got dtype {arg.dtype()}, but expected {common_dtype}"
)
symbolic_args.append(arg._attrs.get("symbolic_value", None))
else:
raise RuntimeError(
Expand Down
104 changes: 104 additions & 0 deletions tests/unittest/compiler/test_op_common_elementwise.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import unittest

from aitemplate import compiler

from aitemplate.compiler import compile_model, ops

from aitemplate.compiler.base import Tensor
from aitemplate.compiler.ops.common.epilogue import FuncEnum
from aitemplate.compiler.transform.fuse_ops import (
fuse_elementwise,
process_singleton_elementwise,
)
from aitemplate.testing import detect_target
from aitemplate.testing.test_utils import (
get_random_torch_tensor,
get_torch_empty_tensor,
)


def _make_graph():
X0 = Tensor(
shape=[3, 5, 7, 9],
dtype="float16",
name="X0",
is_input=True,
)

Y = ops.elementwise(FuncEnum.ABS)(ops.elementwise(FuncEnum.SIN)(X0))

Y._attrs["is_output"] = True
Y._attrs["name"] = "Y"
return Y


class OpCommonElementwiseTestCase(unittest.TestCase):
def test_elementwise_type_promotion_bool_rhs(self):
X0 = Tensor(
shape=[3, 5, 2],
dtype="float16",
name="X0",
is_input=True,
)
X1 = Tensor(
shape=[3, 5, 2],
dtype="bool",
name="X1",
is_input=True,
)
Y = ops.elementwise(FuncEnum.MUL)(X0, X1)
Y._attrs["name"] = "output0"
Y._attrs["is_output"] = True
target = detect_target()
module = compile_model(
Y,
target,
"./tmp",
"test_elementwise_type_promotion_bool_rhs",
)
x0_pt = get_random_torch_tensor([3, 5, 2], "float16")
x1_pt = get_random_torch_tensor([3, 5, 2], "bool")
out_pt = get_torch_empty_tensor([3, 5, 2], "float16")
module.run_with_tensors({"X0": x0_pt, "X1": x1_pt}, {"output0": out_pt})

def test_elementwise_type_promotion_bool_lhs(self):
X0 = Tensor(
shape=[3, 5, 2],
dtype="bool",
name="X1",
is_input=True,
)
X1 = Tensor(
shape=[3, 5, 2],
dtype="float16",
name="X0",
is_input=True,
)
Y = ops.elementwise(FuncEnum.MUL)(X0, X1)
Y._attrs["name"] = "output0"
Y._attrs["is_output"] = True
target = detect_target()
module = compile_model(
Y,
target,
"./tmp",
"test_elementwise_type_promotion_bool_lhs",
)
x0_pt = get_random_torch_tensor([3, 5, 2], "float16")
x1_pt = get_random_torch_tensor([3, 5, 2], "bool")
out_pt = get_torch_empty_tensor([3, 5, 2], "float16")
module.run_with_tensors({"X0": x0_pt, "X1": x1_pt}, {"output0": out_pt})

0 comments on commit d06ccca

Please sign in to comment.