diff --git a/flang/include/flang/Optimizer/Transforms/Passes.h b/flang/include/flang/Optimizer/Transforms/Passes.h index 37c98ef9381f98..d857e97d5b2ca6 100644 --- a/flang/include/flang/Optimizer/Transforms/Passes.h +++ b/flang/include/flang/Optimizer/Transforms/Passes.h @@ -77,6 +77,7 @@ std::unique_ptr createAlgebraicSimplificationPass(); std::unique_ptr createAlgebraicSimplificationPass(const mlir::GreedyRewriteConfig &config); +std::unique_ptr createOMPAllocatableArrayOptPass(); std::unique_ptr createVScaleAttrPass(); std::unique_ptr createVScaleAttrPass(std::pair vscaleAttr); diff --git a/flang/include/flang/Optimizer/Transforms/Passes.td b/flang/include/flang/Optimizer/Transforms/Passes.td index ab98591c911cdf..25d5eabadd8860 100644 --- a/flang/include/flang/Optimizer/Transforms/Passes.td +++ b/flang/include/flang/Optimizer/Transforms/Passes.td @@ -358,6 +358,18 @@ def LoopVersioning : Pass<"loop-versioning", "mlir::func::FuncOp"> { let dependentDialects = [ "fir::FIROpsDialect" ]; } +def OMPAllocatableArrayOpt : Pass<"omp-allocatable-arrays-optimization", + "mlir::func::FuncOp"> { + let summary = "Skip extraction of pointer to allocated memory from " + "allocatable array decriptor."; + let constructor = "::fir::createOMPAllocatableArrayOptPass()"; + let dependentDialects = [ + "fir::FIROpsDialect", + "mlir::func::FuncDialect", + "mlir::omp::OpenMPDialect" + ]; +} + def VScaleAttr : Pass<"vscale-attr", "mlir::func::FuncOp"> { let summary = "Add vscale_range attribute to functions"; let description = [{ diff --git a/flang/include/flang/Tools/CLOptions.inc b/flang/include/flang/Tools/CLOptions.inc index aa2b55d2d1014a..021e1b92bcedb5 100644 --- a/flang/include/flang/Tools/CLOptions.inc +++ b/flang/include/flang/Tools/CLOptions.inc @@ -354,6 +354,8 @@ inline void createHLFIRToFIRPassPipeline( pm.addPass(hlfir::createLowerHLFIRIntrinsics()); pm.addPass(hlfir::createBufferizeHLFIR()); pm.addPass(hlfir::createConvertHLFIRtoFIR()); + pm.addPass(fir::createOMPAllocatableArrayOptPass()); + } using DoConcurrentMappingKind = diff --git a/flang/lib/Optimizer/Transforms/CMakeLists.txt b/flang/lib/Optimizer/Transforms/CMakeLists.txt index b68e3d68b9b83e..fe29e83942cd75 100644 --- a/flang/lib/Optimizer/Transforms/CMakeLists.txt +++ b/flang/lib/Optimizer/Transforms/CMakeLists.txt @@ -23,6 +23,7 @@ add_flang_library(FIRTransforms AddDebugInfo.cpp PolymorphicOpConversion.cpp LoopVersioning.cpp + OMPAllocatableArrayOpt.cpp StackReclaim.cpp VScaleAttr.cpp FunctionAttr.cpp diff --git a/flang/lib/Optimizer/Transforms/OMPAllocatableArrayOpt.cpp b/flang/lib/Optimizer/Transforms/OMPAllocatableArrayOpt.cpp new file mode 100644 index 00000000000000..df98abcb7ad176 --- /dev/null +++ b/flang/lib/Optimizer/Transforms/OMPAllocatableArrayOpt.cpp @@ -0,0 +1,324 @@ +//===- OMPAllocatableArrayOpt.cpp +//-------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements transforms to filter out functions intended for the host +// when compiling for the device and vice versa. +// +//===----------------------------------------------------------------------===// +#include "flang/Lower/Support/Utils.h" +#include "flang/Optimizer/Builder/FIRBuilder.h" +#include "flang/Optimizer/Dialect/FIRDialect.h" +#include "flang/Optimizer/Dialect/FIROpsSupport.h" +#include "flang/Optimizer/Transforms/Passes.h" + +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/OpenMP/OpenMPDialect.h" +#include "mlir/Dialect/OpenMP/OpenMPInterfaces.h" +#include "mlir/IR/BuiltinOps.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/Frontend/OpenMP/OMPConstants.h" + +namespace fir { +#define GEN_PASS_DEF_OMPALLOCATABLEARRAYOPT +#include "flang/Optimizer/Transforms/Passes.h.inc" +} // namespace fir + +using namespace mlir; + +namespace { + +class OMPAllocatableArrayOptPass + : public fir::impl::OMPAllocatableArrayOptBase { + struct ArrayBound { + Value lowerBound; + Value upperBound; + }; + struct AllocatableArrayDescriptorItems { + omp::MapInfoOp basePtrMapInfo; + Value basePtrKernelArg; + llvm::SmallVector boundDesc; + }; + llvm::DenseMap + declareDescriptorMap; + + std::optional getNumberOfArrayDim(fir::ArrayCoorOp arrayCoorOp) { + // TODO: Can we optimize such fir.coor_arr ops? + if (!arrayCoorOp.getShape()) { + return {}; + } + fir::ShapeShiftOp shapeShiftOp = + dyn_cast(arrayCoorOp.getShape().getDefiningOp()); + if (!shapeShiftOp) + return {}; + return shapeShiftOp.getExtents().size(); + } + + fir::DeclareOp findAllocatableDeclareOp(fir::ArrayCoorOp arrayCoorOp) { + // TODO: Can we optimize such fir.coor_arr ops? + if (!arrayCoorOp.getShape()) { + return nullptr; + } + fir::ShapeShiftOp shapeShiftOp = + dyn_cast(arrayCoorOp.getShape().getDefiningOp()); + if (!shapeShiftOp) + return nullptr; + fir::BoxAddrOp boxAddrOp = + dyn_cast(arrayCoorOp.getMemref().getDefiningOp()); + if (!boxAddrOp) + return nullptr; + if (!isa(boxAddrOp.getType())) + return nullptr; + fir::LoadOp loadOp = + dyn_cast(boxAddrOp.getVal().getDefiningOp()); + if (!loadOp) + return nullptr; + return dyn_cast(loadOp.getMemref().getDefiningOp()); + } + void eraseNotUsed(fir::ArrayCoorOp arrayCoorOp, + omp::MapClauseOwningOpInterface mapClauseOwner, + Block *targetEntryBlock) { + fir::ShapeShiftOp shapeShiftOp = + dyn_cast(arrayCoorOp.getShape().getDefiningOp()); + fir::BoxAddrOp boxAddrOp = + dyn_cast(arrayCoorOp.getMemref().getDefiningOp()); + assert(arrayCoorOp->use_empty()); + arrayCoorOp.erase(); + std::vector shapeShiftOpExtents(shapeShiftOp.getExtents()); + if (shapeShiftOp->use_empty()) + shapeShiftOp.erase(); + for (size_t i = 0; i < shapeShiftOpExtents.size(); ++i) { + Value shapeVal = shapeShiftOpExtents[i]; + if (shapeVal.use_empty()) + shapeVal.getDefiningOp()->erase(); + } + + fir::LoadOp loadOp = + dyn_cast(boxAddrOp.getVal().getDefiningOp()); + if (boxAddrOp->use_empty()) + boxAddrOp.erase(); + + fir::DeclareOp declareOp = + dyn_cast(loadOp.getMemref().getDefiningOp()); + if (loadOp->use_empty()) + loadOp.erase(); + OperandRange mapVarsArr = mapClauseOwner.getMapVarsMutable(); + assert(mapVarsArr.size() == targetEntryBlock->getNumArguments()); + for (size_t i = 0; i < targetEntryBlock->getNumArguments(); ++i) { + if (targetEntryBlock->getArgument(i) == declareOp.getMemref()) { + omp::MapInfoOp mapInfo = + dyn_cast(mapVarsArr[i].getDefiningOp()); + if (declareOp->use_empty()) { + declareOp.erase(); + targetEntryBlock->eraseArgument(i); + mapClauseOwner.getMapVarsMutable().erase(i); + mapInfo.erase(); + } + break; + } + } + } + + AllocatableArrayDescriptorItems getAllocatableArrayDescriptorItems( + fir::DeclareOp declareOp, omp::MapClauseOwningOpInterface mapClauseOwner, + Block *targetEntryBlock, size_t numberOfDims) { + AllocatableArrayDescriptorItems descriptorItems; + OperandRange mapVarsArr = mapClauseOwner.getMapVars(); + assert(mapVarsArr.size() == targetEntryBlock->getNumArguments()); + Operation *mapItemVarPtr; + for (size_t i = 0; i < targetEntryBlock->getNumArguments(); ++i) { + if (targetEntryBlock->getArgument(i) == declareOp.getMemref()) { + omp::MapInfoOp mapInfo = + dyn_cast(mapVarsArr[i].getDefiningOp()); + mapItemVarPtr = mapInfo.getVarPtr().getDefiningOp(); + assert(mapInfo && (mapInfo.getMembers().size() == 1) && + "Expected only base addr ptr"); + descriptorItems.basePtrMapInfo = + dyn_cast(mapInfo.getMembers()[0].getDefiningOp()); + break; + } + } + for (size_t index = 0; index < mapVarsArr.size(); ++index) { + if (descriptorItems.basePtrMapInfo == mapVarsArr[index].getDefiningOp()) { + descriptorItems.basePtrKernelArg = targetEntryBlock->getArgument(index); + } + } + assert(descriptorItems.basePtrMapInfo && "Expected base ptr map info"); + assert(descriptorItems.basePtrKernelArg && + "Expected base ptr kernel argument"); + return descriptorItems; + } + + void rewriteMapInfo(AllocatableArrayDescriptorItems &descriptorItem, + omp::MapClauseOwningOpInterface mapClauseOwner, + Block *targetEntryBlock, size_t numberOfDims) { + OperandRange mapVarsArr = mapClauseOwner.getMapVars(); + omp::MapInfoOp mapInfo = descriptorItem.basePtrMapInfo; + size_t index; + for (index = 0; index < mapVarsArr.size(); ++index) { + if (descriptorItem.basePtrMapInfo == mapVarsArr[index].getDefiningOp()) { + break; + } + } + assert(mapInfo); + OpBuilder opBuilder(mapInfo); + fir::FirOpBuilder builder(opBuilder, mapInfo); + Operation *op = opBuilder.create( + mapInfo->getLoc(), mapInfo.getType(), mapInfo.getVarPtrPtr(), + TypeAttr::get(mapInfo.getVarType()), mapInfo.getVarPtrPtr(), + llvm::SmallVector{}, ArrayAttr{}, + llvm::SmallVector(mapInfo.getBounds()), + opBuilder.getIntegerAttr(opBuilder.getIntegerType(64, false), + mapInfo.getMapType().value()), + opBuilder.getAttr( + mapInfo.getMapCaptureType().value()), + opBuilder.getStringAttr(""), opBuilder.getBoolAttr(false)); + + mapInfo.replaceAllUsesWith(op); + mapVarsArr[index] = op->getResult(0); + + OpBuilder::InsertPoint insPt = builder.saveInsertionPoint(); + Block *allocaBlock = builder.getAllocaBlock(); + assert(allocaBlock && "No alloca block found for this top level op"); + llvm::SmallVector newMapOps; + for (size_t i = 0; i < mapVarsArr.size(); ++i) { + newMapOps.push_back(mapVarsArr[i]); + } + size_t mapArgumentIndex = mapVarsArr.size(); + for (size_t dim = 0; dim < numberOfDims; ++dim) { + descriptorItem.boundDesc.push_back({}); + for (size_t i = 0; i < 2; ++i) { + builder.setInsertionPointToStart(allocaBlock); + auto alloca = builder.create(mapInfo->getLoc(), + builder.getIntegerType(64)); + builder.restoreInsertionPoint(insPt); + auto dimVal = builder.createIntegerConstant( + mapInfo->getLoc(), builder.getIndexType(), dim); + Value allocatableDescriptor = + builder.create(mapInfo->getLoc(), mapInfo.getVarPtr()); + auto dimInfo = builder.create( + mapInfo->getLoc(), builder.getIndexType(), builder.getIndexType(), + builder.getIndexType(), allocatableDescriptor, dimVal); + Value bound = + builder.createConvert(mapInfo->getLoc(), builder.getIntegerType(64), + dimInfo->getResult(i)); + opBuilder.create(mapInfo->getLoc(), bound, alloca); + omp::VariableCaptureKind captureKind = omp::VariableCaptureKind::ByCopy; + + Operation *newMapItem = opBuilder.create( + mapInfo->getLoc(), alloca.getType(), alloca, TypeAttr::get(bound.getType()), + Value{}, llvm::SmallVector{}, ArrayAttr{}, + llvm::SmallVector{}, + opBuilder.getIntegerAttr( + opBuilder.getIntegerType(64, false), + llvm::to_underlying( + llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT)), + builder.getAttr(captureKind), + opBuilder.getStringAttr(""), opBuilder.getBoolAttr(false)); + + newMapOps.push_back(newMapItem->getResult(0)); + targetEntryBlock->insertArgument(mapArgumentIndex, + newMapItem->getResult(0).getType(), + newMapItem->getLoc()); + if (i == 0) { + descriptorItem.boundDesc[dim].lowerBound = + targetEntryBlock->getArgument(mapArgumentIndex); + } else if (i == 1) { + descriptorItem.boundDesc[dim].upperBound = + targetEntryBlock->getArgument(mapArgumentIndex); + } + mapArgumentIndex++; + } + } + + mapClauseOwner.getMapVarsMutable().assign(newMapOps); + mapInfo.erase(); + } + + void rewriteArrayCoorOp(fir::ArrayCoorOp arrayCoorOp, + AllocatableArrayDescriptorItems &descriptorItem) { + OpBuilder opBuilder(arrayCoorOp); + fir::FirOpBuilder builder(opBuilder, arrayCoorOp); + Value addr = builder.createConvert(arrayCoorOp.getLoc(), + arrayCoorOp.getMemref().getType(), + descriptorItem.basePtrKernelArg); + llvm::SmallVector lbounds; + llvm::SmallVector ubounds; + + for (size_t dim = 0; dim < descriptorItem.boundDesc.size(); dim++) { +#if 0 + //Experiment - provide bound information in compile time + Value lb = descriptorItem.boundDesc[dim].lowerBound; + lbounds.push_back(builder.createIntegerConstant(lb.getLoc(),builder.getIndexType(), 1)); + ubounds.push_back(builder.createIntegerConstant(lb.getLoc(),builder.getIndexType(), 100)); +#else + Value lb = descriptorItem.boundDesc[dim].lowerBound; + Value lbVal = builder.create(lb.getLoc(), lb); + Value lbValConvert = + builder.createConvert(lb.getLoc(), builder.getIndexType(), lbVal); + lbounds.push_back(lbValConvert); + Value ub = descriptorItem.boundDesc[dim].upperBound; + Value ubVal = builder.create(ub.getLoc(), ub); + Value ubValConvert = + builder.createConvert(ub.getLoc(), builder.getIndexType(), ubVal); + ubounds.push_back(ubValConvert); +#endif + } + + auto shapeShiftArgs = flatZip(lbounds, ubounds); + auto shapeTy = + fir::ShapeShiftType::get(arrayCoorOp->getContext(), lbounds.size()); + Value shapeShift = builder.create( + arrayCoorOp.getLoc(), shapeTy, shapeShiftArgs); + Value optimizedArrayCoorOp = builder.create( + arrayCoorOp.getLoc(), arrayCoorOp.getType(), addr, shapeShift, + arrayCoorOp.getSlice(), arrayCoorOp.getIndices(), + arrayCoorOp.getTypeparams()); + arrayCoorOp.replaceAllUsesWith(optimizedArrayCoorOp); + } + +public: + OMPAllocatableArrayOptPass() = default; + + void runOnOperation() override { + func::FuncOp func = getOperation(); + declareDescriptorMap.clear(); + func->walk([&](omp::TargetOp targetOp) { + auto mapClauseOwner = llvm::dyn_cast( + targetOp.getOperation()); + Block *entryBlock = &targetOp->getRegion(0).front(); + if (mapClauseOwner) { + OperandRange mapVarsArr = mapClauseOwner.getMapVars(); + targetOp->walk([&](fir::ArrayCoorOp arrayCoorOp) { + fir::DeclareOp declareOp = findAllocatableDeclareOp(arrayCoorOp); + std::optional numberOfArrayDim = + getNumberOfArrayDim(arrayCoorOp); + if (!numberOfArrayDim.has_value()) + return; + if (!declareOp) + return; + if (!declareDescriptorMap.contains(declareOp)) { + declareDescriptorMap[declareOp] = + getAllocatableArrayDescriptorItems( + declareOp, mapClauseOwner, entryBlock, *numberOfArrayDim); + rewriteMapInfo(declareDescriptorMap[declareOp], mapClauseOwner, + entryBlock, *numberOfArrayDim); + } + rewriteArrayCoorOp(arrayCoorOp, declareDescriptorMap[declareOp]); + eraseNotUsed(arrayCoorOp, mapClauseOwner, entryBlock); + }); + ; + } + }); + } +}; +} // namespace + +std::unique_ptr fir::createOMPAllocatableArrayOptPass() { + return std::make_unique(); +}