Skip to content

Commit

Permalink
Fix constant bool tensor importing
Browse files Browse the repository at this point in the history
  • Loading branch information
giacs-epic committed Feb 13, 2025
1 parent c9694c6 commit fd256ea
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 6 deletions.
24 changes: 24 additions & 0 deletions projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -2910,6 +2910,30 @@ def NumelZeroRankModule_basic(module, tu: TestUtils):
# ==============================================================================


class BoolTensorConstantModule(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args(
[
None,
]
)
def forward(self):
return torch.tensor(
[1, 0, 0, 1, 1, 0, 1, 0, 0, 0, 0, 1, 1, 0, 1], dtype=torch.bool
)


@register_test_case(module_factory=lambda: BoolTensorConstantModule())
def BoolTensorConstantModule_basic(module, tu: TestUtils):
module.forward()


# ==============================================================================


class BoolTensorReturnFalseModule(torch.nn.Module):
def __init__(self):
super().__init__()
Expand Down
7 changes: 1 addition & 6 deletions python/torch_mlir/extras/onnx_importer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1132,12 +1132,7 @@ def get_operator_function(
np.asarray(tp.float_data, dtype=np.float32).reshape(tp.dims), signless=False
),
onnx.TensorProto.DataType.BOOL: lambda tp: DenseElementsAttr.get(
np.packbits(
np.asarray(tp.int32_data, dtype=np.bool_).reshape(tp.dims),
axis=None,
bitorder="little",
),
signless=False,
np.asarray(tp.int32_data, dtype=np.bool_).reshape(tp.dims), signless=False
),
onnx.TensorProto.DataType.UINT8: lambda tp: DenseElementsAttr.get(
np.asarray(tp.int32_data, dtype=np.uint8).reshape(tp.dims), signless=False
Expand Down

0 comments on commit fd256ea

Please sign in to comment.