Skip to content

Commit

Permalink
allow binary ops where only one arg is an immutable IntVarTensor (fac…
Browse files Browse the repository at this point in the history
…ebookincubator#987)

Summary:

rarely, the output of a call to .size() is one of the operands in a binary op, and as such has type IntVarTensor. In this case, it is okay to forward the call to elmentwise (instead of raising an error before a call to int_elementwise().

Reviewed By: khabinov

Differential Revision: D53240440
  • Loading branch information
bradleyhd authored and facebook-github-bot committed Jan 31, 2024
1 parent cc9703a commit 9a1213a
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 0 deletions.
22 changes: 22 additions & 0 deletions fx2ait/fx2ait/converters/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,28 @@ def create_binary_op(
)
return res

if (
isinstance(rhs, IntVarTensor)
and isinstance(rhs._attrs["int_var"], IntImm)
and rhs_is_constant
and isinstance(lhs, AITTensor)
and not isinstance(lhs, IntVarTensor)
):
# If rhs is a constant IntVarTensor but lhs is not, proceed
rhs = rhs_constant
return elementwise(op_type)(lhs, rhs)

if (
isinstance(lhs, IntVarTensor)
and isinstance(lhs._attrs["int_var"], IntImm)
and lhs_is_constant
and isinstance(rhs, AITTensor)
and not isinstance(rhs, IntVarTensor)
):
# If lhs is a constant IntVarTensor but rhs is not, proceed
lhs = lhs_constant
return elementwise(op_type)(lhs, rhs)

if isinstance(lhs, IntVarTensor) or isinstance(rhs, IntVarTensor):
lhs = IntVarTensor(IntImm(lhs)) if isinstance(lhs, int) else lhs
rhs = IntVarTensor(IntImm(rhs)) if isinstance(rhs, int) else rhs
Expand Down
24 changes: 24 additions & 0 deletions fx2ait/fx2ait/test/converters/test_ait_binary_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,3 +167,27 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
[torch.randn(2, 4).half().cuda()],
expected_ops={acc_ops.reshape, acc_ops.mul},
)

def test_binary_one_intmm_constant_lhs(self) -> None:
class TestModule(torch.nn.Module):
def forward(self, input: torch.Tensor) -> torch.Tensor:
return torch.add(input, input.size()[0])

model = TestModule().cuda()
self.run_test(
model,
[torch.randn((1, 1)).half().cuda()],
expected_ops={acc_ops.add},
)

def test_binary_one_intmm_constant_rhs(self) -> None:
class TestModule(torch.nn.Module):
def forward(self, input: torch.Tensor) -> torch.Tensor:
return torch.add(input.size()[0], input)

model = TestModule().cuda()
self.run_test(
model,
[torch.randn((1, 1)).half().cuda()],
expected_ops={acc_ops.add},
)

0 comments on commit 9a1213a

Please sign in to comment.