Commit
Sync internal master 1785f23..6a90b68 (#50)
* [Byre] Improved getSeed
* [CAT] Improved CAT preprocessing, refined bmm-reshape-transpose fusion
* [frontend/ONNX] Fixed resize, l2norm, batchnorm
* [frontend/torch-mlir] Upgraded to b552d4ed956d82f5d9d0823b4727bb10bac6787c
* [runtime] Fixed rng state
* [Stats] Extended ByteIR-stat to show dtype
* [TransformOp] Extended type in annotate
* [Version] Bumped to 1.3.0
liwenchangbdbz authored Aug 30, 2023
1 parent 2694dd3 commit 0e76730
Showing 30 changed files with 648 additions and 254 deletions.
4 changes: 2 additions & 2 deletions compiler/doc/byteir_mhlo_custom_call.md
@@ -208,15 +208,15 @@ Further needed information for a given coarse-grained op are encoded in a diction
%high = mhlo.constant dense<1.000000e+00> : tensor<f32>
%low = mhlo.constant dense<0.000000e+00> : tensor<f32>
%seed = byre.compute @GetSeed() : tensor<i64>
%offset = byre.compute @GetOffset() : tensor<i64>
%offset = byre.compute @NextOffset() : tensor<i64>
%0 = "mhlo.custom_call"(%low, %high, %seed, %offset) {call_target_name = "byteir.rng_uniform", has_side_effect = false} : (tensor<f32>, tensor<f32>, tensor<i64>, tensor<i64>) -> tensor<8x1024x768xf32>
```
```
// Dynamic Shape Case
%high = mhlo.constant dense<1.000000e+00> : tensor<f32>
%low = mhlo.constant dense<0.000000e+00> : tensor<f32>
%seed = byre.compute @GetSeed() : tensor<i64>
%offset = byre.compute @GetOffset() : tensor<i64>
%offset = byre.compute @NextOffset() : tensor<i64>
%shape = shape.shape_of %arg0 : tensor<3xindex>
%0 = "mhlo.custom_call"(%low, %high, %seed, %offset, %shape) {call_target_name = "byteir.rng_uniform", has_side_effect = false} : (tensor<f32>, tensor<f32>, tensor<i64>, tensor<i64>, tensor<3xindex>) -> tensor<?x?x?xf32>
```
7 changes: 6 additions & 1 deletion compiler/include/byteir/Dialect/mhlo/Passes.td
@@ -71,9 +71,14 @@ def ConvForwardFusion : Pass<"fuse-conv-forward", "mlir::func::FuncOp"> {
// Convert Rng To CustomCall
//===----------------------------------------------------------------------===//

def ConvertOpToCustomCall : Pass<"convert-op-to-customcall", "mlir::func::FuncOp"> {
def ConvertOpToCustomCall : Pass<"convert-op-to-customcall", "ModuleOp"> {
let summary = "Convert op to mhlo.custom_call";
let constructor = "mlir::createConvertOpToCustomCallPass()";
let options = [
Option<"anchorTag", "anchor-tag", "std::string",
/*default=*/"",
"Optional unitAttr anchored tag to apply this pass">
];
}

//===----------------------------------------------------------------------===//
@@ -22,13 +22,13 @@
#include <memory>

namespace mlir {
namespace func {
class FuncOp;
} // namespace func

class ModuleOp;

void populateRngPatternToCustomCall(RewritePatternSet &patterns);

std::unique_ptr<OperationPass<func::FuncOp>> createConvertOpToCustomCallPass();
std::unique_ptr<OperationPass<ModuleOp>>
createConvertOpToCustomCallPass(llvm::StringRef anchor = "");

} // namespace mlir

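The new anchor parameter above surfaces as the `anchor-tag` pass option: when it is set, only functions carrying that unit attribute are converted (see the `runOnOperation` change further down). A minimal sketch of the gating is shown below; the tag name `__byteir_anchor__` is a made-up placeholder and the `mhlo.rng` syntax is written from memory, so treat this as an illustration rather than repo output.

```
// With --convert-op-to-customcall="anchor-tag=__byteir_anchor__", only the tagged
// function is rewritten; a function without the attribute keeps its mhlo.rng as-is.
func.func @tagged(%low: tensor<f32>, %high: tensor<f32>) -> tensor<8x1024x768xf32>
    attributes {__byteir_anchor__} {
  %shape = mhlo.constant dense<[8, 1024, 768]> : tensor<3xi64>
  // rewritten to mhlo.custom_call @byteir.rng_uniform by this pass
  %0 = "mhlo.rng"(%low, %high, %shape) {rng_distribution = #mhlo.rng_distribution<UNIFORM>}
      : (tensor<f32>, tensor<f32>, tensor<3xi64>) -> tensor<8x1024x768xf32>
  return %0 : tensor<8x1024x768xf32>
}
```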
35 changes: 22 additions & 13 deletions compiler/lib/Conversion/HloToCat/FuseHloToCat.cpp
@@ -229,7 +229,8 @@ struct ConvertTransposeReshapeBmmRrrToBmmRcr
}
};

struct ConvertBmmRrrReshapeTransposeToBmmRrc
template <typename SrcBmmType, typename DstBmmType>
struct ConvertBmmReshapeTransposeToBmmReshape
: public OpRewritePattern<mhlo::TransposeOp> {
using OpRewritePattern<mhlo::TransposeOp>::OpRewritePattern;

@@ -239,8 +240,8 @@ struct ConvertBmmRrrReshapeTransposeToBmmRrc
if (!reshapeOp || !reshapeOp.getResult().hasOneUse()) {
return failure();
}
auto bmmrrrOp = reshapeOp.getOperand().getDefiningOp<cat::BMMRRROp>();
if (!bmmrrrOp || !bmmrrrOp.getResult().hasOneUse()) {
auto srcBmmOp = reshapeOp.getOperand().getDefiningOp<SrcBmmType>();
if (!srcBmmOp || !srcBmmOp.getResult().hasOneUse()) {
return failure();
}
SmallVector<int64_t> permutation;
@@ -266,17 +267,17 @@ struct ConvertBmmRrrReshapeTransposeToBmmRrc
return failure();
}

auto bmmrrrOpType = bmmrrrOp.getType().cast<ShapedType>();
// build bmm_rrc op
RankedTensorType bmmrrcResultType = RankedTensorType::get(
{bmmrrrOpType.getDimSize(0), bmmrrrOpType.getDimSize(2),
bmmrrrOpType.getDimSize(1)},
bmmrrrOpType.getElementType());
auto bmmrrcOp = rewriter.create<cat::BMMRRCOp>(
op.getLoc(), bmmrrcResultType, bmmrrrOp.getLhs(), bmmrrrOp.getRhs());
auto srcBmmOpType = srcBmmOp.getType().template cast<ShapedType>();
// build dst bmm op
RankedTensorType dstBmmOpResultType = RankedTensorType::get(
{srcBmmOpType.getDimSize(0), srcBmmOpType.getDimSize(2),
srcBmmOpType.getDimSize(1)},
srcBmmOpType.getElementType());
auto dstBmmOp = rewriter.create<DstBmmType>(
op.getLoc(), dstBmmOpResultType, srcBmmOp.getLhs(), srcBmmOp.getRhs());
// build new reshape op
auto newShapeOp = rewriter.create<mhlo::ReshapeOp>(
op.getLoc(), op.getType(), bmmrrcOp.getResult());
op.getLoc(), op.getType(), dstBmmOp.getResult());
rewriter.replaceOp(op, newShapeOp.getResult());
return success();
}
@@ -309,7 +310,15 @@ void populateFuseMhloToCatPattern(RewritePatternSet &patterns) {
ConvertLayerNorm,
ConvertTransposeGemmRrrToBmmCrr,
ConvertTransposeReshapeBmmRrrToBmmRcr,
ConvertBmmRrrReshapeTransposeToBmmRrc>(patterns.getContext());
ConvertBmmReshapeTransposeToBmmReshape<cat::BMMRRROp, cat::BMMRRCOp>,
ConvertBmmReshapeTransposeToBmmReshape<cat::BMMRCROp, cat::BMMRCCOp>,
ConvertBmmReshapeTransposeToBmmReshape<cat::BMMCRROp, cat::BMMCRCOp>,
ConvertBmmReshapeTransposeToBmmReshape<cat::BMMCCROp, cat::BMMCCCOp>,
ConvertBmmReshapeTransposeToBmmReshape<cat::BMMRRCOp, cat::BMMRRROp>,
ConvertBmmReshapeTransposeToBmmReshape<cat::BMMRCCOp, cat::BMMRCROp>,
ConvertBmmReshapeTransposeToBmmReshape<cat::BMMCRCOp, cat::BMMCRROp>,
ConvertBmmReshapeTransposeToBmmReshape<cat::BMMCCCOp, cat::BMMCCROp>
>(patterns.getContext());
// clang-format on
}

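The rewrite is now a template over (source, destination) BMM pairs and is registered for all eight layout variants, so a reshape followed by a transpose of the last two dimensions after any cat BMM folds into the BMM variant with the opposite output layout. A hedged before/after sketch of the rrr-to-rrc case follows; the cat op assembly format is assumed rather than copied from the dialect definition.

```
// before: bmm_rrr, a batch-splitting reshape, then a transpose swapping the last two dims
%0 = "cat.bmm_rrr"(%lhs, %rhs) : (tensor<16x64x32xf32>, tensor<16x32x128xf32>) -> tensor<16x64x128xf32>
%1 = "mhlo.reshape"(%0) : (tensor<16x64x128xf32>) -> tensor<2x8x64x128xf32>
%2 = "mhlo.transpose"(%1) {permutation = dense<[0, 1, 3, 2]> : tensor<4xi64>}
    : (tensor<2x8x64x128xf32>) -> tensor<2x8x128x64xf32>

// after: the transpose is absorbed into the column-output variant; only the reshape remains
%0 = "cat.bmm_rrc"(%lhs, %rhs) : (tensor<16x64x32xf32>, tensor<16x32x128xf32>) -> tensor<16x128x64xf32>
%1 = "mhlo.reshape"(%0) : (tensor<16x128x64xf32>) -> tensor<2x8x128x64xf32>
```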
23 changes: 23 additions & 0 deletions compiler/lib/Dialect/Transform/IR/TransformExtOps.cpp
@@ -37,6 +37,24 @@ using namespace mlir;
using namespace mlir::transform_ext;

//===---------------------------------------------------------------------===//
// Type Extensions
//===---------------------------------------------------------------------===//
namespace {
struct PDLAttributeTypeTransformParamTypeInterfaceImpl
: public transform::TransformParamTypeInterface::ExternalModel<
PDLAttributeTypeTransformParamTypeInterfaceImpl, pdl::AttributeType> {

/// Accept any attribute.
DiagnosedSilenceableFailure checkPayload(Type type, Location loc,
ArrayRef<Attribute> payload) const {
return DiagnosedSilenceableFailure::success();
}
};
} // namespace

//===---------------------------------------------------------------------===//
// Op Extensions
//
// CanonicalizeExtOp
//===---------------------------------------------------------------------===//

@@ -154,6 +172,11 @@ class TransformExtDialectExtension
#define GET_OP_LIST
#include "byteir/Dialect/Transform/IR/TransformExtOps.cpp.inc"
>();

addCustomInitializationStep([](MLIRContext *context) {
pdl::AttributeType::attachInterface<
PDLAttributeTypeTransformParamTypeInterfaceImpl>(*context);
});
}
};
} // namespace
106 changes: 82 additions & 24 deletions compiler/lib/Dialect/mhlo/Transforms/ConvertOpToCustomCall.cpp
@@ -18,18 +18,62 @@
#include "byteir/Dialect/mhlo/Transforms/ConvertOpToCustomCall.h"

#include "./PassDetail.h"
#include "byteir/Dialect/Byre/ByreDialect.h"
#include "byteir/Dialect/Byre/Common.h"
#include "byteir/Dialect/mhlo/Util/CustomCallUtil.h"
#include "byteir/Utils/Utils.h"
#include "mhlo/IR/hlo_ops.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

namespace {

func::FuncOp getOrCreatePrivateFunctionDeclare(ModuleOp module,
const std::string &funcName,
const std::string &byreOpName,
FunctionType funcType) {
auto func = SymbolTable(module).lookup<func::FuncOp>(funcName);
if (func) {
// TODO(lyq): check func's type == funcType, and check func's attr
return func;
} else {
MLIRContext *context = module.getContext();
OpBuilder builder = OpBuilder::atBlockBegin(module.getBody());
func = builder.create<func::FuncOp>(UnknownLoc::get(context), funcName,
funcType);
func.setPrivate();
func->setAttr(byre::getByreComputeName(),
builder.getStringAttr(byreOpName));
func->setAttr(byre::getByreForceComputeNameAttrName(),
UnitAttr::get(context));
return func;
}
}

func::CallOp getOrCreateCallGetSeedOp(func::FuncOp func,
func::FuncOp getSeedFunc,
PatternRewriter &rewriter) {
func::CallOp callGetSeedOp;
func.walk([&](func::CallOp op) {
if (getFuncOp(op) == getSeedFunc) {
callGetSeedOp = op;
}
});
if (!callGetSeedOp) {
callGetSeedOp = rewriter.create<func::CallOp>(
UnknownLoc::get(rewriter.getContext()), getSeedFunc, ArrayRef<Value>{});
}
// move func.call @getSeed to the beginning of the function
Block *block = callGetSeedOp->getBlock();
callGetSeedOp->moveBefore(&block->front());
return callGetSeedOp;
}

struct ConvertRngUniformToCustomCall : public OpRewritePattern<mhlo::RngOp> {
using OpRewritePattern<mhlo::RngOp>::OpRewritePattern;

@@ -42,13 +86,23 @@ struct ConvertRngUniformToCustomCall : public OpRewritePattern<mhlo::RngOp> {
auto B = op.getB();
auto shape = op.getShape();
TensorType resultType = op.getResult().getType();
TensorType seedType = RankedTensorType::get({}, rewriter.getI64Type());
auto getSeedOp =
rewriter.create<byre::ComputeOp>(op->getLoc(), ArrayRef<Type>{seedType},
"GetSeed", ValueRange(), ArrayAttr());
auto getOffsetOp = rewriter.create<byre::ComputeOp>(
op->getLoc(), ArrayRef<Type>{seedType}, "GetOffset", ValueRange(),
ArrayAttr());
TensorType seedOrOffsetType =
RankedTensorType::get({}, rewriter.getI64Type());

ModuleOp module = op->getParentRegion()->getParentOfType<ModuleOp>();
auto functionType = FunctionType::get(module.getContext(), {},
ArrayRef<Type>{seedOrOffsetType});
func::FuncOp getSeedFunc = getOrCreatePrivateFunctionDeclare(
module, "GetSeedFunc", "GetSeed", functionType);
func::FuncOp nextOffsetFunc = getOrCreatePrivateFunctionDeclare(
module, "NextOffsetFunc", "NextOffset", functionType);

// avoid calling @getSeed every time
auto getSeedOp = getOrCreateCallGetSeedOp(
op->getParentRegion()->getParentOfType<func::FuncOp>(), getSeedFunc,
rewriter);
auto getOffsetOp = rewriter.create<func::CallOp>(
op->getLoc(), nextOffsetFunc, ArrayRef<Value>{});
SmallVector<Value> bufferArgs{A, B, getSeedOp.getResults()[0],
getOffsetOp.getResults()[0]};
if (!op.getType().hasStaticShape()) {
@@ -66,27 +120,31 @@ struct ConvertRngUniformToCustomCall : public OpRewritePattern<mhlo::RngOp> {
return success();
}
};

struct ConvertOpToCustomCallPass
: public ConvertOpToCustomCallBase<ConvertOpToCustomCallPass> {
public:
ConvertOpToCustomCallPass() = default;

void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<byre::ByreDialect>();
registry.insert<mhlo::MhloDialect>();
ConvertOpToCustomCallPass(llvm::StringRef anchor)
: ConvertOpToCustomCallBase() {
this->anchorTag = anchor.str();
}

void runOnOperation() override {
func::FuncOp funcOp = getOperation();
MLIRContext *context = &getContext();
ModuleOp moduleOp = getOperation();

for (auto funcOp : moduleOp.getOps<func::FuncOp>()) {
if (!this->anchorTag.empty() && !funcOp->hasAttr(this->anchorTag)) {
continue;
}

MLIRContext *context = &getContext();

RewritePatternSet patterns(context);
populateRngPatternToCustomCall(patterns);
RewritePatternSet patterns(context);
populateRngPatternToCustomCall(patterns);

FrozenRewritePatternSet frozenPatterns(std::move(patterns));
if (failed(applyPatternsAndFoldGreedily(funcOp, frozenPatterns))) {
signalPassFailure();
FrozenRewritePatternSet frozenPatterns(std::move(patterns));
if (failed(applyPatternsAndFoldGreedily(funcOp, frozenPatterns))) {
signalPassFailure();
}
}
}
};
@@ -97,7 +155,7 @@ void mlir::populateRngPatternToCustomCall(RewritePatternSet &patterns) {
patterns.add<ConvertRngUniformToCustomCall>(patterns.getContext());
}

std::unique_ptr<OperationPass<func::FuncOp>>
mlir::createConvertOpToCustomCallPass() {
return std::make_unique<ConvertOpToCustomCallPass>();
std::unique_ptr<OperationPass<ModuleOp>>
mlir::createConvertOpToCustomCallPass(llvm::StringRef anchor) {
return std::make_unique<ConvertOpToCustomCallPass>(anchor);
}
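For context, a rough sketch of the IR this pass now produces: the two private declarations are created once per module (their byre attributes are elided here; the exact names come from `byre::getByreComputeName()` and `byre::getByreForceComputeNameAttrName()`), the `@GetSeedFunc` call is hoisted to the top of the function and reused, and each converted rng gets its own `@NextOffsetFunc` call.

```
// declarations created once per module; byre attributes omitted in this sketch
func.func private @GetSeedFunc() -> tensor<i64>
func.func private @NextOffsetFunc() -> tensor<i64>

func.func @forward(%low: tensor<f32>, %high: tensor<f32>) -> tensor<8x1024x768xf32> {
  // hoisted to the block entry and shared by every converted rng in this function
  %seed = call @GetSeedFunc() : () -> tensor<i64>
  // one fresh offset per converted mhlo.rng
  %offset = call @NextOffsetFunc() : () -> tensor<i64>
  %0 = "mhlo.custom_call"(%low, %high, %seed, %offset)
      {call_target_name = "byteir.rng_uniform", has_side_effect = false}
      : (tensor<f32>, tensor<f32>, tensor<i64>, tensor<i64>) -> tensor<8x1024x768xf32>
  return %0 : tensor<8x1024x768xf32>
}
```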
19 changes: 15 additions & 4 deletions compiler/lib/Dialect/mhlo/Transforms/GenericFusion.cpp
@@ -18,6 +18,7 @@
#include "byteir/Dialect/mhlo/Transforms/HloFuser.h"

#include "byteir/Dialect/mhlo/Transforms/GenericFusionCommon.h"
#include "byteir/Dialect/mhlo/Util/CustomCallUtil.h"
#include "byteir/Dialect/mhlo/Util/FusionUtil.h"
#include "byteir/Dialect/mhlo/Util/Util.h"
#include "byteir/Utils/IRRewrite.h"
@@ -36,13 +37,21 @@ using namespace mlir::mhlo;
namespace {
namespace elementwise {

bool isCustomMhloRngOp(Operation *op) {
if (auto customOp = llvm::dyn_cast_or_null<mhlo::CustomCallOp>(op)) {
return customOp.getCallTargetName() == getRngUniformName();
}
return false;
}

// TODO: maybe we should support non-splat constant on device in future
bool isFusibleCandidate(Operation *op) {
return isMhlo(op) &&
(op->hasTrait<::mlir::OpTrait::Elementwise>() ||
op->hasTrait<hlo::OpTrait::BroadcastingElementwise>() ||
isSplatMhloConstantLike(op) ||
isa<mhlo::BroadcastInDimOp, mhlo::BroadcastOp, mhlo::ReshapeOp>(op));
isa<mhlo::BroadcastInDimOp, mhlo::BroadcastOp, mhlo::ReshapeOp>(op) ||
isCustomMhloRngOp(op));
}

// every candidate can start
@@ -51,7 +60,7 @@ bool isFusibleStart(Operation *op) { return true; }
bool isFusibleTrigger(Operation *op) {
if (op->hasTrait<::mlir::OpTrait::Elementwise>() ||
op->hasTrait<hlo::OpTrait::BroadcastingElementwise>() ||
isa<mhlo::ReshapeOp>(op)) {
isa<mhlo::ReshapeOp>(op) || isCustomMhloRngOp(op)) {
return true;
}

@@ -76,13 +85,15 @@ bool isFusibleWith(Operation *target, Operation * /*start*/) {
target->hasTrait<hlo::OpTrait::BroadcastingElementwise>() ||
isSplatMhloConstantLike(target) ||
isa<mhlo::BroadcastInDimOp, mhlo::BroadcastOp, mhlo::ReshapeOp>(
target);
target) ||
isCustomMhloRngOp(target);
}

bool isValidSingleOp(Operation *op) {
return op->hasTrait<::mlir::OpTrait::Elementwise>() ||
op->hasTrait<hlo::OpTrait::BroadcastingElementwise>() ||
isa<mhlo::BroadcastInDimOp, mhlo::BroadcastOp, mhlo::IotaOp>(op);
isa<mhlo::BroadcastInDimOp, mhlo::BroadcastOp, mhlo::IotaOp>(op) ||
isCustomMhloRngOp(op);
}

static GenericFuserConfig config{
Expand Down
9 changes: 2 additions & 7 deletions compiler/lib/Pipelines/CatPreprocess.cpp
@@ -17,15 +17,10 @@

#include "byteir/Pipelines/CatPreprocess.h"

#include "byteir/Dialect/mhlo/Transforms/FuseBMMDimension.h"
#include "byteir/Dialect/mhlo/Transforms/HloFolder.h"
#include "byteir/Dialect/mhlo/Transforms/HloMove.h"
#include "byteir/Dialect/mhlo/Transforms/LayoutTransformation.h"
#include "byteir/Dialect/mhlo/Transforms/MatmulLayoutTransform.h"
#include "byteir/Dialect/mhlo/Passes.h"
#include "byteir/Pipelines/Common/Utils.h"
#include "byteir/Transforms/CanonicalizeExt.h"
#include "mhlo/IR/hlo_ops.h"
#include "mhlo/transforms/passes.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Transforms/DialectConversion.h"

@@ -39,7 +34,7 @@ void createCatPreprocessPipelineImpl(OpPassManager &pm,
const std::string &convLayout) {
pm.addNestedPass<func::FuncOp>(createFuseBMMDimensionPass());
pm.addNestedPass<func::FuncOp>(createMatmulLayoutTransformPass(true, "rcr"));
pm.addNestedPass<func::FuncOp>(createTestUnfuseBatchNormPass());
pm.addNestedPass<func::FuncOp>(createUnfuseBatchNormPass());
pm.addNestedPass<func::FuncOp>(createHloFolderPass());
pm.addNestedPass<func::FuncOp>(createLayoutTransformationPass(convLayout));
pm.addNestedPass<func::FuncOp>(createHloMoveDownPass());