[mlir][TilingInterface] Avoid looking at operands for getting slices to continue tile + fuse. (llvm#107882)

The current implementation of `scf::tileConsumerAndFuseProducerUsingSCF`
looks at the operands of tiled/tiled+fused operations to see if they are
produced by `extract_slice` operations, and uses those slices to populate
the worklist for continuing fusion. This implicit assumption does not
always hold. Instead, make the implementations of `getTiledImplementation`
return the slices to use for continuing fusion.

This is a breaking change:

- To keep the existing behavior of
`scf::tileConsumerAndFuseProducerUsingSCF`, change all out-of-tree
implementations of `TilingInterface::getTiledImplementation` to return
the slices on which fusion should continue. All in-tree implementations
have been adapted to this.
- This change also required a simplification of the `ControlFn` in
`scf::SCFTileAndFuseOptions`. It now returns a
`std::optional<scf::SCFTileAndFuseOptions::ControlFnResult>` object that
should be `std::nullopt` if fusion is not to be performed (see the sketch
below).
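
For illustration, an out-of-tree control function might be adapted to the new
signature roughly as follows. This is only a sketch: the lambda body and the
fusion policy are hypothetical, and it assumes `using namespace mlir;`.

  scf::SCFTileAndFuseOptions tileAndFuseOptions;
  tileAndFuseOptions.setFusionControlFn(
      [](tensor::ExtractSliceOp candidateSliceOp, OpResult originalProducer,
         bool isDestinationOperand)
          -> std::optional<scf::SCFTileAndFuseOptions::ControlFnResult> {
        // Returning std::nullopt now means "do not fuse along this slice"
        // (previously the first boolean of the returned tuple).
        if (isDestinationOperand)
          return std::nullopt;
        // Fuse, but do not yield a replacement for the fused producer
        // (previously the second boolean of the returned tuple).
        return scf::SCFTileAndFuseOptions::ControlFnResult{
            /*yieldProducerReplacement=*/false};
      });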

Signed-off-by: MaheshRavishankar <[email protected]>
MaheshRavishankar authored and mgehre-amd committed Nov 5, 2024
1 parent 7489677 commit d70a2f8
Showing 10 changed files with 271 additions and 129 deletions.
11 changes: 6 additions & 5 deletions mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
@@ -178,11 +178,12 @@ computeAllSliceParameters(OpBuilder &builder, Location loc, LinalgOp linalgOp,
/// at offsets `lbs` and with sizes `subShapeSizes`. `omitPartialTileCheck`
/// controls whether to omit the partial/boundary tile condition check in
/// cases where we statically know that it is unnecessary.
Value makeTiledShape(OpBuilder &builder, Location loc, Value valueToTile,
ArrayRef<OpFoldResult> tileSizes, AffineMap map,
ArrayRef<OpFoldResult> lbs, ArrayRef<OpFoldResult> ubs,
ArrayRef<OpFoldResult> subShapeSizes,
bool omitPartialTileCheck);
Operation *makeTiledShape(OpBuilder &builder, Location loc, Value valueToTile,
ArrayRef<OpFoldResult> tileSizes, AffineMap map,
ArrayRef<OpFoldResult> lbs,
ArrayRef<OpFoldResult> ubs,
ArrayRef<OpFoldResult> subShapeSizes,
bool omitPartialTileCheck);
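
A caller-side sketch of the new contract (illustrative names, not part of this
diff): the function now returns the slice operation itself, so callers that
previously consumed the returned `Value` take its result instead:

  Operation *sliceOp = makeTiledShape(builder, loc, valueToTile, tileSizes,
                                      map, lbs, ubs, subShapeSizes,
                                      /*omitPartialTileCheck=*/false);
  Value tiledOperand = sliceOp->getResult(0);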

/// Creates extract_slice/subview ops for all `valuesToTile` of the given
/// `linalgOp` with `builder`, assuming `linalgOp` is being fused into a loop
33 changes: 22 additions & 11 deletions mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
@@ -106,6 +106,9 @@ struct SCFTilingResult {
/// Values to use as replacements for the untiled op. Is the same size as the
/// number of results of the untiled op.
SmallVector<Value> replacements;
/// Slices generated after tiling that can be used for fusing with the tiled
/// producer.
SmallVector<Operation *> generatedSlices;
};
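
As a rough usage sketch (not part of this diff; it assumes `using namespace
mlir;` and that `rewriter`, `op`, and `tilingOptions` are a `RewriterBase`, a
`TilingInterface` op, and an `scf::SCFTilingOptions` set up by the caller),
the new field lets a driver seed its fusion worklist directly from the slices
that tiling created:

  FailureOr<scf::SCFTilingResult> tilingResult =
      scf::tileUsingSCF(rewriter, op, tilingOptions);
  if (failed(tilingResult))
    return failure();
  // Collect candidate slices for continuing fusion instead of scanning the
  // operands of the tiled ops for extract_slice producers.
  SmallVector<tensor::ExtractSliceOp> fusionCandidates;
  for (Operation *slice : tilingResult->generatedSlices)
    if (auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(slice))
      fusionCandidates.push_back(sliceOp);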

/// Method to tile an op that implements the `TilingInterface` using
@@ -129,18 +132,22 @@ struct SCFTileAndFuseOptions {
/// 2) the producer value that is to be fused
/// 3) a boolean value set to `true` if the fusion is from
/// a destination operand.
/// It returns two booleans
/// - returns `true` if the fusion should be done through the candidate slice
/// - returns `true` if a replacement for the fused producer needs to be
/// yielded from within the tiled loop. Note that it is valid to return
/// `true` only if the slice fused is disjoint across all iterations of the
/// tiled loop. It is up to the caller to ensure that this is true for the
/// fused producers.
using ControlFnTy = std::function<std::tuple<bool, bool>(
/// The control function returns a `std::optional<ControlFnResult>`.
/// If the return value is `std::nullopt`, that implies no fusion
/// is to be performed along that slice.
struct ControlFnResult {
/// Set to true if the loop nest has to return a replacement value
/// for the fused producer.
bool yieldProducerReplacement = false;
};
using ControlFnTy = std::function<std::optional<ControlFnResult>(
tensor::ExtractSliceOp candidateSliceOp, OpResult originalProducer,
bool isDestinationOperand)>;
ControlFnTy fusionControlFn = [](tensor::ExtractSliceOp, OpResult, bool) {
return std::make_tuple(true, false);
/// The default control function implements greedy fusion without yielding
/// a replacement for any of the fused results.
ControlFnTy fusionControlFn = [](tensor::ExtractSliceOp, OpResult,
bool) -> std::optional<ControlFnResult> {
return ControlFnResult{};
};
SCFTileAndFuseOptions &setFusionControlFn(ControlFnTy controlFn) {
fusionControlFn = controlFn;
@@ -156,6 +163,7 @@ struct SCFFuseProducerOfSliceResult {
OpResult origProducer; // Original untiled producer.
Value tiledAndFusedProducer; // Tile and fused producer value.
SmallVector<Operation *> tiledOps;
SmallVector<Operation *> generatedSlices;
};
std::optional<SCFFuseProducerOfSliceResult>
tileAndFuseProducerOfSlice(RewriterBase &rewriter,
@@ -215,7 +223,10 @@ tileAndFuseProducerOfSlice(RewriterBase &rewriter,
///
/// The @param `yieldResultNumber` decides which result would be yielded. If not
/// given, yield all `opResult` of fused producer.
LogicalResult yieldReplacementForFusedProducer(
///
/// The method returns the list of new slices added during the process (which
/// can be used to fuse along).
FailureOr<SmallVector<Operation *>> yieldReplacementForFusedProducer(
RewriterBase &rewriter, tensor::ExtractSliceOp sliceOp,
scf::SCFFuseProducerOfSliceResult fusedProducerInfo,
MutableArrayRef<LoopLikeOpInterface> loops,
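
A driver-side sketch for the new return value (hypothetical names, not part of
this diff; `candidateSlice`, `fusedResult`, `loops`, and `fusionCandidates`
are assumed to exist in a surrounding tile-and-fuse loop):

  FailureOr<SmallVector<Operation *>> newSlices =
      scf::yieldReplacementForFusedProducer(rewriter, candidateSlice,
                                            *fusedResult, loops);
  if (failed(newSlices))
    return failure();
  // Slices created while yielding the replacement are further fusion
  // candidates.
  for (Operation *slice : *newSlices)
    if (auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(slice))
      fusionCandidates.push_back(sliceOp);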
7 changes: 5 additions & 2 deletions mlir/include/mlir/Interfaces/TilingInterface.h
@@ -25,12 +25,15 @@ namespace mlir {

/// Container for result values of tiling.
/// - `tiledOps` contains operations created by the tiling implementation that
/// are returned to the caller for further transformations.
/// are returned to the caller for further transformations.
/// - `tiledValues` contains the tiled value corresponding to the result of the
/// untiled operation.
/// untiled operation.
/// - `generatedSlices` contains the list of slices that are generated during
/// tiling. These slices can be used for fusing producers.
struct TilingResult {
SmallVector<Operation *> tiledOps;
SmallVector<Value> tiledValues;
SmallVector<Operation *> generatedSlices;
};
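
To illustrate the breaking change for downstream users, a minimal out-of-tree
`getTiledImplementation` might now look roughly like this. `MyTileableOp` and
its `getInput`/`getInit` accessors are hypothetical; the structure mirrors the
in-tree updates below, and the `generatedSlices` entry in the returned
`TilingResult` is the new part:

  FailureOr<TilingResult>
  MyTileableOp::getTiledImplementation(OpBuilder &builder,
                                       ArrayRef<OpFoldResult> offsets,
                                       ArrayRef<OpFoldResult> sizes) {
    Location loc = getLoc();
    SmallVector<OpFoldResult> strides(offsets.size(), builder.getIndexAttr(1));
    // Tile the operands by taking slices of them.
    auto inputSlice = builder.create<tensor::ExtractSliceOp>(
        loc, getInput(), offsets, sizes, strides);
    auto initSlice = builder.create<tensor::ExtractSliceOp>(
        loc, getInit(), offsets, sizes, strides);
    SmallVector<Value> tiledOperands = {inputSlice.getResult(),
                                        initSlice.getResult()};
    SmallVector<Type> resultTypes = {initSlice.getResult().getType()};
    Operation *tiledOp =
        mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
    // New: report the generated slices so tile + fuse can continue fusing
    // producers through them.
    return TilingResult{
        {tiledOp},
        SmallVector<Value>(tiledOp->getResults()),
        llvm::to_vector(ArrayRef<Operation *>{inputSlice, initSlice})};
  }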

/// Container for the result of merge operation of tiling.
82 changes: 54 additions & 28 deletions mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -66,20 +66,20 @@ static OpFoldResult getDimValue(OpBuilder &builder, Location loc, Value v,

/// Returns a memref.subview or a tensor.extract_slice based on the type of the
/// `source`.
static Value getSlice(OpBuilder &b, Location loc, Value source,
ArrayRef<OpFoldResult> offsets,
ArrayRef<OpFoldResult> sizes,
ArrayRef<OpFoldResult> strides) {
return TypeSwitch<Type, Value>(source.getType())
.Case<RankedTensorType>([&](RankedTensorType t) -> Value {
static Operation *getSlice(OpBuilder &b, Location loc, Value source,
ArrayRef<OpFoldResult> offsets,
ArrayRef<OpFoldResult> sizes,
ArrayRef<OpFoldResult> strides) {
return TypeSwitch<Type, Operation *>(source.getType())
.Case<RankedTensorType>([&](RankedTensorType t) -> Operation * {
return b.create<tensor::ExtractSliceOp>(loc, source, offsets, sizes,
strides);
})
.Case<MemRefType>([&](MemRefType type) -> Value {
.Case<MemRefType>([&](MemRefType type) -> Operation * {
return b.create<memref::SubViewOp>(loc, source, offsets, sizes,
strides);
})
.Default([&](Type t) { return nullptr; });
.Default([&](Type t) -> Operation * { return nullptr; });
}

//===----------------------------------------------------------------------===//
@@ -2599,18 +2599,29 @@ SoftmaxOp::getTiledImplementation(OpBuilder &builder,
auto oneAttr = builder.getI64IntegerAttr(1);
SmallVector<OpFoldResult> strides(rank, oneAttr);
SmallVector<Value> tiledOperands;
tiledOperands.emplace_back(
getSlice(builder, getLoc(), getInput(), offsets, sizes, strides));
tiledOperands.emplace_back(
getSlice(builder, getLoc(), getOutput(), offsets, sizes, strides));
Operation *inputSlice =
getSlice(builder, getLoc(), getInput(), offsets, sizes, strides);
if (!inputSlice) {
return emitOpError("failed to compute input slice");
}
tiledOperands.emplace_back(inputSlice->getResult(0));
Operation *outputSlice =
getSlice(builder, getLoc(), getOutput(), offsets, sizes, strides);
if (!outputSlice) {
return emitOpError("failed to compute output slice");
}
tiledOperands.emplace_back(outputSlice->getResult(0));

SmallVector<Type, 4> resultTypes;
if (hasPureTensorSemantics())
resultTypes.push_back(tiledOperands[1].getType());
Operation *tiledOp =
mlir::clone(builder, getOperation(), resultTypes, tiledOperands);

return TilingResult{{tiledOp}, SmallVector<Value>(tiledOp->getResults())};
return TilingResult{
{tiledOp},
SmallVector<Value>(tiledOp->getResults()),
llvm::to_vector(ArrayRef<Operation *>{inputSlice, outputSlice})};
}

LogicalResult SoftmaxOp::getResultTilePosition(
@@ -2957,8 +2968,9 @@ FailureOr<TilingResult> WinogradFilterTransformOp::getTiledImplementation(
int64_t filterRank = getFilterOperandRank();
SmallVector<OpFoldResult> filterStrides(filterRank, oneAttr);
Location loc = getLoc();
tiledOperands.emplace_back(builder.create<tensor::ExtractSliceOp>(
loc, getFilter(), sliceOffsets, sliceSizes, filterStrides));
auto filterSlice = builder.create<tensor::ExtractSliceOp>(
loc, getFilter(), sliceOffsets, sliceSizes, filterStrides);
tiledOperands.emplace_back(filterSlice);

SmallVector<OpFoldResult> resultOffsets, resultSizes;
if (failed(getResultTilePosition(builder, 1, offsets, sizes, resultOffsets,
@@ -2967,15 +2979,19 @@

int64_t outputRank = getOutputOperandRank();
SmallVector<OpFoldResult> outputStrides(outputRank, oneAttr);
tiledOperands.emplace_back(builder.create<tensor::ExtractSliceOp>(
loc, getOutput(), resultOffsets, resultSizes, outputStrides));
auto outputSlice = builder.create<tensor::ExtractSliceOp>(
loc, getOutput(), resultOffsets, resultSizes, outputStrides);
tiledOperands.emplace_back(outputSlice);

SmallVector<Type> resultTypes;
resultTypes.push_back(tiledOperands[1].getType());
Operation *tiledOp =
mlir::clone(builder, getOperation(), resultTypes, tiledOperands);

return TilingResult{{tiledOp}, SmallVector<Value>(tiledOp->getResults())};
return TilingResult{
{tiledOp},
SmallVector<Value>(tiledOp->getResults()),
llvm::to_vector(ArrayRef<Operation *>{filterSlice, outputSlice})};
}

//===----------------------------------------------------------------------===//
@@ -3124,8 +3140,9 @@ WinogradInputTransformOp::getTiledImplementation(OpBuilder &builder,
{sizes[getOutputNDim()], sizeH, sizeW, sizes[getOutputCDim()]});
int64_t inputRank = getInputOperandRank();
SmallVector<OpFoldResult> inputStrides(inputRank, oneAttr);
tiledOperands.emplace_back(builder.create<tensor::ExtractSliceOp>(
loc, getInput(), sliceOffsets, sliceSizes, inputStrides));
auto inputSlice = builder.create<tensor::ExtractSliceOp>(
loc, getInput(), sliceOffsets, sliceSizes, inputStrides);
tiledOperands.emplace_back(inputSlice);

SmallVector<OpFoldResult> resultOffsets, resultSizes;
if (failed(getResultTilePosition(builder, 1, offsets, sizes, resultOffsets,
@@ -3134,15 +3151,19 @@

int64_t outputRank = getOutputOperandRank();
SmallVector<OpFoldResult> outputStrides(outputRank, oneAttr);
tiledOperands.emplace_back(builder.create<tensor::ExtractSliceOp>(
loc, getOutput(), resultOffsets, resultSizes, outputStrides));
auto outputSlice = builder.create<tensor::ExtractSliceOp>(
loc, getOutput(), resultOffsets, resultSizes, outputStrides);
tiledOperands.emplace_back(outputSlice);

SmallVector<Type> resultTypes;
resultTypes.push_back(tiledOperands[1].getType());
Operation *tiledOp =
mlir::clone(builder, getOperation(), resultTypes, tiledOperands);

return TilingResult{{tiledOp}, SmallVector<Value>(tiledOp->getResults())};
return TilingResult{
{tiledOp},
SmallVector<Value>(tiledOp->getResults()),
llvm::to_vector(ArrayRef<Operation *>{inputSlice, outputSlice})};
}

//===----------------------------------------------------------------------===//
@@ -3286,8 +3307,9 @@ FailureOr<TilingResult> WinogradOutputTransformOp::getTiledImplementation(
sizes[getValueFDim()]});
int64_t valueRank = getValueOperandRank();
SmallVector<OpFoldResult> sliceStrides(valueRank, oneAttr);
tiledOperands.emplace_back(builder.create<tensor::ExtractSliceOp>(
loc, getValue(), sliceOffsets, sliceSizes, sliceStrides));
auto valueSlice = builder.create<tensor::ExtractSliceOp>(
loc, getValue(), sliceOffsets, sliceSizes, sliceStrides);
tiledOperands.emplace_back(valueSlice);

SmallVector<OpFoldResult> resultOffsets, resultSizes;
if (failed(getResultTilePosition(builder, 1, offsets, sizes, resultOffsets,
@@ -3296,15 +3318,19 @@

int64_t outputRank = getOutputOperandRank();
SmallVector<OpFoldResult> strides(outputRank, oneAttr);
tiledOperands.emplace_back(builder.create<tensor::ExtractSliceOp>(
loc, getOutput(), resultOffsets, resultSizes, strides));
auto outputSlice = builder.create<tensor::ExtractSliceOp>(
loc, getOutput(), resultOffsets, resultSizes, strides);
tiledOperands.emplace_back(outputSlice);

SmallVector<Type> resultTypes;
resultTypes.push_back(tiledOperands[1].getType());
Operation *tiledOp =
mlir::clone(builder, getOperation(), resultTypes, tiledOperands);

return TilingResult{{tiledOp}, SmallVector<Value>(tiledOp->getResults())};
return TilingResult{
{tiledOp},
SmallVector<Value>(tiledOp->getResults()),
llvm::to_vector(ArrayRef<Operation *>{valueSlice, outputSlice})};
}

//===----------------------------------------------------------------------===//
26 changes: 21 additions & 5 deletions mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
@@ -120,16 +120,25 @@ struct LinalgOpTilingInterface
Location loc = op->getLoc();
LinalgOp linalgOp = cast<LinalgOp>(op);
SmallVector<Value> valuesToTile = linalgOp->getOperands();
SmallVector<Value, 4> tiledOperands = makeTiledShapes(
SmallVector<Value> tiledOperands = makeTiledShapes(
b, loc, linalgOp, valuesToTile, offsets, sizes, {}, true);
SmallVector<Operation *> generatedSlices = llvm::map_to_vector(
llvm::make_filter_range(
tiledOperands,
[](Value v) -> bool {
return isa_and_nonnull<tensor::ExtractSliceOp, memref::SubViewOp>(
v.getDefiningOp());
}),
[](Value v) -> Operation * { return v.getDefiningOp(); });

SmallVector<Type> resultTensorTypes =
getTensorOutputTypes(linalgOp, tiledOperands);

Operation *tiledOp = clone(b, linalgOp, resultTensorTypes, tiledOperands);
offsetIndices(b, cast<LinalgOp>(tiledOp), offsets);

return TilingResult{{tiledOp}, SmallVector<Value>(tiledOp->getResults())};
return TilingResult{
{tiledOp}, SmallVector<Value>(tiledOp->getResults()), generatedSlices};
}

/// Utility to fetch the offsets and sizes when applied as per the indexing
@@ -260,7 +269,8 @@ struct LinalgOpTilingInterface

return TilingResult{
tilingResult->tiledOps,
SmallVector<Value>{tilingResult->tiledValues[resultNumber]}};
SmallVector<Value>{tilingResult->tiledValues[resultNumber]},
tilingResult->generatedSlices};
}

/// Method to generate the tiled implementation of an operation from the tile
@@ -406,8 +416,12 @@ struct LinalgOpPartialReductionInterface
}

// Step 2a: Extract a slice of the input operands.
SmallVector<Value, 4> tiledInputs = makeTiledShapes(
SmallVector<Value> tiledInputs = makeTiledShapes(
b, loc, linalgOp, linalgOp.getDpsInputs(), offsets, sizes, {}, true);
SmallVector<Operation *> generatedSlices = llvm::map_to_vector(
llvm::make_filter_range(
tiledInputs, [](Value v) -> bool { return v.getDefiningOp(); }),
[](Value v) -> Operation * { return v.getDefiningOp(); });

// Step 2b: Extract a slice of the init operands.
SmallVector<Value, 1> tiledInits;
@@ -424,6 +438,7 @@
auto extractSlice = b.create<tensor::ExtractSliceOp>(
loc, valueToTile, initOffset, initSizes, initStride);
tiledInits.push_back(extractSlice);
generatedSlices.push_back(extractSlice);
}

// Update the indexing maps.
@@ -453,7 +468,8 @@
return TilingResult{
{genericOp.getOperation()},
llvm::map_to_vector(genericOp->getResults(),
[](OpResult r) -> Value { return r; })};
[](OpResult r) -> Value { return r; }),
generatedSlices};
}

FailureOr<MergeResult> mergeReductions(Operation *op, OpBuilder &b,
20 changes: 11 additions & 9 deletions mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -565,9 +565,9 @@ void GenerateLoopNest<scf::ParallelOp>::doit(
assert(ivs.size() == iteratorTypes.size() && "did not generate enough loops");
}

static Value materializeTiledShape(OpBuilder &builder, Location loc,
Value valueToTile,
const SliceParameters &sliceParams) {
static Operation *materializeTiledShape(OpBuilder &builder, Location loc,
Value valueToTile,
const SliceParameters &sliceParams) {
auto shapedType = dyn_cast<ShapedType>(valueToTile.getType());
auto *sliceOp = TypeSwitch<ShapedType, Operation *>(shapedType)
.Case([&](MemRefType) {
Expand All @@ -583,14 +583,15 @@ static Value materializeTiledShape(OpBuilder &builder, Location loc,
.Default([](ShapedType) -> Operation * {
llvm_unreachable("Unexpected shaped type");
});
return sliceOp->getResult(0);
return sliceOp;
}

Value makeTiledShape(OpBuilder &builder, Location loc, Value valueToTile,
ArrayRef<OpFoldResult> tileSizes, AffineMap map,
ArrayRef<OpFoldResult> lbs, ArrayRef<OpFoldResult> ubs,
ArrayRef<OpFoldResult> subShapeSizes,
bool omitPartialTileCheck) {
Operation *makeTiledShape(OpBuilder &builder, Location loc, Value valueToTile,
ArrayRef<OpFoldResult> tileSizes, AffineMap map,
ArrayRef<OpFoldResult> lbs,
ArrayRef<OpFoldResult> ubs,
ArrayRef<OpFoldResult> subShapeSizes,
bool omitPartialTileCheck) {
SliceParameters sliceParams =
computeSliceParameters(builder, loc, valueToTile, tileSizes, map, lbs,
ubs, subShapeSizes, omitPartialTileCheck);
@@ -841,6 +842,7 @@ SmallVector<Value> makeTiledShapes(OpBuilder &builder, Location loc,
tiledShapes.push_back(
sliceParams.has_value()
? materializeTiledShape(builder, loc, valueToTile, *sliceParams)
->getResult(0)
: valueToTile);
}
return tiledShapes;