Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[LLVMGPUVectorDistribute] Support vector.mask + vector.multi_reduce #19880

Merged
merged 3 commits into from
Feb 24, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -587,17 +587,20 @@ static int64_t getShuffleWidth(NestedLayoutAttr layout, int64_t dim) {
/// to shared memory and will be reloaded into a layout where partial
/// reductions will be placed inside threads.
struct DistributeMultiReduction final
: OpDistributionPattern<vector::MultiDimReductionOp> {
using OpDistributionPattern::OpDistributionPattern;
: MaskedOpDistributionPattern<vector::MultiDimReductionOp> {
using MaskedOpDistributionPattern::MaskedOpDistributionPattern;

DistributeMultiReduction(MLIRContext *context, int64_t subgroupSize,
int64_t maxBitsPerShuffle, int64_t benefit = 1)
: OpDistributionPattern(context, benefit), subgroupSize(subgroupSize),
maxBitsPerShuffle(maxBitsPerShuffle) {}

LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReduceOp,
DistributionSignature &signature,
PatternRewriter &rewriter) const override {
: MaskedOpDistributionPattern(context, benefit),
subgroupSize(subgroupSize), maxBitsPerShuffle(maxBitsPerShuffle) {}

LogicalResult
matchAndRewrite(vector::MultiDimReductionOp multiReduceOp,
DistributionSignature &signature, vector::MaskOp maskOp,
std::optional<DistributionSignature> &maskSignature,
PatternRewriter &rewriter) const {
Location loc = multiReduceOp.getLoc();
VectorValue srcVector = multiReduceOp.getSource();
Value acc = multiReduceOp.getAcc();
Value res = multiReduceOp.getResult();
Expand Down Expand Up @@ -630,7 +633,23 @@ struct DistributeMultiReduction final
disAcc = multiReduceOp.getAcc();
}

Location loc = multiReduceOp.getLoc();
VectorValue mask = nullptr;
if (maskOp) {
auto maskLayout = dyn_cast_or_null<NestedLayoutAttr>(
maskSignature.value()[maskOp.getMask()]);
if (!maskLayout) {
return rewriter.notifyMatchFailure(maskOp,
"expected nested layout attr");
}
mask = getDistributed(rewriter, maskOp.getMask(), maskLayout);
Value passThruSrc = getCombiningIdentityValue(
loc, rewriter, multiReduceOp.getKind(), disSrc.getType());

disSrc = cast<VectorValue>(
rewriter.create<arith::SelectOp>(loc, mask, disSrc, passThruSrc)
.getResult());
}

SmallVector<bool> reducedDims = multiReduceOp.getReductionMask();
int64_t rank = srcVector.getType().getRank();

Expand All @@ -645,18 +664,29 @@ struct DistributeMultiReduction final
}
Value localInit = getCombiningIdentityValue(
loc, rewriter, multiReduceOp.getKind(), disAcc.getType());
auto localReduction = rewriter.create<vector::MultiDimReductionOp>(
Value localReduction = rewriter.create<vector::MultiDimReductionOp>(
loc, disSrc, localInit, distributedReductionMask,
multiReduceOp.getKind());

    // TODO: As per current upstream lowering implementations, there is no point
    // in doing this because it does a select much later in a finer granularity
    // rather than supporting predication. Moreover, since we are doing a select
    // to cater reductions across the distribution, we can choose not to mask
    // the op post-distribution.
    // if (mask) {
    //   localReduction =
    //       vector::maskOperation(rewriter, localReduction.getDefiningOp(), mask)
    //           ->getResult(0);
    // }

VectorValue locallyReduced;
if (accVector) {
locallyReduced = dyn_cast<VectorValue>(localReduction.getResult());
locallyReduced = dyn_cast<VectorValue>(localReduction);
} else {
// Broadcast scalar accumulator to vector.
VectorType vecType = VectorType::get(ArrayRef{int64_t(1)}, elemTy);
locallyReduced = rewriter.create<vector::BroadcastOp>(
loc, vecType, localReduction.getResult());
locallyReduced =
rewriter.create<vector::BroadcastOp>(loc, vecType, localReduction);
}

assert(locallyReduced && "result should have been a vector");
Expand Down Expand Up @@ -739,7 +769,6 @@ struct DistributeMultiReduction final

for (unsigned i = 0; i < numElements; ++i) {
Value extracted = rewriter.create<vector::ExtractOp>(loc, flat, i);

// Reduce across all reduction dimensions 1-by-1.
for (unsigned i = 0, e = reductionMask.size(); i != e; ++i) {
if (reductionMask[i]) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -144,9 +144,8 @@ DistributionPattern::getDistributed(RewriterBase &rewriter, VectorValue value,
return toSIMT.getResult();
}

void DistributionPattern::replaceOpWithDistributedValues(
SmallVector<Value> DistributionPattern::getOpDistributedReplacements(
RewriterBase &rewriter, Operation *op, ValueRange values) const {
// Replace all OpResults with the given values.
SmallVector<Value> replacements;
for (auto [opResult, replacement] :
llvm::zip_equal(op->getOpResults(), values)) {
Expand All @@ -162,7 +161,14 @@ void DistributionPattern::replaceOpWithDistributedValues(
}
replacements.push_back(replacement);
}
return replacements;
}

void DistributionPattern::replaceOpWithDistributedValues(
    RewriterBase &rewriter, Operation *op, ValueRange values) const {
  // Compute the distributed (SIMT) replacement for every result of `op`, then
  // substitute the op with those values in a single rewrite step.
  rewriter.replaceOp(op, getOpDistributedReplacements(rewriter, op, values));
}

Expand All @@ -186,6 +192,34 @@ void DistributionPattern::setSignatureForRedistribution(
});
}

/// Distributes a `vector.mask` op whose masked body has already been
/// distributed: rewrites the region terminator's vector operands to their
/// distributed (SIMT) values, inlines the mask body right before the mask op,
/// and replaces the mask op's results with those distributed values.
///
/// Fails when the mask op carries no distribution signature, or when a vector
/// result's layout is not a NestedLayoutAttr.
LogicalResult
DistributionPattern::replaceParentMask(PatternRewriter &rewriter,
                                       vector::MaskOp maskOp) const {
  // New ops (e.g. to_simd/to_simt conversions created by getDistributed) are
  // inserted directly before the mask op; the guard restores the insertion
  // point on exit.
  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPoint(maskOp);
  std::optional<DistributionSignature> signatureMask = getOpSignature(maskOp);
  if (!signatureMask.has_value()) {
    return rewriter.notifyMatchFailure(maskOp, "mask should have a signature.");
  }
  // The mask region terminator's operands become the replacement values for
  // the mask op's results (index-aligned via enumerate below).
  SmallVector<Value> returns = maskOp.getBody()->getTerminator()->getOperands();
  for (auto [idx, ret] : llvm::enumerate(returns)) {
    if (VectorValue vectorRet = dyn_cast<VectorValue>(ret)) {
      VectorValue maskRet = cast<VectorValue>(maskOp.getResult(idx));
      // Look up the layout assigned to the corresponding mask result; only
      // nested layouts are supported here.
      VectorLayoutInterface layout =
          dyn_cast<NestedLayoutAttr>(signatureMask.value()[maskRet]);
      if (!layout) {
        return rewriter.notifyMatchFailure(maskOp,
                                           "layout must be NestedLayoutAttr");
      }
      // Mutates the element of `returns` in place (ret binds by reference).
      ret = getDistributed(rewriter, vectorRet, layout);
    }
  }
  // Drop the terminator first so the body can be spliced out as plain ops,
  // then replace the mask op with the distributed return values.
  rewriter.eraseOp(maskOp.getBody()->getTerminator());
  rewriter.inlineBlockBefore(maskOp.getBody(), maskOp);
  replaceOpWithDistributedValues(rewriter, maskOp, returns);
  return success();
}

static void
debugPrintUniqueOperationNames(const std::deque<Operation *> &worklist) {
DenseSet<StringRef> uniqueNames;
Expand Down Expand Up @@ -239,7 +273,12 @@ static void applyVectorDistribution(Operation *root,
std::deque<Operation *> worklist;
LLVM_DEBUG(llvm::dbgs() << "Collecting operations to be distributed\n");
root->walk([&](Operation *op) {
if (hasOpSignature(op)) {
    // The distribution of the mask op is special.
    // Although its signature is set for visibility purposes,
    // it will be distributed when its body is
    // distributed. Therefore, we explicitly exclude
    // the yield and the mask op.
if (hasOpSignature(op) && !isa<vector::MaskOp, vector::YieldOp>(op)) {
worklist.push_back(op);
}
});
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,11 @@ struct DistributionPattern : RewritePattern {
TypedValue<VectorType> value,
VectorLayoutInterface layout) const;

/// Get the distributed values that could replace the op
SmallVector<Value> getOpDistributedReplacements(RewriterBase &rewriter,
Operation *op,
ValueRange values) const;

/// Replace an op with its distributed replacement values.
void replaceOpWithDistributedValues(RewriterBase &rewriter, Operation *op,
ValueRange values) const;
Expand All @@ -49,6 +54,9 @@ struct DistributionPattern : RewritePattern {
void setSignatureForRedistribution(PatternRewriter &rewriter, Operation *op,
Attribute inputLayoutsAttr,
Attribute outputLayoutsAttr) const;

LogicalResult replaceParentMask(PatternRewriter &rewriter,
vector::MaskOp) const;
};

template <typename SourceOp>
Expand All @@ -70,6 +78,43 @@ struct OpDistributionPattern : DistributionPattern {
}
};

/// Distribution pattern for ops of type `SourceOp` that may appear inside a
/// `vector.mask` region. Subclasses implement the masked `matchAndRewrite`
/// overload; this base class looks up the enclosing mask op (and its
/// signature) and, after the masked op is successfully rewritten, distributes
/// the parent mask op itself via replaceParentMask.
template <typename SourceOp>
struct MaskedOpDistributionPattern : DistributionPattern {
  MaskedOpDistributionPattern(MLIRContext *context, PatternBenefit benefit = 1)
      : DistributionPattern(SourceOp::getOperationName(), benefit, context) {}

  /// Hook for subclasses: rewrite `op` given its distribution signature. When
  /// `op` is masked, `maskOp` is the enclosing `vector.mask` and
  /// `maskSignature` holds that mask's signature; otherwise `maskOp` is null
  /// and `maskSignature` is empty.
  virtual LogicalResult
  matchAndRewrite(SourceOp op, DistributionSignature &opSignature,
                  vector::MaskOp maskOp,
                  std::optional<DistributionSignature> &maskSignature,
                  PatternRewriter &rewriter) const = 0;

  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const final {
    // Ops without a distribution signature are not handled by this pattern.
    std::optional<DistributionSignature> opSignature = getOpSignature(op);
    if (!opSignature) {
      return failure();
    }
    // If the op lives inside a vector.mask region, the mask op must also
    // carry a signature since its results are distributed afterwards.
    auto maskOp = op->getParentOfType<vector::MaskOp>();
    std::optional<DistributionSignature> maskSignature;
    if (maskOp) {
      maskSignature = getOpSignature(maskOp);
      if (!maskSignature) {
        return failure();
      }
    }
    LogicalResult result = matchAndRewrite(cast<SourceOp>(op), *opSignature,
                                           maskOp, maskSignature, rewriter);
    if (failed(result)) {
      return failure();
    }
    // The masked op has been distributed; now inline the mask body and
    // replace the mask op's results with their distributed values.
    if (maskOp) {
      return replaceParentMask(rewriter, maskOp);
    }
    return success();
  }
};

template <template <typename> class TraitType>
class OpTraitDistributionPattern : public DistributionPattern {
public:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -181,3 +181,50 @@ builtin.module attributes { transform.with_named_sequence } {
// CHECK: %[[DISTR_MASK:.+]] = vector.create_mask {{.*}} : vector<8xi1>
// CHECK: %[[DISTR_MASK_0:.+]] = vector.extract_strided_slice %16 {offsets = [0], sizes = [2], strides = [1]} : vector<8xi1> to vector<2xi1>
// CHECK: vector.transfer_read %arg0{{.*}} %[[DISTR_MASK_0]]

// -----

// Nested layout for a 256x128 vector: per-dimension tile factors multiply to
// the vector shape (dim 0: 2*2*2*16*2 = 256, dim 1: 1*1*1*16*8 = 128).
#nested = #iree_vector_ext.nested_layout<
  subgroup_tile = [2, 1],
  batch_tile = [2, 1],
  outer_tile = [2, 1],
  thread_tile = [16, 16],
  element_tile = [2, 8],

  subgroup_strides = [1, 0],
  thread_strides = [16, 1]
>

// Reads a 256x128 tile under a mask derived from the dynamic dim of %arg0,
// applies the nested layout, performs a masked multi_reduction over dim 0,
// and writes the 128-element result to %arg1.
func.func @masked_read_write_reduce(%arg0 : memref<?x128xf16>, %arg1 : memref<128xf16>) {
  %c0 = arith.constant 0 : index
  %c128 = arith.constant 128 : index
  %cst_6 = arith.constant 0.000000e+00 : f16
  %cst_1 = arith.constant dense<0.000000e+00> : vector<128xf16>

  // The mask covers [dim(%arg0, 0), 128] of the 256x128 read.
  %dyn = memref.dim %arg0, %c0 : memref<?x128xf16>
  %41 = vector.create_mask %dyn, %c128 : vector<256x128xi1>
  %42 = vector.transfer_read %arg0[%c0, %c0], %cst_6, %41 {in_bounds = [true, true]} : memref<?x128xf16>, vector<256x128xf16>
  %43 = iree_vector_ext.to_layout %42 to layout(#nested) : vector<256x128xf16>
  // Masked reduction along the (dynamic) dim 0.
  %44 = vector.mask %41 { vector.multi_reduction <add>, %43, %cst_1 [0] : vector<256x128xf16> to vector<128xf16> } : vector<256x128xi1> -> vector<128xf16>
  vector.transfer_write %44, %arg1[%c0] {in_bounds = [true]} : vector<128xf16>, memref<128xf16>
  return
}

// Transform script: run the GPU vector-distribution test pass over every
// func.func in the payload.
builtin.module attributes { transform.with_named_sequence } {
  transform.named_sequence @__transform_main(%variant_op: !transform.any_op {transform.readonly}) {
    %top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
    transform.iree.test_gpu_vector_distribution %top_level_func : !transform.any_op
    transform.yield
  }
}

// CHECK-LABEL: func @masked_read_write_reduce

// CHECK: %[[RED_IDENTITY:.+]] = arith.constant dense<0.000000e+00> : vector<2x1x2x1x2x8xf16>

// CHECK: %[[MASK:.+]] = vector.create_mask
// CHECK: %[[MASK_PCK:.+]] = vector.shape_cast %[[MASK]] : vector<8x8xi1> to vector<2x2x2x1x1x8xi1>
// CHECK: %[[MASK_ITL_PCK:.+]] = vector.transpose %[[MASK_PCK]], [0, 3, 1, 4, 2, 5] : vector<2x2x2x1x1x8xi1> to vector<2x1x2x1x2x8xi1>

// CHECK: %[[SELECT:.+]] = arith.select %[[MASK_ITL_PCK]], {{.*}}, %[[RED_IDENTITY]] : vector<2x1x2x1x2x8xi1>, vector<2x1x2x1x2x8xf16>
// CHECK: vector.multi_reduction <add>, %[[SELECT]], {{.*}} [0, 2, 4] : vector<2x1x2x1x2x8xf16> to vector<1x1x8xf16>
Loading
Loading