Skip to content

Commit

Permalink
[mlir][func] Fix multiple bugs in DuplicateFunctionElimination (llv…
Browse files Browse the repository at this point in the history
…m#109571)

This PR fixes multiple bugs in `DuplicateFunctionElimination`.
- Prevents elimination of function declarations.
- Updates all symbol uses to reference unique function representatives.

Fixes llvm#93483.
  • Loading branch information
CoTinker authored Oct 22, 2024
1 parent b8930cd commit 2ce655c
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 7 deletions.
18 changes: 11 additions & 7 deletions mlir/lib/Dialect/Func/Transforms/DuplicateFunctionElimination.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,10 @@ struct DuplicateFuncOpEquivalenceInfo
if (lhs == getTombstoneKey() || lhs == getEmptyKey() ||
rhs == getTombstoneKey() || rhs == getEmptyKey())
return false;

if (lhs.isDeclaration() || rhs.isDeclaration())
return false;

// Check discardable attributes equivalence
if (lhs->getDiscardableAttrDictionary() !=
rhs->getDiscardableAttrDictionary())
Expand Down Expand Up @@ -97,14 +101,14 @@ struct DuplicateFunctionEliminationPass
}
});

// Update call ops to call unique func op representants.
module.walk([&](func::CallOp callOp) {
func::FuncOp callee = getRepresentant[callOp.getCalleeAttr().getAttr()];
callOp.setCallee(callee.getSymName());
});

// Erase redundant func ops.
// Update all symbol uses to reference unique func op
// representants and erase redundant func ops.
SymbolTableCollection symbolTable;
SymbolUserMap userMap(symbolTable, module);
for (auto it : toBeErased) {
StringAttr oldSymbol = it.getSymNameAttr();
StringAttr newSymbol = getRepresentant[oldSymbol].getSymNameAttr();
userMap.replaceAllUsesWith(it, newSymbol);
it.erase();
}
}
Expand Down
47 changes: 47 additions & 0 deletions mlir/test/Dialect/Func/duplicate-function-elimination.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -366,3 +366,50 @@ func.func @user(%p0: i1, %p1: i1, %p2: i1, %p3: i1, %odd: f32, %even: f32)
// CHECK: @user
// CHECK-2: call @deep_tree
// CHECK: call @reverse_deep_tree

// -----

func.func private @func_declaration(i32, i32) -> i32
func.func private @func_declaration1(i32, i32) -> i32

func.func @user(%arg0: i32, %arg1: i32) -> (i32, i32) {
%0 = call @func_declaration(%arg0, %arg1) : (i32, i32) -> i32
%1 = call @func_declaration1(%arg0, %arg1) : (i32, i32) -> i32
return %0, %1 : i32, i32
}

// CHECK: @func_declaration
// CHECK: @func_declaration1
// CHECK: @user
// CHECK: call @func_declaration
// CHECK: call @func_declaration1

// -----

func.func @identity(%arg0: tensor<f32>) -> tensor<f32> {
return %arg0 : tensor<f32>
}

func.func @also_identity(%arg0: tensor<f32>) -> tensor<f32> {
return %arg0 : tensor<f32>
}

func.func @yet_another_identity(%arg0: tensor<f32>) -> tensor<f32> {
return %arg0 : tensor<f32>
}

func.func @user(%arg0: tensor<f32>) -> tensor<f32> {
%f = constant @identity : (tensor<f32>) -> tensor<f32>
%0 = call_indirect %f(%arg0) : (tensor<f32>) -> tensor<f32>
%f_0 = constant @also_identity : (tensor<f32>) -> tensor<f32>
%1 = call_indirect %f_0(%0) : (tensor<f32>) -> tensor<f32>
%2 = call @yet_another_identity(%1) : (tensor<f32>) -> tensor<f32>
return %2 : tensor<f32>
}

// CHECK: @identity
// CHECK-NOT: @also_identity
// CHECK-NOT: @yet_another_identity
// CHECK: @user
// CHECK-2: constant @identity
// CHECK: call @identity

0 comments on commit 2ce655c

Please sign in to comment.