Skip to content

Commit

Permalink
[mlir][vector] Add vector.transpose with unit-dim to vector.shape_cas…
Browse files Browse the repository at this point in the history
…t pattern (llvm#72105)

This patch extends the vector.transpose lowering to replace:

vector.transpose %0, [1, 0] : vector<nx1x<eltty>> to vector<1xnx<eltty>>

with:

  vector.shape_cast %0 : vector<nx1x<eltty>> to vector<1xnx<eltty>>

Source with leading unit-dim (inverse) is also replaced. Unit dim must
be fixed. Non-unit dim can be scalable.

A check is also added to bail out for scalable vectors before unrolling.
  • Loading branch information
c-rhodes authored Nov 15, 2023
1 parent 33b5158 commit b7b6d54
Show file tree
Hide file tree
Showing 2 changed files with 92 additions and 0 deletions.
21 changes: 21 additions & 0 deletions mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -336,6 +336,27 @@ class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
return rewriter.notifyMatchFailure(
op, "Options specifies lowering to shuffle");

// Replace:
// vector.transpose %0, [1, 0] : vector<nx1x<eltty>> to
// vector<1xnxelty>
// with:
// vector.shape_cast %0 : vector<nx1x<eltty>> to vector<1xnxelty>
//
// Source with leading unit dim (inverse) is also replaced. Unit dim must
// be fixed. Non-unit can be scalable.
if (resType.getRank() == 2 &&
((resType.getShape().front() == 1 &&
!resType.getScalableDims().front()) ||
(resType.getShape().back() == 1 &&
!resType.getScalableDims().back())) &&
transp == ArrayRef<int64_t>({1, 0})) {
rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, resType, input);
return success();
}

if (inputType.isScalable())
return failure();

// Handle a true 2-D matrix transpose differently when requested.
if (vectorTransformOptions.vectorTransposeLowering ==
vector::VectorTransposeLowering::Flat &&
Expand Down
71 changes: 71 additions & 0 deletions mlir/test/Dialect/Vector/vector-transpose-lowering.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,17 @@ func.func @transpose1023_1x1x8x8xf32(%arg0: vector<1x1x8x8xf32>) -> vector<1x1x8
return %0 : vector<1x1x8x8xf32>
}

/// Scalable dim should not be unrolled.

// CHECK-LABEL: func @transpose23_scalable
// CHECK-NOT: vector.extract
// CHECK-NOT: vector.insert
// CHECK: vector.transpose
func.func @transpose23_scalable(%arg0: vector<2x[3]xf32>) -> vector<[3]x2xf32> {
%0 = vector.transpose %arg0, [1, 0] : vector<2x[3]xf32> to vector<[3]x2xf32>
return %0 : vector<[3]x2xf32>
}

module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%func_op: !transform.op<"func.func"> {transform.readonly}) {
transform.apply_patterns to %func_op {
Expand Down Expand Up @@ -778,3 +789,63 @@ module attributes {transform.with_named_sequence} {
transform.yield
}
}

// -----

/// Transpose of rank-2 vector with leading or trailing unit dim to shape_cast.

// CHECK-LABEL: func @transpose10_4x1xf32
func.func @transpose10_4x1xf32(%arg0: vector<4x1xf32>) -> vector<1x4xf32> {
// CHECK-NEXT: vector.shape_cast %{{.*}} : vector<4x1xf32> to vector<1x4xf32>
%0 = vector.transpose %arg0, [1, 0] : vector<4x1xf32> to vector<1x4xf32>
return %0 : vector<1x4xf32>
}

// CHECK-LABEL: func @transpose10_nx4x1xf32
func.func @transpose10_nx4x1xf32(%arg0: vector<[4]x1xf32>) -> vector<1x[4]xf32> {
// CHECK-NEXT: vector.shape_cast %{{.*}} : vector<[4]x1xf32> to vector<1x[4]xf32>
%0 = vector.transpose %arg0, [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
return %0 : vector<1x[4]xf32>
}

// CHECK-LABEL: func @transpose10_1x4xf32
func.func @transpose10_1x4xf32(%arg0: vector<1x4xf32>) -> vector<4x1xf32> {
// CHECK-NEXT: vector.shape_cast %{{.*}} : vector<1x4xf32> to vector<4x1xf32>
%0 = vector.transpose %arg0, [1, 0] : vector<1x4xf32> to vector<4x1xf32>
return %0 : vector<4x1xf32>
}

// CHECK-LABEL: func @transpose10_1xnx4xf32
func.func @transpose10_1xnx4xf32(%arg0: vector<1x[4]xf32>) -> vector<[4]x1xf32> {
// CHECK-NEXT: vector.shape_cast %{{.*}} : vector<1x[4]xf32> to vector<[4]x1xf32>
%0 = vector.transpose %arg0, [1, 0] : vector<1x[4]xf32> to vector<[4]x1xf32>
return %0 : vector<[4]x1xf32>
}

/// Scalable unit dim should not be lowered to shape_cast.

// CHECK-LABEL: func @transpose10_4xnx1xf32
func.func @transpose10_4xnx1xf32(%arg0: vector<4x[1]xf32>) -> vector<[1]x4xf32> {
// CHECK-NOT: vector.shape_cast
// CHECK: vector.transpose %{{.*}} : vector<4x[1]xf32> to vector<[1]x4xf32>
%0 = vector.transpose %arg0, [1, 0] : vector<4x[1]xf32> to vector<[1]x4xf32>
return %0 : vector<[1]x4xf32>
}

// CHECK-LABEL: func @transpose10_nx4xnx1xf32
func.func @transpose10_nx4xnx1xf32(%arg0: vector<4x[1]xf32>) -> vector<[1]x4xf32> {
// CHECK-NOT: vector.shape_cast
// CHECK: vector.transpose %{{.*}} : vector<4x[1]xf32> to vector<[1]x4xf32>
%0 = vector.transpose %arg0, [1, 0] : vector<4x[1]xf32> to vector<[1]x4xf32>

return %0 : vector<[1]x4xf32>
}

module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%func_op: !transform.op<"func.func"> {transform.readonly}) {
transform.apply_patterns to %func_op {
transform.apply_patterns.vector.lower_transpose
} : !transform.op<"func.func">
transform.yield
}
}

0 comments on commit b7b6d54

Please sign in to comment.