diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp index dbd573f96a79f8..39d0ee122b1630 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp @@ -878,8 +878,9 @@ OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) { OpFoldResult PadOp::fold(FoldAdaptor adaptor) { // If the pad is all zeros we can fold this operation away. if (adaptor.getPadding() && getInput1().getType() == getType()) { - auto densePad = llvm::cast(adaptor.getPadding()); - if (densePad.isSplat() && densePad.getSplatValue().isZero()) { + auto densePad = llvm::dyn_cast(adaptor.getPadding()); + if (densePad && densePad.isSplat() && + densePad.getSplatValue().isZero()) { return getInput1(); } } diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir index 3bcf58015831ba..67cd01f62f0bdf 100644 --- a/mlir/test/Dialect/Tosa/canonicalize.mlir +++ b/mlir/test/Dialect/Tosa/canonicalize.mlir @@ -217,6 +217,17 @@ func.func @pad_noop(%arg0: tensor) -> tensor { // ----- +// CHECK-LABEL: @pad_noop_padding_mismatch_nofold +func.func @pad_noop_padding_mismatch_nofold(%arg0: tensor) -> tensor { + // CHECK: %[[PAD:.+]] = tosa.pad + // CHECK: return %[[PAD]] + %0 = "tosa.const"() { value = dense_resource<__elided__> : tensor<2x2xi32>} : () -> tensor<2x2xi32> + %1 = tosa.pad %arg0, %0 : (tensor, tensor<2x2xi32>) -> tensor + return %1 : tensor +} + +// ----- + // CHECK-LABEL: @pad_noop_type_mismatch_nofold func.func @pad_noop_type_mismatch_nofold(%arg0: tensor<10xf32>) -> tensor { // CHECK: %[[PAD:.+]] = tosa.pad