Skip to content

Commit

Permalink
ait: Explicitly throw when indexing a boolean tensor for masking
Browse files Browse the repository at this point in the history
Summary:
Currently AIT hasn't implemented `tensor[boolean_tensor]` for masking. It fail shortly after this call, at:

>  'Tensor' object has no attribute 'upper_bound'


```
> link-tree/aitemplate/utils/shape_utils.py(195)convert_IntVar_to_int()
-> if var.upper_bound() == var.lower_bound():
```

Reviewed By: khabinov

Differential Revision: D53654054
  • Loading branch information
kflu authored and facebook-github-bot committed Feb 12, 2024
1 parent 4b27253 commit 39c7d22
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 0 deletions.
7 changes: 7 additions & 0 deletions fx2ait/fx2ait/converters/ait_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -668,6 +668,13 @@ def acc_ops_getitem(
isinstance(idx, Sequence) and any(isinstance(x, slice) for x in idx)
):
return acc_ops_slice(target, args, kwargs, name)

if isinstance(idx, AITTensor) and idx.dtype() == "bool":
# TODO: could do something similar to acc_ops_masked_select
raise NotImplementedError(
"AIT does not support tensor[boolean_tensor] masking yet"
)

if isinstance(input_val, AITTensor):
return acc_ops_slice(target, args, kwargs, name)

Expand Down
19 changes: 19 additions & 0 deletions fx2ait/fx2ait/test/converters/test_ait_binary_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,25 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
expected_ops={acc_op},
)

def test_getitem_boolean_index(self) -> None:
"""Verify that NotImplementatedError is thrown encountering
tensor[boolean_mask_tensor]
"""

class TestModule(torch.nn.Module):
def forward(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
return x[mask]

mod = TestModule().cuda()
x = torch.rand(10, 4).half().cuda()
mask = (torch.rand((10,)) > 0.5).cuda()
mod(x, mask)

self.assertRaises(
NotImplementedError,
lambda: self.run_test(mod, [x, mask], expected_ops={}),
)

# This is a common binary op combo usage for ads models.
def test_binary_op_combo(self) -> None:
class TestModule(torch.nn.Module):
Expand Down

0 comments on commit 39c7d22

Please sign in to comment.