From c311b2fdac03665da3f240b3bd7c16da6f36b0d9 Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Mon, 6 Jan 2025 14:24:02 -0500 Subject: [PATCH] fixup --- .github/workflows/enzyme-mlir.yml | 2 +- enzyme/Enzyme/MLIR/Passes/EnzymeMLIRPass.cpp | 7 +++---- enzyme/Enzyme/MLIR/Passes/RemovalUtils.cpp | 1 - enzyme/Enzyme/MLIR/Passes/RemovalUtils.h | 2 +- enzyme/Enzyme/MLIR/Passes/RemoveUnusedEnzymeOps.cpp | 8 +++++++- enzyme/test/MLIR/ForwardMode/batched_scalar.mlir | 2 +- 6 files changed, 13 insertions(+), 9 deletions(-) diff --git a/.github/workflows/enzyme-mlir.yml b/.github/workflows/enzyme-mlir.yml index 89ae72957c7..16b3fe6e11e 100644 --- a/.github/workflows/enzyme-mlir.yml +++ b/.github/workflows/enzyme-mlir.yml @@ -36,7 +36,7 @@ jobs: - uses: actions/checkout@v4 with: repository: 'llvm/llvm-project' - ref: 'eaa7b385368fa7e3dad9b95411d04be55e71494e' + ref: 'ff24e9a19e3db330dd6412aac9d1d6c0b416697f' path: 'llvm-project' - name: Get MLIR commit hash diff --git a/enzyme/Enzyme/MLIR/Passes/EnzymeMLIRPass.cpp b/enzyme/Enzyme/MLIR/Passes/EnzymeMLIRPass.cpp index e7ef518dbd5..972222f87ba 100644 --- a/enzyme/Enzyme/MLIR/Passes/EnzymeMLIRPass.cpp +++ b/enzyme/Enzyme/MLIR/Passes/EnzymeMLIRPass.cpp @@ -39,10 +39,9 @@ struct DifferentiatePass : public DifferentiatePassBase { pm.getDependentDialects(registry); } - registry - .insert(); + registry.insert(); } static std::vector mode_from_fn(FunctionOpInterface fn, diff --git a/enzyme/Enzyme/MLIR/Passes/RemovalUtils.cpp b/enzyme/Enzyme/MLIR/Passes/RemovalUtils.cpp index 002b11d6bc9..572fddd1cae 100644 --- a/enzyme/Enzyme/MLIR/Passes/RemovalUtils.cpp +++ b/enzyme/Enzyme/MLIR/Passes/RemovalUtils.cpp @@ -26,7 +26,6 @@ mlir::enzyme::CacheInfo::merge(mlir::enzyme::CacheInfo other) { other.initOp->erase(); } - enzyme::PushOp newPushOp = pushOp; other.pushOp->erase(); enzyme::PopOp newPopOp; diff --git a/enzyme/Enzyme/MLIR/Passes/RemovalUtils.h b/enzyme/Enzyme/MLIR/Passes/RemovalUtils.h index 32308ed1d6b..d56ce6018da 100644 --- a/enzyme/Enzyme/MLIR/Passes/RemovalUtils.h +++ b/enzyme/Enzyme/MLIR/Passes/RemovalUtils.h @@ -41,7 +41,7 @@ struct CacheInfo { Value pushedValue() { return pushOp.getValue(); } Type cachedType() { - return initOp.getResult().getType().cast().getType(); + return cast(initOp.getResult().getType()).getType(); } // Pushed values must be the same diff --git a/enzyme/Enzyme/MLIR/Passes/RemoveUnusedEnzymeOps.cpp b/enzyme/Enzyme/MLIR/Passes/RemoveUnusedEnzymeOps.cpp index 8ee77113e9a..cb25fa6fa8b 100644 --- a/enzyme/Enzyme/MLIR/Passes/RemoveUnusedEnzymeOps.cpp +++ b/enzyme/Enzyme/MLIR/Passes/RemoveUnusedEnzymeOps.cpp @@ -306,12 +306,18 @@ struct RemoveUnusedEnzymeOpsPass applyPatterns(op); + bool failed = false; op->walk([&](FunctionOpInterface func) { func->walk([&](enzyme::EnzymeOpsRemoverOpInterface iface) { - iface.removeEnzymeOps(); + auto result = iface.removeEnzymeOps(); + if (!result.succeeded()) + failed = true; }); }); + if (failed) + return signalPassFailure(); + applyPatterns(op); } }; diff --git a/enzyme/test/MLIR/ForwardMode/batched_scalar.mlir b/enzyme/test/MLIR/ForwardMode/batched_scalar.mlir index d384bdd0933..8acd131c169 100644 --- a/enzyme/test/MLIR/ForwardMode/batched_scalar.mlir +++ b/enzyme/test/MLIR/ForwardMode/batched_scalar.mlir @@ -21,6 +21,6 @@ module { // CHECK-NEXT: %[[s1:.+]] = "enzyme.broadcast"(%[[arg0]]) <{shape = array}> : (f64) -> tensor<2xf64> // CHECK-NEXT: %[[i1:.+]] = arith.mulf %[[arg1]], %[[s1]] : tensor<2xf64> // CHECK-NEXT: %[[i2:.+]] = arith.addf %[[i0]], %[[i1]] : tensor<2xf64> -// CHECK-NEXT: %[[i3:.+]] = arith.mulf %[[arg0]], %[[arg0]] : tensor<2xf64> +// CHECK-NEXT: %[[i3:.+]] = arith.mulf %[[arg0]], %[[arg0]] // CHECK-NEXT: return %[[i2]] : tensor<2xf64> // CHECK-NEXT: }