diff --git a/enzyme/Enzyme/MLIR/Passes/CMakeLists.txt b/enzyme/Enzyme/MLIR/Passes/CMakeLists.txt index 99db4d80034..00b2cae1e38 100644 --- a/enzyme/Enzyme/MLIR/Passes/CMakeLists.txt +++ b/enzyme/Enzyme/MLIR/Passes/CMakeLists.txt @@ -13,6 +13,7 @@ add_mlir_dialect_library(MLIREnzymeTransforms SimplifyMath.cpp AddToOpToIndexAndLoad.cpp AddToOpToSplit.cpp + RemovalUtils.cpp RemoveUnusedEnzymeOps.cpp SimplifyMemrefCache.cpp Utils.cpp diff --git a/enzyme/Enzyme/MLIR/Passes/EnzymeMLIRPass.cpp b/enzyme/Enzyme/MLIR/Passes/EnzymeMLIRPass.cpp index c91f5400fef..e7ef518dbd5 100644 --- a/enzyme/Enzyme/MLIR/Passes/EnzymeMLIRPass.cpp +++ b/enzyme/Enzyme/MLIR/Passes/EnzymeMLIRPass.cpp @@ -41,7 +41,8 @@ struct DifferentiatePass : public DifferentiatePassBase { registry .insert(); + mlir::cf::ControlFlowDialect, mlir::tensor::TensorDialect, + mlir::enzyme::EnzymeDialect>(); } static std::vector mode_from_fn(FunctionOpInterface fn, diff --git a/enzyme/Enzyme/MLIR/Passes/Passes.td b/enzyme/Enzyme/MLIR/Passes/Passes.td index d3494956a12..ebe00135b9f 100644 --- a/enzyme/Enzyme/MLIR/Passes/Passes.td +++ b/enzyme/Enzyme/MLIR/Passes/Passes.td @@ -18,6 +18,7 @@ def DifferentiatePass : Pass<"enzyme"> { "complex::ComplexDialect", "cf::ControlFlowDialect", "tensor::TensorDialect", + "enzyme::EnzymeDialect", ]; let options = [ Option<