Skip to content

Commit

Permalink
[Relax] Remove name requirement from MergeCompositeFunctions
Browse files Browse the repository at this point in the history
Prior to this commit, the `relax.transform.MergeCompositeFunctions`
required the module to contain a single Relax function named `"main"`.
Lack of this function would result in an error when
`mod->GetGlobalVar("main")` was called.  Presence of any other Relax
function would cause an error, since they were not collected into the
`CompositeGroupsBuilder`.

This commit updates `MergeCompositeFunctions` to be indepedendent of
the names of the Relax functions in an IRModule.  The transform now
updates all Relax functions that do not have the `attr::kPrimitive` or
`attr::kCodegen` attributes.

Closes #17210
  • Loading branch information
Lunderberg committed Jul 29, 2024
1 parent df33d73 commit 98aa5c5
Show file tree
Hide file tree
Showing 2 changed files with 152 additions and 83 deletions.
41 changes: 31 additions & 10 deletions src/relax/transform/merge_composite_functions.cc
Original file line number Diff line number Diff line change
Expand Up @@ -332,10 +332,12 @@ class CompositeFunctionAnnotator : public ExprMutator {
}
using ExprMutator::VisitExpr_;

IRModule update() {
auto gvar = mod_->GetGlobalVar("main");
auto func = Downcast<Function>(mod_->Lookup(gvar));
builder_->UpdateFunction(gvar, Downcast<Function>(VisitExpr(func)));
IRModule update(const Array<String>& entry_function_names) {
for (const auto& name : entry_function_names) {
auto gvar = mod_->GetGlobalVar(name);
auto func = Downcast<Function>(mod_->Lookup(gvar));
builder_->UpdateFunction(gvar, Downcast<Function>(VisitExpr(func)));
}
return builder_->GetContextIRModule();
}

Expand Down Expand Up @@ -382,15 +384,34 @@ class CompositeFunctionAnnotator : public ExprMutator {
} // namespace

IRModule MergeCompositeFunctions(IRModule mod) {
auto gvar = mod->GetGlobalVar("main");
auto func = Downcast<Function>(mod->Lookup(gvar));
support::Arena arena;
auto group_map = CompositeGroupsBuilder(mod, &arena).Run(func);
auto new_mod = MakeGroupedFunctions(mod, group_map);
new_mod = CompositeFunctionAnnotator(mod, new_mod).update();

Array<String> entry_function_names;
for (const auto& [gvar, func] : mod->functions) {
if (func.as<FunctionNode>() && !func->GetAttr<String>(attr::kCodegen).defined() &&
!func->GetAttr<Bool>(attr::kPrimitive).defined()) {
entry_function_names.push_back(gvar->name_hint);
}
}

std::unordered_map<const Object*, Group*> group_map;
CompositeGroupsBuilder group_builder(mod, &arena);

for (const auto& name : entry_function_names) {
auto func = Downcast<Function>(mod->Lookup(name));
auto new_group_map = group_builder.Run(func);

for (const auto& [obj, group] : new_group_map) {
ICHECK(!group_map.count(obj));
group_map[obj] = group;
}
}

auto new_mod = MakeGroupedFunctions(mod, group_map, true, entry_function_names);
new_mod = CompositeFunctionAnnotator(mod, new_mod).update(entry_function_names);

// TODO(@tvm-team): Implicit pass dependency. Revisit when we have a better way to handle this.
return DeadCodeElimination(new_mod, {"main"});
return DeadCodeElimination(new_mod, {});
}

namespace transform {
Expand Down
Loading

0 comments on commit 98aa5c5

Please sign in to comment.