Skip to content

Commit

Permalink
[mlir][sparse] avoid non-perm on sparse tensor convert for new (llvm#…
Browse files Browse the repository at this point in the history
…72459)

This avoids seeing non-perm on the convert from COO to non-COO for
higher dimensional new operators (viz. reading in BSR).

This is step 1 out of 3 to make sparse_tensor.new work for BSR
  • Loading branch information
aartbik authored Nov 16, 2023
1 parent 8404406 commit e8fc282
Showing 1 changed file with 18 additions and 6 deletions.
24 changes: 18 additions & 6 deletions mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1189,27 +1189,38 @@ struct NewRewriter : public OpRewritePattern<NewOp> {
LogicalResult matchAndRewrite(NewOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
const auto dstTp = getSparseTensorType(op.getResult());
const auto encDst = dstTp.getEncoding();
if (!dstTp.hasEncoding() || getCOOStart(encDst) == 0)
auto stt = getSparseTensorType(op.getResult());
auto enc = stt.getEncoding();
if (!stt.hasEncoding() || getCOOStart(enc) == 0)
return failure();

// Implement the NewOp as follows:
// %orderedCoo = sparse_tensor.new %filename
// %t = sparse_tensor.convert %orderedCoo
// with enveloping reinterpreted_map ops for non-permutations.
RankedTensorType dstTp = stt.getRankedTensorType();
RankedTensorType cooTp = getCOOType(dstTp, /*ordered=*/true);
Value cooTensor = rewriter.create<NewOp>(loc, cooTp, op.getSource());
Value convert = rewriter.replaceOpWithNewOp<ConvertOp>(
op, dstTp.getRankedTensorType(), cooTensor);
Value convert = cooTensor;
if (!stt.isPermutation()) { // demap coo, demap dstTp
auto coo = getSparseTensorType(cooTensor).getEncoding().withoutDimToLvl();
convert = rewriter.create<ReinterpretMapOp>(loc, coo, convert);
dstTp = getSparseTensorType(convert).withEncoding(enc.withoutDimToLvl());
}
convert = rewriter.create<ConvertOp>(loc, dstTp, convert);
if (!stt.isPermutation()) // remap to original enc
convert = rewriter.create<ReinterpretMapOp>(loc, enc, convert);
rewriter.replaceOp(op, convert);

// Release the ordered COO tensor.
// Release the temporary ordered COO tensor.
rewriter.setInsertionPointAfterValue(convert);
rewriter.create<DeallocTensorOp>(loc, cooTensor);

return success();
}
};

/// Sparse rewriting rule for the out operator.
struct OutRewriter : public OpRewritePattern<OutOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(OutOp op,
Expand Down Expand Up @@ -1250,6 +1261,7 @@ struct OutRewriter : public OpRewritePattern<OutOp> {
primaryTypeFunctionSuffix(eltTp)};
Value value = genAllocaScalar(rewriter, loc, eltTp);
ModuleOp module = op->getParentOfType<ModuleOp>();

// For each element in the source tensor, output the element.
rewriter.create<ForeachOp>(
loc, src, std::nullopt,
Expand Down

0 comments on commit e8fc282

Please sign in to comment.