Skip to content

Commit

Permalink
[tf-frontend] remove shape.split_at, shape.cstr_broadcastable, mhlo.c…
Browse files Browse the repository at this point in the history
…str_reshapable
  • Loading branch information
heromapwrd committed Feb 5, 2024
1 parent 1b6a971 commit 435a78b
Show file tree
Hide file tree
Showing 10 changed files with 199 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/Shape/Transforms/Passes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/Passes.h"
Expand All @@ -56,12 +57,14 @@ struct CustomizedTfToMhloPipelinePass
const std::vector<std::string> &customcall_ops, bool remove_control_flow,
bool staticalize_dynamic_shape, bool stop_after_rewrite_custom_call,
const std::unordered_map<std::string, Attribute>
&additional_main_func_attrs) {
&additional_main_func_attrs,
bool set_assuming_to_be_true) {
this->customCallOps = customcall_ops;
this->removeControlFlow = remove_control_flow;
this->staticalizeDynamicShape = staticalize_dynamic_shape;
this->stopAfterRewriteCustomCall = stop_after_rewrite_custom_call;
this->additional_main_func_attrs = additional_main_func_attrs;
this->setAssumingToBeTrue = set_assuming_to_be_true;
}

void runOnOperation() override {
Expand All @@ -72,15 +75,13 @@ struct CustomizedTfToMhloPipelinePass
pm.addPass(mlir::createSCCPPass());
pm.addPass(mlir::createCanonicalizerPass());

//// not safe remove-control-flow pass
// if (removeControlFlow)
// pm.addNestedPass<mlir::func::FuncOp>(
// mlir::tfext::createRemoveControlFlowPass());
// prun useless tf node
pm.addNestedPass<mlir::func::FuncOp>(
mlir::tf_executor::CreateTFExecutorGraphPruningPass());
if (removeControlFlow) {
pm.addNestedPass<mlir::func::FuncOp>(
mlir::tfext::createTFSwitchMergeToIfPass());
}

// prun useless tf node
pm.addNestedPass<mlir::func::FuncOp>(
mlir::tf_executor::CreateTFExecutorGraphPruningPass());
Expand Down Expand Up @@ -209,6 +210,12 @@ struct CustomizedTfToMhloPipelinePass
mlir::tfext::createRewriteFuncAttrToByteIRPass(
additional_main_func_attrs));

if (setAssumingToBeTrue) {
pm.addNestedPass<mlir::func::FuncOp>(
mlir::createRemoveShapeConstraintsPass());
pm.addNestedPass<mlir::func::FuncOp>(
mlir::tfext::createRemoveCstrReshapablePass());
}
pm.addPass(mlir::createCanonicalizerPass());

pm.addPass(mlir::mhlo::createHloLegalizeToStablehloPass());
Expand All @@ -230,8 +237,10 @@ mlir::tfext::createCustomizedTfToMhloPipelinePass(
bool staticalize_dynamic_shape /*= false*/,
bool stop_after_rewrite_custom_call /*= false*/,
const std::unordered_map<std::string, Attribute>
&additional_main_func_attrs /*= {}*/) {
&additional_main_func_attrs /*= {}*/,
bool set_assuming_to_be_true /*= true*/) {
return std::make_unique<CustomizedTfToMhloPipelinePass>(
customcall_ops, remove_control_flow, staticalize_dynamic_shape,
stop_after_rewrite_custom_call, additional_main_func_attrs);
stop_after_rewrite_custom_call, additional_main_func_attrs,
set_assuming_to_be_true);
}
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@ std::unique_ptr<OperationPass<ModuleOp>> createCustomizedTfToMhloPipelinePass(
bool remove_control_flow = false, bool staticalize_dynamic_shape = false,
bool stop_after_rewrite_custom_call = false,
const std::unordered_map<std::string, Attribute>
&additional_main_func_attrs = {});
&additional_main_func_attrs = {},
bool set_assuming_to_be_true = true);

} // namespace tfext
} // namespace mlir
Expand Down
4 changes: 3 additions & 1 deletion frontends/tf-frontend/tf_mlir_ext/pipelines/passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,9 @@ def CustomizedTfToMhloPipeline : Pass<"customized-tf-to-mhlo", "mlir::ModuleOp">
/*default*/"false", "Aggresively and experimentally try to rewrite "
"the dynamic graph to a equivalent static graph">,
Option<"stopAfterRewriteCustomCall", "stop-after-rewrite-customcall", "bool", /*default=*/"false",
"stop after rewrite customcall ops">
"stop after rewrite customcall ops">,
Option<"setAssumingToBeTrue", "set-assuming-to-be-true", "bool", /*default=*/"true",
"remove cstr_reshapable,cstr_broadcastable, and set assuming to be true">
];
}

Expand Down
2 changes: 2 additions & 0 deletions frontends/tf-frontend/tf_mlir_ext/transforms/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ cc_library(
"rewrite_to_custom_call.cc",
"tf_fallback_to_custom_call.cc",
"tf_switch_merge_to_if.cc",
"remove_cstr_reshapable.cc",
],
hdrs = [
"constant_folding.h",
Expand All @@ -102,6 +103,7 @@ cc_library(
"rewrite_to_custom_call.h",
"tf_fallback_to_custom_call.h",
"tf_switch_merge_to_if.h",
"remove_cstr_reshapable.h",
],
textual_hdrs = [
"passes_detail.h",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -467,9 +467,65 @@ class ConvertStridedSliceOp : public OpRewritePattern<TF::StridedSliceOp> {
}
};

// Returns a PrecisionConfig as an array attribute based on whether TF32
// execution is enabled
static ArrayAttr GetPrecisionConfig(Builder *builder) {
mlir::mhlo::Precision precision = mhlo::Precision::DEFAULT;
llvm::SmallVector<mlir::Attribute, 2> attr_vec;
const int num_inputs = 2;
for (int i = 0; i < num_inputs; i++) {
attr_vec.push_back(
mlir::mhlo::PrecisionAttr::get(builder->getContext(), precision));
}
return builder->getArrayAttr(attr_vec);
}

class ConvertBatchMatMulV2Op : public OpRewritePattern<TF::BatchMatMulV2Op> {
public:
using OpRewritePattern<TF::BatchMatMulV2Op>::OpRewritePattern;

LogicalResult matchAndRewrite(TF::BatchMatMulV2Op op,
PatternRewriter &rewriter) const override {
Value lhs = op.getX();
Value rhs = op.getY();
auto lhs_type = lhs.getType().dyn_cast<RankedTensorType>();
auto rhs_type = rhs.getType().dyn_cast<RankedTensorType>();
if (!lhs_type || !rhs_type)
return failure();
if (lhs_type.getRank() != rhs_type.getRank())
return failure();
auto lhs_batch = lhs_type.getShape().drop_back(2);
auto rhs_batch = rhs_type.getShape().drop_back(2);
bool batch_equal =
std::equal(lhs_batch.begin(), lhs_batch.end(), rhs_batch.begin());
bool is_static = std::all_of(lhs_batch.begin(), lhs_batch.end(),
[](int64_t dim) { return dim > 0; });
if (!batch_equal || !is_static) {
return failure();
}
int64_t rank = lhs_type.getRank();
auto batch_dimensions = llvm::to_vector<4>(llvm::seq<int64_t>(0, rank - 2));
auto lhs_contracting_dimensions = llvm::to_vector<4>(
llvm::ArrayRef({op.getAdjX() ? rank - 2 : rank - 1}));
auto rhs_contracting_dimensions = llvm::to_vector<4>(
llvm::ArrayRef({op.getAdjY() ? rank - 1 : rank - 2}));
auto dimension_numbers = mhlo::DotDimensionNumbersAttr::get(
rewriter.getContext(),
/*lhs_batching_dimensions=*/batch_dimensions,
/*rhs_batching_dimensions=*/batch_dimensions,
/*lhs_contracting_dimensions=*/lhs_contracting_dimensions,
/*rhs_contracting_dimensions=*/rhs_contracting_dimensions);
rewriter.replaceOpWithNewOp<mhlo::DotGeneralOp>(
op, op.getType(), lhs, rhs, dimension_numbers,
/*precision_config=*/GetPrecisionConfig(&rewriter));
return success();
}
};

void PopulateMhloLegalizeTfExtPatterns(MLIRContext *context,
RewritePatternSet *patterns) {
patterns->add(std::make_unique<ConvertStridedSliceOp>(context));
patterns->add(std::make_unique<ConvertBatchMatMulV2Op>(context));
}

struct MhloLegalizeTfExtPass
Expand Down
1 change: 1 addition & 0 deletions frontends/tf-frontend/tf_mlir_ext/transforms/passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
#include "tf_mlir_ext/transforms/mhlo_legalize_tf_ext.h"
#include "tf_mlir_ext/transforms/process_dynamic_stitch_as_static.h"
#include "tf_mlir_ext/transforms/remove_control_flow.h"
#include "tf_mlir_ext/transforms/remove_cstr_reshapable.h"
#include "tf_mlir_ext/transforms/reshape_movedown_string.h"
#include "tf_mlir_ext/transforms/rewrite_func_attr_to_byteir.h"
#include "tf_mlir_ext/transforms/rewrite_to_custom_call.h"
Expand Down
5 changes: 5 additions & 0 deletions frontends/tf-frontend/tf_mlir_ext/transforms/passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -108,4 +108,9 @@ def TFSwitchMergeToIf : Pass<"tf-switch-merge-to-if", "func::FuncOp"> {
let constructor = "mlir::tfext::createTFSwitchMergeToIfPass()";
}

def RemoveCstrReshapable : Pass<"remove-cstr-reshapable", "func::FuncOp"> {
let summary = "replace cstr_reshapable to true";
let constructor = "mlir::tfext::createRemoveCstrReshapablePass()";
}

#endif // TF_MLIR_EXT_PASSES
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
//===- remove_cstr_reshapable.cc ------------------------------*--- C++ -*-===//
//
// Copyright 2022 ByteDance Ltd. and/or its affiliates. All rights reserved.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
//===----------------------------------------------------------------------===//

#include "mhlo/IR/hlo_ops.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Passes.h"
#include "mlir/include/mlir/IR/IRMapping.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"

#include "tf_mlir_ext/transforms/passes_detail.h"
#include "tf_mlir_ext/transforms/remove_cstr_reshapable.h"

using namespace mlir;
using namespace llvm;

#define DEBUG_TYPE "remove_cstr_reshapable"

namespace {
/// Removal patterns.
class RemoveCstrReshapableOp : public OpRewritePattern<mhlo::CstrReshapableOp> {
public:
using OpRewritePattern::OpRewritePattern;

LogicalResult matchAndRewrite(mhlo::CstrReshapableOp op,
PatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<shape::ConstWitnessOp>(op.getOperation(), true);
return success();
}
};

/// Removal pass.
class RemoveCstrReshapablePass
: public RemoveCstrReshapableBase<RemoveCstrReshapablePass> {

void runOnOperation() override {
MLIRContext &ctx = getContext();

RewritePatternSet patterns(&ctx);
patterns.add<RemoveCstrReshapableOp>(patterns.getContext());

(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
}
};

} // namespace

std::unique_ptr<OperationPass<func::FuncOp>>
mlir::tfext::createRemoveCstrReshapablePass() {
return std::make_unique<RemoveCstrReshapablePass>();
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
//===- remove_cstr_reshapable.h -------------------------------*--- C++ -*-===//
//
// Copyright 2022 ByteDance Ltd. and/or its affiliates. All rights reserved.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
//===----------------------------------------------------------------------===//

#ifndef TFEXT_TRANSFORMS_REMOVE_CSTR_RESHAPABLE
#define TFEXT_TRANSFORMS_REMOVE_CSTR_RESHAPABLE

#include <memory>

#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/IR/PatternMatch.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"

namespace mlir {
namespace tfext {

// -------------------------------------------
std::unique_ptr<OperationPass<func::FuncOp>> createRemoveCstrReshapablePass();

} // namespace tfext
} // namespace mlir

#endif // TFEXT_TRANSFORMS_REMOVE_CSTR_RESHAPABLE
9 changes: 8 additions & 1 deletion frontends/tf-frontend/tools/tf_frontend_main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,12 @@ static llvm::cl::opt<bool> keep_original_input_names(
llvm::cl::desc("put original input names in main func as an ArrayAttr"),
llvm::cl::init(false));

static llvm::cl::opt<bool> set_assuming_to_be_true(
"set-assuming-to-be-true",
llvm::cl::desc("remove cstr_reshapable and cstr_broadcastable,"
"and remove assuming"),
llvm::cl::init(true));

int main(int argc, char **argv) {
tensorflow::InitMlir y(&argc, &argv);

Expand Down Expand Up @@ -252,7 +258,8 @@ int main(int argc, char **argv) {
tf_frontend_manager.addPass(
::mlir::tfext::createCustomizedTfToMhloPipelinePass(
customcall_ops_array, remove_control_flow, staticalize_dynamic_shape,
stop_after_rewrite_customcall, additional_main_func_attrs));
stop_after_rewrite_customcall, additional_main_func_attrs,
set_assuming_to_be_true));
if (mlir::failed(tf_frontend_manager.run(*module))) {
llvm::outs() << "tf frontend customized-tf-to-mhlo pipeline failed\n";
return 1;
Expand Down

0 comments on commit 435a78b

Please sign in to comment.