diff --git a/compiler/include/byteir/Conversion/HloToByreTensor/HloToByreTensor.h b/compiler/include/byteir/Conversion/HloToByreTensor/HloToByreTensor.h
index f045933ba..8b9e36c79 100644
--- a/compiler/include/byteir/Conversion/HloToByreTensor/HloToByreTensor.h
+++ b/compiler/include/byteir/Conversion/HloToByreTensor/HloToByreTensor.h
@@ -24,16 +24,14 @@
 #include <memory>
 
 namespace mlir {
+class ModuleOp; // forward decl
 
-namespace func {
-class FuncOp;
-} // namespace func
 
 void populateHloToByreTensorPattern(
     RewritePatternSet &patterns,
     const llvm::StringMap<llvm::StringRef> &supportMap, bool appendArgTypes);
 
-std::unique_ptr<OperationPass<func::FuncOp>>
+std::unique_ptr<OperationPass<ModuleOp>>
 createConvertHloToByreTensorPass(bool appendArgTypes = false);
 
 } // namespace mlir
diff --git a/compiler/include/byteir/Conversion/Passes.td b/compiler/include/byteir/Conversion/Passes.td
index ead522d1c..f98447280 100644
--- a/compiler/include/byteir/Conversion/Passes.td
+++ b/compiler/include/byteir/Conversion/Passes.td
@@ -267,7 +267,7 @@ def MhloToCat : Pass<"mhlo-to-cat", "func::FuncOp"> {
 // HloToByreTensor
 //===----------------------------------------------------------------------===//
 
-def ConvertHloToByreTensor : Pass<"hlo-to-byre-tensor", "func::FuncOp"> {
+def ConvertHloToByreTensor : Pass<"hlo-to-byre-tensor", "ModuleOp"> {
   let summary = "Convert hlo op to byre tensor op.";
   let constructor = "mlir::createConvertHloToByreTensorPass()";
   let dependentDialects = [
diff --git a/compiler/include/byteir/Dialect/mhlo/Passes.h b/compiler/include/byteir/Dialect/mhlo/Passes.h
index 351b071df..58ebeb6f7 100644
--- a/compiler/include/byteir/Dialect/mhlo/Passes.h
+++ b/compiler/include/byteir/Dialect/mhlo/Passes.h
@@ -36,7 +36,6 @@
 #include "byteir/Dialect/mhlo/Transforms/LayoutTransformation.h"
 #include "byteir/Dialect/mhlo/Transforms/MatmulLayoutTransform.h"
 #include "byteir/Dialect/mhlo/Transforms/RewriteWithConstraint.h"
-#include "byteir/Dialect/mhlo/Transforms/ShapeReification.h"
 #include "byteir/Dialect/mhlo/Transforms/StaticShapeInference.h"
 #include "byteir/Dialect/mhlo/Transforms/UnfuseBatchNorm.h"
 
diff --git a/compiler/include/byteir/Dialect/mhlo/Passes.td b/compiler/include/byteir/Dialect/mhlo/Passes.td
index 7fe03f8f8..58f35033d 100644
--- a/compiler/include/byteir/Dialect/mhlo/Passes.td
+++ b/compiler/include/byteir/Dialect/mhlo/Passes.td
@@ -305,25 +305,6 @@ def RewriteWithConstraint : Pass<"rewrite-with-constraint", "mlir::func::FuncOp
   let constructor = "mlir::createRewriteWithConstraintPass()";
 }
 
-//===----------------------------------------------------------------------===//
-// ShapeReification
-//===----------------------------------------------------------------------===//
-
-def ShapeReification : Pass<"byteir-shape-reification", "func::FuncOp"> {
-  let summary = "Iteratively reify all shape computations.";
-  let description = [{
-    If an operation has a shape reification implementation, that is to say, we
-    know how to express the outputs' shape by it's inputs' shape symbolicly,
-    then a tensor.dim or shape.shape_of on this type of operation could be
-    reified. And shape reification procedure could be handled recursively.
-  }];
-  let constructor = "mlir::createByteIRShapeReificationPass()";
-  let dependentDialects = [
-    "mlir::shape::ShapeDialect",
-    "mlir::tensor::TensorDialect"
-  ];
-}
-
 //===----------------------------------------------------------------------===//
 // Static Shape Inference
 //===----------------------------------------------------------------------===//
diff --git a/compiler/include/byteir/Transforms/Passes.h b/compiler/include/byteir/Transforms/Passes.h
index 7a179a34e..0f4beb1ac 100644
--- a/compiler/include/byteir/Transforms/Passes.h
+++ b/compiler/include/byteir/Transforms/Passes.h
@@ -35,6 +35,7 @@
 #include "byteir/Transforms/RewriteOpToStdCall.h"
 #include "byteir/Transforms/SetArgShape.h"
 #include "byteir/Transforms/SetSpace.h"
+#include "byteir/Transforms/ShapeReification.h"
 #include "byteir/Transforms/TryCatchModulePipeline.h"
 
 namespace mlir {
diff --git a/compiler/include/byteir/Transforms/Passes.td b/compiler/include/byteir/Transforms/Passes.td
index 97d69c022..b3d3e9853 100644
--- a/compiler/include/byteir/Transforms/Passes.td
+++ b/compiler/include/byteir/Transforms/Passes.td
@@ -425,4 +425,24 @@ def SetOpSpace: Pass<"set-op-space", "func::FuncOp"> {
   ];
 }
 
+//===----------------------------------------------------------------------===//
+// ShapeReification
+//===----------------------------------------------------------------------===//
+
+def ShapeReification : Pass<"byteir-shape-reification", "func::FuncOp"> {
+  let summary = "Iteratively reify all shape computations.";
+  let description = [{
+    If an operation has a shape reification implementation, that is, we know
+    how to express the shapes of its outputs symbolically in terms of the
+    shapes of its inputs, then a tensor.dim or shape.shape_of on such an
+    operation can be reified. The reification procedure is applied
+    recursively.
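+
+    For example (an illustrative sketch of a single reification step), a
+    shape.shape_of on the result of an element-wise op
+
+      %0 = mhlo.add %arg0, %arg1 : tensor<?x4xf32>
+      %1 = shape.shape_of %0 : tensor<?x4xf32> -> tensor<2xindex>
+
+    can be rewritten to query the operand's shape directly:
+
+      %1 = shape.shape_of %arg0 : tensor<?x4xf32> -> tensor<2xindex>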
+  }];
+  let constructor = "mlir::createByteIRShapeReificationPass()";
+  let dependentDialects = [
+    "mlir::shape::ShapeDialect",
+    "mlir::tensor::TensorDialect",
+    "mlir::arith::ArithDialect",
+  ];
+}
+
 #endif // BYTEIR_TRANSFORMS_PASSES
diff --git a/compiler/include/byteir/Dialect/mhlo/Transforms/ShapeReification.h b/compiler/include/byteir/Transforms/ShapeReification.h
similarity index 94%
rename from compiler/include/byteir/Dialect/mhlo/Transforms/ShapeReification.h
rename to compiler/include/byteir/Transforms/ShapeReification.h
index 19f338f22..7c4cb5043 100644
--- a/compiler/include/byteir/Dialect/mhlo/Transforms/ShapeReification.h
+++ b/compiler/include/byteir/Transforms/ShapeReification.h
@@ -1,6 +1,6 @@
 //===- ShapeReification.h -------------------------------------*--- C++ -*-===//
 //
-// Copyright 2022 ByteDance Ltd. and/or its affiliates. All rights reserved.
+// Copyright 2024 ByteDance Ltd. and/or its affiliates. All rights reserved.
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
 // You may obtain a copy of the License at
diff --git a/compiler/lib/Analysis/SymbolicShape.cpp b/compiler/lib/Analysis/SymbolicShape.cpp
index 1f5d9f499..703dec1c4 100644
--- a/compiler/lib/Analysis/SymbolicShape.cpp
+++ b/compiler/lib/Analysis/SymbolicShape.cpp
@@ -16,7 +16,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "byteir/Analysis/SymbolicShape.h"
-#include "byteir/Dialect/mhlo/Transforms/ShapeReification.h"
+#include "byteir/Transforms/ShapeReification.h"
 #include "mlir/Dialect/Shape/IR/Shape.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/IR/IRMapping.h"
diff --git a/compiler/lib/Conversion/HloToByreTensor/HloToByreTensor.cpp b/compiler/lib/Conversion/HloToByreTensor/HloToByreTensor.cpp
index 91104ec58..dae0bbdc1 100644
--- a/compiler/lib/Conversion/HloToByreTensor/HloToByreTensor.cpp
+++ b/compiler/lib/Conversion/HloToByreTensor/HloToByreTensor.cpp
@@ -768,7 +768,6 @@ struct ConvertHloToByreTensorPass
     MLIRContext &ctx = getContext();
     RewritePatternSet patterns(&ctx);
     ConversionTarget target(ctx);
-    auto funcOp = getOperation();
 
     populateHloToByreTensorPattern(patterns, supportMap, appendArgTypes);
     target.addIllegalDialect<mhlo::MhloDialect>();
@@ -776,7 +775,8 @@
                            shape::ShapeDialect, arith::ArithDialect>();
 
     FrozenRewritePatternSet frozenPatterns(std::move(patterns));
-    if (failed(applyPartialConversion(funcOp, target, frozenPatterns))) {
+    if (failed(
+            applyPartialConversion(getOperation(), target, frozenPatterns))) {
       signalPassFailure();
     }
   }
@@ -810,7 +810,7 @@ void mlir::populateHloToByreTensorPattern(
                   ConvertSliceOp, ConvertConcatenateOp>(patterns.getContext());
 }
 
-std::unique_ptr<OperationPass<func::FuncOp>>
+std::unique_ptr<OperationPass<ModuleOp>>
 mlir::createConvertHloToByreTensorPass(bool appendArgTypes) {
   return std::make_unique<ConvertHloToByreTensorPass>(appendArgTypes);
 }
diff --git a/compiler/lib/Dialect/mhlo/CMakeLists.txt b/compiler/lib/Dialect/mhlo/CMakeLists.txt
index 81667fb71..a6501cf0f 100644
--- a/compiler/lib/Dialect/mhlo/CMakeLists.txt
+++ b/compiler/lib/Dialect/mhlo/CMakeLists.txt
@@ -105,7 +105,6 @@ add_mlir_dialect_library(ByteIRMhloPasses
   Transforms/ReduceFusion.cpp
   Transforms/ReshapeGather.cpp
   Transforms/RewriteWithConstraint.cpp
-  Transforms/ShapeReification.cpp
   Transforms/StaticShapeInference.cpp
   Transforms/TrivialFusion.cpp
   Transforms/UnfuseBatchNorm.cpp
diff --git a/compiler/lib/Dialect/mhlo/Util/ShapeInferUtil.cpp b/compiler/lib/Dialect/mhlo/Util/ShapeInferUtil.cpp
index 0bf8250b5..313025a1d 100644
--- a/compiler/lib/Dialect/mhlo/Util/ShapeInferUtil.cpp
+++ b/compiler/lib/Dialect/mhlo/Util/ShapeInferUtil.cpp
@@ -17,14 +17,24 @@
 
 #include "byteir/Dialect/mhlo/Util/ShapeInferUtil.h"
 #include "byteir/Dialect/mhlo/DynamicShapeOpRegister/Register.h"
+#include "byteir/Transforms/ShapeReification.h"
 #include "mhlo/IR/hlo_ops.h"
+#include "mlir/Analysis/SliceAnalysis.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/OwningOpRef.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Transforms/Passes.h"
+#include "mlir/Transforms/TopologicalSortUtils.h"
 #include "llvm/ADT/StringMap.h"
 #include "llvm/Support/Debug.h"
+#include
+#include
+
 using namespace mlir;
 
 #define DEBUG_TYPE "shape-infer-util"
@@ -177,6 +187,203 @@ mlir::inferReturnTypeComponents(llvm::StringRef name) {
   return nullptr;
 }
 
+namespace {
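+// Returns true if `value` can be computed from the shapes of the enclosing
+// function's arguments alone, i.e. from constants and from tensor.dim /
+// shape.shape_of applied to block arguments, without reading any tensor
+// contents.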
+bool deduceFromFuncArgShape(Value value) {
+  if (value.isa<BlockArgument>()) {
+    return false;
+  }
+
+  auto defOp = value.getDefiningOp();
+  if (!defOp) {
+    return false;
+  }
+
+  if (isa<arith::ConstantIndexOp, arith::ConstantOp>(defOp)) {
+    return true;
+  }
+
+  if (isa<tensor::DimOp, shape::ShapeOfOp>(defOp)) {
+    auto operand = defOp->getOperand(0);
+    if (operand.isa<BlockArgument>()) {
+      return true;
+    }
+    return false;
+  }
+
+  for (Value &&operand : defOp->getOperands()) {
+    if (!deduceFromFuncArgShape(operand)) {
+      return false;
+    }
+  }
+  return true;
+}
+
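+// Build an auxiliary "shape function" for funcOp: clone it into
+// auxiliaryModuleOp, inline all nested func.call ops, replace each returned
+// tensor with its shape.shape_of, and run shape reification on the clone so
+// that the returned shapes become explicit shape computations.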
+// auxiliaryModuleOp must be an empty module that is used only to hold the
+// shape function.
+FailureOr<func::FuncOp>
+createCorrespondingShapeFunc(func::FuncOp funcOp, ModuleOp auxiliaryModuleOp) {
+  // use an auxiliary builder and create the shape function at the start of
+  // auxiliaryModuleOp
+  ModuleOp oriModuleOp = funcOp->getParentOfType<ModuleOp>();
+  OpBuilder builder = OpBuilder::atBlockBegin(auxiliaryModuleOp.getBody());
+
+  // clone funcOp; the clone is used to deduce the shapes of the results
+  Twine shapeFuncName = funcOp.getName() + "_Shape";
+  auto shapeFunc = builder.create<func::FuncOp>(
+      funcOp->getLoc(), shapeFuncName.str(), funcOp.getFunctionType());
+  shapeFunc.setPrivate();
+  IRMapping emptyBvm;
+  funcOp.cloneInto(shapeFunc, emptyBvm);
+  llvm::DenseSet<func::CallOp> callOpSet;
+  shapeFunc.walk([&](func::CallOp callOp) { callOpSet.insert(callOp); });
+
+  while (!callOpSet.empty()) {
+    auto callOp = *callOpSet.begin();
+    callOpSet.erase(callOpSet.begin());
+    auto callFunc = oriModuleOp.lookupSymbol<func::FuncOp>(callOp.getCallee());
+    // inline this callee
+    builder.setInsertionPoint(callOp);
+    IRMapping bvm;
+    for (auto inputAndArg :
+         llvm::zip(callFunc.getArguments(), callOp.getOperands())) {
+      bvm.map(std::get<0>(inputAndArg), std::get<1>(inputAndArg));
+    }
+    Block &entryBlock = callFunc.getBlocks().front();
+    ValueRange funcOuts;
+    for (Operation &op : entryBlock) {
+      auto retOp = mlir::dyn_cast<func::ReturnOp>(op);
+      if (!retOp) {
+        auto newOp = builder.clone(op, bvm);
+        if (auto nestedCall = dyn_cast<func::CallOp>(newOp)) {
+          callOpSet.insert(nestedCall);
+        }
+      } else {
+        funcOuts = retOp.getOperands();
+      }
+    }
+
+    for (auto callResultAndFuncOuts :
+         llvm::zip(callOp.getResults(), funcOuts)) {
+      auto mappedOut = bvm.lookup(std::get<1>(callResultAndFuncOuts));
+      std::get<0>(callResultAndFuncOuts).replaceAllUsesWith(mappedOut);
+    }
+    callOp->erase();
+  }
+
+  // replace the operands of the returnOp with their corresponding shapes
+  func::ReturnOp retOp = *shapeFunc.getOps<func::ReturnOp>().begin();
+  if (!retOp) {
+    shapeFunc->erase();
+    return failure();
+  }
+
+  for (Value &&retTensor : retOp.getOperands()) {
+    auto retTy = retTensor.getType();
+    if (!retTy.isa<RankedTensorType>()) {
+      shapeFunc->erase();
+      return failure();
+    }
+  }
+
+  SmallVector<Type> allResultTypes;
+  SmallVector<Value> allResults;
+
+  builder.setInsertionPoint(retOp);
+  for (Value &&retTensor : retOp.getOperands()) {
+    auto retShape = builder.create<shape::ShapeOfOp>(retOp.getLoc(), retTensor);
+    allResultTypes.emplace_back(retShape.getType());
+    allResults.emplace_back(retShape);
+  }
+
+  // return the shapes of the tensors originally returned by the function
+  auto shapeFuncRetOp =
+      builder.create<func::ReturnOp>(retOp.getLoc(), allResults);
+  auto shapeFuncType =
+      builder.getFunctionType(shapeFunc.getArgumentTypes(), allResultTypes);
+  shapeFunc.setFunctionType(shapeFuncType);
+  retOp->erase();
+
+  // reify shapeFunc to get the shape computation
+  {
+    PassManager pm(oriModuleOp->getContext(),
+                   func::FuncOp::getOperationName());
+    // run the pipeline only on shapeFunc
+    pm.addPass(createCanonicalizerPass());
+    pm.addPass(createCSEPass());
+    pm.addPass(createByteIRShapeReificationPass());
+    pm.addPass(createCanonicalizerPass());
+    pm.addPass(createCSEPass());
+    if (mlir::failed(pm.run(shapeFunc))) {
+      shapeFunc->erase();
+      return failure();
+    }
+  }
+  return shapeFunc;
+}
+
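+// Reify the result shapes of a func.call by building the callee's shape
+// function, checking that the returned shapes depend only on the argument
+// shapes, and cloning the shape computation back to the call site.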
+LogicalResult reifyCallOp(OpBuilder &builder, Operation *op,
+                          SmallVectorImpl<Value> &reifications) {
+  OpBuilder::InsertionGuard guard(builder);
+  auto callOp = dyn_cast<func::CallOp>(op);
+  if (!callOp) {
+    return failure();
+  }
+
+  ModuleOp moduleOp = op->getParentOfType<ModuleOp>();
+  StringRef funcName = callOp.getCallee();
+  auto funcOp = moduleOp.lookupSymbol<func::FuncOp>(funcName);
+
+  // create a temporary module to hold the corresponding shape function
+  OwningOpRef<ModuleOp> auxiliaryModuleOp(
+      ModuleOp::create(UnknownLoc::get(moduleOp->getContext())));
+  auto maybeShapeFunc =
+      createCorrespondingShapeFunc(funcOp, auxiliaryModuleOp.get());
+  if (failed(maybeShapeFunc)) {
+    return failure();
+  }
+
+  func::FuncOp shapeFunc = *maybeShapeFunc;
+  func::ReturnOp retOp = *shapeFunc.getOps<func::ReturnOp>().begin();
+
+  // collect all shape computation ops
+  SetVector<Operation *> reificationOpSet;
+  getBackwardSlice(retOp.getOperation(), &reificationOpSet);
+  SmallVector<Operation *> reificationOps(reificationOpSet.begin(),
+                                          reificationOpSet.end());
+  // bail out unless every returned value depends only on the shapes of the
+  // function arguments
+  for (Value &&ret : retOp.getOperands()) {
+    if (!deduceFromFuncArgShape(ret)) {
+      shapeFunc->erase();
+      return failure();
+    }
+  }
+
+  // map the shape computation ops to the call site and collect the
+  // reifications
+  {
+    mlir::computeTopologicalSorting(reificationOps);
+
+    IRMapping bvm;
+    size_t numArg = shapeFunc.getNumArguments();
+    for (size_t i = 0; i < numArg; ++i) {
+      bvm.map(shapeFunc.getArgument(i), callOp.getOperand(i));
+    }
+
+    builder.setInsertionPoint(callOp);
+
+    for (Operation *oldOp : reificationOps) {
+      builder.clone(*oldOp, bvm);
+    }
+
+    for (Value &&ret : retOp.getOperands()) {
+      reifications.push_back(bvm.lookup(ret));
+    }
+  }
+
+  // remove the auxiliary shape function
+  shapeFunc->erase();
+  return success();
+}
+
+} // namespace
+
 LogicalResult mlir::reifyShapes(OpBuilder &builder, Operation *op,
                                 SmallVectorImpl<Value> &reifications) {
   if (!op)
@@ -207,6 +414,16 @@ LogicalResult mlir::reifyShapes(OpBuilder &builder, Operation *op,
     }
     if (failed(inferFunc(op, builder, op->getOperands(), reifications)))
       return failure();
+  } else if (auto callOp = dyn_cast<func::CallOp>(op)) {
+    if (failed(reifyCallOp(builder, op, reifications))) {
+      return failure();
+    }
+  } else if (auto dpsOp = dyn_cast<DestinationStyleOpInterface>(op)) {
+    for (OpResult &&result : op->getOpResults()) {
+      auto tiedOperand = dpsOp.getTiedOpOperand(result);
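+      // a DPS result is tied to (and shape-equal with) its init operand, so
+      // its shape is simply shape_of on that operand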
+      reifications.push_back(
+          builder.create<shape::ShapeOfOp>(op->getLoc(), tiedOperand->get()));
+    }
   } else {
     // Return failure if op doesn't have InferShapedTypeOpInterface and not
     // registered.
diff --git a/compiler/lib/Pipelines/ByreTensorOpt.cpp b/compiler/lib/Pipelines/ByreTensorOpt.cpp
index 5b1f710ad..91940b0d4 100644
--- a/compiler/lib/Pipelines/ByreTensorOpt.cpp
+++ b/compiler/lib/Pipelines/ByreTensorOpt.cpp
@@ -45,8 +45,8 @@ void createByreTensorOptPipelineImpl(OpPassManager &pm, std::string entryFunc,
   pm.addPass(createCanonicalizerPass());
   pm.addNestedPass<func::FuncOp>(
       createConvertHloToByreCustomPass(getCudaByreCustomConfig()));
-  pm.addNestedPass<func::FuncOp>(
-      createConvertHloToByreTensorPass(appendArgTypes));
+  pm.addPass(createConvertHloToByreTensorPass(appendArgTypes));
+  pm.addNestedPass<func::FuncOp>(createByteIRShapeReificationPass());
   pm.addPass(createCanonicalizerPass());
 }
 } // namespace
diff --git a/compiler/lib/Transforms/CMakeLists.txt b/compiler/lib/Transforms/CMakeLists.txt
index 9ac510696..3ab4a25ab 100644
--- a/compiler/lib/Transforms/CMakeLists.txt
+++ b/compiler/lib/Transforms/CMakeLists.txt
@@ -17,6 +17,7 @@ add_mlir_library(ByteIRTransforms
   RewriteOpToStdCall.cpp
   SetArgShape.cpp
   SetSpace.cpp
+  ShapeReification.cpp
  Utils.cpp
 
   ADDITIONAL_HEADER_DIRS
diff --git a/compiler/lib/Transforms/PassDetail.h b/compiler/lib/Transforms/PassDetail.h
index f0cf6f3fa..0a63eda78 100644
--- a/compiler/lib/Transforms/PassDetail.h
+++ b/compiler/lib/Transforms/PassDetail.h
@@ -43,6 +43,10 @@ namespace memref {
 class MemRefDialect;
 } // namespace memref
 
+namespace arith {
+class ArithDialect;
+} // namespace arith
+
 namespace mhlo {
 class MhloDialect;
 } // namespace mhlo
@@ -51,6 +55,14 @@ namespace scf {
 class SCFDialect;
 } // namespace scf
 
+namespace shape {
+class ShapeDialect;
+} // namespace shape
+
+namespace tensor {
+class TensorDialect;
+} // namespace tensor
+
 #define GEN_PASS_CLASSES
 #include "byteir/Transforms/Passes.h.inc"
 
diff --git a/compiler/lib/Dialect/mhlo/Transforms/ShapeReification.cpp b/compiler/lib/Transforms/ShapeReification.cpp
similarity index 94%
rename from compiler/lib/Dialect/mhlo/Transforms/ShapeReification.cpp
rename to compiler/lib/Transforms/ShapeReification.cpp
index 7b6c1b548..8d09215a8 100644
--- a/compiler/lib/Dialect/mhlo/Transforms/ShapeReification.cpp
+++ b/compiler/lib/Transforms/ShapeReification.cpp
@@ -1,6 +1,6 @@
 //===- ShapeReification.cpp -----------------------------------*--- C++ -*-===//
 //
-// Copyright 2022 ByteDance Ltd. and/or its affiliates. All rights reserved.
+// Copyright 2024 ByteDance Ltd. and/or its affiliates. All rights reserved.
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
 // You may obtain a copy of the License at
@@ -15,11 +15,12 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include "byteir/Dialect/mhlo/Transforms/ShapeReification.h"
+#include "byteir/Transforms/ShapeReification.h"
 
 #include "byteir/Dialect/mhlo/DynamicShapeOpRegister/Register.h"
 #include "byteir/Dialect/mhlo/Util/ShapeInferUtil.h"
 #include "mhlo/IR/hlo_ops.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/Shape/IR/Shape.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
@@ -59,11 +60,12 @@ struct ShapeReificationOnTensorDimPattern
 
     // Insert cast, if needed.
     if (dimOfShape.getType() != op.getType()) {
-      dimOfShape = rewriter.create<arith::IndexCastOp>(op.getLoc(), op.getType(),
-                                                       dimOfShape);
+      dimOfShape = rewriter.create<arith::IndexCastOp>(
+          op.getLoc(), op.getType(), dimOfShape);
     }
 
     rewriter.replaceOp(op, dimOfShape);
+    return success();
   }
 };
diff --git a/compiler/test/Conversion/FuncToByre/func_to_byre_tensor.mlir b/compiler/test/Conversion/FuncToByre/func_to_byre_tensor.mlir
index 31a468e7d..4821a5915 100644
--- a/compiler/test/Conversion/FuncToByre/func_to_byre_tensor.mlir
+++ b/compiler/test/Conversion/FuncToByre/func_to_byre_tensor.mlir
@@ -20,3 +20,20 @@ func.func @test_normal_function_call(%arg0 : tensor<4xf32>) -> tensor<4xf32> att
 }
 // CHECK-LABEL: test_normal_function_call
 // CHECK: call @some_func
+
+
+// -----
+
+func.func private @Unknown0(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) -> tensor<?xf32> attributes {__byteir_elementwise_fusion__, byre_compute_name = "Unknown0"} {
+  %0 = mhlo.add %arg0, %arg1 : tensor<?xf32>
+  return %0 : tensor<?xf32>
+}
+
+func.func @forward(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) -> tensor<?xf32> attributes {__placeholder__byre.entry_point} {
+  %1 = call @Unknown0(%arg1, %arg0) : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
+  return %1 : tensor<?xf32>
+}
+
+// CHECK-LABEL: func.func @forward
+// CHECK: tensor.empty
+// CHECK-NEXT: byre.compute_on_tensor @Unknown0
diff --git a/compiler/test/Transforms/shapeReification.mlir b/compiler/test/Transforms/shapeReification.mlir
index d1b3cd530..cfba2985c 100644
--- a/compiler/test/Transforms/shapeReification.mlir
+++ b/compiler/test/Transforms/shapeReification.mlir
@@ -1,4 +1,4 @@
-// RUN: byteir-opt %s -byteir-shape-reification -canonicalize -cse | FileCheck %s
+// RUN: byteir-opt %s --split-input-file -byteir-shape-reification -canonicalize -cse | FileCheck %s
 
 func.func @several_ops(%arg0: tensor<?x2xf32>, %arg1: tensor<2x4xf32>, %arg2: tensor<4xf32>) -> (!shape.shape, !shape.shape, !shape.shape, !shape.shape) {
   %0 = "mhlo.dot"(%arg0, %arg1) : (tensor<?x2xf32>, tensor<2x4xf32>) -> tensor<?x4xf32>
@@ -26,6 +26,8 @@ func.func @several_ops(%arg0: tensor<?x2xf32>, %arg1: tensor<2x4xf32>, %arg2: te
 // CHECK-DAG: %[[V3:.+]] = shape.value_as_shape %[[C2]] : tensor<1xindex> -> !shape.shape
 // CHECK-DAG: return %[[V2]], %[[V3]], %[[V2]], %[[V2]] : !shape.shape, !shape.shape, !shape.shape, !shape.shape
 
+// -----
+
 // CHECK-LABEL: @infer_shape_using_dim_op
 func.func @infer_shape_using_dim_op(%arg0: tensor<?x4xf32>, %arg1: tensor<?x4xf32>, %arg2: tensor<4x4xf32>) -> !shape.shape {
   %0 = mhlo.add %arg0, %arg1 : tensor<?x4xf32>
@@ -40,6 +42,8 @@ func.func @infer_shape_using_dim_op(%arg0: tensor<?x4xf32>, %arg1: tensor
 
+// -----
+
 func.func @dynamic_stitch(%arg0: tensor<?xi32>, %arg1: tensor<?xi32>, %arg2: tensor<?xf32>, %arg3: tensor<?xf32>) -> tensor<?xf32> {
   %0 = "mhlo.custom_call"(%arg0, %arg1, %arg2, %arg3) {call_target_name = "tf.DynamicStitch", has_side_effect = false} : (tensor<?xi32>, tensor<?xi32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
   %c0 = arith.constant 0 : index
@@ -52,6 +56,8 @@ func.func @dynamic_stitch(%arg0: tensor<?xi32>, %arg1: tensor<?xi32>, %arg2: ten
   return %0 : tensor<?xf32>
 }
 
+// -----
+
 func.func @gelu(%arg0: tensor<?xf32>) -> tensor<?xf32> {
   %0 = mhlo.custom_call @byteir.gelu(%arg0) {backend_config = "", byteir_attrs = {approximate = "erf"}} : (tensor<?xf32>) -> tensor<?xf32>
   %c0 = arith.constant 0 : index
@@ -62,6 +68,8 @@ func.func @gelu(%arg0: tensor<?xf32>) -> tensor<?xf32> {
   return %0 : tensor<?xf32>
 }
 
+// -----
+
 // CHECK-LABEL: func.func @dot_general
 func.func @dot_general(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?x?xf32>) -> tensor<3xindex> {
   %c1 = arith.constant 1 : index
@@ -80,11 +88,14 @@ func.func @dot_general(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?x?xf32>) ->
   return %3 : tensor<3xindex>
 }
 
+// -----
+
 // TODO: Check this after nested function call is supported
 func.func private @inner_func(%arg0 : tensor<?x4xf32>, %arg1 : tensor<?x4xf32>) -> tensor<?x4xf32>
 {
   %0 = mhlo.add %arg0, %arg1 : tensor<?x4xf32>
   return %0 : tensor<?x4xf32>
 }
+
 func.func @outer_func(%arg0: tensor<?x4xf32>, %arg1: tensor<?x4xf32>) -> (!shape.shape, !shape.shape) {
   %0 = mhlo.add %arg0, %arg1 : tensor<?x4xf32>
   %1 = shape.shape_of %0 : tensor<?x4xf32> -> tensor<2xindex>
@@ -94,3 +105,55 @@ func.func @outer_func(%arg0: tensor<?x4xf32>, %arg1: tensor<?x4xf32>) -> (!shape
   %5 = shape.value_as_shape %4 : tensor<2xindex> -> !shape.shape
   return %2, %5 : !shape.shape, !shape.shape
 }
+// CHECK-LABEL: func.func @outer_func
+// CHECK: %[[V0:.*]] = shape.shape_of %arg0 : tensor<?x4xf32> -> tensor<2xindex>
+// CHECK: %[[V1:.*]] = shape.value_as_shape %[[V0]] : tensor<2xindex> -> !shape.shape
+// CHECK: return %[[V1]], %[[V1]] : !shape.shape, !shape.shape
+
+// -----
+
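+// Nested call chain (@forward -> @Unknown0 -> @Unknown1 -> @Unknown2): the
+// shape of @forward's result is reified by recursively inlining the callees'
+// shape functions.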
+func.func private @Unknown2(%arg0: tensor<?x10xf32>, %arg1: tensor<?x20xf32>) -> tensor<?x20xf32> attributes {__byteir_matmul_epilogue_fusion__} {
+  %0 = mhlo.constant dense_resource<__elided__> : tensor<10x20xf32>
+  %1 = "mhlo.dot"(%arg0, %0) : (tensor<?x10xf32>, tensor<10x20xf32>) -> tensor<?x20xf32>
+  %2 = mhlo.add %1, %arg1 : tensor<?x20xf32>
+  return %2 : tensor<?x20xf32>
+}
+
+func.func private @Unknown1(%arg0: tensor<?x10xf32>, %arg1: tensor<?x20xf32>) -> tensor<?x20xf32> attributes {__byteir_matmul_epilogue_fusion__} {
+  %0 = mhlo.constant dense_resource<__elided__> : tensor<10x20xf32>
+  %1 = call @Unknown2(%arg0, %arg1) : (tensor<?x10xf32>, tensor<?x20xf32>) -> tensor<?x20xf32>
+  %2 = "mhlo.dot"(%arg0, %0) : (tensor<?x10xf32>, tensor<10x20xf32>) -> tensor<?x20xf32>
+  %3 = mhlo.add %2, %1 : tensor<?x20xf32>
+  return %3 : tensor<?x20xf32>
+}
+
+func.func private @Unknown0(%arg0: tensor<?x10xf32>, %arg1: tensor<20xf32>, %arg2: tensor<?x20xf32>, %arg3: tensor<?x20xf32>) -> (tensor<?x20xf32>, tensor<?x20xf32>) {
+  %0 = mhlo.constant dense<0.000000e+00> : tensor<f32>
+  %c20 = arith.constant 20 : index
+  %c0 = arith.constant 0 : index
+  %dim = tensor.dim %arg0, %c0 : tensor<?x10xf32>
+  %from_elements = tensor.from_elements %dim, %c20 : tensor<2xindex>
+  %1 = "mhlo.dynamic_broadcast_in_dim"(%arg1, %from_elements) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<20xf32>, tensor<2xindex>) -> tensor<?x20xf32>
+  %2 = mhlo.add %arg2, %1 : tensor<?x20xf32>
+  %3 = call @Unknown1(%arg0, %2) : (tensor<?x10xf32>, tensor<?x20xf32>) -> tensor<?x20xf32>
+  %4 = mhlo.maximum %2, %3 : tensor<?x20xf32>
+  return %4, %3 : tensor<?x20xf32>, tensor<?x20xf32>
+}
+
+func.func @forward(%arg0: tensor<?x10xf32>, %arg1: tensor<?x20xf32>, %arg2: tensor<20x?xf32>) -> tensor<2xindex> attributes {__placeholder__byre.entry_point} {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %0 = mhlo.constant dense_resource<__elided__> : tensor<10x20xf32>
+  %1 = mhlo.constant dense_resource<__elided__> : tensor<20xf32>
+  %2 = "mhlo.dot"(%arg0, %0) : (tensor<?x10xf32>, tensor<10x20xf32>) -> tensor<?x20xf32>
+  %3:2 = call @Unknown0(%arg0, %1, %2, %arg1) : (tensor<?x10xf32>, tensor<20xf32>, tensor<?x20xf32>, tensor<?x20xf32>) -> (tensor<?x20xf32>, tensor<?x20xf32>)
+  %4 = "mhlo.dot"(%3#0, %arg2) : (tensor<?x20xf32>, tensor<20x?xf32>) -> tensor<?x?xf32>
+  %5 = shape.shape_of %4 : tensor<?x?xf32> -> tensor<2xindex>
+  return %5 : tensor<2xindex>
+}
+
+// CHECK-LABEL: func.func @forward
+// CHECK: %[[DIM:.*]] = tensor.dim %arg0, %c0 : tensor<?x10xf32>
+// CHECK-NEXT: %[[DIM0:.*]] = tensor.dim %arg2, %c1 : tensor<20x?xf32>
+// CHECK-NEXT: %[[SHAPE:.*]] = tensor.from_elements %[[DIM]], %[[DIM0]] : tensor<2xindex>
+// CHECK-NEXT: return %[[SHAPE]] : tensor<2xindex>
\ No newline at end of file