[GlobalOpt][CPU] Move to using indexing maps for data tiling encodings instead of named op enums (iree-org#15984)

This PR adds a `user_indexing_maps` attribute to `linalg_ext.encoding`, and
uses this attribute in MaterializeEncoding in place of the case-by-case
enums for matmul and batch_matmul. This will enable data tiling on
transposed matmul cases like `linalg.matmul_transpose_a`, and is a step
towards data-tiling of `linalg.generic` contraction ops.
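
To see why the named enums fall short, compare the indexing maps of the plain and LHS-transposed matmuls (a sketch using the upstream `linalg` conventions; the dimension names are ours):

```mlir
// linalg.matmul:             C[m, n] += A[m, k] * B[k, n]
//   LHS map:    affine_map<(m, n, k) -> (m, k)>
//   RHS map:    affine_map<(m, n, k) -> (k, n)>
//   RESULT map: affine_map<(m, n, k) -> (m, n)>
//
// linalg.matmul_transpose_a: C[m, n] += A[k, m] * B[k, n]
//   LHS map:    affine_map<(m, n, k) -> (k, m)>   // only this map changes
//   RHS map:    affine_map<(m, n, k) -> (k, n)>
//   RESULT map: affine_map<(m, n, k) -> (m, n)>
```

The two ops are the same contraction; only the LHS map differs, and that is exactly the information the new attribute records.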

In SetEncoding, the `user_indexing_maps` attribute is set, containing the
indexing maps of the LHS, RHS, and RESULT of the op to be data-tiled.
The case-by-case checks are removed by this PR, and transposed
`linalg::ContractionOpInterface` ops are allowed to get encodings. The
`MATMUL` and `BATCH_MATMUL` user encodings are kept for now, but will
eventually be removed.
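
As a rough sketch, the encoding that SetEncoding attaches to the LHS of a `linalg.matmul_transpose_a` would look something like this (attribute syntax abbreviated and some fields elided; not verbatim compiler output):

```mlir
%lhs = iree_linalg_ext.set_encoding %arg0 : tensor<?x?xf32>
    -> tensor<?x?xf32, #iree_linalg_ext.encoding<
         user = MATMUL, role = LHS, element_types = [f32, f32, f32],
         user_indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d0)>,
                               affine_map<(d0, d1, d2) -> (d2, d1)>,
                               affine_map<(d0, d1, d2) -> (d0, d1)>]>>
```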

In MaterializeEncoding, the `user_indexing_maps` are used to infer the
contraction dimensions (M, N, K, Batch) of the inputs, and a
`tensor.pack` op is created with appropriate `inner_dims_pos` and
`outer_dims_perm` to transpose and pack the input into the canonical
`linalg.mmt4d` input shapes.
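
For example, for the LHS of a `matmul_transpose_a` (shape `KxM`, map `(d0, d1, d2) -> (d2, d0)`), M lives at tensor dim 1 and K at tensor dim 0, so materialization would emit a pack along these lines (the tile sizes 8 and 4 are made up for illustration):

```mlir
%packed = tensor.pack %lhs
    outer_dims_perm = [1, 0] inner_dims_pos = [1, 0] inner_tiles = [8, 4]
    into %dest : tensor<?x?xf32> -> tensor<?x?x8x4xf32>
```

This yields the `M/8 x K/4 x 8 x 4` layout that `linalg.mmt4d` expects for its LHS, with the transpose folded into the pack.
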
Max191 authored Jan 11, 2024
1 parent 17e9529 commit c27ed41
Showing 8 changed files with 1,536 additions and 784 deletions.

Two large diffs are not rendered by default.

177 changes: 119 additions & 58 deletions compiler/src/iree/compiler/Codegen/Common/EncodingUtils.cpp
@@ -5,7 +5,9 @@
 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
 #include "iree/compiler/Codegen/Common/EncodingUtils.h"
+#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
 
 namespace mlir::iree_compiler {
 
@@ -70,6 +72,85 @@ EncodingAttr getEncodingAttr(RankedTensorType type) {
   return type.getEncoding().dyn_cast_or_null<EncodingAttr>();
 }
 
+static AffineMap getMapForRole(EncodingAttr encoding) {
+  EncodingRole role = encoding.getRole().getValue();
+  if (role == EncodingRole::LHS)
+    return cast<AffineMapAttr>(encoding.getUserIndexingMaps()[0])
+        .getAffineMap();
+  else if (role == EncodingRole::RHS)
+    return cast<AffineMapAttr>(encoding.getUserIndexingMaps()[1])
+        .getAffineMap();
+  else
+    return cast<AffineMapAttr>(encoding.getUserIndexingMaps()[2])
+        .getAffineMap();
+}
+
+static FailureOr<linalg::ContractionDimensions>
+getEncodingContractionDims(EncodingAttr encoding) {
+  auto indexingMapsAttr = encoding.getUserIndexingMaps();
+  SmallVector<AffineMap> indexingMaps = llvm::map_to_vector(
+      indexingMapsAttr.getValue(), [](Attribute m) -> AffineMap {
+        return cast<AffineMapAttr>(m).getAffineMap();
+      });
+  return linalg::inferContractionDims(indexingMaps);
+}
+
+/// Given a dim position in the encoding's `user_indexing_maps`, returns the
+/// matching result index in the map for this encoding's tensor.
+static unsigned mapDimToRoleIndex(int64_t dimPos, EncodingAttr encoding) {
+  AffineMap map = getMapForRole(encoding);
+  auto idx = map.getResultPosition(getAffineDimExpr(dimPos, map.getContext()));
+  assert(idx.has_value());
+  return idx.value();
+}
+
+std::optional<SmallVector<int64_t>>
+getPermutationToCanonicalMatmulShape(EncodingAttr encoding) {
+  FailureOr<linalg::ContractionDimensions> cDims =
+      getEncodingContractionDims(encoding);
+  if (failed(cDims)) {
+    return std::nullopt;
+  }
+  // Only support at most one batch, M, N, and K dimension for now.
+  if (cDims->m.size() > 1 || cDims->n.size() > 1 || cDims->k.size() > 1 ||
+      cDims->batch.size() > 1) {
+    return std::nullopt;
+  }
+  SmallVector<int64_t> perm;
+  EncodingRole role = encoding.getRole().getValue();
+  EncodingUser user = encoding.getUser().getValue();
+  // Add batch dim.
+  if (user == EncodingUser::BATCH_MATMUL) {
+    perm.push_back(mapDimToRoleIndex(cDims->batch[0], encoding));
+  }
+  // Add M dim.
+  if (role != EncodingRole::RHS && cDims->m.size() == 1) {
+    perm.push_back(mapDimToRoleIndex(cDims->m[0], encoding));
+  }
+  // Add K dim.
+  if (role != EncodingRole::RESULT) {
+    perm.push_back(mapDimToRoleIndex(cDims->k[0], encoding));
+  }
+  // Add N dim.
+  if (role != EncodingRole::LHS && cDims->n.size() == 1) {
+    perm.push_back(mapDimToRoleIndex(cDims->n[0], encoding));
+  }
+  return perm;
+}
+
+RankedTensorType getCanonicalMatmulTypeWithEncoding(RankedTensorType type) {
+  auto encoding = getEncodingAttr(type);
+  if (!encoding) {
+    return type;
+  }
+  auto perm = getPermutationToCanonicalMatmulShape(encoding);
+  if (!perm) {
+    return type;
+  }
+  return RankedTensorType::get(applyPermutation(type.getShape(), perm.value()),
+                               type.getElementType(), encoding);
+}
+
 RankedTensorType getOriginalTypeWithEncoding(RankedTensorType type) {
   auto encoding = getEncodingAttr(type);
   if (!encoding) {
@@ -100,23 +181,27 @@ int64_t getIntOrZero(IntegerAttr a) {
 }
 
 bool isVecmatEncoding(EncodingAttr encoding) {
-  return encoding.getUser().getValue() == EncodingUser::MATMUL &&
-         getIntOrZero(encoding.getMatmulNarrow_M()) == 1;
+  auto cDims = getEncodingContractionDims(encoding);
+  return !failed(cDims) && cDims->batch.size() == 0 && cDims->m.size() == 0 &&
+         cDims->k.size() == 1 && cDims->n.size() == 1;
 }
 
 bool isMatvecEncoding(EncodingAttr encoding) {
-  return encoding.getUser().getValue() == EncodingUser::MATMUL &&
-         getIntOrZero(encoding.getMatmulNarrow_N()) == 1;
+  auto cDims = getEncodingContractionDims(encoding);
+  return !failed(cDims) && cDims->batch.size() == 0 && cDims->m.size() == 1 &&
+         cDims->k.size() == 1 && cDims->n.size() == 0;
 }
 
 bool isBatchVecmatEncoding(EncodingAttr encoding) {
-  return encoding.getUser().getValue() == EncodingUser::BATCH_MATMUL &&
-         getIntOrZero(encoding.getMatmulNarrow_M()) == 1;
+  auto cDims = getEncodingContractionDims(encoding);
+  return !failed(cDims) && cDims->batch.size() == 1 && cDims->m.size() == 0 &&
+         cDims->k.size() == 1 && cDims->n.size() == 1;
 }
 
 bool isBatchMatvecEncoding(EncodingAttr encoding) {
-  return encoding.getUser().getValue() == EncodingUser::BATCH_MATMUL &&
-         getIntOrZero(encoding.getMatmulNarrow_N()) == 1;
+  auto cDims = getEncodingContractionDims(encoding);
+  return !failed(cDims) && cDims->batch.size() == 1 && cDims->m.size() == 1 &&
+         cDims->k.size() == 1 && cDims->n.size() == 0;
 }
 
 bool isVectorEncoding(int64_t rank, EncodingUser user) {
@@ -126,61 +211,37 @@ bool isVectorEncoding(int64_t rank, EncodingUser user) {
 MaterializeEncodingInfo getEncodingInfoForMatmul(EncodingAttr encoding,
                                                  int64_t rank,
                                                  TileMxNxK tileMxNxK) {
-  EncodingUser user = encoding.getUser().getValue();
   EncodingRole role = encoding.getRole().getValue();
-  bool isVector = isVectorEncoding(rank, user);
-  bool isVecmatVector = (isVector && (isVecmatEncoding(encoding) ||
-                                      isBatchVecmatEncoding(encoding)));
-  bool isMatvecVector = (isVector && (isMatvecEncoding(encoding) ||
-                                      isBatchMatvecEncoding(encoding)));
-  // Start dim of the MxK (LHS), KxN (RHS), or MxN (RESULT) 2D matrix.
-  int64_t matmulDimBase = isBatchMatmulEncodingUser(user) ? 1 : 0;
 
   MaterializeEncodingInfo encodingInfo;
-  if (isVector) {
-    encodingInfo.innerDimsPos = {matmulDimBase};
-  } else {
-    encodingInfo.innerDimsPos = {matmulDimBase, matmulDimBase + 1};
-  }
-
-  switch (role) {
-  case (EncodingRole::LHS): {
-    if (isVecmatVector) {
-      encodingInfo.innerTileSizes = {tileMxNxK.K};
-      break;
-    }
-    encodingInfo.innerTileSizes = {tileMxNxK.M, tileMxNxK.K};
-    break;
-  }
-  case (EncodingRole::RHS): {
-    if (isMatvecVector) {
-      encodingInfo.innerTileSizes = {tileMxNxK.K};
-      break;
-    }
-    encodingInfo.innerTileSizes = {tileMxNxK.N, tileMxNxK.K};
-    encodingInfo.innerDimsPos = {matmulDimBase + 1, matmulDimBase};
-    encodingInfo.outerDimsPerm =
-        llvm::to_vector(llvm::seq<int64_t>(0, matmulDimBase));
-    encodingInfo.outerDimsPerm.push_back(matmulDimBase + 1);
-    encodingInfo.outerDimsPerm.push_back(matmulDimBase);
-    break;
+  auto cDims = getEncodingContractionDims(encoding);
+  // The following expects M, N, K, and Batch sizes of at most 1 for now.
+  assert(cDims->m.size() <= 1 && cDims->n.size() <= 1 && cDims->k.size() <= 1 &&
+         cDims->batch.size() <= 1 &&
+         "Expected at most one M, N, K, and Batch dimension");
+  if (!cDims->batch.empty()) {
+    encodingInfo.outerDimsPerm.push_back(
+        mapDimToRoleIndex(cDims->batch[0], encoding));
   }
-  case (EncodingRole::RESULT): {
-    if (isVecmatVector) {
-      encodingInfo.innerTileSizes = {tileMxNxK.N};
-      break;
-    }
-    if (isMatvecVector) {
-      encodingInfo.innerTileSizes = {tileMxNxK.M};
-      break;
-    }
-    encodingInfo.innerTileSizes = {tileMxNxK.M, tileMxNxK.N};
-    break;
+  if (role != EncodingRole::RHS && !cDims->m.empty()) {
+    encodingInfo.outerDimsPerm.push_back(
+        mapDimToRoleIndex(cDims->m[0], encoding));
+    encodingInfo.innerDimsPos.push_back(
+        mapDimToRoleIndex(cDims->m[0], encoding));
+    encodingInfo.innerTileSizes.push_back(tileMxNxK.M);
   }
-  default: {
-    assert(false);
-    return {};
+  if (role != EncodingRole::LHS && !cDims->n.empty()) {
+    encodingInfo.outerDimsPerm.push_back(
+        mapDimToRoleIndex(cDims->n[0], encoding));
+    encodingInfo.innerDimsPos.push_back(
+        mapDimToRoleIndex(cDims->n[0], encoding));
+    encodingInfo.innerTileSizes.push_back(tileMxNxK.N);
   }
+  if (role != EncodingRole::RESULT) {
+    encodingInfo.outerDimsPerm.push_back(
+        mapDimToRoleIndex(cDims->k[0], encoding));
+    encodingInfo.innerDimsPos.push_back(
+        mapDimToRoleIndex(cDims->k[0], encoding));
+    encodingInfo.innerTileSizes.push_back(tileMxNxK.K);
+  }
   return encodingInfo;
 }
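
A worked example of the new map-driven logic (ours, not from the patch): for the RHS of `linalg.matmul_transpose_b`, the RHS map is `(m, n, k) -> (n, k)`, and `getPermutationToCanonicalMatmulShape` visits K then N for the RHS role:

```mlir
// K -> result position of k in (n, k) = 1
// N -> result position of n in (n, k) = 0
// perm = [1, 0]; applyPermutation([N, K], perm) = [K, N]
// i.e. a tensor<NxKxf32> RHS canonicalizes to the ordinary matmul
// RHS type tensor<KxNxf32>.
```
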
9 changes: 9 additions & 0 deletions compiler/src/iree/compiler/Codegen/Common/EncodingUtils.h
@@ -76,6 +76,15 @@ class OpMaterializeEncodingPattern : public OpConversionPattern<OpTy> {
 /// Otherwise, returns null.
 IREE::LinalgExt::EncodingAttr getEncodingAttr(RankedTensorType type);
 
+/// Returns the permutation that maps the input shape to the canonical matmul
+/// input shape, based on the encoding's `user_indexing_maps` attribute.
+std::optional<SmallVector<int64_t>>
+getPermutationToCanonicalMatmulShape(IREE::LinalgExt::EncodingAttr encoding);
+
+/// Returns a RankedTensorType that has been transposed into the canonical
+/// form for an ordinary matmul/batch_matmul op.
+RankedTensorType getCanonicalMatmulTypeWithEncoding(RankedTensorType type);
+
 /// Returns the original type carried by the encoding.
 RankedTensorType getOriginalTypeWithEncoding(RankedTensorType type);
 
69 changes: 43 additions & 26 deletions compiler/src/iree/compiler/GlobalOptimization/SetEncoding.cpp
@@ -20,6 +20,7 @@
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
 #include "mlir/Dialect/Linalg/Utils/Utils.h"
 #include "mlir/Dialect/MemRef/Transforms/Transforms.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
@@ -98,17 +99,21 @@ enum class ContractionOpType {
 
 static ContractionOpType
 getContractionOpType(linalg::ContractionOpInterface op) {
-  if (op.isRowMajorMatmul() || op.isColumnMajorMatmul())
+  FailureOr<linalg::ContractionDimensions> cDims =
+      linalg::inferContractionDims(cast<linalg::LinalgOp>(op.getOperation()));
+  if (failed(cDims))
+    return ContractionOpType::kInvalid;
+  if (cDims->batch.size() == 0 && cDims->m.size() == 1 && cDims->n.size() == 1)
     return ContractionOpType::kMatmul;
-  if (op.isRowMajorBatchMatmul())
-    return ContractionOpType::kBatchMatmul;
-  if (op.isVecmat())
+  if (cDims->batch.size() == 0 && cDims->m.size() == 0 && cDims->n.size() == 1)
     return ContractionOpType::kVecmat;
-  if (op.isBatchVecmat())
-    return ContractionOpType::kBatchVecmat;
-  if (op.isMatvec())
+  if (cDims->batch.size() == 0 && cDims->m.size() == 1 && cDims->n.size() == 0)
     return ContractionOpType::kMatvec;
-  if (op.isBatchMatvec())
+  if (cDims->batch.size() == 1 && cDims->m.size() == 1 && cDims->n.size() == 1)
+    return ContractionOpType::kBatchMatmul;
+  if (cDims->batch.size() == 1 && cDims->m.size() == 0 && cDims->n.size() == 1)
+    return ContractionOpType::kBatchVecmat;
+  if (cDims->batch.size() == 1 && cDims->m.size() == 1 && cDims->n.size() == 0)
     return ContractionOpType::kBatchMatvec;
   return ContractionOpType::kInvalid;
 }
@@ -120,25 +125,33 @@ struct MatmulNarrowSizes {
 // Returns the minimum of the static sizes of the M/N-dimensions in the type
 // of the output.
 static MatmulNarrowSizes getMatmulNarrowSizes(ShapedType outType,
-                                              ContractionOpType opType) {
+                                              ContractionOpType opType,
+                                              linalg::LinalgOp linalgOp) {
+  linalg::ContractionDimensions cDims =
+      linalg::inferContractionDims(linalgOp).value();
+  auto map = linalgOp.getIndexingMapsArray().back();
+  auto getOutputSizeAtDimPos = [&](unsigned dimPos) -> int64_t {
+    return outType.getDimSize(
+        map.getResultPosition(getAffineDimExpr(dimPos, linalgOp->getContext()))
+            .value());
+  };
   int64_t M, N;
-  int64_t rank = outType.getRank();
   switch (opType) {
   case ContractionOpType::kMatmul:
   case ContractionOpType::kBatchMatmul: {
-    M = outType.getDimSize(rank - 2);
-    N = outType.getDimSize(rank - 1);
+    M = getOutputSizeAtDimPos(cDims.m[0]);
+    N = getOutputSizeAtDimPos(cDims.n[0]);
    break;
   }
   case ContractionOpType::kVecmat:
   case ContractionOpType::kBatchVecmat: {
     M = 1;
-    N = outType.getDimSize(outType.getRank() - 1);
+    N = getOutputSizeAtDimPos(cDims.n[0]);
     break;
   }
   case ContractionOpType::kMatvec:
   case ContractionOpType::kBatchMatvec: {
-    M = outType.getDimSize(outType.getRank() - 1);
+    M = getOutputSizeAtDimPos(cDims.m[0]);
     N = 1;
     break;
   }
@@ -167,7 +180,8 @@ static MatmulNarrowSizes getMatmulNarrowSizes(ShapedType outType,
 static IREE::LinalgExt::EncodingAttr
 makeEncoding(OpBuilder &builder, IREE::LinalgExt::EncodingUser user,
              IREE::LinalgExt::EncodingRole role, TypeRange operandTypes,
-             Type originalType, MatmulNarrowSizes narrow) {
+             Type originalType, MatmulNarrowSizes narrow,
+             ArrayAttr indexingMaps) {
   auto *context = builder.getContext();
   auto userAttr = IREE::LinalgExt::EncodingUserAttr::get(context, user);
   auto roleAttr = IREE::LinalgExt::EncodingRoleAttr::get(context, role);
@@ -184,7 +198,7 @@ makeEncoding(OpBuilder &builder, IREE::LinalgExt::EncodingUser user,
   };
   return IREE::LinalgExt::EncodingAttr::get(
       context, userAttr, roleAttr, operandElemTypesAttr, originalTypeAttr,
-      getAttr(narrow.M), getAttr(narrow.N));
+      getAttr(narrow.M), getAttr(narrow.N), indexingMaps);
 }
 
 // Creates a linalg::GenericOp that performs an element-wise cast of the same
@@ -207,14 +221,15 @@ static Value
 padAndSetEncoding(OpBuilder &builder, Location loc, Value source,
                   IREE::LinalgExt::EncodingUser user,
                   IREE::LinalgExt::EncodingRole role, TypeRange operandTypes,
-                  MatmulNarrowSizes narrow,
+                  MatmulNarrowSizes narrow, ArrayAttr indexingMaps,
                   std::optional<CastOpInterface> castOp = std::nullopt) {
   Value padSource = castOp ? source.getDefiningOp()->getOperand(0) : source;
   // No need to specify original_type in the encoding passed to pad(), because
   // the operand there is the `source` tensor, so it will default to reading its
   // original shape.
-  auto encodingForPad = makeEncoding(builder, user, role, operandTypes,
-                                     /*originalType=*/Type{}, narrow);
+  auto encodingForPad =
+      makeEncoding(builder, user, role, operandTypes,
+                   /*originalType=*/Type{}, narrow, indexingMaps);
   Value padded = pad(builder, loc, padSource, encodingForPad);
   // For setEncoding() below, we potentially need to specify an encoding with an
   // explicit original_type, because the operand there is the padded tensor
@@ -224,8 +239,9 @@ padAndSetEncoding(OpBuilder &builder, Location loc, Value source,
   // the tensor type that the encoding is applied to.
   auto encodingForSetEncoding = encodingForPad;
   if (padded.getType() != padSource.getType()) {
-    encodingForSetEncoding = makeEncoding(builder, user, role, operandTypes,
-                                          padSource.getType(), narrow);
+    encodingForSetEncoding =
+        makeEncoding(builder, user, role, operandTypes, padSource.getType(),
+                     narrow, indexingMaps);
   }
   Value encoded = setEncoding(builder, loc, padded, encodingForSetEncoding);
   if (castOp) {
@@ -321,9 +337,10 @@ struct setContractionOpEncoding
       return rewriter.notifyMatchFailure(op, "unsupported contraction op");
     }
 
-    MatmulNarrowSizes narrowSizes =
-        getMatmulNarrowSizes(origOut.getType().cast<ShapedType>(), opType);
+    MatmulNarrowSizes narrowSizes = getMatmulNarrowSizes(
+        origOut.getType().cast<ShapedType>(), opType, linalgOp);
 
+    auto maps = linalgOp.getIndexingMaps();
     Location loc = linalgOp.getLoc();
     SmallVector<Type> operandTypes(linalgOp->getOperandTypes());
     operandTypes[0] =
@@ -332,13 +349,13 @@ struct setContractionOpEncoding
         cast<RankedTensorType>(operandTypes[1]).clone(rhsElemType);
     Value encodedLhs = padAndSetEncoding(
         rewriter, loc, origLhs, user, IREE::LinalgExt::EncodingRole::LHS,
-        operandTypes, narrowSizes, maybeLhsCastOp);
+        operandTypes, narrowSizes, maps, maybeLhsCastOp);
     Value encodedRhs = padAndSetEncoding(
         rewriter, loc, origRhs, user, IREE::LinalgExt::EncodingRole::RHS,
-        operandTypes, narrowSizes, maybeRhsCastOp);
+        operandTypes, narrowSizes, maps, maybeRhsCastOp);
     Value encodedOut = padAndSetEncoding(rewriter, loc, origOut, user,
                                          IREE::LinalgExt::EncodingRole::RESULT,
-                                         operandTypes, narrowSizes);
+                                         operandTypes, narrowSizes, maps);
     Value opTiled;
     opTiled = clone(rewriter, linalgOp, encodedOut.getType(),
                     ValueRange{encodedLhs, encodedRhs, encodedOut})
(Diffs for the remaining changed files are not shown.)
