Skip to content

Commit

Permalink
Introduce MaskedOpDistributionPattern that provides
Browse files Browse the repository at this point in the history
a hook to provide vector.mask { op } rewrites.

This removes the rewrite ordering constraint that
would otherwise be there where body op has to be
distributed prior to mask op.

Now, using this hook, developers could write
masked op distribution pattern where pre-distribution
mask op would be removed as part of the rewrite.

Signed-off-by: Manupa Karunaratne <[email protected]>
  • Loading branch information
manupak committed Feb 18, 2025
1 parent bf61f7c commit a8dcc8b
Show file tree
Hide file tree
Showing 3 changed files with 99 additions and 43 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -587,17 +587,19 @@ static int64_t getShuffleWidth(NestedLayoutAttr layout, int64_t dim) {
/// to shared memory and will be reloaded into a layout where partial
/// reductions will be placed inside threads.
struct DistributeMultiReduction final
: OpDistributionPattern<vector::MultiDimReductionOp> {
using OpDistributionPattern::OpDistributionPattern;
: MaskedOpDistributionPattern<vector::MultiDimReductionOp> {
using MaskedOpDistributionPattern::MaskedOpDistributionPattern;

DistributeMultiReduction(MLIRContext *context, int64_t subgroupSize,
int64_t maxBitsPerShuffle, int64_t benefit = 1)
: OpDistributionPattern(context, benefit), subgroupSize(subgroupSize),
maxBitsPerShuffle(maxBitsPerShuffle) {}

LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReduceOp,
DistributionSignature &signature,
PatternRewriter &rewriter) const override {
: MaskedOpDistributionPattern(context, benefit),
subgroupSize(subgroupSize), maxBitsPerShuffle(maxBitsPerShuffle) {}

LogicalResult
matchAndRewrite(vector::MultiDimReductionOp multiReduceOp,
DistributionSignature &signature, vector::MaskOp maskOp,
std::optional<DistributionSignature> &maskSignature,
PatternRewriter &rewriter) const {
Location loc = multiReduceOp.getLoc();
VectorValue srcVector = multiReduceOp.getSource();
Value acc = multiReduceOp.getAcc();
Expand Down Expand Up @@ -632,11 +634,9 @@ struct DistributeMultiReduction final
}

VectorValue mask = nullptr;
if (auto maskOp = multiReduceOp->getParentOfType<vector::MaskOp>()) {
std::optional<DistributionSignature> signatureMask =
getOpSignature(maskOp);
if (maskOp) {
auto maskLayout = dyn_cast_or_null<NestedLayoutAttr>(
signatureMask.value()[maskOp.getMask()]);
maskSignature.value()[maskOp.getMask()]);
if (!maskLayout) {
return rewriter.notifyMatchFailure(maskOp,
"expected nested layout attr");
Expand Down Expand Up @@ -1680,33 +1680,6 @@ struct DistributeCreateMask final
int64_t subgroupSize;
};

struct DistributeMask final : OpDistributionPattern<vector::MaskOp> {
using OpDistributionPattern::OpDistributionPattern;

LogicalResult matchAndRewrite(vector::MaskOp maskOp,
DistributionSignature &signature,
PatternRewriter &rewriter) const override {
SmallVector<Value> returns =
maskOp.getBody()->getTerminator()->getOperands();
for (auto [idx, ret] : llvm::enumerate(returns)) {
if (VectorValue vectorRet = dyn_cast<VectorValue>(ret)) {
VectorValue maskRet = cast<VectorValue>(maskOp.getResult(idx));
VectorLayoutInterface layout =
dyn_cast<NestedLayoutAttr>(signature[maskRet]);
if (!layout) {
return rewriter.notifyMatchFailure(maskOp,
"layout must be NestedLayoutAttr");
}
ret = getDistributed(rewriter, vectorRet, layout);
}
}
rewriter.eraseOp(maskOp.getBody()->getTerminator());
rewriter.inlineBlockBefore(maskOp.getBody(), maskOp);
replaceOpWithDistributedValues(rewriter, maskOp, returns);
return success();
}
};

} // namespace

void populateGPUDistributeNestedLayoutAttrPatterns(RewritePatternSet &patterns,
Expand All @@ -1723,7 +1696,6 @@ void populateGPUDistributeNestedLayoutAttrPatterns(RewritePatternSet &patterns,
patterns.add<DistributeStep>(patterns.getContext(), threadId, subgroupSize);
patterns.add<DistributeCreateMask>(patterns.getContext(), threadId,
subgroupSize);
patterns.add<DistributeMask>(patterns.getContext());
}

}; // namespace mlir::iree_compiler
Original file line number Diff line number Diff line change
Expand Up @@ -144,9 +144,8 @@ DistributionPattern::getDistributed(RewriterBase &rewriter, VectorValue value,
return toSIMT.getResult();
}

void DistributionPattern::replaceOpWithDistributedValues(
SmallVector<Value> DistributionPattern::getOpDistributedReplacements(
RewriterBase &rewriter, Operation *op, ValueRange values) const {
// Replace all OpResults with the given values.
SmallVector<Value> replacements;
for (auto [opResult, replacement] :
llvm::zip_equal(op->getOpResults(), values)) {
Expand All @@ -162,7 +161,14 @@ void DistributionPattern::replaceOpWithDistributedValues(
}
replacements.push_back(replacement);
}
return replacements;
}

void DistributionPattern::replaceOpWithDistributedValues(
RewriterBase &rewriter, Operation *op, ValueRange values) const {
// Replace all OpResults with the given values.
SmallVector<Value> replacements =
getOpDistributedReplacements(rewriter, op, values);
rewriter.replaceOp(op, replacements);
}

Expand All @@ -186,6 +192,34 @@ void DistributionPattern::setSignatureForRedistribution(
});
}

LogicalResult
DistributionPattern::replaceParentMask(PatternRewriter &rewriter,
vector::MaskOp maskOp) const {
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPoint(maskOp);
std::optional<DistributionSignature> signatureMask = getOpSignature(maskOp);
if (!signatureMask.has_value()) {
return rewriter.notifyMatchFailure(maskOp, "mask should have a signature.");
}
SmallVector<Value> returns = maskOp.getBody()->getTerminator()->getOperands();
for (auto [idx, ret] : llvm::enumerate(returns)) {
if (VectorValue vectorRet = dyn_cast<VectorValue>(ret)) {
VectorValue maskRet = cast<VectorValue>(maskOp.getResult(idx));
VectorLayoutInterface layout =
dyn_cast<NestedLayoutAttr>(signatureMask.value()[maskRet]);
if (!layout) {
return rewriter.notifyMatchFailure(maskOp,
"layout must be NestedLayoutAttr");
}
ret = getDistributed(rewriter, vectorRet, layout);
}
}
rewriter.eraseOp(maskOp.getBody()->getTerminator());
rewriter.inlineBlockBefore(maskOp.getBody(), maskOp);
replaceOpWithDistributedValues(rewriter, maskOp, returns);
return success();
}

static void
debugPrintUniqueOperationNames(const std::deque<Operation *> &worklist) {
DenseSet<StringRef> uniqueNames;
Expand Down Expand Up @@ -239,7 +273,12 @@ static void applyVectorDistribution(Operation *root,
std::deque<Operation *> worklist;
LLVM_DEBUG(llvm::dbgs() << "Collecting operations to be distributed\n");
root->walk([&](Operation *op) {
if (hasOpSignature(op)) {
// The distribution of mask op is special.
// Although the signature set for visibility purposes
// but it will be distributed when the body is
// distributed. Therefore, we explicitly exclude
// the yield and the mask op.
if (hasOpSignature(op) && !isa<vector::MaskOp, vector::YieldOp>(op)) {
worklist.push_back(op);
}
});
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,11 @@ struct DistributionPattern : RewritePattern {
TypedValue<VectorType> value,
VectorLayoutInterface layout) const;

/// Get the distributed values that could replace the op
SmallVector<Value> getOpDistributedReplacements(RewriterBase &rewriter,
Operation *op,
ValueRange values) const;

/// Replace an op with its distributed replacement values.
void replaceOpWithDistributedValues(RewriterBase &rewriter, Operation *op,
ValueRange values) const;
Expand All @@ -49,6 +54,9 @@ struct DistributionPattern : RewritePattern {
void setSignatureForRedistribution(PatternRewriter &rewriter, Operation *op,
Attribute inputLayoutsAttr,
Attribute outputLayoutsAttr) const;

LogicalResult replaceParentMask(PatternRewriter &rewriter,
vector::MaskOp) const;
};

template <typename SourceOp>
Expand All @@ -70,6 +78,43 @@ struct OpDistributionPattern : DistributionPattern {
}
};

template <typename SourceOp>
struct MaskedOpDistributionPattern : DistributionPattern {
MaskedOpDistributionPattern(MLIRContext *context, PatternBenefit benefit = 1)
: DistributionPattern(SourceOp::getOperationName(), benefit, context) {}

virtual LogicalResult
matchAndRewrite(SourceOp op, DistributionSignature &opSignature,
vector::MaskOp maskOp,
std::optional<DistributionSignature> &maskSignature,
PatternRewriter &rewriter) const = 0;

LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const final {
std::optional<DistributionSignature> opSignature = getOpSignature(op);
if (!opSignature) {
return failure();
}
auto maskOp = op->getParentOfType<vector::MaskOp>();
std::optional<DistributionSignature> maskSignature;
if (maskOp) {
maskSignature = getOpSignature(maskOp);
if (!maskSignature) {
return failure();
}
}
LogicalResult result = matchAndRewrite(cast<SourceOp>(op), *opSignature,
maskOp, maskSignature, rewriter);
if (failed(result)) {
return failure();
}
if (maskOp) {
return replaceParentMask(rewriter, maskOp);
}
return success();
}
};

template <template <typename> class TraitType>
class OpTraitDistributionPattern : public DistributionPattern {
public:
Expand Down

0 comments on commit a8dcc8b

Please sign in to comment.