Skip to content

Commit

Permalink
AIRSegmentLoopFusion: Make memref shrinkage support vectorized code (#565)
Browse files Browse the repository at this point in the history

* Enable memref shrinkage to analyze access patterns for vector transfer_read/write

* (WiP) Apply shrinkage to vector reads and writes

* Improve code stability with subview in scf loop nest, vector read/write in scf loop nest, and linalg.generic in scf nest

* Tests
  • Loading branch information
erwei-xilinx authored May 8, 2024
1 parent 5a1991d commit c8be1f9
Show file tree
Hide file tree
Showing 4 changed files with 528 additions and 67 deletions.
5 changes: 5 additions & 0 deletions mlir/include/air/Util/Util.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"

Expand Down Expand Up @@ -178,6 +179,10 @@ std::tuple<SmallVector<Value>, SmallVector<Value>, SmallVector<Value>>
writeAccessPattern(air::ChannelInterface chanOp);
std::tuple<SmallVector<Value>, SmallVector<Value>, SmallVector<Value>>
writeAccessPattern(memref::SubViewOp subview);
std::tuple<SmallVector<Value>, SmallVector<Value>, SmallVector<Value>>
writeAccessPattern(mlir::vector::TransferReadOp readOp);
std::tuple<SmallVector<Value>, SmallVector<Value>, SmallVector<Value>>
writeAccessPattern(mlir::vector::TransferWriteOp writeOp);
SmallVector<int64_t>
getDataAccessShapeFromMemcpyOp(Value memref,
SmallVector<air::ChannelInterface> chanUsers);
Expand Down
240 changes: 173 additions & 67 deletions mlir/lib/Transform/AIRDependencyScheduleOpt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/Transforms/Transforms.h"
#include "mlir/Dialect/SCF/Utils/Utils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/IntegerSet.h"
Expand Down Expand Up @@ -3746,79 +3747,48 @@ struct ShrinkMemrefSizesByAccessPattern
if (boundsAreAllOnes)
return failure(); // Memref access pattern analysis failed
if (shrinkMemref) {
// Start shrinking memref.
// Shrink access patterns to memref.
for (auto user : users) {
// Update access patterns to shrunk memref from air.channel puts and
// gets.
if (!isa<air::ChannelInterface>(user))
continue;
auto chanOp = dyn_cast<air::ChannelInterface>(user);
rewriter.setInsertionPoint(chanOp);
// Update offsets.
auto new_offsets = getUpdatedOffsetsAfterShrinkage(
memref_shape, overall_access_bounds, chanOp.getOffsets());
int offsetListIdxOffset =
dyn_cast<air::AsyncOpInterface>(chanOp.getOperation())
.getAsyncDependencies()
.size() +
chanOp.getIndices().size() + 1;
for (unsigned i = offsetListIdxOffset;
i < offsetListIdxOffset + chanOp.getOffsets().size(); i++) {
if (new_offsets[i - offsetListIdxOffset] < 0)
continue;
chanOp->getOpOperand(i).assign(
rewriter.create<arith::ConstantIndexOp>(
chanOp->getLoc(), new_offsets[i - offsetListIdxOffset]));
}
// Update strides.
auto new_strides = getUpdatedStridesAfterShrinkage(
memref_shape, overall_access_bounds, chanOp.getStrides());
int strideListIdxOffset = offsetListIdxOffset +
chanOp.getOffsets().size() +
chanOp.getSizes().size();
for (unsigned i = strideListIdxOffset;
i < strideListIdxOffset + chanOp.getStrides().size(); i++) {
chanOp->getOpOperand(i).assign(
rewriter.create<arith::ConstantIndexOp>(
chanOp->getLoc(), new_strides[i - strideListIdxOffset]));
if (!chanOp)
continue;
if (updateAccessPatternAfterShrinkage(chanOp, memref_shape,
overall_access_bounds, rewriter)
.failed()) {
alloc->setAttr("shrinkage", rewriter.getBoolAttr(false));
return failure();
}
}
for (auto user : users) {
// Update access patterns to shrunk memref from memref.subview.
if (!isa<memref::SubViewOp>(user))
continue;
auto subViewOp = dyn_cast<memref::SubViewOp>(user);
auto subview_sizes = subViewOp.getSizes().begin();
auto subview_strides = subViewOp.getStrides().begin();
auto static_sizes = subViewOp.getStaticSizes();
auto static_strides = subViewOp.getStaticStrides();
for (unsigned i = 0; i < static_sizes.size(); i++) {
if (static_sizes[i] < 0) {
if (*getConstantIntValue(*subview_sizes++) !=
overall_access_bounds[i])
return failure(); // Memref shrinkage attempting to mutate
// memref.subview, NYI.
} else {
if (static_sizes[i] != overall_access_bounds[i])
return failure(); // Memref shrinkage attempting to mutate
// memref.subview, NYI.
}
if (!subViewOp)
continue;
if (updateAccessPatternAfterShrinkage(subViewOp, users,
overall_access_bounds, rewriter)
.failed()) {
alloc->setAttr("shrinkage", rewriter.getBoolAttr(false));
return failure();
}
for (unsigned i = 0; i < static_strides.size(); i++) {
if (static_strides[i] < 0) {
if (*getConstantIntValue(*subview_strides++) != 1)
return failure(); // Memref shrinkage attempting to mutate
// memref.subview, NYI.
} else {
if (static_strides[i] != 1)
return failure(); // Memref shrinkage attempting to mutate
// memref.subview, NYI.
}
for (auto user : users) {
// Update access patterns to shrunk memref from
// vector.transfer_read/write.
auto transReadOp = dyn_cast<vector::TransferReadOp>(user);
auto transWriteOp = dyn_cast<vector::TransferWriteOp>(user);
if (transReadOp) {
if (updateAccessPatternAfterShrinkage(transReadOp, rewriter)
.failed()) {
alloc->setAttr("shrinkage", rewriter.getBoolAttr(false));
return failure();
}
} else if (transWriteOp) {
if (updateAccessPatternAfterShrinkage(transWriteOp, rewriter)
.failed()) {
alloc->setAttr("shrinkage", rewriter.getBoolAttr(false));
return failure();
}
}
subViewOp.getResult().replaceAllUsesWith(subViewOp.getSource());
rewriter.eraseOp(subViewOp);
for (auto newUser : subViewOp.getSource().getUsers())
push_back_if_unique<Operation *>(users, newUser);
}

// Replace memref alloc op;
Expand Down Expand Up @@ -3879,22 +3849,27 @@ struct ShrinkMemrefSizesByAccessPattern
for (auto user : memref.getUsers()) {
if (auto da = dyn_cast<memref::DeallocOp>(user))
dealloc = da;
else if (auto chanOp = dyn_cast<air::ChannelInterface>(user)) {
else if (isa<air::ChannelInterface>(user))
users.push_back(user);
else if (isa<memref::SubViewOp>(user))
users.push_back(user);
else if (isa<mlir::vector::TransferReadOp>(user))
users.push_back(user);
} else if (auto subviewOp = dyn_cast<memref::SubViewOp>(user)) {
else if (isa<mlir::vector::TransferWriteOp>(user))
users.push_back(user);
} else if (auto herdOp = dyn_cast<air::HerdOp>(user)) {
else if (auto herdOp = dyn_cast<air::HerdOp>(user)) {
for (unsigned i = 0; i < herdOp.getNumKernelOperands(); i++) {
if (herdOp.getKernelOperand(i) == memref) {
auto memrefInHerd = herdOp.getKernelArgument(i);
if (getAllChanUsers(memrefInHerd, users, dealloc, builder).failed())
return failure();
}
}

} else
return failure(); // NYI.
}
if (users.empty())
return failure();
return success();
}

Expand All @@ -3914,6 +3889,137 @@ struct ShrinkMemrefSizesByAccessPattern
herdOp.getBody().front().eraseArgument(herdOp.getNumDims() * 2 + i + 1);
}
}

// Update access patterns to shrunk memref from air.channel puts and gets.
LogicalResult
updateAccessPatternAfterShrinkage(air::ChannelInterface chanOp,
                                  SmallVector<int> memref_shape,
                                  SmallVector<int64_t> overall_access_bounds,
                                  PatternRewriter &rewriter) const {
  rewriter.setInsertionPoint(chanOp);
  auto loc = chanOp->getLoc();
  // The offset operand list starts after the async dependencies, the channel
  // indices, and one additional leading operand.
  unsigned offsetsBegin =
      dyn_cast<air::AsyncOpInterface>(chanOp.getOperation())
          .getAsyncDependencies()
          .size() +
      chanOp.getIndices().size() + 1;
  // Rewrite offsets; a negative updated entry means "leave this dimension
  // unchanged".
  auto updatedOffsets = getUpdatedOffsetsAfterShrinkage(
      memref_shape, overall_access_bounds, chanOp.getOffsets());
  for (unsigned d = 0; d < chanOp.getOffsets().size(); d++) {
    if (updatedOffsets[d] < 0)
      continue;
    chanOp->getOpOperand(offsetsBegin + d)
        .assign(
            rewriter.create<arith::ConstantIndexOp>(loc, updatedOffsets[d]));
  }
  // Rewrite strides; the stride operand list starts right after the offset
  // and size lists.
  auto updatedStrides = getUpdatedStridesAfterShrinkage(
      memref_shape, overall_access_bounds, chanOp.getStrides());
  unsigned stridesBegin =
      offsetsBegin + chanOp.getOffsets().size() + chanOp.getSizes().size();
  for (unsigned d = 0; d < chanOp.getStrides().size(); d++)
    chanOp->getOpOperand(stridesBegin + d)
        .assign(
            rewriter.create<arith::ConstantIndexOp>(loc, updatedStrides[d]));
  return success();
}

// Update access patterns to shrunk memref from memref.subview.
//
// If the subview's sizes match the shrunk bounds and all strides are one, the
// subview becomes a no-op and is folded away; otherwise its result type is
// retyped to the type inferred from the shrunk source. In both cases the
// subview's offset indices are rewritten via the shared index helper.
LogicalResult
updateAccessPatternAfterShrinkage(memref::SubViewOp subViewOp,
                                  SmallVector<Operation *> &users,
                                  SmallVector<int64_t> overall_access_bounds,
                                  PatternRewriter &rewriter) const {
  rewriter.setInsertionPoint(subViewOp);
  auto subview_sizes = subViewOp.getSizes().begin();
  auto subview_strides = subViewOp.getStrides().begin();
  auto static_sizes = subViewOp.getStaticSizes();
  auto static_strides = subViewOp.getStaticStrides();
  // Get MemRefType after shrinkage.
  Type elemType =
      subViewOp.getSource().getType().cast<MemRefType>().getElementType();
  Attribute memorySpace =
      subViewOp.getSource().getType().cast<MemRefType>().getMemorySpace();
  auto shrunkMemrefType =
      MemRefType::get(overall_access_bounds, elemType, nullptr, memorySpace);
  MemRefType inferredSubViewOutputTy =
      memref::SubViewOp::inferResultType(
          shrunkMemrefType, subViewOp.getStaticOffsets(),
          subViewOp.getStaticSizes(), subViewOp.getStaticStrides())
          .cast<MemRefType>();
  // Capture the operands needed after folding up front: once the subview op
  // is erased below, it must not be dereferenced again (use-after-erase).
  Value source = subViewOp.getSource();
  SmallVector<Value> offsets(subViewOp.getOffsets().begin(),
                             subViewOp.getOffsets().end());
  // Any size not equal to the shrunk bound means the subview still performs a
  // non-trivial view: retype its result and only rewrite its offsets.
  for (unsigned i = 0; i < static_sizes.size(); i++) {
    // Dynamic sizes (negative static entry) are read off the operand list.
    int64_t sizeVal = static_sizes[i] < 0
                          ? *getConstantIntValue(*subview_sizes++)
                          : static_sizes[i];
    if (sizeVal != overall_access_bounds[i]) {
      subViewOp.getResult().setType(inferredSubViewOutputTy);
      return updateAccessPatternAfterShrinkage(offsets, rewriter);
    }
  }
  // Likewise, any stride not equal to one keeps the subview non-trivial.
  for (unsigned i = 0; i < static_strides.size(); i++) {
    int64_t strideVal = static_strides[i] < 0
                            ? *getConstantIntValue(*subview_strides++)
                            : static_strides[i];
    if (strideVal != 1) {
      subViewOp.getResult().setType(inferredSubViewOutputTy);
      return updateAccessPatternAfterShrinkage(offsets, rewriter);
    }
  }
  // The subview is a no-op on the shrunk memref: fold it away.
  subViewOp.getResult().replaceAllUsesWith(source);
  rewriter.eraseOp(subViewOp);
  // The source memref now has new users transferred from the subview; queue
  // them so their access patterns also get updated.
  for (auto newUser : source.getUsers())
    push_back_if_unique<Operation *>(users, newUser);
  return updateAccessPatternAfterShrinkage(offsets, rewriter);
}

// Shared index-rewriting helper for accesses to the shrunk memref: for every
// non-constant index produced by an air.execute op, replace the execute's
// herd-argument operands with a constant zero inserted at the top of the
// owning herd's body.
LogicalResult
updateAccessPatternAfterShrinkage(SmallVector<Value> indices,
                                  PatternRewriter &rewriter) const {
  for (Value idx : indices) {
    if (getConstantIntValue(idx))
      continue; // Already a compile-time constant; nothing to rewrite.
    Operation *defOp = idx.getDefiningOp();
    if (!defOp)
      continue; // Block argument; no defining op to inspect.
    auto execOp = dyn_cast<air::ExecuteOp>(defOp);
    if (!execOp)
      continue;
    for (Value oper : execOp.getChildOp()->getOperands()) {
      auto herdOp = air::getHerdArgOwner(oper);
      if (!herdOp)
        continue;
      rewriter.setInsertionPointToStart(&herdOp.getBody().front());
      Value zero = rewriter.create<arith::ConstantIndexOp>(
          rewriter.getUnknownLoc(), 0);
      execOp.getChildOp()->replaceUsesOfWith(oper, zero);
    }
  }
  return success();
}
// Rewrite the indices of a vector.transfer_read on the shrunk memref.
LogicalResult
updateAccessPatternAfterShrinkage(vector::TransferReadOp transReadOp,
                                  PatternRewriter &rewriter) const {
  SmallVector<Value> readIndices = transReadOp.getIndices();
  return updateAccessPatternAfterShrinkage(readIndices, rewriter);
}
// Rewrite the indices of a vector.transfer_write on the shrunk memref.
LogicalResult
updateAccessPatternAfterShrinkage(vector::TransferWriteOp transWriteOp,
                                  PatternRewriter &rewriter) const {
  SmallVector<Value> writeIndices = transWriteOp.getIndices();
  return updateAccessPatternAfterShrinkage(writeIndices, rewriter);
}
};

// A pass which performs loop fusion within air.segment op's region.
Expand Down
64 changes: 64 additions & 0 deletions mlir/lib/Util/Util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1100,6 +1100,36 @@ SmallVector<int64_t> air::getDataAccessShapeFromMemcpyOp(
return overall_access_bounds;
}

// Fold enclosing scf.for loops into an access pattern: for each index that is
// an scf.for induction variable (directly, or via the operands of a wrapping
// air.execute op), set that dimension's wrap to the loop's static trip count
// and scale its stride by the loop step.
void updateAccessPatternByScfForNest(
    std::tuple<SmallVector<Value>, SmallVector<Value>, SmallVector<Value>>
        &pattern,
    SmallVector<Value> indices, OpBuilder builder) {
  auto loc = builder.getUnknownLoc();
  auto updateWrapAndStride = [&](Value index, int i) {
    auto scfForOp = scf::getForInductionVarOwner(index);
    if (!scfForOp)
      return;
    // Only fold loops whose trip count, step, and current stride are all
    // static; bail out instead of dereferencing empty optionals.
    auto tripCount = air::getStaticScfForTripCountAsInt(scfForOp);
    auto step = getConstantIntValue(scfForOp.getStep());
    auto currStride = getConstantIntValue(std::get<2>(pattern)[i]);
    if (!tripCount || !step || !currStride)
      return;
    std::get<1>(pattern)[i] =
        builder.create<arith::ConstantIndexOp>(loc, *tripCount);
    std::get<2>(pattern)[i] =
        builder.create<arith::ConstantIndexOp>(loc, (*step) * (*currStride));
  };
  int dim = -1;
  for (auto index : indices) {
    dim++;
    // Constant indices contribute no loop-dependent wrap/stride.
    if (getConstantIntValue(index))
      continue;
    updateWrapAndStride(index, dim);
    if (!index.getDefiningOp())
      continue;
    // An air.execute-produced index wraps a child op; its operands may be
    // the actual induction variables.
    if (auto execOp = dyn_cast<air::ExecuteOp>(index.getDefiningOp()))
      for (auto oper : execOp.getChildOp()->getOperands())
        updateWrapAndStride(oper, dim);
  }
}

std::tuple<SmallVector<Value>, SmallVector<Value>, SmallVector<Value>>
air::writeAccessPattern(air::ChannelInterface chanOp) {
std::tuple<SmallVector<Value>, SmallVector<Value>, SmallVector<Value>>
Expand Down Expand Up @@ -1151,8 +1181,38 @@ air::writeAccessPattern(memref::SubViewOp subview) {
else
std::get<2>(pattern).push_back(*subview_strides++);
}
updateAccessPatternByScfForNest(pattern, std::get<0>(pattern), builder);
return pattern;
}

// Build the access pattern (offsets, wraps, strides) for the source memref of
// a vector.transfer_read: start from the default whole-memref pattern, then
// fold any enclosing scf.for nest driving the read indices into it.
std::tuple<SmallVector<Value>, SmallVector<Value>, SmallVector<Value>>
air::writeAccessPattern(mlir::vector::TransferReadOp readOp) {
  OpBuilder builder(readOp);
  std::tuple<SmallVector<Value>, SmallVector<Value>, SmallVector<Value>>
      pattern;
  // The source must be a memref; checked here instead of binding an unused
  // local that only existed for the assert.
  assert(llvm::isa<MemRefType>(readOp.getSource().getType()) &&
         "Not a memref");
  populateDefaultWrapsAndStrides(builder, readOp.getSource(),
                                 std::get<0>(pattern), std::get<1>(pattern),
                                 std::get<2>(pattern));
  updateAccessPatternByScfForNest(pattern, readOp.getIndices(), builder);
  return pattern;
}

// Build the access pattern (offsets, wraps, strides) for the destination
// memref of a vector.transfer_write: start from the default whole-memref
// pattern, then fold any enclosing scf.for nest driving the write indices.
std::tuple<SmallVector<Value>, SmallVector<Value>, SmallVector<Value>>
air::writeAccessPattern(mlir::vector::TransferWriteOp writeOp) {
  OpBuilder builder(writeOp);
  std::tuple<SmallVector<Value>, SmallVector<Value>, SmallVector<Value>>
      pattern;
  // The target must be a memref; checked here instead of binding an unused
  // local that only existed for the assert.
  assert(llvm::isa<MemRefType>(writeOp.getSource().getType()) &&
         "Not a memref");
  populateDefaultWrapsAndStrides(builder, writeOp.getSource(),
                                 std::get<0>(pattern), std::get<1>(pattern),
                                 std::get<2>(pattern));
  updateAccessPatternByScfForNest(pattern, writeOp.getIndices(), builder);
  return pattern;
}

SmallVector<int64_t> air::getDataAccessShapeFromMemcpyOp(
Value memref, SmallVector<air::ChannelInterface> chanUsers) {
SmallVector<
Expand All @@ -1174,6 +1234,10 @@ air::getDataAccessShapeFromMemcpyOp(Value memref,
accessPatterns.push_back(writeAccessPattern(chanUser));
else if (auto svUser = dyn_cast<memref::SubViewOp>(user))
accessPatterns.push_back(writeAccessPattern(svUser));
else if (auto vecReadUser = dyn_cast<mlir::vector::TransferReadOp>(user))
accessPatterns.push_back(writeAccessPattern(vecReadUser));
else if (auto vecWriteUser = dyn_cast<mlir::vector::TransferWriteOp>(user))
accessPatterns.push_back(writeAccessPattern(vecWriteUser));
}
return getDataAccessShapeFromMemcpyOp(memref, accessPatterns);
}
Expand Down
Loading

0 comments on commit c8be1f9

Please sign in to comment.