From 0e7673074c9d9a299717bb3379cb774c9921f72c Mon Sep 17 00:00:00 2001 From: Li-Wen Chang <120213201+liwenchangbdbz@users.noreply.github.com> Date: Wed, 30 Aug 2023 13:14:22 -0700 Subject: [PATCH] Sync internal master 1785f23d..6a90b687 (#50) * [Byre] Improved getSeed * [CAT] Improved CAT preprocessing, refined bmm-reshape-transpose fusion * [frontend/ONNX] Fixed resize, l2norm, batchnorm * [frontend/torch-mlir] Upgraded to b552d4ed956d82f5d9d0823b4727bb10bac6787c * [runtime] Fixed rng state * [Stats] Extended ByteIR-stat to show dtype * [TransformOp] Extended type in annotate * [Version] Bumped to 1.3.0 --- compiler/doc/byteir_mhlo_custom_call.md | 4 +- .../include/byteir/Dialect/mhlo/Passes.td | 7 +- .../mhlo/Transforms/ConvertOpToCustomCall.h | 8 +- .../lib/Conversion/HloToCat/FuseHloToCat.cpp | 35 +++--- .../Dialect/Transform/IR/TransformExtOps.cpp | 23 ++++ .../mhlo/Transforms/ConvertOpToCustomCall.cpp | 106 ++++++++++++++---- .../Dialect/mhlo/Transforms/GenericFusion.cpp | 19 +++- compiler/lib/Pipelines/CatPreprocess.cpp | 9 +- compiler/lib/Pipelines/HloOpt.cpp | 5 +- compiler/lib/Stat/OpCnt/OpCnt.cpp | 99 ++++++++++++++-- compiler/python/version.txt | 2 +- .../test/Conversion/HloToCat/fused_ops.mlir | 11 ++ .../ToLinalg/rngCustomCallToLinalg.mlir | 26 +++-- compiler/test/Dialect/Linalg/annotate.mlir | 24 +++- .../transforms/ConvertOpToCustomCall.mlir | 47 ++++++-- compiler/test/Pipelines/HloOpts/rng.mlir | 82 ++++++-------- compiler/test/Stat/opCnt.mlir | 56 ++++----- compiler/test/Stat/opTypes.mlir | 65 +++++++++++ .../src/Compiler/OFCompilerPipelines.cpp | 10 +- .../src/Conversion/OFCanonicalizer.cpp | 10 +- .../src/Conversion/OFRewriteToCustomCall.cpp | 31 ++++- .../src/Conversion/OFRewriteToCustomCall.td | 10 +- .../test/of_rewrite_to_custom_call.mlir | 16 ++- .../patches/OnnxMlirResizeV13ShapeInfer.patch | 103 +++++++++++++++++ frontends/torch-frontend/requirements.txt | 8 +- .../patches/torchtostablehlo_basic.patch | 68 ----------- .../torch-frontend/third_party/torch-mlir | 2 +- .../torch-frontend/torch-requirements.txt | 2 +- runtime/VERSION_NUMBER | 2 +- .../default/tensor_generate/rng_state.cc | 12 +- 30 files changed, 648 insertions(+), 254 deletions(-) create mode 100644 compiler/test/Stat/opTypes.mlir create mode 100644 frontends/onnx-frontend/third_party/patches/OnnxMlirResizeV13ShapeInfer.patch delete mode 100644 frontends/torch-frontend/third_party/patches/torchtostablehlo_basic.patch diff --git a/compiler/doc/byteir_mhlo_custom_call.md b/compiler/doc/byteir_mhlo_custom_call.md index 8c2dbdd7a..1f97aff5d 100644 --- a/compiler/doc/byteir_mhlo_custom_call.md +++ b/compiler/doc/byteir_mhlo_custom_call.md @@ -208,7 +208,7 @@ Further needed infomation for a given coarse-grained op are encoded in a diction %high = mhlo.constant dense<1.000000e+00> : tensor %low = mhlo.constant dense<0.000000e+00> : tensor %seed = byre.compute @GetSeed() : tensor -%offset = byre.compute @GetOffset() : tensor +%offset = byre.compute @NextOffset() : tensor %0 = "mhlo.custom_call"(%low, %high, %seed, %offset) {call_target_name = "byteir.rng_uniform", has_side_effect = false} : (tensor, tensor, tensor, tensor) -> tensor<8x1024x768xf32> ``` ``` %high = mhlo.constant dense<1.000000e+00> : tensor %low = mhlo.constant dense<0.000000e+00> : tensor %seed = byre.compute @GetSeed() : tensor -%offset = byre.compute @GetOffset() : tensor +%offset = byre.compute @NextOffset() : tensor %shape = 
shape.shape_of %arg0 : tensor<3xindex> %0 = "mhlo.custom_call"(%low, %high, %seed, %offset, %shape) {call_target_name = "byteir.rng_uniform", has_side_effect = false} : (tensor, tensor, tensor, tensor, tensor<3xindex>) -> tensor ``` diff --git a/compiler/include/byteir/Dialect/mhlo/Passes.td b/compiler/include/byteir/Dialect/mhlo/Passes.td index b98b15563..7fe03f8f8 100644 --- a/compiler/include/byteir/Dialect/mhlo/Passes.td +++ b/compiler/include/byteir/Dialect/mhlo/Passes.td @@ -71,9 +71,14 @@ def ConvForwardFusion : Pass<"fuse-conv-forward", "mlir::func::FuncOp"> { // Convert Rng To CustomCall //===----------------------------------------------------------------------===// -def ConvertOpToCustomCall : Pass<"convert-op-to-customcall", "mlir::func::FuncOp"> { +def ConvertOpToCustomCall : Pass<"convert-op-to-customcall", "ModuleOp"> { let summary = "Convert op to mhlo.custom_call"; let constructor = "mlir::createConvertOpToCustomCallPass()"; + let options = [ + Option<"anchorTag", "anchor-tag", "std::string", + /*default=*/"", + "Optional unitAttr anchored tag to apply this pass"> + ]; } //===----------------------------------------------------------------------===// diff --git a/compiler/include/byteir/Dialect/mhlo/Transforms/ConvertOpToCustomCall.h b/compiler/include/byteir/Dialect/mhlo/Transforms/ConvertOpToCustomCall.h index a5556d483..e84497390 100644 --- a/compiler/include/byteir/Dialect/mhlo/Transforms/ConvertOpToCustomCall.h +++ b/compiler/include/byteir/Dialect/mhlo/Transforms/ConvertOpToCustomCall.h @@ -22,13 +22,13 @@ #include namespace mlir { -namespace func { -class FuncOp; -} // namespace func + +class ModuleOp; void populateRngPatternToCustomCall(RewritePatternSet &patterns); -std::unique_ptr> createConvertOpToCustomCallPass(); +std::unique_ptr> +createConvertOpToCustomCallPass(llvm::StringRef anchor = ""); } // namespace mlir diff --git a/compiler/lib/Conversion/HloToCat/FuseHloToCat.cpp b/compiler/lib/Conversion/HloToCat/FuseHloToCat.cpp index eae11e03b..e78709194 100644 --- a/compiler/lib/Conversion/HloToCat/FuseHloToCat.cpp +++ b/compiler/lib/Conversion/HloToCat/FuseHloToCat.cpp @@ -229,7 +229,8 @@ struct ConvertTransposeReshapeBmmRrrToBmmRcr } }; -struct ConvertBmmRrrReshapeTransposeToBmmRrc +template +struct ConvertBmmReshapeTransposeToBmmReshape : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -239,8 +240,8 @@ struct ConvertBmmRrrReshapeTransposeToBmmRrc if (!reshapeOp || !reshapeOp.getResult().hasOneUse()) { return failure(); } - auto bmmrrrOp = reshapeOp.getOperand().getDefiningOp(); - if (!bmmrrrOp || !bmmrrrOp.getResult().hasOneUse()) { + auto srcBmmOp = reshapeOp.getOperand().getDefiningOp(); + if (!srcBmmOp || !srcBmmOp.getResult().hasOneUse()) { return failure(); } SmallVector permutation; @@ -266,17 +267,17 @@ struct ConvertBmmRrrReshapeTransposeToBmmRrc return failure(); } - auto bmmrrrOpType = bmmrrrOp.getType().cast(); - // build bmm_rrc op - RankedTensorType bmmrrcResultType = RankedTensorType::get( - {bmmrrrOpType.getDimSize(0), bmmrrrOpType.getDimSize(2), - bmmrrrOpType.getDimSize(1)}, - bmmrrrOpType.getElementType()); - auto bmmrrcOp = rewriter.create( - op.getLoc(), bmmrrcResultType, bmmrrrOp.getLhs(), bmmrrrOp.getRhs()); + auto srcBmmOpType = srcBmmOp.getType().template cast(); + // build dst bmm op + RankedTensorType dstBmmOpResultType = RankedTensorType::get( + {srcBmmOpType.getDimSize(0), srcBmmOpType.getDimSize(2), + srcBmmOpType.getDimSize(1)}, + srcBmmOpType.getElementType()); + auto dstBmmOp = rewriter.create( + 
op.getLoc(), dstBmmOpResultType, srcBmmOp.getLhs(), srcBmmOp.getRhs()); // build new reshape op auto newShapeOp = rewriter.create( - op.getLoc(), op.getType(), bmmrrcOp.getResult()); + op.getLoc(), op.getType(), dstBmmOp.getResult()); rewriter.replaceOp(op, newShapeOp.getResult()); return success(); } @@ -309,7 +310,15 @@ void populateFuseMhloToCatPattern(RewritePatternSet &patterns) { ConvertLayerNorm, ConvertTransposeGemmRrrToBmmCrr, ConvertTransposeReshapeBmmRrrToBmmRcr, - ConvertBmmRrrReshapeTransposeToBmmRrc>(patterns.getContext()); + ConvertBmmReshapeTransposeToBmmReshape, + ConvertBmmReshapeTransposeToBmmReshape, + ConvertBmmReshapeTransposeToBmmReshape, + ConvertBmmReshapeTransposeToBmmReshape, + ConvertBmmReshapeTransposeToBmmReshape, + ConvertBmmReshapeTransposeToBmmReshape, + ConvertBmmReshapeTransposeToBmmReshape, + ConvertBmmReshapeTransposeToBmmReshape + >(patterns.getContext()); // clang-format on } diff --git a/compiler/lib/Dialect/Transform/IR/TransformExtOps.cpp b/compiler/lib/Dialect/Transform/IR/TransformExtOps.cpp index fd0b65cc6..0cf7ab253 100644 --- a/compiler/lib/Dialect/Transform/IR/TransformExtOps.cpp +++ b/compiler/lib/Dialect/Transform/IR/TransformExtOps.cpp @@ -37,6 +37,24 @@ using namespace mlir; using namespace mlir::transform_ext; //===---------------------------------------------------------------------===// +// Type Extensions +//===---------------------------------------------------------------------===// +namespace { +struct PDLAttributeTypeTransformParamTypeInterfaceImpl + : public transform::TransformParamTypeInterface::ExternalModel< + PDLAttributeTypeTransformParamTypeInterfaceImpl, pdl::AttributeType> { + + /// Accept any attribute. + DiagnosedSilenceableFailure checkPayload(Type type, Location loc, + ArrayRef payload) const { + return DiagnosedSilenceableFailure::success(); + } +}; +} // namespace + +//===---------------------------------------------------------------------===// +// Op Extensions +// // CanonicalizeExtOp //===---------------------------------------------------------------------===// @@ -154,6 +172,11 @@ class TransformExtDialectExtension #define GET_OP_LIST #include "byteir/Dialect/Transform/IR/TransformExtOps.cpp.inc" >(); + + addCustomInitializationStep([](MLIRContext *context) { + pdl::AttributeType::attachInterface< + PDLAttributeTypeTransformParamTypeInterfaceImpl>(*context); + }); } }; } // namespace diff --git a/compiler/lib/Dialect/mhlo/Transforms/ConvertOpToCustomCall.cpp b/compiler/lib/Dialect/mhlo/Transforms/ConvertOpToCustomCall.cpp index 4ba13d433..0f7bd48f5 100644 --- a/compiler/lib/Dialect/mhlo/Transforms/ConvertOpToCustomCall.cpp +++ b/compiler/lib/Dialect/mhlo/Transforms/ConvertOpToCustomCall.cpp @@ -18,18 +18,62 @@ #include "byteir/Dialect/mhlo/Transforms/ConvertOpToCustomCall.h" #include "./PassDetail.h" -#include "byteir/Dialect/Byre/ByreDialect.h" +#include "byteir/Dialect/Byre/Common.h" #include "byteir/Dialect/mhlo/Util/CustomCallUtil.h" +#include "byteir/Utils/Utils.h" #include "mhlo/IR/hlo_ops.h" #include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/BuiltinOps.h" #include "mlir/IR/MLIRContext.h" #include "mlir/IR/PatternMatch.h" +#include "mlir/IR/SymbolTable.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" using namespace mlir; namespace { +func::FuncOp getOrCreatePrivateFunctionDeclare(ModuleOp module, + const std::string &funcName, + const std::string &byreOpName, + FunctionType funcType) { + auto func = SymbolTable(module).lookup(funcName); + if (func) { + // TODO(lyq): check func's 
type == funcType, and check func's attr + return func; + } else { + MLIRContext *context = module.getContext(); + OpBuilder builder = OpBuilder::atBlockBegin(module.getBody()); + func = builder.create(UnknownLoc::get(context), funcName, + funcType); + func.setPrivate(); + func->setAttr(byre::getByreComputeName(), + builder.getStringAttr(byreOpName)); + func->setAttr(byre::getByreForceComputeNameAttrName(), + UnitAttr::get(context)); + return func; + } +} + +func::CallOp getOrCreateCallGetSeedOp(func::FuncOp func, + func::FuncOp getSeedFunc, + PatternRewriter &rewriter) { + func::CallOp callGetSeedOp; + func.walk([&](func::CallOp op) { + if (getFuncOp(op) == getSeedFunc) { + callGetSeedOp = op; + } + }); + if (!callGetSeedOp) { + callGetSeedOp = rewriter.create( + UnknownLoc::get(rewriter.getContext()), getSeedFunc, ArrayRef{}); + } + // move the @GetSeedFunc call to the beginning of the function + Block *block = callGetSeedOp->getBlock(); + callGetSeedOp->moveBefore(&block->front()); + return callGetSeedOp; +} + struct ConvertRngUniformToCustomCall : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -42,13 +86,23 @@ struct ConvertRngUniformToCustomCall : public OpRewritePattern { auto B = op.getB(); auto shape = op.getShape(); TensorType resultType = op.getResult().getType(); - TensorType seedType = RankedTensorType::get({}, rewriter.getI64Type()); - auto getSeedOp = - rewriter.create(op->getLoc(), ArrayRef{seedType}, - "GetSeed", ValueRange(), ArrayAttr()); - auto getOffsetOp = rewriter.create( - op->getLoc(), ArrayRef{seedType}, "GetOffset", ValueRange(), - ArrayAttr()); + TensorType seedOrOffsetType = + RankedTensorType::get({}, rewriter.getI64Type()); + + ModuleOp module = op->getParentRegion()->getParentOfType(); + auto functionType = FunctionType::get(module.getContext(), {}, + ArrayRef{seedOrOffsetType}); + func::FuncOp getSeedFunc = getOrCreatePrivateFunctionDeclare( + module, "GetSeedFunc", "GetSeed", functionType); + func::FuncOp nextOffsetFunc = getOrCreatePrivateFunctionDeclare( + module, "NextOffsetFunc", "NextOffset", functionType); + + // avoid a separate @GetSeedFunc call for every rng op + auto getSeedOp = getOrCreateCallGetSeedOp( + op->getParentRegion()->getParentOfType(), getSeedFunc, + rewriter); + auto getOffsetOp = rewriter.create( + op->getLoc(), nextOffsetFunc, ArrayRef{}); SmallVector bufferArgs{A, B, getSeedOp.getResults()[0], getOffsetOp.getResults()[0]}; if (!op.getType().hasStaticShape()) { @@ -66,27 +120,31 @@ struct ConvertRngUniformToCustomCall : public OpRewritePattern { return success(); } }; - struct ConvertOpToCustomCallPass : public ConvertOpToCustomCallBase { -public: - ConvertOpToCustomCallPass() = default; - void getDependentDialects(DialectRegistry &registry) const override { - registry.insert(); - registry.insert(); + ConvertOpToCustomCallPass(llvm::StringRef anchor) : ConvertOpToCustomCallBase() { + this->anchorTag = anchor.str(); } void runOnOperation() override { - func::FuncOp funcOp = getOperation(); - MLIRContext *context = &getContext(); + ModuleOp moduleOp = getOperation(); + + for (auto funcOp : moduleOp.getOps()) { + if (!this->anchorTag.empty() && !funcOp->hasAttr(this->anchorTag)) { + continue; + } + + MLIRContext *context = &getContext(); - RewritePatternSet patterns(context); - populateRngPatternToCustomCall(patterns); + RewritePatternSet patterns(context); + populateRngPatternToCustomCall(patterns); - FrozenRewritePatternSet frozenPatterns(std::move(patterns)); - if (failed(applyPatternsAndFoldGreedily(funcOp, frozenPatterns))) { - 
signalPassFailure(); + FrozenRewritePatternSet frozenPatterns(std::move(patterns)); + if (failed(applyPatternsAndFoldGreedily(funcOp, frozenPatterns))) { + signalPassFailure(); + } } } }; @@ -97,7 +155,7 @@ void mlir::populateRngPatternToCustomCall(RewritePatternSet &patterns) { patterns.add(patterns.getContext()); } -std::unique_ptr> -mlir::createConvertOpToCustomCallPass() { - return std::make_unique(); +std::unique_ptr> +mlir::createConvertOpToCustomCallPass(llvm::StringRef anchor) { + return std::make_unique(anchor); } diff --git a/compiler/lib/Dialect/mhlo/Transforms/GenericFusion.cpp b/compiler/lib/Dialect/mhlo/Transforms/GenericFusion.cpp index 60999f980..0caf95fbf 100644 --- a/compiler/lib/Dialect/mhlo/Transforms/GenericFusion.cpp +++ b/compiler/lib/Dialect/mhlo/Transforms/GenericFusion.cpp @@ -18,6 +18,7 @@ #include "byteir/Dialect/mhlo/Transforms/HloFuser.h" #include "byteir/Dialect/mhlo/Transforms/GenericFusionCommon.h" +#include "byteir/Dialect/mhlo/Util/CustomCallUtil.h" #include "byteir/Dialect/mhlo/Util/FusionUtil.h" #include "byteir/Dialect/mhlo/Util/Util.h" #include "byteir/Utils/IRRewrite.h" @@ -36,13 +37,21 @@ using namespace mlir::mhlo; namespace { namespace elementwise { +bool isCustomMhloRngOp(Operation *op) { + if (auto customOp = llvm::dyn_cast_or_null(op)) { + return customOp.getCallTargetName() == getRngUniformName(); + } + return false; +} + // TODO: maybe we should support non-splat constant on device in future bool isFusibleCandidate(Operation *op) { return isMhlo(op) && (op->hasTrait<::mlir::OpTrait::Elementwise>() || op->hasTrait() || isSplatMhloConstantLike(op) || - isa(op)); + isa(op) || + isCustomMhloRngOp(op)); } // every candidate can start @@ -51,7 +60,7 @@ bool isFusibleStart(Operation *op) { return true; } bool isFusibleTrigger(Operation *op) { if (op->hasTrait<::mlir::OpTrait::Elementwise>() || op->hasTrait() || - isa(op)) { + isa(op) || isCustomMhloRngOp(op)) { return true; } @@ -76,13 +85,15 @@ bool isFusibleWith(Operation *target, Operation * /*start*/) { target->hasTrait() || isSplatMhloConstantLike(target) || isa( - target); + target) || + isCustomMhloRngOp(target); } bool isValidSingleOp(Operation *op) { return op->hasTrait<::mlir::OpTrait::Elementwise>() || op->hasTrait() || - isa(op); + isa(op) || + isCustomMhloRngOp(op); } static GenericFuserConfig config{ diff --git a/compiler/lib/Pipelines/CatPreprocess.cpp b/compiler/lib/Pipelines/CatPreprocess.cpp index 2e1818b51..f3b1ea665 100644 --- a/compiler/lib/Pipelines/CatPreprocess.cpp +++ b/compiler/lib/Pipelines/CatPreprocess.cpp @@ -17,15 +17,10 @@ #include "byteir/Pipelines/CatPreprocess.h" -#include "byteir/Dialect/mhlo/Transforms/FuseBMMDimension.h" -#include "byteir/Dialect/mhlo/Transforms/HloFolder.h" -#include "byteir/Dialect/mhlo/Transforms/HloMove.h" -#include "byteir/Dialect/mhlo/Transforms/LayoutTransformation.h" -#include "byteir/Dialect/mhlo/Transforms/MatmulLayoutTransform.h" +#include "byteir/Dialect/mhlo/Passes.h" #include "byteir/Pipelines/Common/Utils.h" #include "byteir/Transforms/CanonicalizeExt.h" #include "mhlo/IR/hlo_ops.h" -#include "mhlo/transforms/passes.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Transforms/DialectConversion.h" @@ -39,7 +34,7 @@ void createCatPreprocessPipelineImpl(OpPassManager &pm, const std::string &convLayout) { pm.addNestedPass(createFuseBMMDimensionPass()); pm.addNestedPass(createMatmulLayoutTransformPass(true, "rcr")); - pm.addNestedPass(createTestUnfuseBatchNormPass()); + pm.addNestedPass(createUnfuseBatchNormPass()); 
pm.addNestedPass(createHloFolderPass()); pm.addNestedPass(createLayoutTransformationPass(convLayout)); pm.addNestedPass(createHloMoveDownPass()); diff --git a/compiler/lib/Pipelines/HloOpt.cpp b/compiler/lib/Pipelines/HloOpt.cpp index 9320a3f24..b80e815de 100644 --- a/compiler/lib/Pipelines/HloOpt.cpp +++ b/compiler/lib/Pipelines/HloOpt.cpp @@ -32,8 +32,8 @@ void addGenericHloFusionPatterns(OpPassManager &pm, const std::string &entry, bool outlineSingleElemwiseOp, bool outlineCatOp, bool aggressiveCatFusion) { // cluster constraint - pm.addNestedPass(createClusterConstraintPass()); - pm.addPass(createFusionOutliningPass()); + // pm.addNestedPass(createClusterConstraintPass()); + // pm.addPass(createFusionOutliningPass()); // Fusion passes if (outlineCatOp) { @@ -87,6 +87,7 @@ void createHloOptPipelineImpl(OpPassManager &pm, const std::string &entryFunc, pm.addNestedPass(createHloTransposeDotToDotGeneralPass()); pm.addNestedPass(createReduceFusionPass()); pm.addNestedPass(createReshapeGatherPass()); + pm.addPass(createConvertOpToCustomCallPass()); // rewrite with constraint pm.addNestedPass(createRewriteWithConstraintPass()); diff --git a/compiler/lib/Stat/OpCnt/OpCnt.cpp b/compiler/lib/Stat/OpCnt/OpCnt.cpp index 452ece21e..9a1c670e2 100644 --- a/compiler/lib/Stat/OpCnt/OpCnt.cpp +++ b/compiler/lib/Stat/OpCnt/OpCnt.cpp @@ -21,7 +21,10 @@ #include "mlir/Dialect/Func/IR/FuncOps.h" #include "llvm/Support/CommandLine.h" +#include + using namespace byteir; +using namespace llvm; using namespace mlir; //===----------------------------------------------------------------------===// @@ -40,18 +43,47 @@ mlir::LogicalResult byteir::opCntStatistics(ModuleOp moduleOp, llvm::raw_ostream &os, const std::string &funcNmae, bool topOnly) { - os << "========== Operation Type and Its Numbers ============\n"; llvm::StringMap opCnt; + llvm::StringMap> opInDTypes, opOutDTypes; + + auto collectTypes = [&](Type type, StringRef key, bool isOperand) { + if (auto shapedType = type.dyn_cast_or_null()) { + auto dtype = shapedType.getElementType(); + // type to string + std::string typeStr; + llvm::raw_string_ostream os(typeStr); + dtype.print(os); + if (isOperand) { + opInDTypes[key].insert(os.str()); + } else { + opOutDTypes[key].insert(os.str()); + } + } + }; if (funcNmae.empty()) { for (func::FuncOp func : moduleOp.getOps()) { if (topOnly) { - for (auto &op : func.getOps()) { - opCnt[op.getName().getStringRef()] += 1; - } + auto countOps = [&](auto &op) { + llvm::StringRef key = op.getName().getStringRef(); + opCnt[key] += 1; + + llvm::for_each(op.getOperandTypes(), + [&](Type type) { collectTypes(type, key, true); }); + llvm::for_each(op.getResultTypes(), + [&](Type type) { collectTypes(type, key, false); }); + }; + llvm::for_each(func.getOps(), countOps); } else { - func.walk( - [&](Operation *op) { opCnt[op->getName().getStringRef()] += 1; }); + func.walk([&](Operation *op) { + llvm::StringRef key = op->getName().getStringRef(); + opCnt[key] += 1; + + llvm::for_each(op->getOperandTypes(), + [&](Type type) { collectTypes(type, key, true); }); + llvm::for_each(op->getResultTypes(), + [&](Type type) { collectTypes(type, key, false); }); + }); } } } else { @@ -63,19 +95,62 @@ mlir::LogicalResult byteir::opCntStatistics(ModuleOp moduleOp, return success(); if (topOnly) { - for (auto &op : func.getOps()) { - opCnt[op.getName().getStringRef()] += 1; - } + auto countOps = [&](auto &op) { + llvm::StringRef key = op.getName().getStringRef(); + opCnt[key] += 1; + + llvm::for_each(op.getOperandTypes(), + [&](Type type) { 
collectTypes(type, key, true); }); + llvm::for_each(op.getResultTypes(), + [&](Type type) { collectTypes(type, key, false); }); + }; + llvm::for_each(func.getOps(), countOps); } else { - func.walk( - [&](Operation *op) { opCnt[op->getName().getStringRef()] += 1; }); + func.walk([&](Operation *op) { + llvm::StringRef key = op->getName().getStringRef(); + opCnt[key] += 1; + + llvm::for_each(op->getOperandTypes(), + [&](Type type) { collectTypes(type, key, true); }); + llvm::for_each(op->getResultTypes(), + [&](Type type) { collectTypes(type, key, false); }); + }); } } SmallVector sorted(opCnt.keys()); llvm::sort(sorted); + os << "========== Operation Statistics ============\n"; + os << "Operation Type \t Numbers \t Operand Types \t Result Types\n"; for (auto opType : sorted) { - os << opType << " " << opCnt[opType] << "\n"; + os << opType << "\t\t" << opCnt[opType] << "\t"; + + // Operands data types + for (auto it = opInDTypes[opType].begin(); it != opInDTypes[opType].end(); + ++it) { + if (it == std::prev(opInDTypes[opType].end())) { + os << *it; + } else { + os << *it << ","; + } + } + if (opInDTypes[opType].empty()) { + os << "NA"; + } + os << "\t"; + // Results data types + for (auto it = opOutDTypes[opType].begin(); it != opOutDTypes[opType].end(); + ++it) { + if (it == std::prev(opOutDTypes[opType].end())) { + os << *it; + } else { + os << *it << ","; + } + } + if (opOutDTypes[opType].empty()) { + os << "NA"; + } + os << "\n"; } return success(); } \ No newline at end of file diff --git a/compiler/python/version.txt b/compiler/python/version.txt index 867e52437..589268e6f 100644 --- a/compiler/python/version.txt +++ b/compiler/python/version.txt @@ -1 +1 @@ -1.2.0 \ No newline at end of file +1.3.0 \ No newline at end of file diff --git a/compiler/test/Conversion/HloToCat/fused_ops.mlir b/compiler/test/Conversion/HloToCat/fused_ops.mlir index f53518ce7..0b84fc4d5 100644 --- a/compiler/test/Conversion/HloToCat/fused_ops.mlir +++ b/compiler/test/Conversion/HloToCat/fused_ops.mlir @@ -253,6 +253,17 @@ func.func @test_bmm_rrr_reshape_transpose_to_bmm_rrc_reshape(%arg0: tensor<64x12 // CHECK-NEXT: mhlo.reshape // CHECK-NEXT: return +func.func @test_bmm_crr_reshape_transpose_to_bmm_crc_reshape(%arg0: tensor<512x128x128xf16>, %arg1: tensor<512x128x128xf16>) -> tensor<16x32x128x128xf16> { + %0 = "cat.bmm_crr"(%arg0, %arg1) : (tensor<512x128x128xf16>, tensor<512x128x128xf16>) -> tensor<512x128x128xf16> + %1 = mhlo.reshape %0 : (tensor<512x128x128xf16>) -> tensor<16x32x128x128xf16> + %2 = "mhlo.transpose"(%1) {permutation = dense<[0, 1, 3, 2]> : tensor<4xi64>} : (tensor<16x32x128x128xf16>) -> tensor<16x32x128x128xf16> + return %2 : tensor<16x32x128x128xf16> +} +// CHECK: func.func @test_bmm_crr_reshape_transpose_to_bmm_crc_reshape +// CHECK-NEXT: cat.bmm_crc +// CHECK-NEXT: mhlo.reshape +// CHECK-NEXT: return + func.func @test_softmax_f16(%arg0 : tensor<1x12x1024x1024xf16>) -> tensor<1x12x1024x1024xf32> { %0 = mhlo.custom_call @byteir.softmax(%arg0) {backend_config = "", byteir_attrs = {axis = 3 : i64}} : (tensor<1x12x1024x1024xf16>) -> tensor<1x12x1024x1024xf32> return %0 : tensor<1x12x1024x1024xf32> diff --git a/compiler/test/Conversion/ToLinalg/rngCustomCallToLinalg.mlir b/compiler/test/Conversion/ToLinalg/rngCustomCallToLinalg.mlir index 7a0516082..93ab5ad6d 100644 --- a/compiler/test/Conversion/ToLinalg/rngCustomCallToLinalg.mlir +++ b/compiler/test/Conversion/ToLinalg/rngCustomCallToLinalg.mlir @@ -1,18 +1,20 @@ -// RUN: byteir-opt -hlo-fusion-to-linalg -cse %s | FileCheck %s +// RUN: 
byteir-opt -hlo-fusion-to-linalg -cse -split-input-file %s | FileCheck %s -func.func @convert_rng_static() -> tensor<8x1024x768xf32> attributes {__placeholder__byre.entry_point} { +func.func private @NextOffsetFunc() -> tensor attributes {byre_compute_name = "NextOffset", byre_force_compute_name} +func.func private @GetSeedFunc() -> tensor attributes {byre_compute_name = "GetSeed", byre_force_compute_name} +func.func @convert_rng_static() -> tensor<8x1024x768xf32> { %0 = mhlo.constant dense<1.000000e+00> : tensor %1 = mhlo.constant dense<0.000000e+00> : tensor - %2 = byre.compute @GetSeed() -> tensor - %3 = byre.compute @GetOffset() -> tensor + %2 = call @GetSeedFunc() : () -> tensor + %3 = call @NextOffsetFunc() : ()-> tensor %4 = mhlo.custom_call @byteir.rng_uniform(%1, %0, %2, %3) {backend_config = ""} : (tensor, tensor, tensor, tensor) -> tensor<8x1024x768xf32> return %4 : tensor<8x1024x768xf32> } // CHECK-LABEL: func.func @convert_rng_static // CHECK-DAG: arith.constant // CHECK-DAG: arith.constant -// CHECK-DAG: byre.compute -// CHECK-DAG: byre.compute +// CHECK-DAG: call +// CHECK-DAG: call // CHECK-DAG: tensor.empty // CHECK-NEXT: linalg.generic // CHECK-DAG: ^{{.*}}(%[[MIN:.+]]: f32, %[[MAX:.+]]: f32, %[[SEED:.+]]: i64, %[[OFFSET:.+]]: i64, %[[OUT:.+]]: f32 @@ -45,13 +47,17 @@ func.func @convert_rng_static() -> tensor<8x1024x768xf32> attributes {__placehol // CHECK-NEXT: linalg.yield %[[VAL7]] : f32 // CHECK-NEXT: -> tensor<8x1024x768xf32> +// ----- + +func.func private @NextOffsetFunc() -> tensor attributes {byre_compute_name = "NextOffset", byre_force_compute_name} +func.func private @GetSeedFunc() -> tensor attributes {byre_compute_name = "GetSeed", byre_force_compute_name} func.func @convert_rng_dynamic(%arg0: tensor) -> tensor attributes {__placeholder__byre.entry_point} { %0 = mhlo.constant dense<1.000000e+00> : tensor %1 = mhlo.constant dense<0.000000e+00> : tensor %2 = shape.shape_of %arg0 : tensor -> tensor<3xindex> %3 = arith.index_cast %2 : tensor<3xindex> to tensor<3xi64> - %4 = byre.compute @GetSeed() -> tensor - %5 = byre.compute @GetOffset() -> tensor + %4 = call @GetSeedFunc() : () -> tensor + %5 = call @NextOffsetFunc() : () -> tensor %6 = mhlo.custom_call @byteir.rng_uniform(%1, %0, %4, %5, %3) {backend_config = ""} : (tensor, tensor, tensor, tensor, tensor<3xi64>) -> tensor return %6 : tensor } @@ -59,8 +65,8 @@ func.func @convert_rng_dynamic(%arg0: tensor) -> tensor at // CHECK-LABEL: func.func @convert_rng_dynamic // CHECK-DAG: arith.constant // CHECK-DAG: arith.constant -// CHECK-DAG: byre.compute -// CHECK-DAG: byre.compute +// CHECK-DAG: call +// CHECK-DAG: call // CHECK-DAG: arith.index_cast // CHECK-DAG: arith.constant // CHECK-DAG: tensor.extract diff --git a/compiler/test/Dialect/Linalg/annotate.mlir b/compiler/test/Dialect/Linalg/annotate.mlir index 263816d98..c017f9b00 100644 --- a/compiler/test/Dialect/Linalg/annotate.mlir +++ b/compiler/test/Dialect/Linalg/annotate.mlir @@ -1,4 +1,4 @@ -// RUN: byteir-opt %s -linalg-bufferize --transform-dialect-interpreter -cse -canonicalize | FileCheck %s +// RUN: byteir-opt %s --transform-dialect-interpreter -cse -canonicalize --split-input-file | FileCheck %s transform.sequence failures(propagate) { ^bb0(%arg1: !pdl.operation): @@ -21,4 +21,24 @@ func.func @hgemm(%arg0: memref<5376x2048xf16>, %arg1: memref<2048x5376xf16, #map linalg.matmul ins(%arg0, %arg1: memref<5376x2048xf16>, memref<2048x5376xf16, #map>) outs(%arg2: memref<5376x5376xf16>) return -} \ No newline at end of file +} + +// ----- + 
+transform.sequence failures(propagate) { +^bb0(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!pdl.operation) -> !pdl.operation + %1 = transform.param.constant "test_attr" -> !pdl.attribute + transform.annotate %0 "test_name" = %1 : !pdl.operation, !pdl.attribute +} + +#map = affine_map<(d0, d1) -> (d1 * 2048 + d0)> + +// CHECK-LABEL: func.func @hgemm +func.func @hgemm(%arg0: memref<5376x2048xf16>, %arg1: memref<2048x5376xf16, #map>, %arg2: memref<5376x5376xf16>) { + // CHECK: linalg.matmul + // CHECK-SAME: test_name = "test_attr" + linalg.matmul ins(%arg0, %arg1: memref<5376x2048xf16>, memref<2048x5376xf16, #map>) + outs(%arg2: memref<5376x5376xf16>) + return +} diff --git a/compiler/test/Dialect/Mhlo/transforms/ConvertOpToCustomCall.mlir b/compiler/test/Dialect/Mhlo/transforms/ConvertOpToCustomCall.mlir index 3700d02f8..1b13f2851 100644 --- a/compiler/test/Dialect/Mhlo/transforms/ConvertOpToCustomCall.mlir +++ b/compiler/test/Dialect/Mhlo/transforms/ConvertOpToCustomCall.mlir @@ -1,21 +1,48 @@ -// RUN: byteir-opt %s -convert-op-to-customcall | FileCheck %s +// RUN: byteir-opt %s -convert-op-to-customcall --split-input-file | FileCheck %s -func.func @convert_rng_static() -> tensor<8x1024x768xf32> attributes {__placeholder__byre.entry_point} { +func.func @convert_rng_static() -> tensor<8x1024x768xf32> { %16 = mhlo.constant dense<1.000000e+00> : tensor %17 = mhlo.constant dense<0.000000e+00> : tensor %18 = mhlo.constant dense<[8, 1024, 768]> : tensor<3xi64> %26 = "mhlo.rng"(%17, %16, %18) {rng_distribution = #mhlo.rng_distribution} : (tensor, tensor, tensor<3xi64>) -> tensor<8x1024x768xf32> return %26 : tensor<8x1024x768xf32> } +// CHECK-LABEL: func.func private @NextOffsetFunc() -> tensor attributes {byre_compute_name = "NextOffset", byre_force_compute_name} +// CHECK-LABEL: func.func private @GetSeedFunc() -> tensor attributes {byre_compute_name = "GetSeed", byre_force_compute_name} // CHECK-LABEL: func.func @convert_rng_static // CHECK-NEXT: mhlo.constant // CHECK-NEXT: mhlo.constant -// CHECK-NEXT: byre.compute @GetSeed -// CHECK-NEXT: byre.compute @GetOffset +// CHECK-NEXT: call @GetSeedFunc +// CHECK-NEXT: call @NextOffsetFunc // CHECK-NEXT: mhlo.custom_call -// CHEKC-SAME: call_target_name = "byteir.rng_uniform" +// CHECK-SAME: @byteir.rng_uniform -func.func @convert_rng_dynamic(%arg0: tensor) -> tensor attributes {__placeholder__byre.entry_point} { +// ----- + +func.func @convert_two_rng_static() -> (tensor<8x1024x768xf32>, tensor<8x1024x768xf32>) { + %16 = mhlo.constant dense<1.000000e+00> : tensor + %17 = mhlo.constant dense<0.000000e+00> : tensor + %18 = mhlo.constant dense<[8, 1024, 768]> : tensor<3xi64> + %26 = "mhlo.rng"(%17, %16, %18) {rng_distribution = #mhlo.rng_distribution} : (tensor, tensor, tensor<3xi64>) -> tensor<8x1024x768xf32> + %27 = "mhlo.rng"(%17, %16, %18) {rng_distribution = #mhlo.rng_distribution} : (tensor, tensor, tensor<3xi64>) -> tensor<8x1024x768xf32> + return %26, %27 : tensor<8x1024x768xf32>, tensor<8x1024x768xf32> +} +// CHECK-LABEL: func.func private @NextOffsetFunc() -> tensor attributes {byre_compute_name = "NextOffset", byre_force_compute_name} +// CHECK-LABEL: func.func private @GetSeedFunc() -> tensor attributes {byre_compute_name = "GetSeed", byre_force_compute_name} +// CHECK-LABEL: func.func @convert_two_rng_static +// CHECK-NEXT: mhlo.constant +// CHECK-NEXT: mhlo.constant +// CHECK-NEXT: call @GetSeedFunc +// CHECK-NEXT: call @NextOffsetFunc +// CHECK-NEXT: mhlo.custom_call +// CHECK-SAME: 
@byteir.rng_uniform +// CHECK-NEXT: call @NextOffsetFunc +// CHECK-NEXT: mhlo.custom_call +// CHECK-SAME: @byteir.rng_uniform + +// ----- + +func.func @convert_rng_dynamic(%arg0: tensor) -> tensor { %16 = mhlo.constant dense<1.000000e+00> : tensor %17 = mhlo.constant dense<0.000000e+00> : tensor %shape = shape.shape_of %arg0 : tensor -> tensor<3xindex> @@ -23,12 +50,14 @@ func.func @convert_rng_dynamic(%arg0: tensor) -> tensor at %26 = "mhlo.rng"(%17, %16, %shape1) {rng_distribution = #mhlo.rng_distribution} : (tensor, tensor, tensor<3xi64>) -> tensor return %26 : tensor } +// CHECK-LABEL: func.func private @NextOffsetFunc() -> tensor attributes {byre_compute_name = "NextOffset", byre_force_compute_name} +// CHECK-LABEL: func.func private @GetSeedFunc() -> tensor attributes {byre_compute_name = "GetSeed", byre_force_compute_name} // CHECK-LABEL: func.func @convert_rng_dynamic // CHECK-NEXT: mhlo.constant // CHECK-NEXT: mhlo.constant +// CHECK-NEXT: call @GetSeedFunc // CHECK-NEXT: shape.shape_of // CHECK-NEXT: arith.index_cast -// CHECK-NEXT: byre.compute @GetSeed -// CHECK-NEXT: byre.compute @GetOffset +// CHECK-NEXT: call @NextOffsetFunc // CHECK-NEXT: mhlo.custom_call -// CHEKC-SAME: call_target_name = "byteir.rng_uniform" \ No newline at end of file +// CHECK-SAME: @byteir.rng_uniform diff --git a/compiler/test/Pipelines/HloOpts/rng.mlir b/compiler/test/Pipelines/HloOpts/rng.mlir index ac9fd7aab..7ef7668f2 100644 --- a/compiler/test/Pipelines/HloOpts/rng.mlir +++ b/compiler/test/Pipelines/HloOpts/rng.mlir @@ -1,57 +1,47 @@ // RUN: byteir-opt -hlo-opt %s | FileCheck %s -func.func @uniform_rngf32() -> tensor<2x128x128xf32> { - %0 = mhlo.constant dense<0.000000e+00> : tensor - %1 = mhlo.constant dense<1.000000e+00> : tensor - %2 = mhlo.constant dense<[2, 128, 128]> : tensor<3xi64> - %3 = "mhlo.rng"(%0, %1, %2) {rng_distribution = #mhlo.rng_distribution} : (tensor, tensor, tensor<3xi64>) -> tensor<2x128x128xf32> - %4 = "mhlo.rng"(%0, %1, %2) {rng_distribution = #mhlo.rng_distribution} : (tensor, tensor, tensor<3xi64>) -> tensor<2x128x128xf32> - %5 = mhlo.add %3, %4 : tensor<2x128x128xf32> - return %5 : tensor<2x128x128xf32> +module @uniform_rng { + func.func @uniform_rngf32() -> tensor<2x128x128xf32> { + %0 = mhlo.constant dense<0.000000e+00> : tensor + %1 = mhlo.constant dense<1.000000e+00> : tensor + %2 = mhlo.constant dense<[2, 128, 128]> : tensor<3xi64> + %3 = "mhlo.rng"(%0, %1, %2) {rng_distribution = #mhlo.rng_distribution} : (tensor, tensor, tensor<3xi64>) -> tensor<2x128x128xf32> + %4 = mhlo.add %3, %3 : tensor<2x128x128xf32> + return %4 : tensor<2x128x128xf32> + } + func.func @uniform_rngf64() -> tensor<2x128x128xf64> { + %0 = mhlo.constant dense<0.000000e+00> : tensor + %1 = mhlo.constant dense<1.000000e+00> : tensor + %2 = mhlo.constant dense<[2, 128, 128]> : tensor<3xi64> + %3 = "mhlo.rng"(%0, %1, %2) {rng_distribution = #mhlo.rng_distribution} : (tensor, tensor, tensor<3xi64>) -> tensor<2x128x128xf64> + %4 = mhlo.add %3, %3 : tensor<2x128x128xf64> + return %4 : tensor<2x128x128xf64> + } } -// CHECK-LABEL: func.func private @RngUniform_f32f32_f320 -// CHECK-DAG: __byre__high = 1.000000e+00 -// CHECK-DAG: __byre__low = 0.000000e+00 -// CHECK-DAG: byre_compute_name = "RngUniform_f32f32_f32" -// CHECK-DAG: byre_force_compute_name +// CHECK-LABEL: func.func private @NextOffsetFunc() -> tensor attributes {byre_compute_name = "NextOffset", byre_force_compute_name} +// CHECK-LABEL: func.func private @GetSeedFunc() -> tensor attributes {byre_compute_name = "GetSeed", 
byre_force_compute_name} -// CHECK-LABEL: func.func private @RngUniform_f32f32_f321 -// CHECK-DAG: __byre__high = 1.000000e+00 -// CHECK-DAG: __byre__low = 0.000000e+00 -// CHECK-DAG: byre_compute_name = "RngUniform_f32f32_f32" -// CHECK-DAG: byre_force_compute_name +// CHECK-LABEL: func.func private @Unknown0 +// CHECK-DAG: mhlo.constant +// CHECK-DAG: mhlo.constant +// CHECK-DAG: mhlo.custom_call +// CHECK-DAG: mhlo.add // CHECK-LABEL: func.func @uniform_rngf32 -// CHECK: %[[VAR_0:.*]] = call @RngUniform_f32f32_f320 -// CHECK: %[[VAR_1:.*]] = call @RngUniform_f32f32_f321 -// CHECK: %[[VAR_2:.*]] = mhlo.add %[[VAR_0]], %[[VAR_1]] +// CHECK: %[[VAR_0:.*]] = call @GetSeedFunc() +// CHECK: %[[VAR_1:.*]] = call @NextOffsetFunc() +// CHECK: %[[VAR_2:.*]] = call @Unknown0(%[[VAR_0]], %[[VAR_1]]) // CHECK: return %[[VAR_2]] -func.func @uniform_rngf64() -> tensor<2x128x128xf64> { - %0 = mhlo.constant dense<0.000000e+00> : tensor - %1 = mhlo.constant dense<1.000000e+00> : tensor - %2 = mhlo.constant dense<[2, 128, 128]> : tensor<3xi64> - %3 = "mhlo.rng"(%0, %1, %2) {rng_distribution = #mhlo.rng_distribution} : (tensor, tensor, tensor<3xi64>) -> tensor<2x128x128xf64> - %4 = "mhlo.rng"(%0, %1, %2) {rng_distribution = #mhlo.rng_distribution} : (tensor, tensor, tensor<3xi64>) -> tensor<2x128x128xf64> - %5 = mhlo.add %3, %4 : tensor<2x128x128xf64> - return %5 : tensor<2x128x128xf64> -} - -// CHECK-LABEL: func.func private @RngUniform_f64f64_f642 -// CHECK-DAG: __byre__high = 1.000000e+00 -// CHECK-DAG: __byre__low = 0.000000e+00 -// CHECK-DAG: byre_compute_name = "RngUniform_f64f64_f64" -// CHECK-DAG: byre_force_compute_name - -// CHECK-LABEL: func.func private @RngUniform_f64f64_f643 -// CHECK-DAG: __byre__high = 1.000000e+00 -// CHECK-DAG: __byre__low = 0.000000e+00 -// CHECK-DAG: byre_compute_name = "RngUniform_f64f64_f64" -// CHECK-DAG: byre_force_compute_name +// CHECK-LABEL: func.func private @Unknown1 +// CHECK-DAG: mhlo.constant +// CHECK-DAG: mhlo.constant +// CHECK-DAG: mhlo.custom_call +// CHECK-DAG: mhlo.add // CHECK-LABEL: func.func @uniform_rngf64 -// CHECK: %[[VAR_0:.*]] = call @RngUniform_f64f64_f642 -// CHECK: %[[VAR_1:.*]] = call @RngUniform_f64f64_f643 -// CHECK: %[[VAR_2:.*]] = mhlo.add %[[VAR_0]], %[[VAR_1]] -// CHECK: return %[[VAR_2]] +// CHECK: %[[VAR_3:.*]] = call @GetSeedFunc() +// CHECK: %[[VAR_4:.*]] = call @NextOffsetFunc() +// CHECK: %[[VAR_5:.*]] = call @Unknown1(%[[VAR_3]], %[[VAR_4]]) +// CHECK: return %[[VAR_5]] diff --git a/compiler/test/Stat/opCnt.mlir b/compiler/test/Stat/opCnt.mlir index cd3eb4a9d..6bf6a0d7b 100644 --- a/compiler/test/Stat/opCnt.mlir +++ b/compiler/test/Stat/opCnt.mlir @@ -8,7 +8,7 @@ module { %0 = "tf.Add"(%arg0, %arg1) : (tensor<2x4xf32>, tensor<2x4xf32>) -> tensor<2x4xf32> %1 = "tf.Mul"(%0, %arg0) : (tensor<2x4xf32>, tensor<2x4xf32>) -> tensor<2x4xf32> %2 = "mhlo.add"(%1, %arg0) : (tensor<2x4xf32>, tensor<2x4xf32>) -> tensor<2x4xf32> - %3 = "mhlo.add"(%1, %2) : (tensor<2x4xf32>, tensor<2x4xf32>) -> tensor<2x4xf32> + %3 = "mhlo.add"(%1, %2) : (tensor<2x4xf32>, tensor<2x4xf32>) -> tensor<2x4xf32> return %3 : tensor<2x4xf32> } func.func @tf_add(%arg0 : tensor<2x4xf32>, %arg1 : tensor<2x4xf32>) -> (tensor<2x4xf32>) { @@ -25,34 +25,34 @@ module { return %5 : tensor<2x4xf32> } } -// DEFAULT: func.call 1 -// DEFAULT: func.func 2 -// DEFAULT: func.return 2 -// DEFAULT: mhlo.add 6 -// DEFAULT: mhlo.fusion 1 -// DEFAULT: mhlo.return 1 -// DEFAULT: tf.Add 2 -// DEFAULT: tf.Mul 2 +// DEFAULT: func.call 1 f32 f32 +// DEFAULT: func.func 2 NA NA +// DEFAULT: 
func.return 2 f32 NA +// DEFAULT: mhlo.add 6 f32 f32 +// DEFAULT: mhlo.fusion 1 f32 f32 +// DEFAULT: mhlo.return 1 f32 NA +// DEFAULT: tf.Add 2 f32 f32 +// DEFAULT: tf.Mul 2 f32 f32 -// FUNCNAME: func.call 1 +// FUNCNAME: func.call 1 f32 f32 // FUNCNAME: func.func 1 -// FUNCNAME: func.return 1 -// FUNCNAME: mhlo.add 4 -// FUNCNAME: mhlo.fusion 1 -// FUNCNAME: mhlo.return 1 -// FUNCNAME: tf.Add 1 -// FUNCNAME: tf.Mul 1 +// FUNCNAME: func.return 1 f32 NA +// FUNCNAME: mhlo.add 4 f32 f32 +// FUNCNAME: mhlo.fusion 1 f32 f32 +// FUNCNAME: mhlo.return 1 f32 NA +// FUNCNAME: tf.Add 1 f32 f32 +// FUNCNAME: tf.Mul 1 f32 f32 -// TOPONLY: func.call 1 -// TOPONLY: func.return 2 -// TOPONLY: mhlo.add 4 -// TOPONLY: mhlo.fusion 1 -// TOPONLY: tf.Add 2 -// TOPONLY: tf.Mul 2 +// TOPONLY: func.call 1 f32 f32 +// TOPONLY: func.return 2 f32 NA +// TOPONLY: mhlo.add 4 f32 f32 +// TOPONLY: mhlo.fusion 1 f32 f32 +// TOPONLY: tf.Add 2 f32 f32 +// TOPONLY: tf.Mul 2 f32 f32 -// FUNCNAMETOPONLY: func.call 1 -// FUNCNAMETOPONLY: func.return 1 -// FUNCNAMETOPONLY: mhlo.add 2 -// FUNCNAMETOPONLY: mhlo.fusion 1 -// FUNCNAMETOPONLY: tf.Add 1 -// FUNCNAMETOPONLY: tf.Mul 1 +// FUNCNAMETOPONLY: func.call 1 f32 f32 +// FUNCNAMETOPONLY: func.return 1 f32 NA +// FUNCNAMETOPONLY: mhlo.add 2 f32 f32 +// FUNCNAMETOPONLY: mhlo.fusion 1 f32 f32 +// FUNCNAMETOPONLY: tf.Add 1 f32 f32 +// FUNCNAMETOPONLY: tf.Mul 1 f32 f32 diff --git a/compiler/test/Stat/opTypes.mlir b/compiler/test/Stat/opTypes.mlir new file mode 100644 index 000000000..a3b32cc93 --- /dev/null +++ b/compiler/test/Stat/opTypes.mlir @@ -0,0 +1,65 @@ +// RUN: byteir-stat -op-cnt %s | FileCheck %s -check-prefix=DEFAULT +// RUN: byteir-stat -op-cnt -func-name="tf_add" %s | FileCheck %s -check-prefix=FUNCNAME +// RUN: byteir-stat -op-cnt -top-only %s | FileCheck %s -check-prefix=TOPONLY +// RUN: byteir-stat -op-cnt -func-name="tf_add" -top-only %s | FileCheck %s -check-prefix=FUNCNAMETOPONLY + +module { + func.func private @some_callee(%arg0 : tensor<2x4xf32>, %arg1 : tensor<2x4xf32>) -> (tensor<2x4xf32>) { + %0 = "tf.Add"(%arg0, %arg1) : (tensor<2x4xf32>, tensor<2x4xf32>) -> tensor<2x4xf32> + %1 = "tf.Mul"(%0, %arg0) : (tensor<2x4xf32>, tensor<2x4xf32>) -> tensor<2x4xf32> + %2 = "mhlo.add"(%1, %arg0) : (tensor<2x4xf32>, tensor<2x4xf32>) -> tensor<2x4xf32> + %11 = mhlo.convert %1 : (tensor<2x4xf32>) -> tensor<2x4xf16> + %21 = mhlo.convert %2 : (tensor<2x4xf32>) -> tensor<2x4xf16> + %3 = "mhlo.add"(%11, %21) : (tensor<2x4xf16>, tensor<2x4xf16>) -> tensor<2x4xf16> + %31 = mhlo.custom_call @byteir.softmax(%3) {backend_config = "", byteir_attrs = {axis = 1 : i64}} : (tensor<2x4xf16>) -> tensor<2x4xf32> + return %31 : tensor<2x4xf32> + } + func.func @tf_add(%arg0 : tensor<2x4xf32>, %arg1 : tensor<2x4xf32>) -> (tensor<2x4xf32>) { + %0 = "tf.Add"(%arg0, %arg1) : (tensor<2x4xf32>, tensor<2x4xf32>) -> tensor<2x4xf32> + %1 = "tf.Mul"(%0, %arg0) : (tensor<2x4xf32>, tensor<2x4xf32>) -> tensor<2x4xf32> + %2 = "mhlo.add"(%1, %arg0) : (tensor<2x4xf32>, tensor<2x4xf32>) -> tensor<2x4xf32> + %3 = "mhlo.add"(%1, %2) : (tensor<2x4xf32>, tensor<2x4xf32>) -> tensor<2x4xf32> + %4 = "mhlo.fusion"(%arg0, %arg1) ( { + %6 = "mhlo.add" (%arg0, %arg1) : (tensor<2x4xf32>, tensor<2x4xf32>) -> tensor<2x4xf32> + %7 = "mhlo.add" (%arg0, %6) : (tensor<2x4xf32>, tensor<2x4xf32>) -> tensor<2x4xf32> + "mhlo.return"(%6) : (tensor<2x4xf32>) -> () + }) : (tensor<2x4xf32>, tensor<2x4xf32>) -> tensor<2x4xf32> + %5 = call @some_callee(%3, %4) : (tensor<2x4xf32>, tensor<2x4xf32>) -> tensor<2x4xf32> + return %5 : 
tensor<2x4xf32> + } +} +// DEFAULT: func.call 1 f32 f32 +// DEFAULT: func.func 2 NA NA +// DEFAULT: func.return 2 f32 NA +// DEFAULT: mhlo.add 6 f16,f32 f16,f32 +// DEFAULT: mhlo.convert 2 f32 f16 +// DEFAULT: mhlo.custom_call 1 f16 f32 +// DEFAULT: mhlo.fusion 1 f32 f32 +// DEFAULT: mhlo.return 1 f32 NA +// DEFAULT: tf.Add 2 f32 f32 +// DEFAULT: tf.Mul 2 f32 f32 + +// FUNCNAME: func.call 1 f32 f32 +// FUNCNAME: func.func 1 +// FUNCNAME: func.return 1 f32 NA +// FUNCNAME: mhlo.add 4 f32 f32 +// FUNCNAME: mhlo.fusion 1 f32 f32 +// FUNCNAME: mhlo.return 1 f32 NA +// FUNCNAME: tf.Add 1 f32 f32 +// FUNCNAME: tf.Mul 1 f32 f32 + +// TOPONLY: func.call 1 f32 f32 +// TOPONLY: func.return 2 f32 NA +// TOPONLY: mhlo.add 4 f16,f32 f16,f32 +// TOPONLY: mhlo.convert 2 f32 f16 +// TOPONLY: mhlo.custom_call 1 f16 f32 +// TOPONLY: mhlo.fusion 1 f32 f32 +// TOPONLY: tf.Add 2 f32 f32 +// TOPONLY: tf.Mul 2 f32 f32 + +// FUNCNAMETOPONLY: func.call 1 f32 f32 +// FUNCNAMETOPONLY: func.return 1 f32 NA +// FUNCNAMETOPONLY: mhlo.add 2 f32 f32 +// FUNCNAMETOPONLY: mhlo.fusion 1 f32 f32 +// FUNCNAMETOPONLY: tf.Add 1 f32 f32 +// FUNCNAMETOPONLY: tf.Mul 1 f32 f32 diff --git a/frontends/onnx-frontend/onnx-frontend/src/Compiler/OFCompilerPipelines.cpp b/frontends/onnx-frontend/onnx-frontend/src/Compiler/OFCompilerPipelines.cpp index e2bad0d49..e5bec370d 100644 --- a/frontends/onnx-frontend/onnx-frontend/src/Compiler/OFCompilerPipelines.cpp +++ b/frontends/onnx-frontend/onnx-frontend/src/Compiler/OFCompilerPipelines.cpp @@ -44,9 +44,13 @@ void addCustomizedONNXToMhloPasses( onnx_frontend::createOFRewriteToCustomCallPass(customCallOps)); pm.addNestedPass( onnx_mlir::createDecomposeONNXToONNXPass("mhlo")); - pm.addPass(onnx_mlir::createShapeInferencePass()); - pm.addPass(onnx_frontend::createOFCanonicalizerPass()); - pm.addPass(onnx_mlir::createShapeInferencePass()); + for (int i = 0; i < onnx_frontend::ofRepeatStatic; i++) { + pm.addPass(onnx_mlir::createShapeInferencePass()); + pm.addPass(onnx_frontend::createOFCanonicalizerPass()); + pm.addPass(onnx_mlir::createShapeInferencePass()); + pm.addNestedPass( + onnx_mlir::createConstPropONNXToONNXPass()); + } } // There are more opportunities for const propagation once all tensors have diff --git a/frontends/onnx-frontend/onnx-frontend/src/Conversion/OFCanonicalizer.cpp b/frontends/onnx-frontend/onnx-frontend/src/Conversion/OFCanonicalizer.cpp index b4ad3a257..87a8a88e9 100644 --- a/frontends/onnx-frontend/onnx-frontend/src/Conversion/OFCanonicalizer.cpp +++ b/frontends/onnx-frontend/onnx-frontend/src/Conversion/OFCanonicalizer.cpp @@ -30,14 +30,18 @@ struct OFCanonicalizerPass LogicalResult initialize(MLIRContext *context) override { RewritePatternSet owningPatterns(context); + SmallVector disabledPatterns{ + "FuseBatchNormInferenceModeConvPattern", + "RewriteBatchNormInferenceModeConvPattern1", + "RewriteBatchNormInferenceModeConvPattern2"}; for (auto *dialect : context->getLoadedDialects()) dialect->getCanonicalizationPatterns(owningPatterns); for (RegisteredOperationName op : context->getRegisteredOperations()) op.getCanonicalizationPatterns(owningPatterns, context); - patterns = FrozenRewritePatternSet(std::move(owningPatterns), - /*disabledPatterns*/ {}, - /*enabledPatterns*/ {}); + patterns = + FrozenRewritePatternSet(std::move(owningPatterns), disabledPatterns, + /*enabledPatterns*/ {}); return success(); } diff --git a/frontends/onnx-frontend/onnx-frontend/src/Conversion/OFRewriteToCustomCall.cpp 
b/frontends/onnx-frontend/onnx-frontend/src/Conversion/OFRewriteToCustomCall.cpp index 11f36caba..dfda80cc2 100644 --- a/frontends/onnx-frontend/onnx-frontend/src/Conversion/OFRewriteToCustomCall.cpp +++ b/frontends/onnx-frontend/onnx-frontend/src/Conversion/OFRewriteToCustomCall.cpp @@ -167,6 +167,33 @@ Value createL2Norm(PatternRewriter &rewriter, Location loc, Value input, return customCallOp.getResults()[0]; } +Value createL2NormWithoutEps(PatternRewriter &rewriter, Location loc, + Value input, ArrayAttr axis_attr) { + RankedTensorType inputType = + input.getType().dyn_cast_or_null(); + assert(inputType != nullptr && "L2Norm input type must be ranked"); + + int64_t axis = axis_attr[0].cast().getInt(); + // canonicalize axis to be positive + if (axis < 0) { + axis = inputType.getRank() + axis; + } + + std::string call_target_name = getL2NormNameWithPrefix(); + mhlo::CustomCallOp customCallOp = rewriter.create( + loc, llvm::ArrayRef{inputType}, llvm::ArrayRef{input}, + call_target_name, false, rewriter.getStringAttr(""), + mhlo::CustomCallApiVersion::API_VERSION_ORIGINAL, + rewriter.getArrayAttr(llvm::ArrayRef{}), + mhlo::CustomCallSchedule::NONE, nullptr, nullptr, + rewriter.getArrayAttr(llvm::ArrayRef{})); + DictionaryAttrWrapper attrs(rewriter.getContext()); + attrs.setAttr("axis", rewriter.getI64ArrayAttr({axis})); + customCallOp->setAttr(BYTEIR_ATTRS, getCleanAttr(attrs)); + + return customCallOp.getResults()[0]; +} + //===----------------------------------------------------------------------===// // Quantize/Dequantize //===----------------------------------------------------------------------===// @@ -524,7 +551,9 @@ struct OFRewriteToCustomCallPass llvm::SmallVector>> validOpSet; validOpSet[getL2NormName()].emplace_back( - std::make_unique(context)); + std::make_unique(context)); + validOpSet[getL2NormName()].emplace_back( + std::make_unique(context)); validOpSet[getQuantizeName()].emplace_back( std::make_unique(context)); validOpSet[getDequantizeName()].emplace_back( diff --git a/frontends/onnx-frontend/onnx-frontend/src/Conversion/OFRewriteToCustomCall.td b/frontends/onnx-frontend/onnx-frontend/src/Conversion/OFRewriteToCustomCall.td index 238071949..c18e3a7fb 100644 --- a/frontends/onnx-frontend/onnx-frontend/src/Conversion/OFRewriteToCustomCall.td +++ b/frontends/onnx-frontend/onnx-frontend/src/Conversion/OFRewriteToCustomCall.td @@ -32,7 +32,7 @@ def HasOneUse : Constraint, "value has exactly one use"> //===----------------------------------------------------------------------===// // L2Norm Pattern //===----------------------------------------------------------------------===// -def RewriteL2Norm : Pat< +def RewriteL2NormPat1 : Pat< (ONNXDivOp $input, (ONNXExpandOp @@ -46,6 +46,14 @@ def RewriteL2Norm : Pat< (NativeCodeCall<"createL2Norm($_builder, $_loc, $0, $1, $2)"> $input, $axis_attr, $epsilon_attr), [(IsOneSize $axis_attr), (TrueBoolAttr $keep_dims), (IsOneSizeElements $epsilon_attr)]>; +def RewriteL2NormPat2 : Pat< + (ONNXDivOp + $input, + (ONNXReduceL2V13Op $input, $axis_attr, $keep_dims) + ), + (NativeCodeCall<"createL2NormWithoutEps($_builder, $_loc, $0, $1)"> $input, $axis_attr), + [(IsOneSize $axis_attr), (TrueBoolAttr $keep_dims)]>; + //===----------------------------------------------------------------------===// // Quantize/Dequantize Pattern //===----------------------------------------------------------------------===// diff --git a/frontends/onnx-frontend/onnx-frontend/test/of_rewrite_to_custom_call.mlir 
b/frontends/onnx-frontend/onnx-frontend/test/of_rewrite_to_custom_call.mlir index 1165efee2..a508ec109 100644 --- a/frontends/onnx-frontend/onnx-frontend/test/of_rewrite_to_custom_call.mlir +++ b/frontends/onnx-frontend/onnx-frontend/test/of_rewrite_to_custom_call.mlir @@ -155,7 +155,7 @@ func.func @test_gelu_without_last_mul(%arg0: tensor<1x3x5x5xf32>, %arg1: tensor< // ----- -func.func @test_l2_norm(%267: tensor<16x128xf32>) -> tensor<16x128xf32> { +func.func @test_l2_norm_pat1(%267: tensor<16x128xf32>) -> tensor<16x128xf32> { %5 = "onnx.Constant"() {value = dense<9.99999996E-13> : tensor} : () -> tensor %126 = "onnx.Constant"() {value = dense<[16, 128]> : tensor<2xi64>} : () -> tensor<2xi64> %268 = "onnx.ReduceL2V13"(%267) {axes = [-1], keepdims = 1 : si64, onnx_node_name = "ReduceL2_213"} : (tensor<16x128xf32>) -> tensor<16x1xf32> @@ -163,7 +163,7 @@ func.func @test_l2_norm(%267: tensor<16x128xf32>) -> tensor<16x128xf32> { %270 = "onnx.Expand"(%269, %126) {onnx_node_name = "Expand_217"} : (tensor<16x1xf32>, tensor<2xi64>) -> tensor<16x128xf32> %271 = "onnx.Div"(%267, %270) {onnx_node_name = "Div_218"} : (tensor<16x128xf32>, tensor<16x128xf32>) -> tensor<16x128xf32> return %271 : tensor<16x128xf32> -// CHECK-LABEL: @test_l2_norm +// CHECK-LABEL: @test_l2_norm_pat1 // CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<16x128xf32>) -> tensor<16x128xf32> { // CHECK-NEXT: [[VAR_0_:%.+]] = mhlo.custom_call @byteir.l2_norm(%arg0) {backend_config = "", byteir_attrs = {axis = [1], epsilon = 9.999999960041972E-13 : f64}} : (tensor<16x128xf32>) -> tensor<16x128xf32> // CHECK-NEXT: return [[VAR_0_]] : tensor<16x128xf32> @@ -171,6 +171,18 @@ func.func @test_l2_norm(%267: tensor<16x128xf32>) -> tensor<16x128xf32> { // ----- +func.func @test_l2_norm_pat2(%1146: tensor<12x128xf32>) -> tensor<12x128xf32> { + %1147 = "onnx.ReduceL2V13"(%1146) {axes = [1], keepdims = 1 : si64, onnx_node_name = "ReduceL2_769"} : (tensor<12x128xf32>) -> tensor<12x1xf32> + %1148 = "onnx.Div"(%1146, %1147) {onnx_node_name = "Div_770"} : (tensor<12x128xf32>, tensor<12x1xf32>) -> tensor<12x128xf32> + return %1148 : tensor<12x128xf32> +// CHECK-LABEL: @test_l2_norm_pat2 +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<12x128xf32>) -> tensor<12x128xf32> { +// CHECK-NEXT: [[VAR_0_:%.+]] = mhlo.custom_call @byteir.l2_norm(%arg0) {backend_config = "", byteir_attrs = {axis = [1]}} : (tensor<12x128xf32>) -> tensor<12x128xf32> +// CHECK-NEXT: return [[VAR_0_]] : tensor<12x128xf32> +} + +// ----- + func.func @test_quantize_per_tensor(%arg0: tensor<16x3x256x256xf32>) -> tensor<16x3x256x256xi8> { %291 = mhlo.constant dense<0.0207054354> : tensor %292 = mhlo.constant dense<0> : tensor diff --git a/frontends/onnx-frontend/third_party/patches/OnnxMlirResizeV13ShapeInfer.patch b/frontends/onnx-frontend/third_party/patches/OnnxMlirResizeV13ShapeInfer.patch new file mode 100644 index 000000000..85f95e97a --- /dev/null +++ b/frontends/onnx-frontend/third_party/patches/OnnxMlirResizeV13ShapeInfer.patch @@ -0,0 +1,103 @@ +diff --git a/src/Dialect/ONNX/ONNXOps/ShapeHelper.hpp b/src/Dialect/ONNX/ONNXOps/ShapeHelper.hpp +index d7dff949..95ad6b4c 100644 +--- a/src/Dialect/ONNX/ONNXOps/ShapeHelper.hpp ++++ b/src/Dialect/ONNX/ONNXOps/ShapeHelper.hpp +@@ -760,6 +760,18 @@ struct ONNXResizeOpShapeHelper : public ONNXOpShapeHelper { + llvm::SmallVector scales; + }; + ++struct ONNXResizeV13OpShapeHelper : public ONNXOpShapeHelper { ++ ONNXResizeV13OpShapeHelper(mlir::Operation *op, mlir::ValueRange operands, ++ IndexExprBuilder *ieBuilder = nullptr, IndexExprScope *scope = 
nullptr) ++ : ONNXOpShapeHelper(op, operands, ieBuilder, scope) {} ++ virtual ~ONNXResizeV13OpShapeHelper() {} ++ mlir::LogicalResult computeShape() final; ++ // Values set by computeShape: scales is a float index expression. It is ++ // directly the `scale` argument when scale is provided by the op. When ++ // `size` is provided, then scale is float(`size`)/float(dim). ++ llvm::SmallVector scales; ++}; ++ + //===----------------------------------------------------------------------===// + // Non specific Ops, namely ops that + // * need customization only for themselves (no sharing of code) +diff --git a/src/Dialect/ONNX/ONNXOps/Tensor/Resize.cpp b/src/Dialect/ONNX/ONNXOps/Tensor/Resize.cpp +index d8e11c29..abedb620 100644 +--- a/src/Dialect/ONNX/ONNXOps/Tensor/Resize.cpp ++++ b/src/Dialect/ONNX/ONNXOps/Tensor/Resize.cpp +@@ -86,6 +86,48 @@ LogicalResult ONNXResizeOpShapeHelper::computeShape() { + return success(); + } + ++LogicalResult ONNXResizeV13OpShapeHelper::computeShape() { ++ ONNXResizeV13OpAdaptor operandAdaptor(operands); ++ uint64_t rank = createIE->getShapedTypeRank(operandAdaptor.getX()); ++ DimsExpr inputDims, outputDims; ++ createIE->getShapeAsDims(operandAdaptor.getX(), inputDims); ++ bool scalesIsAbsent = isAbsent(operandAdaptor.getScales()); ++ ++ if (!scalesIsAbsent) { ++ // Read and save scales as float. ++ createIE->getFloatFromArrayAsNonAffine(operandAdaptor.getScales(), scales); ++ if (inputDims.size() != scales.size()) ++ return op->emitError("expected scales to have the same rank as input"); ++ // Compute output dims = int(floor(float(input dims) * scales)). ++ for (uint64_t i = 0; i < rank; ++i) { ++ // Special case for scale == 1.0 as converts are then needed. ++ if (scales[i].isLiteralAndIdenticalTo(1.0)) { ++ outputDims.emplace_back(inputDims[i]); ++ } else { ++ IndexExpr floatInputDim = inputDims[i].convertToFloat(); ++ IndexExpr floatProduct = floatInputDim * scales[i]; ++ // Formula has a floor, but convert of positive number already rounds ++ // toward zero, so skip the floor. ++ outputDims.emplace_back(floatProduct.convertToIndex()); ++ } ++ } ++ } else { ++ // Output size is defined by input `sizes`. ++ createIE->getIntFromArrayAsSymbols(operandAdaptor.getSizes(), outputDims); ++ if (inputDims.size() != outputDims.size()) ++ return op->emitError("expected scales to have the same rank as input"); ++ // Compute scales as float(output dims) / float(input dims). 
++ for (uint64_t i = 0; i < rank; ++i) { ++ IndexExpr floatInputDim = inputDims[i].convertToFloat(); ++ IndexExpr floatOutputDim = outputDims[i].convertToFloat(); ++ scales.emplace_back(floatOutputDim / floatInputDim); ++ } ++ } ++ // Save output dims ++ setOutputDims(outputDims); ++ return success(); ++} ++ + } // namespace onnx_mlir + + //===----------------------------------------------------------------------===// +@@ -127,3 +191,13 @@ LogicalResult ONNXResizeOp::inferShapes( + ONNXResizeOpShapeHelper shapeHelper(getOperation(), {}); + return shapeHelper.computeShapeAndUpdateType(elementType); + } ++ ++LogicalResult ONNXResizeV13Op::inferShapes( ++ std::function doShapeInference) { ++ if (!hasShapeAndRank(getX())) ++ return success(); ++ ++ Type elementType = getX().getType().cast().getElementType(); ++ ONNXResizeV13OpShapeHelper shapeHelper(getOperation(), {}); ++ return shapeHelper.computeShapeAndUpdateType(elementType); ++} +diff --git a/src/Dialect/ONNX/ONNXUnsupportedOps.hpp b/src/Dialect/ONNX/ONNXUnsupportedOps.hpp +index baa23f55..e8bce2a0 100644 +--- a/src/Dialect/ONNX/ONNXUnsupportedOps.hpp ++++ b/src/Dialect/ONNX/ONNXUnsupportedOps.hpp +@@ -62,7 +62,7 @@ UNSUPPORTED_OPS(ONNXRandomUniformLikeOp) + UNSUPPORTED_OPS(ONNXRandomUniformOp) + UNSUPPORTED_OPS(ONNXResizeV10Op) + UNSUPPORTED_OPS(ONNXResizeV11Op) +-UNSUPPORTED_OPS(ONNXResizeV13Op) ++// UNSUPPORTED_OPS(ONNXResizeV13Op) + UNSUPPORTED_OPS(ONNXSequenceMapOp) + UNSUPPORTED_OPS(ONNXSVMClassifierOp) + UNSUPPORTED_OPS(ONNXSVMRegressorOp) diff --git a/frontends/torch-frontend/requirements.txt b/frontends/torch-frontend/requirements.txt index f69aa183d..8b642e2f5 100644 --- a/frontends/torch-frontend/requirements.txt +++ b/frontends/torch-frontend/requirements.txt @@ -1,14 +1,14 @@ # torch and torchvision -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html --pre -torch==2.1.0.dev20230725 -torchvision==0.16.0.dev20230725 +torch==2.1.0.dev20230820 +torchvision==0.16.0.dev20230820 # cuda torch and torchvision # --extra-index-url https://download.pytorch.org/whl/nightly/cu118 # --pre -# torch==2.1.0.dev20230725+cu118 -# torchvision==0.16.0.dev20230725+cu118 +# torch==2.1.0.dev20230820+cu118 +# torchvision==0.16.0.dev20230820+cu118 # transformers diff --git a/frontends/torch-frontend/third_party/patches/torchtostablehlo_basic.patch b/frontends/torch-frontend/third_party/patches/torchtostablehlo_basic.patch deleted file mode 100644 index 3c1d8ad91..000000000 --- a/frontends/torch-frontend/third_party/patches/torchtostablehlo_basic.patch +++ /dev/null @@ -1,68 +0,0 @@ -diff --git a/lib/Conversion/TorchToStablehlo/Basic.cpp b/lib/Conversion/TorchToStablehlo/Basic.cpp -index c12b84f8..df5e84fa 100644 ---- a/lib/Conversion/TorchToStablehlo/Basic.cpp -+++ b/lib/Conversion/TorchToStablehlo/Basic.cpp -@@ -923,7 +923,6 @@ LogicalResult ConvertAtenOp::matchAndRewrite( - AtenBatchNormOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { - Value input = adaptor.getInput(); -- // shape = [N, C, H, W] - auto inputTy = input.getType().cast(); - Value weight = adaptor.getWeight(); - Value bias = adaptor.getBias(); -@@ -942,7 +941,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( - } - auto inputElemTy = inputTy.getElementType().cast(); - -- Value channelDim = rewriter.create(op->getLoc(), input, 1); -+ Value channelDim = -+ rewriter.create(op->getLoc(), input, feature_index); - - if (options.dimSizeIndexBits == 32) { - auto channelDimI64 = rewriter.create( -@@ -1027,21 +1027,30 @@ LogicalResult 
ConvertAtenOp::matchAndRewrite( - return success(); - } else { - Type outputTy = getTypeConverter()->convertType(op.getType()); -- SmallVector castShape{inputTy.getShape().begin(), -- inputTy.getShape().end()}; -- castShape[1] = weightTy.getShape()[0]; -- auto castTy = RankedTensorType::get(castShape, inputTy.getElementType()); -- // Feature counts must match among operands of -- // stablehlo::BatchNormInferenceOp. -- Value inputCasted = -- rewriter.create(op.getLoc(), castTy, input); -- Value output = rewriter.create( -- op.getLoc(), inputCasted.getType(), inputCasted, weight, bias, -- runningMean, runningVar, -- // 'epsilon' must satisfy constraint: 32-bit float attribute. -- rewriter.getF32FloatAttr(eps), -- rewriter.getI64IntegerAttr(feature_index)); -- rewriter.replaceOpWithNewOp(op, outputTy, output); -+ bool mixedType = (inputTy.getElementType() != weightTy.getElementType()); -+ Value output; -+ if (mixedType) { -+ RankedTensorType convertedType = -+ RankedTensorType::get(inputTy.getShape(), rewriter.getF32Type()); -+ input = hlo::promoteType(rewriter, op.getLoc(), input, convertedType); -+ weight = hlo::promoteType(rewriter, op.getLoc(), weight, convertedType); -+ bias = hlo::promoteType(rewriter, op.getLoc(), bias, convertedType); -+ runningMean = hlo::promoteType(rewriter, op.getLoc(), runningMean, convertedType); -+ runningVar = hlo::promoteType(rewriter, op.getLoc(), runningVar, convertedType); -+ Value bnResult = rewriter.create( -+ op.getLoc(), convertedType, input, weight, bias, runningMean, -+ runningVar, rewriter.getF32FloatAttr(eps), -+ rewriter.getI64IntegerAttr(feature_index)); -+ output = -+ hlo::promoteType(rewriter, op.getLoc(), bnResult, outputTy.cast()); -+ } else { -+ output = rewriter.create( -+ op.getLoc(), outputTy, input, weight, bias, runningMean, runningVar, -+ // 'epsilon' must satisfy constraint: 32-bit float attribute. 
-+ rewriter.getF32FloatAttr(eps), -+ rewriter.getI64IntegerAttr(feature_index)); -+ } -+ rewriter.replaceOp(op, output); - return success(); - } - } diff --git a/frontends/torch-frontend/third_party/torch-mlir b/frontends/torch-frontend/third_party/torch-mlir index 9a1fae97b..b552d4ed9 160000 --- a/frontends/torch-frontend/third_party/torch-mlir +++ b/frontends/torch-frontend/third_party/torch-mlir @@ -1 +1 @@ -Subproject commit 9a1fae97b1e56464a109206acd3b4be7d63ebfc7 +Subproject commit b552d4ed956d82f5d9d0823b4727bb10bac6787c diff --git a/frontends/torch-frontend/torch-requirements.txt b/frontends/torch-frontend/torch-requirements.txt index 57264a852..7d9870d60 100644 --- a/frontends/torch-frontend/torch-requirements.txt +++ b/frontends/torch-frontend/torch-requirements.txt @@ -1,4 +1,4 @@ # cuda torch and torchvision --extra-index-url https://download.pytorch.org/whl/nightly/cu118 --pre -torch==2.1.0.dev20230725+cu118 +torch==2.1.0.dev20230820+cu118 diff --git a/runtime/VERSION_NUMBER b/runtime/VERSION_NUMBER index 26aaba0e8..f0bb29e76 100644 --- a/runtime/VERSION_NUMBER +++ b/runtime/VERSION_NUMBER @@ -1 +1 @@ -1.2.0 +1.3.0 diff --git a/runtime/lib/backends/cuda/providers/default/tensor_generate/rng_state.cc b/runtime/lib/backends/cuda/providers/default/tensor_generate/rng_state.cc index 7a40c62f9..408c680e0 100644 --- a/runtime/lib/backends/cuda/providers/default/tensor_generate/rng_state.cc +++ b/runtime/lib/backends/cuda/providers/default/tensor_generate/rng_state.cc @@ -39,13 +39,15 @@ common::Status GetSeedOpKernel::RunImpl(const ExecutionContext &ctx) { int64_t rngSeed = rngStateHandle->getSeed(); OpAccessor accessor(info_, ctx.exec_frame); DTypeEnum dtype = accessor.GetArgDTypeEnum(0); + cudaStream_t stream = + static_cast(ctx.work_queue)->GetComputeStream(); void *device_p = accessor.GetArgAsyncValueRef(0); #define CASE(D) \ case DTypeEnum::D: { \ using ctype = DTypeTraits::type_t; \ ctype castedRngSeed = static_cast(rngSeed); \ - cudaMemcpy(device_p, &castedRngSeed, sizeof(ctype), \ - cudaMemcpyHostToDevice); \ + cudaMemcpyAsync(device_p, &castedRngSeed, sizeof(ctype), \ + cudaMemcpyHostToDevice, stream); \ return common::Status::OK(); \ } BRT_DISPATCH_NUMBER_TYPES(dtype, CASE) @@ -73,13 +75,15 @@ common::Status NextOffsetOpKernel::RunImpl(const ExecutionContext &ctx) { int64_t rngOffset = rngStateHandle->nextOffset(); OpAccessor accessor(info_, ctx.exec_frame); DTypeEnum dtype = accessor.GetArgDTypeEnum(0); + cudaStream_t stream = + static_cast(ctx.work_queue)->GetComputeStream(); void *device_p = accessor.GetArgAsyncValueRef(0); #define CASE(D) \ case DTypeEnum::D: { \ using ctype = DTypeTraits::type_t; \ ctype casteRngOffset = static_cast(rngOffset); \ - cudaMemcpy(device_p, &casteRngOffset, sizeof(ctype), \ - cudaMemcpyHostToDevice); \ + cudaMemcpyAsync(device_p, &casteRngOffset, sizeof(ctype), \ + cudaMemcpyHostToDevice, stream); \ return common::Status::OK(); \ } BRT_DISPATCH_NUMBER_TYPES(dtype, CASE)
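The final hunk above replaces the blocking `cudaMemcpy` of the RNG seed/offset scalars with `cudaMemcpyAsync` issued on the work queue's compute stream, so the host-to-device upload is ordered with the kernels already enqueued on that stream rather than forcing a device-wide synchronization. Below is a minimal standalone sketch of that pattern using only the public CUDA runtime API; the `stream`, `seed`, and `device_p` names mirror the kernel above, while ByteIR's `CUDAWorkQueue`/`OpAccessor` plumbing is omitted, so treat it as an illustration rather than runtime code.

```cpp
#include <cuda_runtime.h>
#include <cstdint>

int main() {
  // Stands in for CUDAWorkQueue::GetComputeStream() in the kernel above.
  cudaStream_t stream;
  cudaStreamCreate(&stream);

  int64_t seed = 42;           // host-side scalar, like rngStateHandle->getSeed()
  int64_t *device_p = nullptr; // device destination buffer
  cudaMalloc(reinterpret_cast<void **>(&device_p), sizeof(int64_t));

  // Stream-ordered copy: it runs after work already enqueued on `stream`
  // and, unlike the plain cudaMemcpy it replaces, does not force a
  // device-wide synchronization.
  cudaMemcpyAsync(device_p, &seed, sizeof(int64_t), cudaMemcpyHostToDevice,
                  stream);

  // The host variable must outlive the copy; a standalone program can simply
  // synchronize the stream (the brt runtime relies on its own stream
  // ordering instead).
  cudaStreamSynchronize(stream);

  cudaFree(device_p);
  cudaStreamDestroy(stream);
  return 0;
}
```

One caveat: with pageable host memory `cudaMemcpyAsync` may still stage through an internal pinned buffer, but the call remains stream-ordered, which is what this change is after.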