From fd256ea91461397164dfb4c6e1d2474f093efafe Mon Sep 17 00:00:00 2001 From: Giacomo Serafini <179146510+giacs-epic@users.noreply.github.com> Date: Thu, 13 Feb 2025 14:23:34 +0000 Subject: [PATCH] Fix constant bool tensor importing --- .../torch_mlir_e2e_test/test_suite/basic.py | 24 +++++++++++++++++++ python/torch_mlir/extras/onnx_importer.py | 7 +----- 2 files changed, 25 insertions(+), 6 deletions(-) diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py index 4ba497452a76..f1f2ebbac75f 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py @@ -2910,6 +2910,30 @@ def NumelZeroRankModule_basic(module, tu: TestUtils): # ============================================================================== +class BoolTensorConstantModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ] + ) + def forward(self): + return torch.tensor( + [1, 0, 0, 1, 1, 0, 1, 0, 0, 0, 0, 1, 1, 0, 1], dtype=torch.bool + ) + + +@register_test_case(module_factory=lambda: BoolTensorConstantModule()) +def BoolTensorConstantModule_basic(module, tu: TestUtils): + module.forward() + + +# ============================================================================== + + class BoolTensorReturnFalseModule(torch.nn.Module): def __init__(self): super().__init__() diff --git a/python/torch_mlir/extras/onnx_importer.py b/python/torch_mlir/extras/onnx_importer.py index 7ce3647ee8c4..f1abaa29c05f 100644 --- a/python/torch_mlir/extras/onnx_importer.py +++ b/python/torch_mlir/extras/onnx_importer.py @@ -1132,12 +1132,7 @@ def get_operator_function( np.asarray(tp.float_data, dtype=np.float32).reshape(tp.dims), signless=False ), onnx.TensorProto.DataType.BOOL: lambda tp: DenseElementsAttr.get( - np.packbits( - np.asarray(tp.int32_data, dtype=np.bool_).reshape(tp.dims), - axis=None, - bitorder="little", - ), - signless=False, + np.asarray(tp.int32_data, dtype=np.bool_).reshape(tp.dims), signless=False ), onnx.TensorProto.DataType.UINT8: lambda tp: DenseElementsAttr.get( np.asarray(tp.int32_data, dtype=np.uint8).reshape(tp.dims), signless=False