[compiler] support shape reification for callOp
XG-zheng committed Apr 7, 2024
1 parent 173af1c commit 7a9d18c
Showing 14 changed files with 286 additions and 26 deletions.
2 changes: 1 addition & 1 deletion compiler/include/byteir/Dialect/mhlo/Passes.h
@@ -36,7 +36,7 @@
#include "byteir/Dialect/mhlo/Transforms/LayoutTransformation.h"
#include "byteir/Dialect/mhlo/Transforms/MatmulLayoutTransform.h"
#include "byteir/Dialect/mhlo/Transforms/RewriteWithConstraint.h"
#include "byteir/Dialect/mhlo/Transforms/ShapeReification.h"
#include "byteir/Transforms/ShapeReification.h"
#include "byteir/Dialect/mhlo/Transforms/StaticShapeInference.h"
#include "byteir/Dialect/mhlo/Transforms/UnfuseBatchNorm.h"

19 changes: 0 additions & 19 deletions compiler/include/byteir/Dialect/mhlo/Passes.td
@@ -305,25 +305,6 @@ def RewriteWithConstraint : Pass<"rewrite-with-constraint", "mlir::func::FuncOp
let constructor = "mlir::createRewriteWithConstraintPass()";
}

//===----------------------------------------------------------------------===//
// ShapeReification
//===----------------------------------------------------------------------===//

def ShapeReification : Pass<"byteir-shape-reification", "func::FuncOp"> {
let summary = "Iteratively reify all shape computations.";
let description = [{
If an operation has a shape reification implementation, that is to say, we
know how to express the shapes of its outputs symbolically in terms of the
shapes of its inputs, then a tensor.dim or shape.shape_of on such an
operation can be reified. The shape reification procedure is applied
recursively.
}];
let constructor = "mlir::createByteIRShapeReificationPass()";
let dependentDialects = [
"mlir::shape::ShapeDialect",
"mlir::tensor::TensorDialect"
];
}

//===----------------------------------------------------------------------===//
// Static Shape Inference
//===----------------------------------------------------------------------===//
1 change: 1 addition & 0 deletions compiler/include/byteir/Transforms/Passes.h
@@ -36,6 +36,7 @@
#include "byteir/Transforms/SetArgShape.h"
#include "byteir/Transforms/SetSpace.h"
#include "byteir/Transforms/TryCatchModulePipeline.h"
#include "byteir/Transforms/ShapeReification.h"

namespace mlir {

19 changes: 19 additions & 0 deletions compiler/include/byteir/Transforms/Passes.td
@@ -425,4 +425,23 @@ def SetOpSpace: Pass<"set-op-space", "func::FuncOp"> {
];
}

//===----------------------------------------------------------------------===//
// ShapeReification
//===----------------------------------------------------------------------===//

def ShapeReification : Pass<"byteir-shape-reification", "func::FuncOp"> {
let summary = "Iteratively reify all shape computations.";
let description = [{
If an operation has a shape reification implementation, that is to say, we
know how to express the shapes of its outputs symbolically in terms of the
shapes of its inputs, then a tensor.dim or shape.shape_of on such an
operation can be reified. The shape reification procedure is applied
recursively.
}];
let constructor = "mlir::createByteIRShapeReificationPass()";
let dependentDialects = [
"mlir::shape::ShapeDialect",
"mlir::tensor::TensorDialect"
];
}

#endif // BYTEIR_TRANSFORMS_PASSES
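
To make the pass's contract concrete, here is a minimal hand-written sketch (an illustration under assumed IR, not a test from this commit; the byteir-opt invocation is likewise assumed). Since mhlo.add produces a result shaped like its operands, a shape.shape_of on the add can be reified to one on the function argument:

// Input (run with, e.g., byteir-opt --byteir-shape-reification):
func.func @example(%arg0: tensor<?x4xf32>) -> tensor<2xindex> {
%0 = mhlo.add %arg0, %arg0 : tensor<?x4xf32>
%1 = shape.shape_of %0 : tensor<?x4xf32> -> tensor<2xindex>
return %1 : tensor<2xindex>
}
// After reification and canonicalization, the returned shape no longer
// depends on %0:
//   %1 = shape.shape_of %arg0 : tensor<?x4xf32> -> tensor<2xindex>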
compiler/include/byteir/Transforms/ShapeReification.h (moved from compiler/include/byteir/Dialect/mhlo/Transforms/ShapeReification.h)
@@ -1,6 +1,6 @@
//===- ShapeReification.h -------------------------------------*--- C++ -*-===//
//
// Copyright 2022 ByteDance Ltd. and/or its affiliates. All rights reserved.
// Copyright 2024 ByteDance Ltd. and/or its affiliates. All rights reserved.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
2 changes: 1 addition & 1 deletion compiler/lib/Analysis/SymbolicShape.cpp
@@ -16,7 +16,7 @@
//===----------------------------------------------------------------------===//

#include "byteir/Analysis/SymbolicShape.h"
#include "byteir/Dialect/mhlo/Transforms/ShapeReification.h"
#include "byteir/Transforms/ShapeReification.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/IRMapping.h"
1 change: 0 additions & 1 deletion compiler/lib/Dialect/mhlo/CMakeLists.txt
@@ -105,7 +105,6 @@ add_mlir_dialect_library(ByteIRMhloPasses
Transforms/ReduceFusion.cpp
Transforms/ReshapeGather.cpp
Transforms/RewriteWithConstraint.cpp
Transforms/ShapeReification.cpp
Transforms/StaticShapeInference.cpp
Transforms/TrivialFusion.cpp
Transforms/UnfuseBatchNorm.cpp
178 changes: 178 additions & 0 deletions compiler/lib/Dialect/mhlo/Util/ShapeInferUtil.cpp
@@ -17,13 +17,19 @@

#include "byteir/Dialect/mhlo/Util/ShapeInferUtil.h"
#include "byteir/Dialect/mhlo/DynamicShapeOpRegister/Register.h"
#include "byteir/Transforms/ShapeReification.h"
#include "mhlo/IR/hlo_ops.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/Passes.h"
#include "mlir/Transforms/TopologicalSortUtils.h"
#include "llvm/ADT/StringMap.h"
#include "llvm/Support/Debug.h"
#include <queue>
#include <string>

using namespace mlir;

@@ -177,6 +183,168 @@ mlir::inferReturnTypeComponents(llvm::StringRef name) {
return nullptr;
}

namespace {

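// Collect the transitive backward slice of `retOp`: every operation that its
// operands (directly or indirectly) depend on, excluding `retOp` itself.
// Implemented as a breadth-first walk over def-use chains.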
SmallVector<Operation *> collectAllOpsForReturn(Operation *retOp) {
llvm::DenseSet<Operation *> visitedOp;
std::queue<Operation *> opQueue;

opQueue.push(retOp);
while (!opQueue.empty()) {
auto frontOp = opQueue.front();
opQueue.pop();
if (visitedOp.contains(frontOp)) {
continue;
}
visitedOp.insert(frontOp);
for (Value operand : frontOp->getOperands()) {
// Block arguments have no defining op and contribute nothing to the slice.
if (Operation *defOp = operand.getDefiningOp()) {
opQueue.push(defOp);
}
}
}
visitedOp.erase(retOp);
return SmallVector<Operation *>(visitedOp.begin(), visitedOp.end());
}

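// Returns true if `value` can be computed purely from the *shapes* of the
// enclosing function's arguments: the chain of defining ops must bottom out
// at constants or at tensor.dim / shape.shape_of applied directly to a block
// argument. Sketch (assumed IR): %d below is deducible, while %s is not,
// because it queries the shape of an intermediate result:
//   %d = tensor.dim %arg0, %c0 : tensor<?x4xf32>
//   %0 = "some.op"(%arg0) : (tensor<?x4xf32>) -> tensor<?xf32>
//   %s = shape.shape_of %0 : tensor<?xf32> -> tensor<1xindex>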
bool deduceFromFuncArgShape(Value value) {
if (value.isa<BlockArgument>()) {
return false;
}

auto defOp = value.getDefiningOp();
if (!defOp) {
return false;
}

if (isa<arith::ConstantIndexOp, arith::ConstantOp>(defOp)) {
return true;
}

if (isa<tensor::DimOp, shape::ShapeOfOp>(defOp)) {
// Deducible only when the dim/shape_of directly queries a function argument.
return defOp->getOperand(0).isa<BlockArgument>();
}

for (Value &&operand : defOp->getOperands()) {
if (!deduceFromFuncArgShape(operand)) {
return false;
}
}
return true;
}

LogicalResult reifyCallOp(OpBuilder &builder, Operation *op,
SmallVectorImpl<Value> &reifications) {
OpBuilder::InsertionGuard guard(builder);
auto callOp = dyn_cast<func::CallOp>(op);
if (!callOp) {
return failure();
}

ModuleOp moduleOp = op->getParentRegion()->getParentOfType<ModuleOp>();
// An auxiliary builder is used to create operations in the shape function;
// the original builder may be a rewriter that is restricted to creating
// operations inside a specific pattern.
OpBuilder auxiliaryBuilder(moduleOp);
StringRef funcName = callOp.getCallee();
auto funcOp = moduleOp.lookupSymbol<func::FuncOp>(funcName);
if (!funcOp) {
return failure();
}

// Clone funcOp; the clone (newFuncOp) is used to deduce the result shapes.
std::string newFuncName = funcName.str() + "_Shape";
auxiliaryBuilder.setInsertionPointToStart(moduleOp.getBody());
auto newFuncOp = auxiliaryBuilder.create<func::FuncOp>(
funcOp->getLoc(), newFuncName, funcOp.getFunctionType());
newFuncOp.setPrivate();
IRMapping emptyBvm;
funcOp.cloneInto(newFuncOp, emptyBvm);

// Replace the operands of the returnOp with their corresponding shapes.
auto retOps = newFuncOp.getOps<func::ReturnOp>();
if (retOps.empty()) {
newFuncOp->erase();
return failure();
}
func::ReturnOp retOp = *retOps.begin();

SmallVector<Type> allResultTypes;
SmallVector<Value> allResults;

auxiliaryBuilder.setInsertionPoint(retOp);
for (Value &&retTensor : retOp.getOperands()) {
auto retShape =
auxiliaryBuilder.create<shape::ShapeOfOp>(retOp.getLoc(), retTensor);
allResultTypes.emplace_back(retShape.getType());
allResults.emplace_back(retShape);
}

// Return the shapes of the tensors originally returned by the function.
auto newRetOp =
auxiliaryBuilder.create<func::ReturnOp>(retOp.getLoc(), allResults);
auto newFuncType = auxiliaryBuilder.getFunctionType(
newFuncOp.getArgumentTypes(), allResultTypes);
newFuncOp.setFunctionType(newFuncType);
retOp->erase();

// Reify newFuncOp to materialize the shape computation for this callOp.
{
PassManager pm(moduleOp->getContext(), func::FuncOp::getOperationName());
pm.addPass(createCanonicalizerPass());
pm.addPass(createCSEPass());
pm.addPass(createByteIRShapeReificationPass());
pm.addPass(createCanonicalizerPass());
pm.addPass(createCSEPass());

if (mlir::failed(pm.run(newFuncOp))) {
newFuncOp->erase();
return failure();
}
}

// collect all shape computation ops
SmallVector<Operation *> reificationOps = collectAllOpsForReturn(newRetOp);

// Each returned value must be deducible solely from the shapes of the
// function arguments; otherwise reification of this call is rejected.
for (Value &&ret : newRetOp.getOperands()) {
if (!deduceFromFuncArgShape(ret)) {
newFuncOp->erase();
return failure();
}
}

// Map the shape computation ops into the caller and collect the reified
// values.
{
mlir::computeTopologicalSorting(reificationOps);

IRMapping bvm;
size_t numArg = newFuncOp.getNumArguments();
for (size_t i = 0; i < numArg; ++i) {
bvm.map(newFuncOp.getArgument(i), callOp.getOperand(i));
}

builder.setInsertionPoint(callOp);

for (Operation *oldOp : reificationOps) {
// Cloning through `bvm` also records the mapping from old results to the
// cloned results, so subsequent ops and the final lookups are remapped.
builder.clone(*oldOp, bvm);
}

for (Value &&ret : newRetOp.getOperands()) {
reifications.push_back(bvm.lookup(ret));
}
}

// remove newFuncOp
newFuncOp->erase();
return success();
}

} // namespace

LogicalResult mlir::reifyShapes(OpBuilder &builder, Operation *op,
SmallVectorImpl<Value> &reifications) {
if (!op)
@@ -207,6 +375,16 @@ LogicalResult mlir::reifyShapes(OpBuilder &builder, Operation *op,
}
if (failed(inferFunc(op, builder, op->getOperands(), reifications)))
return failure();
} else if (auto callOp = dyn_cast<func::CallOp>(op)) {
if (failed(reifyCallOp(builder, op, reifications))) {
return failure();
}
} else if (auto dpsOp = dyn_cast<DestinationStyleOpInterface>(op)) {
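// A destination-style op's result is tied to an init operand and shares its
// shape, so the result's shape is just shape.shape_of on the tied operand.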
for (OpResult &&result : op->getOpResults()) {
auto tiedOperand = dpsOp.getTiedOpOperand(result);
reifications.push_back(
builder.create<shape::ShapeOfOp>(op->getLoc(), tiedOperand->get()));
}
} else {
// Return failure if op doesn't have InferShapedTypeOpInterface and not
// registered.
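To illustrate the new callOp path end to end, a hand-written sketch under assumed IR (not code or a test from this commit): given

func.func private @Unknown0(%a: tensor<?x20xf32>, %b: tensor<?x20xf32>) -> tensor<?x20xf32> { ... }
%r = call @Unknown0(%x, %y) : (tensor<?x20xf32>, tensor<?x20xf32>) -> tensor<?x20xf32>

reifyCallOp clones @Unknown0 into a private @Unknown0_Shape whose return values are shape.shape_of of each original result, runs canonicalize/CSE plus byteir-shape-reification on the clone, checks that every returned shape is deducible from the argument shapes alone, and then clones the surviving shape computation to the call site with the callee arguments substituted. For an elementwise body the reification of %r collapses to

%r_shape = shape.shape_of %x : tensor<?x20xf32> -> tensor<2xindex>

and the auxiliary @Unknown0_Shape is erased afterwards.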
1 change: 1 addition & 0 deletions compiler/lib/Pipelines/ByreTensorOpt.cpp
@@ -47,6 +47,7 @@ void createByreTensorOptPipelineImpl(OpPassManager &pm, std::string entryFunc,
createConvertHloToByreCustomPass(getCudaByreCustomConfig()));
pm.addNestedPass<func::FuncOp>(
createConvertHloToByreTensorPass(appendArgTypes));
pm.addNestedPass<func::FuncOp>(createByteIRShapeReificationPass());
pm.addPass(createCanonicalizerPass());
}
} // namespace
1 change: 1 addition & 0 deletions compiler/lib/Transforms/CMakeLists.txt
@@ -17,6 +17,7 @@ add_mlir_library(ByteIRTransforms
RewriteOpToStdCall.cpp
SetArgShape.cpp
SetSpace.cpp
ShapeReification.cpp
Utils.cpp

ADDITIONAL_HEADER_DIRS
8 changes: 8 additions & 0 deletions compiler/lib/Transforms/PassDetail.h
@@ -51,6 +51,14 @@ namespace scf {
class SCFDialect;
} // namespace scf

namespace shape {
class ShapeDialect;
} // namespace shape

namespace tensor {
class TensorDialect;
} // namespace tensor

#define GEN_PASS_CLASSES
#include "byteir/Transforms/Passes.h.inc"

compiler/lib/Transforms/ShapeReification.cpp (moved from compiler/lib/Dialect/mhlo/Transforms/ShapeReification.cpp)
@@ -1,6 +1,6 @@
//===- ShapeReification.cpp -----------------------------------*--- C++ -*-===//
//
// Copyright 2022 ByteDance Ltd. and/or its affiliates. All rights reserved.
// Copyright 2024 ByteDance Ltd. and/or its affiliates. All rights reserved.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
@@ -15,7 +15,7 @@
//
//===----------------------------------------------------------------------===//

#include "byteir/Dialect/mhlo/Transforms/ShapeReification.h"
#include "byteir/Transforms/ShapeReification.h"

#include "byteir/Dialect/mhlo/DynamicShapeOpRegister/Register.h"
#include "byteir/Dialect/mhlo/Util/ShapeInferUtil.h"
17 changes: 17 additions & 0 deletions compiler/test/Conversion/FuncToByre/func_to_byre_tensor.mlir
@@ -20,3 +20,20 @@ func.func @test_normal_function_call(%arg0 : tensor<4xf32>) -> tensor<4xf32> att
}
// CHECK-LABEL: test_normal_function_call
// CHECK: call @some_func


// -----

func.func private @Unknown0(%arg0: tensor<?x20xf32>, %arg1: tensor<?x20xf32>) -> tensor<?x20xf32> attributes {__byteir_elementwise_fusion__, byre_compute_name = "Unknown0"} {
%0 = mhlo.add %arg0, %arg1 : tensor<?x20xf32>
return %0 : tensor<?x20xf32>
}

func.func @forward(%arg0: tensor<?x20xf32>, %arg1: tensor<?x20xf32>) -> tensor<?x20xf32> attributes {__placeholder__byre.entry_point} {
%1 = call @Unknown0(%arg1, %arg0) : (tensor<?x20xf32>, tensor<?x20xf32>) -> tensor<?x20xf32>
return %1 : tensor<?x20xf32>
}

// CHECK-LABEL: func.func @forward
// CHECK: tensor.empty
// CHECK-NEXT: byre.compute_on_tensor @Unknown0
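
The CHECK lines only pin down the key structure. For context, a sketch of the expected materialization (assumed exact IR; the precise output depends on the conversion passes): the dynamic output is sized from an argument's dynamic dimension, which is precisely the information reifyCallOp recovers from @Unknown0's body,

%c0 = arith.constant 0 : index
%dim = tensor.dim %arg1, %c0 : tensor<?x20xf32>
%empty = tensor.empty(%dim) : tensor<?x20xf32>

after which the fused computation is issued as a byre.compute_on_tensor with %empty as its destination.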