From 98aa5c5d9f213f58aab25c1d5e3f7d09852004e0 Mon Sep 17 00:00:00 2001
From: Eric Lunderberg
Date: Mon, 29 Jul 2024 08:44:00 -0500
Subject: [PATCH] [Relax] Remove name requirement from MergeCompositeFunctions

Prior to this commit, the `relax.transform.MergeCompositeFunctions`
pass required the module to contain a single Relax function named
`"main"`.  Lack of this function would result in an error when
`mod->GetGlobalVar("main")` was called.  Presence of any other Relax
function would also cause an error, since such functions were not
collected into the `CompositeGroupsBuilder`.

This commit updates `MergeCompositeFunctions` to be independent of
the names of the Relax functions in an IRModule.  The transform now
updates all Relax functions that do not have the `attr::kPrimitive`
or `attr::kCodegen` attributes.

Closes https://github.com/apache/tvm/issues/17210
---
 .../transform/merge_composite_functions.cc    |  41 +++-
 ...est_transform_merge_composite_functions.py | 194 +++++++++++-------
 2 files changed, 152 insertions(+), 83 deletions(-)

diff --git a/src/relax/transform/merge_composite_functions.cc b/src/relax/transform/merge_composite_functions.cc
index 0dd14f5bb1afe..e22e424485e18 100644
--- a/src/relax/transform/merge_composite_functions.cc
+++ b/src/relax/transform/merge_composite_functions.cc
@@ -332,10 +332,12 @@ class CompositeFunctionAnnotator : public ExprMutator {
   }
 
   using ExprMutator::VisitExpr_;
 
-  IRModule update() {
-    auto gvar = mod_->GetGlobalVar("main");
-    auto func = Downcast<Function>(mod_->Lookup(gvar));
-    builder_->UpdateFunction(gvar, Downcast<Function>(VisitExpr(func)));
+  IRModule update(const Array<String>& entry_function_names) {
+    for (const auto& name : entry_function_names) {
+      auto gvar = mod_->GetGlobalVar(name);
+      auto func = Downcast<Function>(mod_->Lookup(gvar));
+      builder_->UpdateFunction(gvar, Downcast<Function>(VisitExpr(func)));
+    }
     return builder_->GetContextIRModule();
   }
 
@@ -382,15 +384,34 @@
 }  // namespace
 
 IRModule MergeCompositeFunctions(IRModule mod) {
-  auto gvar = mod->GetGlobalVar("main");
-  auto func = Downcast<Function>(mod->Lookup(gvar));
   support::Arena arena;
-  auto group_map = CompositeGroupsBuilder(mod, &arena).Run(func);
-  auto new_mod = MakeGroupedFunctions(mod, group_map);
-  new_mod = CompositeFunctionAnnotator(mod, new_mod).update();
+
+  Array<String> entry_function_names;
+  for (const auto& [gvar, func] : mod->functions) {
+    if (func.as<relax::FunctionNode>() && !func->GetAttr<String>(attr::kCodegen).defined() &&
+        !func->GetAttr<Integer>(attr::kPrimitive).defined()) {
+      entry_function_names.push_back(gvar->name_hint);
+    }
+  }
+
+  std::unordered_map<const Object*, CompositeGroupsBuilder::Group*> group_map;
+  CompositeGroupsBuilder group_builder(mod, &arena);
+
+  for (const auto& name : entry_function_names) {
+    auto func = Downcast<Function>(mod->Lookup(name));
+    auto new_group_map = group_builder.Run(func);
+
+    for (const auto& [obj, group] : new_group_map) {
+      ICHECK(!group_map.count(obj));
+      group_map[obj] = group;
+    }
+  }
+
+  auto new_mod = MakeGroupedFunctions(mod, group_map, true, entry_function_names);
+  new_mod = CompositeFunctionAnnotator(mod, new_mod).update(entry_function_names);
+
   // TODO(@tvm-team): Implicit pass dependency. Revisit when we have a better way to handle this.
- return DeadCodeElimination(new_mod, {"main"}); + return DeadCodeElimination(new_mod, {}); } namespace transform { diff --git a/tests/python/relax/test_transform_merge_composite_functions.py b/tests/python/relax/test_transform_merge_composite_functions.py index cff832a21ff96..41b8576da5446 100644 --- a/tests/python/relax/test_transform_merge_composite_functions.py +++ b/tests/python/relax/test_transform_merge_composite_functions.py @@ -32,12 +32,12 @@ def main( ) -> R.Tensor((1, 64, 54, 54), dtype="float32"): cls = Conv2dReLUx2 with R.dataflow(): - lv: R.Tensor( - (1, 64, 56, 56), dtype="float32" - ) = cls.fused_relax_nn_conv2d_relax_nn_relu(data, weight1) - gv: R.Tensor( - (1, 64, 54, 54), dtype="float32" - ) = cls.fused_relax_nn_conv2d_relax_nn_relu1(lv, weight2) + lv: R.Tensor((1, 64, 56, 56), dtype="float32") = ( + cls.fused_relax_nn_conv2d_relax_nn_relu(data, weight1) + ) + gv: R.Tensor((1, 64, 54, 54), dtype="float32") = ( + cls.fused_relax_nn_conv2d_relax_nn_relu1(lv, weight2) + ) R.output(gv) return gv @@ -84,10 +84,10 @@ def main( ) -> R.Tensor((1, 64, 54, 54), dtype="float32"): cls = Conv2dReLUx2_merged with R.dataflow(): - gv: R.Tensor( - (1, 64, 54, 54), dtype="float32" - ) = cls.fused_relax_nn_conv2d_relax_nn_relu_relax_nn_conv2d_relax_nn_relu1_dnnl( - data, weight1, weight2 + gv: R.Tensor((1, 64, 54, 54), dtype="float32") = ( + cls.fused_relax_nn_conv2d_relax_nn_relu_relax_nn_conv2d_relax_nn_relu1_dnnl( + data, weight1, weight2 + ) ) R.output(gv) return gv @@ -157,9 +157,9 @@ def main( return gv2 @R.function(private=True) - def fused_relax_nn_gelu( - lv: R.Tensor((1, 64, 54, 54), dtype="float32") - ) -> R.Tensor((1, 64, 54, 54), dtype="float32"): + def fused_relax_nn_gelu(lv: R.Tensor((1, 64, 54, 54), dtype="float32")) -> R.Tensor( + (1, 64, 54, 54), dtype="float32" + ): R.func_attr({"Primitive": 1, "Composite": "compiler_A.gelu"}) with R.dataflow(): gv: R.Tensor((1, 64, 54, 54), dtype="float32") = R.nn.gelu(lv) @@ -167,9 +167,9 @@ def fused_relax_nn_gelu( return gv @R.function(private=True) - def fused_relax_nn_relu( - lv1: R.Tensor((1, 64, 54, 54), dtype="float32") - ) -> R.Tensor((1, 64, 54, 54), dtype="float32"): + def fused_relax_nn_relu(lv1: R.Tensor((1, 64, 54, 54), dtype="float32")) -> R.Tensor( + (1, 64, 54, 54), dtype="float32" + ): R.func_attr({"Primitive": 1, "Composite": "compiler_A.relu"}) with R.dataflow(): gv1: R.Tensor((1, 64, 54, 54), dtype="float32") = R.nn.relu(lv1) @@ -241,9 +241,9 @@ def lv( lv2: R.Tensor((1, 64, 54, 54), dtype="float32") = lv(data, weight) @R.function - def lv1( - lv11: R.Tensor((1, 64, 54, 54), dtype="float32") - ) -> R.Tensor((1, 64, 54, 54), dtype="float32"): + def lv1(lv11: R.Tensor((1, 64, 54, 54), dtype="float32")) -> R.Tensor( + (1, 64, 54, 54), dtype="float32" + ): # function attr dict R.func_attr({"Composite": "compiler_A.relu"}) # block 0 @@ -255,9 +255,9 @@ def lv1( lv3: R.Tensor((1, 64, 54, 54), dtype="float32") = lv1(lv2) @R.function - def lv21( - lv4: R.Tensor((1, 64, 54, 54), dtype="float32") - ) -> R.Tensor((1, 64, 54, 54), dtype="float32"): + def lv21(lv4: R.Tensor((1, 64, 54, 54), dtype="float32")) -> R.Tensor( + (1, 64, 54, 54), dtype="float32" + ): # function attr dict R.func_attr({"Composite": "compiler_A.gelu"}) # block 0 @@ -291,10 +291,10 @@ def main( ) -> R.Tensor((1, 64, 54, 54), dtype="float32"): cls = Diamond_merged with R.dataflow(): - gv5: R.Tensor( - (1, 64, 54, 54), dtype="float32" - ) = cls.fused_relax_nn_conv2d_relax_nn_relu_relax_nn_gelu_relax_add_compiler_A( - data2, weight2 + gv5: R.Tensor((1, 64, 
54, 54), dtype="float32") = ( + cls.fused_relax_nn_conv2d_relax_nn_relu_relax_nn_gelu_relax_add_compiler_A( + data2, weight2 + ) ) R.output(gv5) return gv5 @@ -319,9 +319,9 @@ def main( return gv2 @R.function(private=True) - def fused_relax_nn_gelu( - lv: R.Tensor((1, 64, 54, 54), dtype="float32") - ) -> R.Tensor((1, 64, 54, 54), dtype="float32"): + def fused_relax_nn_gelu(lv: R.Tensor((1, 64, 54, 54), dtype="float32")) -> R.Tensor( + (1, 64, 54, 54), dtype="float32" + ): R.func_attr({"Primitive": 1, "Composite": "compiler_B.gelu"}) with R.dataflow(): gv: R.Tensor((1, 64, 54, 54), dtype="float32") = R.nn.gelu(lv) @@ -329,9 +329,9 @@ def fused_relax_nn_gelu( return gv @R.function(private=True) - def fused_relax_nn_relu( - lv1: R.Tensor((1, 64, 54, 54), dtype="float32") - ) -> R.Tensor((1, 64, 54, 54), dtype="float32"): + def fused_relax_nn_relu(lv1: R.Tensor((1, 64, 54, 54), dtype="float32")) -> R.Tensor( + (1, 64, 54, 54), dtype="float32" + ): R.func_attr({"Primitive": 1, "Composite": "compiler_A.relu"}) with R.dataflow(): gv1: R.Tensor((1, 64, 54, 54), dtype="float32") = R.nn.relu(lv1) @@ -416,9 +416,9 @@ def lv( gv: R.Tensor((1, 64, 54, 54), dtype="float32") = lv(data, weight) @R.function - def lv1( - lv11: R.Tensor((1, 64, 54, 54), dtype="float32") - ) -> R.Tensor((1, 64, 54, 54), dtype="float32"): + def lv1(lv11: R.Tensor((1, 64, 54, 54), dtype="float32")) -> R.Tensor( + (1, 64, 54, 54), dtype="float32" + ): R.func_attr({"Composite": "compiler_A.relu"}) with R.dataflow(): gv1: R.Tensor((1, 64, 54, 54), dtype="float32") = R.nn.relu(lv11) @@ -431,14 +431,14 @@ def lv1( @R.function def fused_relax_nn_gelu1_compiler_B( - lv2: R.Tensor((1, 64, 54, 54), dtype="float32") + lv2: R.Tensor((1, 64, 54, 54), dtype="float32"), ) -> R.Tensor((1, 64, 54, 54), dtype="float32"): R.func_attr({"Codegen": "compiler_B"}) @R.function - def lv21( - lv3: R.Tensor((1, 64, 54, 54), dtype="float32") - ) -> R.Tensor((1, 64, 54, 54), dtype="float32"): + def lv21(lv3: R.Tensor((1, 64, 54, 54), dtype="float32")) -> R.Tensor( + (1, 64, 54, 54), dtype="float32" + ): R.func_attr({"Composite": "compiler_B.gelu"}) with R.dataflow(): gv2: R.Tensor((1, 64, 54, 54), dtype="float32") = R.nn.gelu(lv3) @@ -487,9 +487,9 @@ def main( return gv1 @R.function(private=True) - def fused_relax_nn_relu( - x11: R.Tensor((10,), dtype="float32") - ) -> R.Tensor((10,), dtype="float32"): + def fused_relax_nn_relu(x11: R.Tensor((10,), dtype="float32")) -> R.Tensor( + (10,), dtype="float32" + ): R.func_attr({"Primitive": 1, "Composite": "compiler_A.relu"}) with R.dataflow(): gv2: R.Tensor((10,), dtype="float32") = R.nn.relu(x11) @@ -497,9 +497,9 @@ def fused_relax_nn_relu( return gv2 @R.function(private=True) - def fused_relax_nn_gelu( - x21: R.Tensor((10,), dtype="float32") - ) -> R.Tensor((10,), dtype="float32"): + def fused_relax_nn_gelu(x21: R.Tensor((10,), dtype="float32")) -> R.Tensor( + (10,), dtype="float32" + ): R.func_attr({"Primitive": 1, "Composite": "compiler_A.gelu"}) with R.dataflow(): gv3: R.Tensor((10,), dtype="float32") = R.nn.gelu(x21) @@ -574,10 +574,10 @@ def main( ) -> R.Tensor((10,), dtype="float32"): cls = MultipleProducers_merged with R.dataflow(): - gv4: R.Tensor( - (10,), dtype="float32" - ) = cls.fused_relax_nn_relu_relax_nn_gelu_relax_nn_relu_relax_nn_gelu_relax_add_compiler_A( - x12, x22 + gv4: R.Tensor((10,), dtype="float32") = ( + cls.fused_relax_nn_relu_relax_nn_gelu_relax_nn_relu_relax_nn_gelu_relax_add_compiler_A( + x12, x22 + ) ) R.output(gv4) return gv4 @@ -597,9 +597,9 @@ def main(x1: R.Tensor((10,), 
dtype="float32")) -> R.Tensor((10,), dtype="float32 return gv1 @R.function(private=True) - def fused_relax_nn_relu( - x11: R.Tensor((10,), dtype="float32") - ) -> R.Tensor((10,), dtype="float32"): + def fused_relax_nn_relu(x11: R.Tensor((10,), dtype="float32")) -> R.Tensor( + (10,), dtype="float32" + ): R.func_attr({"Primitive": 1, "Composite": "compiler_A.relu"}) with R.dataflow(): gv2: R.Tensor((10,), dtype="float32") = R.nn.relu(x11) @@ -607,9 +607,9 @@ def fused_relax_nn_relu( return gv2 @R.function(private=True) - def fused_relax_nn_gelu( - x21: R.Tensor((10,), dtype="float32") - ) -> R.Tensor((10,), dtype="float32"): + def fused_relax_nn_gelu(x21: R.Tensor((10,), dtype="float32")) -> R.Tensor( + (10,), dtype="float32" + ): R.func_attr({"Primitive": 1, "Composite": "compiler_A.gelu"}) with R.dataflow(): gv3: R.Tensor((10,), dtype="float32") = R.nn.gelu(x21) @@ -642,9 +642,9 @@ def main(x1: R.Tensor((10,), dtype="float32")) -> R.Tensor((10,), dtype="float32 return gv @R.function - def fused_relax_nn_relu1_compiler_A( - x11: R.Tensor((10,), dtype="float32") - ) -> R.Tensor((10,), dtype="float32"): + def fused_relax_nn_relu1_compiler_A(x11: R.Tensor((10,), dtype="float32")) -> R.Tensor( + (10,), dtype="float32" + ): # function attr dict R.func_attr({"Codegen": "compiler_A"}) @@ -720,9 +720,9 @@ def main( return gv1 @R.function(private=True) - def fused_relax_nn_relu( - add2: R.Tensor((10,), dtype="float32") - ) -> R.Tensor((10,), dtype="float32"): + def fused_relax_nn_relu(add2: R.Tensor((10,), dtype="float32")) -> R.Tensor( + (10,), dtype="float32" + ): R.func_attr({"Primitive": 1, "Composite": "compiler_A.relu"}) with R.dataflow(): gv: R.Tensor((10,), dtype="float32") = R.nn.relu(add2) @@ -740,9 +740,9 @@ def fused_relax_add( return gv2 @R.function(private=True) - def fused_relax_nn_gelu( - x31: R.Tensor((10,), dtype="float32") - ) -> R.Tensor((10,), dtype="float32"): + def fused_relax_nn_gelu(x31: R.Tensor((10,), dtype="float32")) -> R.Tensor( + (10,), dtype="float32" + ): R.func_attr({"Primitive": 1, "Composite": "compiler_B.gelu"}) with R.dataflow(): gv3: R.Tensor((10,), dtype="float32") = R.nn.gelu(x31) @@ -815,9 +815,9 @@ def lv31(add2: R.Tensor((10,), dtype="float32")) -> R.Tensor((10,), dtype="float return gv3 @R.function - def fused_relax_nn_gelu1_compiler_B( - x3: R.Tensor((10,), dtype="float32") - ) -> R.Tensor((10,), dtype="float32"): + def fused_relax_nn_gelu1_compiler_B(x3: R.Tensor((10,), dtype="float32")) -> R.Tensor( + (10,), dtype="float32" + ): R.func_attr({"Codegen": "compiler_B"}) @R.function @@ -840,9 +840,9 @@ def main( cls = MergeCompilerRegionsExampleRef with R.dataflow(): lv5: R.Tensor((10,), dtype="float32") = cls.fused_relax_nn_gelu1_compiler_B(x32) - lv13: R.Tuple( - R.Tensor((10,), dtype="float32"), R.Tensor((10,), dtype="float32") - ) = cls.fused_relax_add_relax_add_relax_nn_relu_compiler_A(x12, x22, lv5) + lv13: R.Tuple(R.Tensor((10,), dtype="float32"), R.Tensor((10,), dtype="float32")) = ( + cls.fused_relax_add_relax_add_relax_nn_relu_compiler_A(x12, x22, lv5) + ) lv23: R.Tensor((10,), dtype="float32") = lv13[0] lv32: R.Tensor((10,), dtype="float32") = lv13[1] lv41: R.Tensor((10,), dtype="float32") = cls.fused_relax_nn_gelu1_compiler_B(lv23) @@ -1097,14 +1097,62 @@ def main( lv1: R.Tensor((784, 512), dtype="float32") = R.permute_dims( linear_relu_stack_0_weight, axes=None ) - gv: R.Tensor( - (1, 512), dtype="float32" - ) = cls.fused_relax_reshape_relax_matmul_tensorrt(inp_0, lv1) + gv: R.Tensor((1, 512), dtype="float32") = ( + 
cls.fused_relax_reshape_relax_matmul_tensorrt(inp_0, lv1) + ) R.output(gv) return gv check(Module, Expected) +def test_relax_is_not_required_to_be_named_main(): + """Functions inside IRModule may have arbitrary names + + This is a regression test. Earlier implementations of + MergeCompositeFunctions required the public-facing Relax function + of an IRModule to be named "main". + + """ + + Before = Conv2dReLUx2.clone() + Before["main_with_another_name"] = Before["main"].with_attr( + "global_symbol", "main_with_another_name" + ) + del Before["main"] + + Expected = Conv2dReLUx2_merged.clone() + Expected["main_with_another_name"] = Expected["main"].with_attr( + "global_symbol", "main_with_another_name" + ) + del Expected["main"] + + check(Before, Expected) + + +def test_multiple_relax_functions_may_be_present(): + """IRModule may contain multiple Relax functions. + + This is a regression test. Earlier implementations of + MergeCompositeFunctions required the IRModule to have only a + single Relax function. + + """ + + Before = Conv2dReLUx2.clone() + Before["main2"] = relax.utils.copy_with_new_vars(Before["main"]).with_attr( + "global_symbol", "main2" + ) + del Before["main"] + + Expected = Conv2dReLUx2_merged.clone() + Expected["main2"] = relax.utils.copy_with_new_vars(Expected["main"]).with_attr( + "global_symbol", "main2" + ) + del Expected["main"] + + check(Before, Expected) + + if __name__ == "__main__": pytest.main([__file__])
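
Usage sketch (not part of the patch itself): the snippet below illustrates the workflow this change enables, where the module's public entry function is not named "main". The module `Module`, the function name `my_entry`, the tensor shape, and the pattern names `compiler_A.relu` / `compiler_A.gelu` are hypothetical examples chosen for illustration; `FuseOpsByPattern` and `MergeCompositeFunctions` are the real passes exercised by this patch.

import tvm
from tvm import relax
from tvm.relax.dpl.pattern import is_op, wildcard
from tvm.script import ir as I
from tvm.script import relax as R


@I.ir_module
class Module:
    @R.function
    def my_entry(x: R.Tensor((1, 8), "float32")) -> R.Tensor((1, 8), "float32"):
        # A public Relax function that is deliberately *not* named "main".
        with R.dataflow():
            y = R.nn.relu(x)
            z = R.nn.gelu(y)
            R.output(z)
        return z


# Pattern names share the "compiler_A" prefix; MergeCompositeFunctions groups
# composite functions by that prefix when forming the offloaded function.
patterns = [
    ("compiler_A.relu", is_op("relax.nn.relu")(wildcard())),
    ("compiler_A.gelu", is_op("relax.nn.gelu")(wildcard())),
]

seq = tvm.transform.Sequential(
    [
        relax.transform.FuseOpsByPattern(patterns, annotate_codegen=False),
        # After this patch, the pass visits every Relax function that lacks the
        # kPrimitive/kCodegen attributes, so the entry function no longer needs
        # to be named "main".
        relax.transform.MergeCompositeFunctions(),
    ]
)
merged_mod = seq(Module)
merged_mod.show()

Before this change, the MergeCompositeFunctions step above would fail at `mod->GetGlobalVar("main")`, since the module contains no function named "main".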