Skip to content

Commit

Permalink
[mlir][SCF] Add an scf.take_assumed_branch transform op.
Browse files Browse the repository at this point in the history
Given an scf.if conditional, using this transformation is akin to injecting
user-specified information that it is always safe to execute only the specified
`if` or `else` branch.

This is achieved by just replacing the scf.if by the content of one of its
branches.

This is particularly useful for user-controlled rewriting of conditionals
that exist solely to guard against out-of-bounds behavior.

At the moment, no assume or assert operation is emitted as it is not always
desirable. In the future, this may be controlled by a dedicated attribute.

Differential Revision: https://reviews.llvm.org/D148125
  • Loading branch information
nicolasvasilache committed Apr 12, 2023
1 parent 34f5774 commit 88b7e8e
Show file tree
Hide file tree
Showing 4 changed files with 133 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ class FuncOp;
} // namespace func
namespace scf {
class ForOp;
class IfOp;
} // namespace scf
} // namespace mlir

Expand Down
41 changes: 41 additions & 0 deletions mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -215,4 +215,45 @@ def LoopCoalesceOp : Op<Transform_Dialect, "loop.coalesce", [
}];
}

def TakeAssumedBranchOp : Op<Transform_Dialect, "scf.take_assumed_branch", [
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
TransformOpInterface, TransformEachOpTrait]> {
let description = [{
Given an scf.if conditional, inject user-defined information that it is
always safe to execute only the if or else branch.

This is achieved by just replacing the scf.if by the content of one of its
branches.

This is particularly useful for user-controlled rewriting of conditionals
that exist solely to guard against out-of-bounds behavior.

At the moment, no assume or assert operation is emitted as it is not always
desirable. In the future, this may be controlled by a dedicated attribute.

#### Return modes

The transform only consumes its operand and does not produce any result.
The transform definitely fails if `take_else_branch` is specified and the
`else` region is empty.
}];
let arguments = (ins TransformHandleTypeInterface:$target,
OptionalAttr<UnitAttr>:$take_else_branch);
let results = (outs);

let assemblyFormat = [{
$target
(`take_else_branch` $take_else_branch^)?
attr-dict
`:` functional-type(operands, results)
}];

let extraClassDeclaration = [{
::mlir::DiagnosedSilenceableFailure applyToOne(
::mlir::scf::IfOp ifOp,
::mlir::transform::ApplyToEachResultList &results,
::mlir::transform::TransformState &state);
}];
}

#endif // SCF_TRANSFORM_OPS
41 changes: 41 additions & 0 deletions mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include "mlir/Dialect/SCF/Utils/Utils.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
#include "mlir/Dialect/Transform/IR/TransformOps.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"

using namespace mlir;
Expand Down Expand Up @@ -245,6 +246,46 @@ transform::LoopCoalesceOp::applyToOne(Operation *op,
return DiagnosedSilenceableFailure::success();
}

//===----------------------------------------------------------------------===//
// TakeAssumedBranchOp
//===----------------------------------------------------------------------===//
/// Replaces the given op with the contents of the given single-block region,
/// using the operands of the block terminator to replace operation results.
static void replaceOpWithRegion(RewriterBase &rewriter, Operation *op,
Region &region) {
assert(llvm::hasSingleElement(region) && "expected single-region block");
Block *block = &region.front();
Operation *terminator = block->getTerminator();
ValueRange results = terminator->getOperands();
rewriter.inlineBlockBefore(block, op, /*blockArgs=*/{});
rewriter.replaceOp(op, results);
rewriter.eraseOp(terminator);
}

DiagnosedSilenceableFailure transform::TakeAssumedBranchOp::applyToOne(
scf::IfOp ifOp, transform::ApplyToEachResultList &results,
transform::TransformState &state) {
TrackingListener listener(state, *this);
IRRewriter rewriter(ifOp->getContext(), &listener);
rewriter.setInsertionPoint(ifOp);

Region &region =
getTakeElseBranch() ? ifOp.getElseRegion() : ifOp.getThenRegion();
if (!llvm::hasSingleElement(region)) {
return emitDefiniteFailure()
<< "requires an scf.if op with a single-block "
<< ((getTakeElseBranch()) ? "`else`" : "`then`") << " region";
}
replaceOpWithRegion(rewriter, ifOp, region);
return DiagnosedSilenceableFailure::success();
}

void transform::TakeAssumedBranchOp::getEffects(
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
consumesHandle(getTarget(), effects);
modifiesPayload(effects);
}

//===----------------------------------------------------------------------===//
// Transform op registration
//===----------------------------------------------------------------------===//
Expand Down
50 changes: 50 additions & 0 deletions mlir/test/Dialect/SCF/transform-op-take-assumed-branch.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
// RUN: mlir-opt %s -test-transform-dialect-interpreter -split-input-file -verify-diagnostics --allow-unregistered-dialect | FileCheck %s

func.func @if_no_else(%cond: i1, %a: index, %b: memref<?xf32>, %c: i8) {
scf.if %cond {
"some_op"(%cond, %b) : (i1, memref<?xf32>) -> ()
scf.yield
}
return
}

transform.sequence failures(propagate) {
^bb0(%arg1: !transform.any_op):
%if = transform.structured.match ops{["scf.if"]} in %arg1
: (!transform.any_op) -> !transform.any_op

// expected-error @+1 {{requires an scf.if op with a single-block `else` region}}
transform.scf.take_assumed_branch %if take_else_branch
: (!transform.any_op) -> ()
}

// -----

// CHECK-LABEL: tile_tensor_pad
func.func @tile_tensor_pad(
%arg0 : tensor<?x?xf32>, %cst : f32, %low: index, %high: index)
-> tensor<20x40xf32>
{
// CHECK: scf.forall
// CHECK-NOT: scf.if
// CHECK-NOT: tensor.generate
// CHECK-NOT: else
// CHECK: tensor.pad {{.*}} nofold
%0 = tensor.pad %arg0 nofold low[%low, %low] high[%high, %high] {
^bb0(%arg9: index, %arg10: index):
tensor.yield %cst : f32
} : tensor<?x?xf32> to tensor<20x40xf32>
return %0 : tensor<20x40xf32>
}

transform.sequence failures(propagate) {
^bb0(%arg1: !transform.any_op):
%0 = transform.structured.match ops{["tensor.pad"]} in %arg1
: (!transform.any_op) -> !pdl.operation
transform.structured.tile_to_forall_op %0 tile_sizes[1, 1]

%if = transform.structured.match ops{["scf.if"]} in %arg1
: (!transform.any_op) -> !transform.any_op
transform.scf.take_assumed_branch %if take_else_branch
: (!transform.any_op) -> ()
}

0 comments on commit 88b7e8e

Please sign in to comment.