Skip to content

Commit

Permalink
Remove allocatable array descriptor
Browse files Browse the repository at this point in the history
Change-Id: Idce7497185c3c34164e1c1b76b44751efa7a5766
  • Loading branch information
DominikAdamski committed Sep 26, 2024
1 parent 6d8995a commit 4c15d4c
Show file tree
Hide file tree
Showing 5 changed files with 340 additions and 0 deletions.
1 change: 1 addition & 0 deletions flang/include/flang/Optimizer/Transforms/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ std::unique_ptr<mlir::Pass> createAlgebraicSimplificationPass();
std::unique_ptr<mlir::Pass>
createAlgebraicSimplificationPass(const mlir::GreedyRewriteConfig &config);

std::unique_ptr<mlir::Pass> createOMPAllocatableArrayOptPass();
std::unique_ptr<mlir::Pass> createVScaleAttrPass();
std::unique_ptr<mlir::Pass>
createVScaleAttrPass(std::pair<unsigned, unsigned> vscaleAttr);
Expand Down
12 changes: 12 additions & 0 deletions flang/include/flang/Optimizer/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -358,6 +358,18 @@ def LoopVersioning : Pass<"loop-versioning", "mlir::func::FuncOp"> {
let dependentDialects = [ "fir::FIROpsDialect" ];
}

def OMPAllocatableArrayOpt : Pass<"omp-allocatable-arrays-optimization",
"mlir::func::FuncOp"> {
let summary = "Skip extraction of pointer to allocated memory from "
"allocatable array decriptor.";
let constructor = "::fir::createOMPAllocatableArrayOptPass()";
let dependentDialects = [
"fir::FIROpsDialect",
"mlir::func::FuncDialect",
"mlir::omp::OpenMPDialect"
];
}

def VScaleAttr : Pass<"vscale-attr", "mlir::func::FuncOp"> {
let summary = "Add vscale_range attribute to functions";
let description = [{
Expand Down
2 changes: 2 additions & 0 deletions flang/include/flang/Tools/CLOptions.inc
Original file line number Diff line number Diff line change
Expand Up @@ -354,6 +354,8 @@ inline void createHLFIRToFIRPassPipeline(
pm.addPass(hlfir::createLowerHLFIRIntrinsics());
pm.addPass(hlfir::createBufferizeHLFIR());
pm.addPass(hlfir::createConvertHLFIRtoFIR());
pm.addPass(fir::createOMPAllocatableArrayOptPass());

}

using DoConcurrentMappingKind =
Expand Down
1 change: 1 addition & 0 deletions flang/lib/Optimizer/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ add_flang_library(FIRTransforms
AddDebugInfo.cpp
PolymorphicOpConversion.cpp
LoopVersioning.cpp
OMPAllocatableArrayOpt.cpp
StackReclaim.cpp
VScaleAttr.cpp
FunctionAttr.cpp
Expand Down
324 changes: 324 additions & 0 deletions flang/lib/Optimizer/Transforms/OMPAllocatableArrayOpt.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,324 @@
//===- OMPAllocatableArrayOpt.cpp
//-------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements transforms to filter out functions intended for the host
// when compiling for the device and vice versa.
//
//===----------------------------------------------------------------------===//
#include "flang/Lower/Support/Utils.h"
#include "flang/Optimizer/Builder/FIRBuilder.h"
#include "flang/Optimizer/Dialect/FIRDialect.h"
#include "flang/Optimizer/Dialect/FIROpsSupport.h"
#include "flang/Optimizer/Transforms/Passes.h"

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
#include "mlir/Dialect/OpenMP/OpenMPInterfaces.h"
#include "mlir/IR/BuiltinOps.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/Frontend/OpenMP/OMPConstants.h"

namespace fir {
#define GEN_PASS_DEF_OMPALLOCATABLEARRAYOPT
#include "flang/Optimizer/Transforms/Passes.h.inc"
} // namespace fir

using namespace mlir;

namespace {

class OMPAllocatableArrayOptPass
: public fir::impl::OMPAllocatableArrayOptBase<OMPAllocatableArrayOptPass> {
struct ArrayBound {
Value lowerBound;
Value upperBound;
};
struct AllocatableArrayDescriptorItems {
omp::MapInfoOp basePtrMapInfo;
Value basePtrKernelArg;
llvm::SmallVector<ArrayBound> boundDesc;
};
llvm::DenseMap<fir::DeclareOp, AllocatableArrayDescriptorItems>
declareDescriptorMap;

std::optional<size_t> getNumberOfArrayDim(fir::ArrayCoorOp arrayCoorOp) {
// TODO: Can we optimize such fir.coor_arr ops?
if (!arrayCoorOp.getShape()) {
return {};
}
fir::ShapeShiftOp shapeShiftOp =
dyn_cast<fir::ShapeShiftOp>(arrayCoorOp.getShape().getDefiningOp());
if (!shapeShiftOp)
return {};
return shapeShiftOp.getExtents().size();
}

fir::DeclareOp findAllocatableDeclareOp(fir::ArrayCoorOp arrayCoorOp) {
// TODO: Can we optimize such fir.coor_arr ops?
if (!arrayCoorOp.getShape()) {
return nullptr;
}
fir::ShapeShiftOp shapeShiftOp =
dyn_cast<fir::ShapeShiftOp>(arrayCoorOp.getShape().getDefiningOp());
if (!shapeShiftOp)
return nullptr;
fir::BoxAddrOp boxAddrOp =
dyn_cast<fir::BoxAddrOp>(arrayCoorOp.getMemref().getDefiningOp());
if (!boxAddrOp)
return nullptr;
if (!isa<fir::HeapType>(boxAddrOp.getType()))
return nullptr;
fir::LoadOp loadOp =
dyn_cast<fir::LoadOp>(boxAddrOp.getVal().getDefiningOp());
if (!loadOp)
return nullptr;
return dyn_cast<fir::DeclareOp>(loadOp.getMemref().getDefiningOp());
}
void eraseNotUsed(fir::ArrayCoorOp arrayCoorOp,
omp::MapClauseOwningOpInterface mapClauseOwner,
Block *targetEntryBlock) {
fir::ShapeShiftOp shapeShiftOp =
dyn_cast<fir::ShapeShiftOp>(arrayCoorOp.getShape().getDefiningOp());
fir::BoxAddrOp boxAddrOp =
dyn_cast<fir::BoxAddrOp>(arrayCoorOp.getMemref().getDefiningOp());
assert(arrayCoorOp->use_empty());
arrayCoorOp.erase();
std::vector<Value> shapeShiftOpExtents(shapeShiftOp.getExtents());
if (shapeShiftOp->use_empty())
shapeShiftOp.erase();
for (size_t i = 0; i < shapeShiftOpExtents.size(); ++i) {
Value shapeVal = shapeShiftOpExtents[i];
if (shapeVal.use_empty())
shapeVal.getDefiningOp()->erase();
}

fir::LoadOp loadOp =
dyn_cast<fir::LoadOp>(boxAddrOp.getVal().getDefiningOp());
if (boxAddrOp->use_empty())
boxAddrOp.erase();

fir::DeclareOp declareOp =
dyn_cast<fir::DeclareOp>(loadOp.getMemref().getDefiningOp());
if (loadOp->use_empty())
loadOp.erase();
OperandRange mapVarsArr = mapClauseOwner.getMapVarsMutable();
assert(mapVarsArr.size() == targetEntryBlock->getNumArguments());
for (size_t i = 0; i < targetEntryBlock->getNumArguments(); ++i) {
if (targetEntryBlock->getArgument(i) == declareOp.getMemref()) {
omp::MapInfoOp mapInfo =
dyn_cast<omp::MapInfoOp>(mapVarsArr[i].getDefiningOp());
if (declareOp->use_empty()) {
declareOp.erase();
targetEntryBlock->eraseArgument(i);
mapClauseOwner.getMapVarsMutable().erase(i);
mapInfo.erase();
}
break;
}
}
}

AllocatableArrayDescriptorItems getAllocatableArrayDescriptorItems(
fir::DeclareOp declareOp, omp::MapClauseOwningOpInterface mapClauseOwner,
Block *targetEntryBlock, size_t numberOfDims) {
AllocatableArrayDescriptorItems descriptorItems;
OperandRange mapVarsArr = mapClauseOwner.getMapVars();
assert(mapVarsArr.size() == targetEntryBlock->getNumArguments());
Operation *mapItemVarPtr;
for (size_t i = 0; i < targetEntryBlock->getNumArguments(); ++i) {
if (targetEntryBlock->getArgument(i) == declareOp.getMemref()) {
omp::MapInfoOp mapInfo =
dyn_cast<omp::MapInfoOp>(mapVarsArr[i].getDefiningOp());
mapItemVarPtr = mapInfo.getVarPtr().getDefiningOp();
assert(mapInfo && (mapInfo.getMembers().size() == 1) &&
"Expected only base addr ptr");
descriptorItems.basePtrMapInfo =
dyn_cast<omp::MapInfoOp>(mapInfo.getMembers()[0].getDefiningOp());
break;
}
}
for (size_t index = 0; index < mapVarsArr.size(); ++index) {
if (descriptorItems.basePtrMapInfo == mapVarsArr[index].getDefiningOp()) {
descriptorItems.basePtrKernelArg = targetEntryBlock->getArgument(index);
}
}
assert(descriptorItems.basePtrMapInfo && "Expected base ptr map info");
assert(descriptorItems.basePtrKernelArg &&
"Expected base ptr kernel argument");
return descriptorItems;
}

void rewriteMapInfo(AllocatableArrayDescriptorItems &descriptorItem,
omp::MapClauseOwningOpInterface mapClauseOwner,
Block *targetEntryBlock, size_t numberOfDims) {
OperandRange mapVarsArr = mapClauseOwner.getMapVars();
omp::MapInfoOp mapInfo = descriptorItem.basePtrMapInfo;
size_t index;
for (index = 0; index < mapVarsArr.size(); ++index) {
if (descriptorItem.basePtrMapInfo == mapVarsArr[index].getDefiningOp()) {
break;
}
}
assert(mapInfo);
OpBuilder opBuilder(mapInfo);
fir::FirOpBuilder builder(opBuilder, mapInfo);
Operation *op = opBuilder.create<omp::MapInfoOp>(
mapInfo->getLoc(), mapInfo.getType(), mapInfo.getVarPtrPtr(),
TypeAttr::get(mapInfo.getVarType()), mapInfo.getVarPtrPtr(),
llvm::SmallVector<Value>{}, ArrayAttr{},
llvm::SmallVector<Value>(mapInfo.getBounds()),
opBuilder.getIntegerAttr(opBuilder.getIntegerType(64, false),
mapInfo.getMapType().value()),
opBuilder.getAttr<omp::VariableCaptureKindAttr>(
mapInfo.getMapCaptureType().value()),
opBuilder.getStringAttr(""), opBuilder.getBoolAttr(false));

mapInfo.replaceAllUsesWith(op);
mapVarsArr[index] = op->getResult(0);

OpBuilder::InsertPoint insPt = builder.saveInsertionPoint();
Block *allocaBlock = builder.getAllocaBlock();
assert(allocaBlock && "No alloca block found for this top level op");
llvm::SmallVector<Value> newMapOps;
for (size_t i = 0; i < mapVarsArr.size(); ++i) {
newMapOps.push_back(mapVarsArr[i]);
}
size_t mapArgumentIndex = mapVarsArr.size();
for (size_t dim = 0; dim < numberOfDims; ++dim) {
descriptorItem.boundDesc.push_back({});
for (size_t i = 0; i < 2; ++i) {
builder.setInsertionPointToStart(allocaBlock);
auto alloca = builder.create<fir::AllocaOp>(mapInfo->getLoc(),
builder.getIntegerType(64));
builder.restoreInsertionPoint(insPt);
auto dimVal = builder.createIntegerConstant(
mapInfo->getLoc(), builder.getIndexType(), dim);
Value allocatableDescriptor =
builder.create<fir::LoadOp>(mapInfo->getLoc(), mapInfo.getVarPtr());
auto dimInfo = builder.create<fir::BoxDimsOp>(
mapInfo->getLoc(), builder.getIndexType(), builder.getIndexType(),
builder.getIndexType(), allocatableDescriptor, dimVal);
Value bound =
builder.createConvert(mapInfo->getLoc(), builder.getIntegerType(64),
dimInfo->getResult(i));
opBuilder.create<fir::StoreOp>(mapInfo->getLoc(), bound, alloca);
omp::VariableCaptureKind captureKind = omp::VariableCaptureKind::ByCopy;

Operation *newMapItem = opBuilder.create<omp::MapInfoOp>(
mapInfo->getLoc(), alloca.getType(), alloca, TypeAttr::get(bound.getType()),
Value{}, llvm::SmallVector<Value>{}, ArrayAttr{},
llvm::SmallVector<Value>{},
opBuilder.getIntegerAttr(
opBuilder.getIntegerType(64, false),
llvm::to_underlying(
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT)),
builder.getAttr<omp::VariableCaptureKindAttr>(captureKind),
opBuilder.getStringAttr(""), opBuilder.getBoolAttr(false));

newMapOps.push_back(newMapItem->getResult(0));
targetEntryBlock->insertArgument(mapArgumentIndex,
newMapItem->getResult(0).getType(),
newMapItem->getLoc());
if (i == 0) {
descriptorItem.boundDesc[dim].lowerBound =
targetEntryBlock->getArgument(mapArgumentIndex);
} else if (i == 1) {
descriptorItem.boundDesc[dim].upperBound =
targetEntryBlock->getArgument(mapArgumentIndex);
}
mapArgumentIndex++;
}
}

mapClauseOwner.getMapVarsMutable().assign(newMapOps);
mapInfo.erase();
}

void rewriteArrayCoorOp(fir::ArrayCoorOp arrayCoorOp,
AllocatableArrayDescriptorItems &descriptorItem) {
OpBuilder opBuilder(arrayCoorOp);
fir::FirOpBuilder builder(opBuilder, arrayCoorOp);
Value addr = builder.createConvert(arrayCoorOp.getLoc(),
arrayCoorOp.getMemref().getType(),
descriptorItem.basePtrKernelArg);
llvm::SmallVector<Value> lbounds;
llvm::SmallVector<Value> ubounds;

for (size_t dim = 0; dim < descriptorItem.boundDesc.size(); dim++) {
#if 0
//Experiment - provide bound information in compile time
Value lb = descriptorItem.boundDesc[dim].lowerBound;
lbounds.push_back(builder.createIntegerConstant(lb.getLoc(),builder.getIndexType(), 1));
ubounds.push_back(builder.createIntegerConstant(lb.getLoc(),builder.getIndexType(), 100));
#else
Value lb = descriptorItem.boundDesc[dim].lowerBound;
Value lbVal = builder.create<fir::LoadOp>(lb.getLoc(), lb);
Value lbValConvert =
builder.createConvert(lb.getLoc(), builder.getIndexType(), lbVal);
lbounds.push_back(lbValConvert);
Value ub = descriptorItem.boundDesc[dim].upperBound;
Value ubVal = builder.create<fir::LoadOp>(ub.getLoc(), ub);
Value ubValConvert =
builder.createConvert(ub.getLoc(), builder.getIndexType(), ubVal);
ubounds.push_back(ubValConvert);
#endif
}

auto shapeShiftArgs = flatZip(lbounds, ubounds);
auto shapeTy =
fir::ShapeShiftType::get(arrayCoorOp->getContext(), lbounds.size());
Value shapeShift = builder.create<fir::ShapeShiftOp>(
arrayCoorOp.getLoc(), shapeTy, shapeShiftArgs);
Value optimizedArrayCoorOp = builder.create<fir::ArrayCoorOp>(
arrayCoorOp.getLoc(), arrayCoorOp.getType(), addr, shapeShift,
arrayCoorOp.getSlice(), arrayCoorOp.getIndices(),
arrayCoorOp.getTypeparams());
arrayCoorOp.replaceAllUsesWith(optimizedArrayCoorOp);
}

public:
OMPAllocatableArrayOptPass() = default;

void runOnOperation() override {
func::FuncOp func = getOperation();
declareDescriptorMap.clear();
func->walk<WalkOrder::PreOrder>([&](omp::TargetOp targetOp) {
auto mapClauseOwner = llvm::dyn_cast<omp::MapClauseOwningOpInterface>(
targetOp.getOperation());
Block *entryBlock = &targetOp->getRegion(0).front();
if (mapClauseOwner) {
OperandRange mapVarsArr = mapClauseOwner.getMapVars();
targetOp->walk<WalkOrder::PreOrder>([&](fir::ArrayCoorOp arrayCoorOp) {
fir::DeclareOp declareOp = findAllocatableDeclareOp(arrayCoorOp);
std::optional<size_t> numberOfArrayDim =
getNumberOfArrayDim(arrayCoorOp);
if (!numberOfArrayDim.has_value())
return;
if (!declareOp)
return;
if (!declareDescriptorMap.contains(declareOp)) {
declareDescriptorMap[declareOp] =
getAllocatableArrayDescriptorItems(
declareOp, mapClauseOwner, entryBlock, *numberOfArrayDim);
rewriteMapInfo(declareDescriptorMap[declareOp], mapClauseOwner,
entryBlock, *numberOfArrayDim);
}
rewriteArrayCoorOp(arrayCoorOp, declareDescriptorMap[declareOp]);
eraseNotUsed(arrayCoorOp, mapClauseOwner, entryBlock);
});
;
}
});
}
};
} // namespace

std::unique_ptr<Pass> fir::createOMPAllocatableArrayOptPass() {
return std::make_unique<OMPAllocatableArrayOptPass>();
}

0 comments on commit 4c15d4c

Please sign in to comment.