-
Notifications
You must be signed in to change notification settings - Fork 12.7k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[mlir][SCF] Add an scf.take_assumed_branch transform op.
Given an scf.if conditional, using this transformation is akin to injecting user-specified information that it is always safe to execute only the specified `if` or `else` branch. This is achieved by just replacing the scf.if by the content of one of its branches. This is particularly useful for user-controlled rewriting of conditionals that exist solely to guard against out-of-bounds behavior. At the moment, no assume or assert operation is emitted as it is not always desirable. In the future, this may be controlled by a dedicated attribute. Differential Revision: https://reviews.llvm.org/D148125
- Loading branch information
1 parent
34f5774
commit 88b7e8e
Showing
4 changed files
with
133 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
50 changes: 50 additions & 0 deletions
50
mlir/test/Dialect/SCF/transform-op-take-assumed-branch.mlir
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
// RUN: mlir-opt %s -test-transform-dialect-interpreter -split-input-file -verify-diagnostics --allow-unregistered-dialect | FileCheck %s | ||
|
||
func.func @if_no_else(%cond: i1, %a: index, %b: memref<?xf32>, %c: i8) { | ||
scf.if %cond { | ||
"some_op"(%cond, %b) : (i1, memref<?xf32>) -> () | ||
scf.yield | ||
} | ||
return | ||
} | ||
|
||
transform.sequence failures(propagate) { | ||
^bb0(%arg1: !transform.any_op): | ||
%if = transform.structured.match ops{["scf.if"]} in %arg1 | ||
: (!transform.any_op) -> !transform.any_op | ||
|
||
// expected-error @+1 {{requires an scf.if op with a single-block `else` region}} | ||
transform.scf.take_assumed_branch %if take_else_branch | ||
: (!transform.any_op) -> () | ||
} | ||
|
||
// ----- | ||
|
||
// CHECK-LABEL: tile_tensor_pad | ||
func.func @tile_tensor_pad( | ||
%arg0 : tensor<?x?xf32>, %cst : f32, %low: index, %high: index) | ||
-> tensor<20x40xf32> | ||
{ | ||
// CHECK: scf.forall | ||
// CHECK-NOT: scf.if | ||
// CHECK-NOT: tensor.generate | ||
// CHECK-NOT: else | ||
// CHECK: tensor.pad {{.*}} nofold | ||
%0 = tensor.pad %arg0 nofold low[%low, %low] high[%high, %high] { | ||
^bb0(%arg9: index, %arg10: index): | ||
tensor.yield %cst : f32 | ||
} : tensor<?x?xf32> to tensor<20x40xf32> | ||
return %0 : tensor<20x40xf32> | ||
} | ||
|
||
transform.sequence failures(propagate) { | ||
^bb0(%arg1: !transform.any_op): | ||
%0 = transform.structured.match ops{["tensor.pad"]} in %arg1 | ||
: (!transform.any_op) -> !pdl.operation | ||
transform.structured.tile_to_forall_op %0 tile_sizes[1, 1] | ||
|
||
%if = transform.structured.match ops{["scf.if"]} in %arg1 | ||
: (!transform.any_op) -> !transform.any_op | ||
transform.scf.take_assumed_branch %if take_else_branch | ||
: (!transform.any_op) -> () | ||
} |