Skip to content

Commit

Permalink
fixup
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Jan 7, 2025
1 parent fa93aab commit c311b2f
Show file tree
Hide file tree
Showing 6 changed files with 13 additions and 9 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/enzyme-mlir.yml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ jobs:
- uses: actions/checkout@v4
with:
repository: 'llvm/llvm-project'
ref: 'eaa7b385368fa7e3dad9b95411d04be55e71494e'
ref: 'ff24e9a19e3db330dd6412aac9d1d6c0b416697f'
path: 'llvm-project'

- name: Get MLIR commit hash
Expand Down
7 changes: 3 additions & 4 deletions enzyme/Enzyme/MLIR/Passes/EnzymeMLIRPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,9 @@ struct DifferentiatePass : public DifferentiatePassBase<DifferentiatePass> {
pm.getDependentDialects(registry);
}

registry
.insert<mlir::arith::ArithDialect, mlir::complex::ComplexDialect,
mlir::cf::ControlFlowDialect, mlir::tensor::TensorDialect,
mlir::enzyme::EnzymeDialect>();
registry.insert<mlir::arith::ArithDialect, mlir::complex::ComplexDialect,
mlir::cf::ControlFlowDialect, mlir::tensor::TensorDialect,
mlir::enzyme::EnzymeDialect>();
}

static std::vector<DIFFE_TYPE> mode_from_fn(FunctionOpInterface fn,
Expand Down
1 change: 0 additions & 1 deletion enzyme/Enzyme/MLIR/Passes/RemovalUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ mlir::enzyme::CacheInfo::merge(mlir::enzyme::CacheInfo other) {
other.initOp->erase();
}

enzyme::PushOp newPushOp = pushOp;
other.pushOp->erase();

enzyme::PopOp newPopOp;
Expand Down
2 changes: 1 addition & 1 deletion enzyme/Enzyme/MLIR/Passes/RemovalUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ struct CacheInfo {

Value pushedValue() { return pushOp.getValue(); }
Type cachedType() {
return initOp.getResult().getType().cast<enzyme::CacheType>().getType();
return cast<enzyme::CacheType>(initOp.getResult().getType()).getType();
}

// Pushed values must be the same
Expand Down
8 changes: 7 additions & 1 deletion enzyme/Enzyme/MLIR/Passes/RemoveUnusedEnzymeOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -306,12 +306,18 @@ struct RemoveUnusedEnzymeOpsPass

applyPatterns(op);

bool failed = false;
op->walk([&](FunctionOpInterface func) {
func->walk([&](enzyme::EnzymeOpsRemoverOpInterface iface) {
iface.removeEnzymeOps();
auto result = iface.removeEnzymeOps();
if (!result.succeeded())
failed = true;
});
});

if (failed)
return signalPassFailure();

applyPatterns(op);
}
};
Expand Down
2 changes: 1 addition & 1 deletion enzyme/test/MLIR/ForwardMode/batched_scalar.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,6 @@ module {
// CHECK-NEXT: %[[s1:.+]] = "enzyme.broadcast"(%[[arg0]]) <{shape = array<i64: 2>}> : (f64) -> tensor<2xf64>
// CHECK-NEXT: %[[i1:.+]] = arith.mulf %[[arg1]], %[[s1]] : tensor<2xf64>
// CHECK-NEXT: %[[i2:.+]] = arith.addf %[[i0]], %[[i1]] : tensor<2xf64>
// CHECK-NEXT: %[[i3:.+]] = arith.mulf %[[arg0]], %[[arg0]] : tensor<2xf64>
// CHECK-NEXT: %[[i3:.+]] = arith.mulf %[[arg0]], %[[arg0]]
// CHECK-NEXT: return %[[i2]] : tensor<2xf64>
// CHECK-NEXT: }

0 comments on commit c311b2f

Please sign in to comment.