Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

allow binary ops where only one arg is an immutable IntVarTensor #987

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions fx2ait/fx2ait/converters/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,28 @@ def create_binary_op(
)
return res

if (
isinstance(rhs, IntVarTensor)
and isinstance(rhs._attrs["int_var"], IntImm)
and rhs_is_constant
and isinstance(lhs, AITTensor)
and not isinstance(lhs, IntVarTensor)
):
# If rhs is a constant IntVarTensor but lhs is not, proceed
rhs = rhs_constant
return elementwise(op_type)(lhs, rhs)

if (
isinstance(lhs, IntVarTensor)
and isinstance(lhs._attrs["int_var"], IntImm)
and lhs_is_constant
and isinstance(rhs, AITTensor)
and not isinstance(rhs, IntVarTensor)
):
# If lhs is a constant IntVarTensor but rhs is not, proceed
lhs = lhs_constant
return elementwise(op_type)(lhs, rhs)

if isinstance(lhs, IntVarTensor) or isinstance(rhs, IntVarTensor):
lhs = IntVarTensor(IntImm(lhs)) if isinstance(lhs, int) else lhs
rhs = IntVarTensor(IntImm(rhs)) if isinstance(rhs, int) else rhs
Expand Down
24 changes: 24 additions & 0 deletions fx2ait/fx2ait/test/converters/test_ait_binary_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,3 +167,27 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
[torch.randn(2, 4).half().cuda()],
expected_ops={acc_ops.reshape, acc_ops.mul},
)

def test_binary_one_intmm_constant_lhs(self) -> None:
class TestModule(torch.nn.Module):
def forward(self, input: torch.Tensor) -> torch.Tensor:
return torch.add(input, input.size()[0])

model = TestModule().cuda()
self.run_test(
model,
[torch.randn((1, 1)).half().cuda()],
expected_ops={acc_ops.add},
)

def test_binary_one_intmm_constant_rhs(self) -> None:
class TestModule(torch.nn.Module):
def forward(self, input: torch.Tensor) -> torch.Tensor:
return torch.add(input.size()[0], input)

model = TestModule().cuda()
self.run_test(
model,
[torch.randn((1, 1)).half().cuda()],
expected_ops={acc_ops.add},
)
Loading