diff --git a/fx2ait/fx2ait/converters/utils.py b/fx2ait/fx2ait/converters/utils.py
index d64cb084a..540024fab 100644
--- a/fx2ait/fx2ait/converters/utils.py
+++ b/fx2ait/fx2ait/converters/utils.py
@@ -86,6 +86,28 @@ def create_binary_op(
         )
         return res
 
+    if (
+        isinstance(rhs, IntVarTensor)
+        and isinstance(rhs._attrs["int_var"], IntImm)
+        and rhs_is_constant
+        and isinstance(lhs, AITTensor)
+        and not isinstance(lhs, IntVarTensor)
+    ):
+        # If rhs is a constant IntVarTensor but lhs is not, proceed
+        rhs = rhs_constant
+        return elementwise(op_type)(lhs, rhs)
+
+    if (
+        isinstance(lhs, IntVarTensor)
+        and isinstance(lhs._attrs["int_var"], IntImm)
+        and lhs_is_constant
+        and isinstance(rhs, AITTensor)
+        and not isinstance(rhs, IntVarTensor)
+    ):
+        # If lhs is a constant IntVarTensor but rhs is not, proceed
+        lhs = lhs_constant
+        return elementwise(op_type)(lhs, rhs)
+
     if isinstance(lhs, IntVarTensor) or isinstance(rhs, IntVarTensor):
         lhs = IntVarTensor(IntImm(lhs)) if isinstance(lhs, int) else lhs
         rhs = IntVarTensor(IntImm(rhs)) if isinstance(rhs, int) else rhs
diff --git a/fx2ait/fx2ait/test/converters/test_ait_binary_op.py b/fx2ait/fx2ait/test/converters/test_ait_binary_op.py
index 1a13daaa9..1da7b11e1 100644
--- a/fx2ait/fx2ait/test/converters/test_ait_binary_op.py
+++ b/fx2ait/fx2ait/test/converters/test_ait_binary_op.py
@@ -167,3 +167,27 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
             [torch.randn(2, 4).half().cuda()],
             expected_ops={acc_ops.reshape, acc_ops.mul},
         )
+
+    def test_binary_one_intmm_constant_lhs(self) -> None:
+        class TestModule(torch.nn.Module):
+            def forward(self, input: torch.Tensor) -> torch.Tensor:
+                return torch.add(input, input.size()[0])
+
+        model = TestModule().cuda()
+        self.run_test(
+            model,
+            [torch.randn((1, 1)).half().cuda()],
+            expected_ops={acc_ops.add},
+        )
+
+    def test_binary_one_intmm_constant_rhs(self) -> None:
+        class TestModule(torch.nn.Module):
+            def forward(self, input: torch.Tensor) -> torch.Tensor:
+                return torch.add(input.size()[0], input)
+
+        model = TestModule().cuda()
+        self.run_test(
+            model,
+            [torch.randn((1, 1)).half().cuda()],
+            expected_ops={acc_ops.add},
+        )