Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add support for permute call with tuple args #986

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion fx2ait/fx2ait/converters/ait_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,10 +376,20 @@ def acc_ops_permute(
) -> ConverterOutput:
input_val = kwargs["input"]
if not isinstance(input_val, AITTensor):
raise ValueError(f"Unexpected input for {name}: {input_val}")
raise ValueError(f"Unexpected input for {name}: input={input_val}")

permutation = kwargs["permutation"]

if (
isinstance(permutation, (list, tuple))
and permutation
and isinstance(permutation[0], (list, tuple))
):
# If permutation is a nested list or tuple, unwrap one level.
# This is needed for some valid invocations of permute like
# t.permute((2, 0, 1)).
permutation = permutation[0]

return permute()(input_val, permutation)


Expand Down
67 changes: 67 additions & 0 deletions fx2ait/fx2ait/test/converters/test_ait_permute.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
#!/usr/bin/env fbpython
# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.

import torch
from fx2ait.acc_tracer import acc_ops
from fx2ait.tools.common_fx2ait import AITTestCase


class TestPermuteConverter(AITTestCase):
def test_permute_torch_op(
self,
):
class TestModule(torch.nn.Module):
def forward(self, x: torch.Tensor) -> torch.Tensor:
return torch.permute(x, (2, 0, 1))

model = TestModule().half().cuda()
inputs = [torch.randn(32, 256, 256).cuda().half()]
self.run_test(
model,
inputs,
expected_ops={acc_ops.permute},
)

def test_permute_op_on_tensor_tuple(
self,
):
class TestModule(torch.nn.Module):
def forward(self, x: torch.Tensor) -> torch.Tensor:
return x.permute((2, 0, 1))

model = TestModule().half().cuda()
inputs = [torch.randn(32, 256, 256).cuda().half()]
self.run_test(
model,
inputs,
expected_ops={acc_ops.permute},
)

def test_permute_op_on_tensor_args(
self,
):
class TestModule(torch.nn.Module):
def forward(self, x: torch.Tensor) -> torch.Tensor:
return x.permute(2, 0, 1)

model = TestModule().half().cuda()
inputs = [torch.randn(32, 256, 256).cuda().half()]
self.run_test(
model,
inputs,
expected_ops={acc_ops.permute},
)
31 changes: 25 additions & 6 deletions python/aitemplate/compiler/ops/common/elementwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from aitemplate.compiler.op_registry import OP_REGISTRY
from aitemplate.compiler.ops.common.epilogue import FuncEnum
from aitemplate.compiler.ops.common.int_elementwise import INT_ELEMENTWISE_FUNC

from aitemplate.compiler.ops.tensor import cast
from aitemplate.utils import shape_utils

# pylint: disable=C0103,W0221,W0102,C0301,W0223,R1724
Expand Down Expand Up @@ -225,12 +225,31 @@ def __call__(self, *args: Tensor) -> Tensor:
symbolic_args.append(arg._attrs["int_var"].symbolic_value())
elif isinstance(arg, Tensor):
converted_args.append(arg)
arg_dtype = normalize_dtype(arg.dtype())
if common_dtype is None:
common_dtype = normalize_dtype(arg.dtype())
elif normalize_dtype(arg.dtype()) != common_dtype:
raise NotImplementedError(
f"Type promotions are not supported; got dtype {arg.dtype()}, but expected {common_dtype}"
)
common_dtype = arg_dtype
elif arg_dtype != common_dtype:
if arg.dtype() == "bool" and common_dtype != "bool":
# If this arg is bool, and the common is not bool, cast to the common type.
converted_args[-1] = cast()(
x=converted_args[-1], dtype=common_dtype
)
elif (
arg.dtype() != "bool"
and common_dtype == "bool"
and len(converted_args) >= 2
):
# If this arg is non-bool and the common type is bool,
# cast all previous bool args to the non-bool type.
common_dtype = arg_dtype
for i in range(0, len(converted_args) - 1):
converted_args[i] = cast()(
x=converted_args[i], dtype=common_dtype
)
else:
raise NotImplementedError(
f"Type promotions are not supported; got dtype {arg.dtype()}, but expected {common_dtype}"
)
symbolic_args.append(arg._attrs.get("symbolic_value", None))
else:
raise RuntimeError(
Expand Down
98 changes: 98 additions & 0 deletions tests/unittest/compiler/test_op_common_elementwise.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import unittest

from aitemplate.compiler import compile_model, ops

from aitemplate.compiler.base import Tensor
from aitemplate.compiler.ops.common.epilogue import FuncEnum
from aitemplate.testing import detect_target
from aitemplate.testing.test_utils import (
get_random_torch_tensor,
get_torch_empty_tensor,
)


def _make_graph():
X0 = Tensor(
shape=[3, 5, 7, 9],
dtype="float16",
name="X0",
is_input=True,
)

Y = ops.elementwise(FuncEnum.ABS)(ops.elementwise(FuncEnum.SIN)(X0))

Y._attrs["is_output"] = True
Y._attrs["name"] = "Y"
return Y


class OpCommonElementwiseTestCase(unittest.TestCase):
def test_elementwise_type_promotion_bool_rhs(self):
X0 = Tensor(
shape=[3, 5, 2],
dtype="float16",
name="X0",
is_input=True,
)
X1 = Tensor(
shape=[3, 5, 2],
dtype="bool",
name="X1",
is_input=True,
)
Y = ops.elementwise(FuncEnum.MUL)(X0, X1)
Y._attrs["name"] = "output0"
Y._attrs["is_output"] = True
target = detect_target()
module = compile_model(
Y,
target,
"./tmp",
"test_elementwise_type_promotion_bool_rhs",
)
x0_pt = get_random_torch_tensor([3, 5, 2], "float16")
x1_pt = get_random_torch_tensor([3, 5, 2], "bool")
out_pt = get_torch_empty_tensor([3, 5, 2], "float16")
module.run_with_tensors({"X0": x0_pt, "X1": x1_pt}, {"output0": out_pt})

def test_elementwise_type_promotion_bool_lhs(self):
X0 = Tensor(
shape=[3, 5, 2],
dtype="bool",
name="X1",
is_input=True,
)
X1 = Tensor(
shape=[3, 5, 2],
dtype="float16",
name="X0",
is_input=True,
)
Y = ops.elementwise(FuncEnum.MUL)(X0, X1)
Y._attrs["name"] = "output0"
Y._attrs["is_output"] = True
target = detect_target()
module = compile_model(
Y,
target,
"./tmp",
"test_elementwise_type_promotion_bool_lhs",
)
x0_pt = get_random_torch_tensor([3, 5, 2], "float16")
x1_pt = get_random_torch_tensor([3, 5, 2], "bool")
out_pt = get_torch_empty_tensor([3, 5, 2], "float16")
module.run_with_tensors({"X0": x0_pt, "X1": x1_pt}, {"output0": out_pt})
Loading