diff --git a/frontends/tf-frontend/tf_mlir_ext/pipelines/customized_tf_to_mhlo.cc b/frontends/tf-frontend/tf_mlir_ext/pipelines/customized_tf_to_mhlo.cc index a874c815b..1b2f98be7 100644 --- a/frontends/tf-frontend/tf_mlir_ext/pipelines/customized_tf_to_mhlo.cc +++ b/frontends/tf-frontend/tf_mlir_ext/pipelines/customized_tf_to_mhlo.cc @@ -30,6 +30,7 @@ #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Shape/IR/Shape.h" +#include "mlir/Dialect/Shape/Transforms/Passes.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/Pass/PassManager.h" #include "mlir/Transforms/Passes.h" @@ -56,12 +57,14 @@ struct CustomizedTfToMhloPipelinePass const std::vector &customcall_ops, bool remove_control_flow, bool staticalize_dynamic_shape, bool stop_after_rewrite_custom_call, const std::unordered_map - &additional_main_func_attrs) { + &additional_main_func_attrs, + bool set_assuming_to_be_true) { this->customCallOps = customcall_ops; this->removeControlFlow = remove_control_flow; this->staticalizeDynamicShape = staticalize_dynamic_shape; this->stopAfterRewriteCustomCall = stop_after_rewrite_custom_call; this->additional_main_func_attrs = additional_main_func_attrs; + this->setAssumingToBeTrue = set_assuming_to_be_true; } void runOnOperation() override { @@ -72,15 +75,13 @@ struct CustomizedTfToMhloPipelinePass pm.addPass(mlir::createSCCPPass()); pm.addPass(mlir::createCanonicalizerPass()); - //// not safe remove-control-flow pass - // if (removeControlFlow) - // pm.addNestedPass( - // mlir::tfext::createRemoveControlFlowPass()); + // prun useless tf node + pm.addNestedPass( + mlir::tf_executor::CreateTFExecutorGraphPruningPass()); if (removeControlFlow) { pm.addNestedPass( mlir::tfext::createTFSwitchMergeToIfPass()); } - // prun useless tf node pm.addNestedPass( mlir::tf_executor::CreateTFExecutorGraphPruningPass()); @@ -209,6 +210,12 @@ struct CustomizedTfToMhloPipelinePass mlir::tfext::createRewriteFuncAttrToByteIRPass( additional_main_func_attrs)); + if (setAssumingToBeTrue) { + pm.addNestedPass( + mlir::createRemoveShapeConstraintsPass()); + pm.addNestedPass( + mlir::tfext::createRemoveCstrReshapablePass()); + } pm.addPass(mlir::createCanonicalizerPass()); pm.addPass(mlir::mhlo::createHloLegalizeToStablehloPass()); @@ -230,8 +237,10 @@ mlir::tfext::createCustomizedTfToMhloPipelinePass( bool staticalize_dynamic_shape /*= false*/, bool stop_after_rewrite_custom_call /*= false*/, const std::unordered_map - &additional_main_func_attrs /*= {}*/) { + &additional_main_func_attrs /*= {}*/, + bool set_assuming_to_be_true /*= true*/) { return std::make_unique( customcall_ops, remove_control_flow, staticalize_dynamic_shape, - stop_after_rewrite_custom_call, additional_main_func_attrs); + stop_after_rewrite_custom_call, additional_main_func_attrs, + set_assuming_to_be_true); } diff --git a/frontends/tf-frontend/tf_mlir_ext/pipelines/customized_tf_to_mhlo.h b/frontends/tf-frontend/tf_mlir_ext/pipelines/customized_tf_to_mhlo.h index 89788ae85..fc11c492a 100644 --- a/frontends/tf-frontend/tf_mlir_ext/pipelines/customized_tf_to_mhlo.h +++ b/frontends/tf-frontend/tf_mlir_ext/pipelines/customized_tf_to_mhlo.h @@ -33,7 +33,8 @@ std::unique_ptr> createCustomizedTfToMhloPipelinePass( bool remove_control_flow = false, bool staticalize_dynamic_shape = false, bool stop_after_rewrite_custom_call = false, const std::unordered_map - &additional_main_func_attrs = {}); + &additional_main_func_attrs = {}, + bool set_assuming_to_be_true = true); } // namespace tfext } // namespace mlir diff --git a/frontends/tf-frontend/tf_mlir_ext/pipelines/passes.td b/frontends/tf-frontend/tf_mlir_ext/pipelines/passes.td index 041a0c02c..baeb7a288 100644 --- a/frontends/tf-frontend/tf_mlir_ext/pipelines/passes.td +++ b/frontends/tf-frontend/tf_mlir_ext/pipelines/passes.td @@ -41,7 +41,9 @@ def CustomizedTfToMhloPipeline : Pass<"customized-tf-to-mhlo", "mlir::ModuleOp"> /*default*/"false", "Aggresively and experimentally try to rewrite " "the dynamic graph to a equivalent static graph">, Option<"stopAfterRewriteCustomCall", "stop-after-rewrite-customcall", "bool", /*default=*/"false", - "stop after rewrite customcall ops"> + "stop after rewrite customcall ops">, + Option<"setAssumingToBeTrue", "set-assuming-to-be-true", "bool", /*default=*/"true", + "remove cstr_reshapable,cstr_broadcastable, and set assuming to be true"> ]; } diff --git a/frontends/tf-frontend/tf_mlir_ext/transforms/BUILD b/frontends/tf-frontend/tf_mlir_ext/transforms/BUILD index a38478478..fff7b1181 100644 --- a/frontends/tf-frontend/tf_mlir_ext/transforms/BUILD +++ b/frontends/tf-frontend/tf_mlir_ext/transforms/BUILD @@ -88,6 +88,7 @@ cc_library( "rewrite_to_custom_call.cc", "tf_fallback_to_custom_call.cc", "tf_switch_merge_to_if.cc", + "remove_cstr_reshapable.cc", ], hdrs = [ "constant_folding.h", @@ -102,6 +103,7 @@ cc_library( "rewrite_to_custom_call.h", "tf_fallback_to_custom_call.h", "tf_switch_merge_to_if.h", + "remove_cstr_reshapable.h", ], textual_hdrs = [ "passes_detail.h", diff --git a/frontends/tf-frontend/tf_mlir_ext/transforms/mhlo_legalize_tf_ext.cc b/frontends/tf-frontend/tf_mlir_ext/transforms/mhlo_legalize_tf_ext.cc index 5dc05d502..80f55d5b4 100644 --- a/frontends/tf-frontend/tf_mlir_ext/transforms/mhlo_legalize_tf_ext.cc +++ b/frontends/tf-frontend/tf_mlir_ext/transforms/mhlo_legalize_tf_ext.cc @@ -467,9 +467,65 @@ class ConvertStridedSliceOp : public OpRewritePattern { } }; +// Returns a PrecisionConfig as an array attribute based on whether TF32 +// execution is enabled +static ArrayAttr GetPrecisionConfig(Builder *builder) { + mlir::mhlo::Precision precision = mhlo::Precision::DEFAULT; + llvm::SmallVector attr_vec; + const int num_inputs = 2; + for (int i = 0; i < num_inputs; i++) { + attr_vec.push_back( + mlir::mhlo::PrecisionAttr::get(builder->getContext(), precision)); + } + return builder->getArrayAttr(attr_vec); +} + +class ConvertBatchMatMulV2Op : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(TF::BatchMatMulV2Op op, + PatternRewriter &rewriter) const override { + Value lhs = op.getX(); + Value rhs = op.getY(); + auto lhs_type = lhs.getType().dyn_cast(); + auto rhs_type = rhs.getType().dyn_cast(); + if (!lhs_type || !rhs_type) + return failure(); + if (lhs_type.getRank() != rhs_type.getRank()) + return failure(); + auto lhs_batch = lhs_type.getShape().drop_back(2); + auto rhs_batch = rhs_type.getShape().drop_back(2); + bool batch_equal = + std::equal(lhs_batch.begin(), lhs_batch.end(), rhs_batch.begin()); + bool is_static = std::all_of(lhs_batch.begin(), lhs_batch.end(), + [](int64_t dim) { return dim > 0; }); + if (!batch_equal || !is_static) { + return failure(); + } + int64_t rank = lhs_type.getRank(); + auto batch_dimensions = llvm::to_vector<4>(llvm::seq(0, rank - 2)); + auto lhs_contracting_dimensions = llvm::to_vector<4>( + llvm::ArrayRef({op.getAdjX() ? rank - 2 : rank - 1})); + auto rhs_contracting_dimensions = llvm::to_vector<4>( + llvm::ArrayRef({op.getAdjY() ? rank - 1 : rank - 2})); + auto dimension_numbers = mhlo::DotDimensionNumbersAttr::get( + rewriter.getContext(), + /*lhs_batching_dimensions=*/batch_dimensions, + /*rhs_batching_dimensions=*/batch_dimensions, + /*lhs_contracting_dimensions=*/lhs_contracting_dimensions, + /*rhs_contracting_dimensions=*/rhs_contracting_dimensions); + rewriter.replaceOpWithNewOp( + op, op.getType(), lhs, rhs, dimension_numbers, + /*precision_config=*/GetPrecisionConfig(&rewriter)); + return success(); + } +}; + void PopulateMhloLegalizeTfExtPatterns(MLIRContext *context, RewritePatternSet *patterns) { patterns->add(std::make_unique(context)); + patterns->add(std::make_unique(context)); } struct MhloLegalizeTfExtPass diff --git a/frontends/tf-frontend/tf_mlir_ext/transforms/passes.h b/frontends/tf-frontend/tf_mlir_ext/transforms/passes.h index 969ec6cfa..9baecbe81 100644 --- a/frontends/tf-frontend/tf_mlir_ext/transforms/passes.h +++ b/frontends/tf-frontend/tf_mlir_ext/transforms/passes.h @@ -31,6 +31,7 @@ #include "tf_mlir_ext/transforms/mhlo_legalize_tf_ext.h" #include "tf_mlir_ext/transforms/process_dynamic_stitch_as_static.h" #include "tf_mlir_ext/transforms/remove_control_flow.h" +#include "tf_mlir_ext/transforms/remove_cstr_reshapable.h" #include "tf_mlir_ext/transforms/reshape_movedown_string.h" #include "tf_mlir_ext/transforms/rewrite_func_attr_to_byteir.h" #include "tf_mlir_ext/transforms/rewrite_to_custom_call.h" diff --git a/frontends/tf-frontend/tf_mlir_ext/transforms/passes.td b/frontends/tf-frontend/tf_mlir_ext/transforms/passes.td index 407438129..dc82e8543 100644 --- a/frontends/tf-frontend/tf_mlir_ext/transforms/passes.td +++ b/frontends/tf-frontend/tf_mlir_ext/transforms/passes.td @@ -108,4 +108,9 @@ def TFSwitchMergeToIf : Pass<"tf-switch-merge-to-if", "func::FuncOp"> { let constructor = "mlir::tfext::createTFSwitchMergeToIfPass()"; } +def RemoveCstrReshapable : Pass<"remove-cstr-reshapable", "func::FuncOp"> { + let summary = "replace cstr_reshapable to true"; + let constructor = "mlir::tfext::createRemoveCstrReshapablePass()"; +} + #endif // TF_MLIR_EXT_PASSES diff --git a/frontends/tf-frontend/tf_mlir_ext/transforms/remove_cstr_reshapable.cc b/frontends/tf-frontend/tf_mlir_ext/transforms/remove_cstr_reshapable.cc new file mode 100644 index 000000000..1b2a0c729 --- /dev/null +++ b/frontends/tf-frontend/tf_mlir_ext/transforms/remove_cstr_reshapable.cc @@ -0,0 +1,68 @@ +//===- remove_cstr_reshapable.cc ------------------------------*--- C++ -*-===// +// +// Copyright 2022 ByteDance Ltd. and/or its affiliates. All rights reserved. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//===----------------------------------------------------------------------===// + +#include "mhlo/IR/hlo_ops.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Shape/IR/Shape.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "mlir/Transforms/Passes.h" +#include "mlir/include/mlir/IR/IRMapping.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/raw_ostream.h" + +#include "tf_mlir_ext/transforms/passes_detail.h" +#include "tf_mlir_ext/transforms/remove_cstr_reshapable.h" + +using namespace mlir; +using namespace llvm; + +#define DEBUG_TYPE "remove_cstr_reshapable" + +namespace { +/// Removal patterns. +class RemoveCstrReshapableOp : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(mhlo::CstrReshapableOp op, + PatternRewriter &rewriter) const override { + rewriter.replaceOpWithNewOp(op.getOperation(), true); + return success(); + } +}; + +/// Removal pass. +class RemoveCstrReshapablePass + : public RemoveCstrReshapableBase { + + void runOnOperation() override { + MLIRContext &ctx = getContext(); + + RewritePatternSet patterns(&ctx); + patterns.add(patterns.getContext()); + + (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + } +}; + +} // namespace + +std::unique_ptr> +mlir::tfext::createRemoveCstrReshapablePass() { + return std::make_unique(); +} diff --git a/frontends/tf-frontend/tf_mlir_ext/transforms/remove_cstr_reshapable.h b/frontends/tf-frontend/tf_mlir_ext/transforms/remove_cstr_reshapable.h new file mode 100644 index 000000000..89762c5aa --- /dev/null +++ b/frontends/tf-frontend/tf_mlir_ext/transforms/remove_cstr_reshapable.h @@ -0,0 +1,37 @@ +//===- remove_cstr_reshapable.h -------------------------------*--- C++ -*-===// +// +// Copyright 2022 ByteDance Ltd. and/or its affiliates. All rights reserved. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//===----------------------------------------------------------------------===// + +#ifndef TFEXT_TRANSFORMS_REMOVE_CSTR_RESHAPABLE +#define TFEXT_TRANSFORMS_REMOVE_CSTR_RESHAPABLE + +#include + +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" + +namespace mlir { +namespace tfext { + +// ------------------------------------------- +std::unique_ptr> createRemoveCstrReshapablePass(); + +} // namespace tfext +} // namespace mlir + +#endif // TFEXT_TRANSFORMS_REMOVE_CSTR_RESHAPABLE diff --git a/frontends/tf-frontend/tools/tf_frontend_main.cc b/frontends/tf-frontend/tools/tf_frontend_main.cc index 6290bb47c..eaad13d1a 100644 --- a/frontends/tf-frontend/tools/tf_frontend_main.cc +++ b/frontends/tf-frontend/tools/tf_frontend_main.cc @@ -119,6 +119,12 @@ static llvm::cl::opt keep_original_input_names( llvm::cl::desc("put original input names in main func as an ArrayAttr"), llvm::cl::init(false)); +static llvm::cl::opt set_assuming_to_be_true( + "set-assuming-to-be-true", + llvm::cl::desc("remove cstr_reshapable and cstr_broadcastable," + "and remove assuming"), + llvm::cl::init(true)); + int main(int argc, char **argv) { tensorflow::InitMlir y(&argc, &argv); @@ -252,7 +258,8 @@ int main(int argc, char **argv) { tf_frontend_manager.addPass( ::mlir::tfext::createCustomizedTfToMhloPipelinePass( customcall_ops_array, remove_control_flow, staticalize_dynamic_shape, - stop_after_rewrite_customcall, additional_main_func_attrs)); + stop_after_rewrite_customcall, additional_main_func_attrs, + set_assuming_to_be_true)); if (mlir::failed(tf_frontend_manager.run(*module))) { llvm::outs() << "tf frontend customized-tf-to-mhlo pipeline failed\n"; return 1;