diff --git a/fx2ait/fx2ait/converters/ait_converters.py b/fx2ait/fx2ait/converters/ait_converters.py index 5acb24c2d..5149481a2 100644 --- a/fx2ait/fx2ait/converters/ait_converters.py +++ b/fx2ait/fx2ait/converters/ait_converters.py @@ -1607,6 +1607,54 @@ def acc_ops_avg_pool2d( return ait_nhwc2nchw(result) +@ait_converter(acc_ops.avg_pool3d) +def acc_ops_avg_pool3d( + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> ConverterOutput: + input_val = kwargs["input"] + + if not isinstance(input_val, AITTensor): + raise RuntimeError(f"Non-tensor inputs for {name}: {input_val}") + + input_val = ait_ncdhw2ndhwc(input_val) + + kernel_size = identical_elem_tuple_to_int(kwargs["kernel_size"]) + stride = ( + identical_elem_tuple_to_int(kwargs["stride"]) + if kwargs["stride"] + else kernel_size + ) + padding = identical_elem_tuple_to_int(kwargs["padding"]) + + assert kernel_size == 1, "avg_pool3d only supports kT == 1 currently" + assert stride == 1, "avg_pool3d only supports sT == 1 currently" + assert padding == 0, "avg_pool3d only supports T_padding == 0 currently" + + ceil_mode = kwargs["ceil_mode"] + count_include_pad = kwargs["count_include_pad"] + divisor_override = kwargs["divisor_override"] + if ceil_mode or not count_include_pad or divisor_override: + raise RuntimeError( + "Non-default ceil_mode/count_include_pad/divisor_override not supported yet" + ) + + N, D, H, W, C = input_val.shape() + + shape_0 = (-1, H, W, C) + input_val = reshape()(input_val, shape_0) + + output = avg_pool2d(kernel_size=kernel_size, stride=stride, pad=padding)(input_val) + + _, H_o, W_o, _ = output.shape() + shape_1 = (N, D, H_o, W_o, C) + + output = reshape()(output, shape_1) + return ait_ndhwc2ncdhw(output) + + @ait_converter(acc_ops.adaptive_avg_pool2d) def acc_ops_adaptive_avg_pool2d( target: Target, diff --git a/fx2ait/fx2ait/test/converters/test_ait_avg_pool3d.py b/fx2ait/fx2ait/test/converters/test_ait_avg_pool3d.py new file mode 100644 index 000000000..141a01050 --- /dev/null +++ b/fx2ait/fx2ait/test/converters/test_ait_avg_pool3d.py @@ -0,0 +1,39 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. +import torch +from fx2ait.acc_tracer import acc_ops +from fx2ait.tools.common_fx2ait import AITTestCase +from parameterized import parameterized + + +class TestAvgPool3dConverter(AITTestCase): + @parameterized.expand([(1, 1, 0), ((1, 1, 1), (1, 1, 1), (0, 0, 0))]) + def test_avgpool3d(self, kernel_size, stride, padding): + class TestModule(torch.nn.Module): + def __init__(self, kernel_size, stride, padding): + super().__init__() + self.pool = torch.nn.AvgPool3d(kernel_size, stride, padding) + + def forward(self, x): + return self.pool(x) + + model = TestModule(kernel_size, stride, padding).half().cuda() + inputs = [torch.randn(1, 4, 256, 256, 256).cuda().half()] + self.run_test( + model, + inputs, + expected_ops={acc_ops.avg_pool3d}, + )