diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp index 811bdc57ce14fb..3fe0c551be57a4 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp @@ -1189,20 +1189,30 @@ struct NewRewriter : public OpRewritePattern { LogicalResult matchAndRewrite(NewOp op, PatternRewriter &rewriter) const override { Location loc = op.getLoc(); - const auto dstTp = getSparseTensorType(op.getResult()); - const auto encDst = dstTp.getEncoding(); - if (!dstTp.hasEncoding() || getCOOStart(encDst) == 0) + auto stt = getSparseTensorType(op.getResult()); + auto enc = stt.getEncoding(); + if (!stt.hasEncoding() || getCOOStart(enc) == 0) return failure(); // Implement the NewOp as follows: // %orderedCoo = sparse_tensor.new %filename // %t = sparse_tensor.convert %orderedCoo + // with enveloping reinterpreted_map ops for non-permutations. + RankedTensorType dstTp = stt.getRankedTensorType(); RankedTensorType cooTp = getCOOType(dstTp, /*ordered=*/true); Value cooTensor = rewriter.create(loc, cooTp, op.getSource()); - Value convert = rewriter.replaceOpWithNewOp( - op, dstTp.getRankedTensorType(), cooTensor); + Value convert = cooTensor; + if (!stt.isPermutation()) { // demap coo, demap dstTp + auto coo = getSparseTensorType(cooTensor).getEncoding().withoutDimToLvl(); + convert = rewriter.create(loc, coo, convert); + dstTp = getSparseTensorType(convert).withEncoding(enc.withoutDimToLvl()); + } + convert = rewriter.create(loc, dstTp, convert); + if (!stt.isPermutation()) // remap to original enc + convert = rewriter.create(loc, enc, convert); + rewriter.replaceOp(op, convert); - // Release the ordered COO tensor. + // Release the temporary ordered COO tensor. rewriter.setInsertionPointAfterValue(convert); rewriter.create(loc, cooTensor); @@ -1210,6 +1220,7 @@ struct NewRewriter : public OpRewritePattern { } }; +/// Sparse rewriting rule for the out operator. struct OutRewriter : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(OutOp op, @@ -1250,6 +1261,7 @@ struct OutRewriter : public OpRewritePattern { primaryTypeFunctionSuffix(eltTp)}; Value value = genAllocaScalar(rewriter, loc, eltTp); ModuleOp module = op->getParentOfType(); + // For each element in the source tensor, output the element. rewriter.create( loc, src, std::nullopt,