diff --git a/fx2ait/fx2ait/converters/ait_converters.py b/fx2ait/fx2ait/converters/ait_converters.py index 5149481a2..3c21094ba 100644 --- a/fx2ait/fx2ait/converters/ait_converters.py +++ b/fx2ait/fx2ait/converters/ait_converters.py @@ -376,10 +376,20 @@ def acc_ops_permute( ) -> ConverterOutput: input_val = kwargs["input"] if not isinstance(input_val, AITTensor): - raise ValueError(f"Unexpected input for {name}: {input_val}") + raise ValueError(f"Unexpected input for {name}: input={input_val}") permutation = kwargs["permutation"] + if ( + isinstance(permutation, (list, tuple)) + and permutation + and isinstance(permutation[0], (list, tuple)) + ): + # If permutation is a nested list or tuple, unwrap one level. + # This is needed for some valid invocations of permute like + # t.permute((2, 0, 1)). + permutation = permutation[0] + return permute()(input_val, permutation) diff --git a/fx2ait/fx2ait/test/converters/test_ait_permute.py b/fx2ait/fx2ait/test/converters/test_ait_permute.py new file mode 100644 index 000000000..9a910cee2 --- /dev/null +++ b/fx2ait/fx2ait/test/converters/test_ait_permute.py @@ -0,0 +1,67 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +#!/usr/bin/env fbpython +# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. + +import torch +from fx2ait.acc_tracer import acc_ops +from fx2ait.tools.common_fx2ait import AITTestCase + + +class TestPermuteConverter(AITTestCase): + def test_permute_torch_op( + self, + ): + class TestModule(torch.nn.Module): + def forward(self, x: torch.Tensor) -> torch.Tensor: + return torch.permute(x, (2, 0, 1)) + + model = TestModule().half().cuda() + inputs = [torch.randn(32, 256, 256).cuda().half()] + self.run_test( + model, + inputs, + expected_ops={acc_ops.permute}, + ) + + def test_permute_op_on_tensor_tuple( + self, + ): + class TestModule(torch.nn.Module): + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x.permute((2, 0, 1)) + + model = TestModule().half().cuda() + inputs = [torch.randn(32, 256, 256).cuda().half()] + self.run_test( + model, + inputs, + expected_ops={acc_ops.permute}, + ) + + def test_permute_op_on_tensor_args( + self, + ): + class TestModule(torch.nn.Module): + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x.permute(2, 0, 1) + + model = TestModule().half().cuda() + inputs = [torch.randn(32, 256, 256).cuda().half()] + self.run_test( + model, + inputs, + expected_ops={acc_ops.permute}, + )