Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Extend offset composition to cover scf, air.hier and affine loops #730

Merged
merged 2 commits into from
Sep 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 16 additions & 5 deletions mlir/lib/Util/Util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -874,10 +874,14 @@ LogicalResult eraseWrapNStrideDim(OpBuilder builder,
offset_producer = castOp.getIn().getDefiningOp();
}
if (!offset_producer) {
if (!affine::getForInductionVarOwner(offsets[i]))
if (auto afo = affine::getForInductionVarOwner(offsets[i])) {
builder.setInsertionPointToStart(afo.getBody());
} else if (auto sfo = scf::getForInductionVarOwner(offsets[i])) {
builder.setInsertionPointToStart(sfo.getBody());
} else if (auto aho = air::getHierarchyArgOwner(offsets[i])) {

} else
continue;
auto afo = affine::getForInductionVarOwner(offsets[i]);
builder.setInsertionPointToStart(afo.getBody());
// Create a new affine.apply on the loop induction variable (affine.for or
// scf.for) or air hierarchy argument, as a handle for subsequent offset
// composition.
auto sym0_expr = getAffineSymbolExpr(0, builder.getContext());
Expand Down Expand Up @@ -960,12 +964,19 @@ LogicalResult air::canonicalizeWrapAndStrideList(OpBuilder builder,
if (!sizes.empty()) {
for (int i = sizes.size() - 1; i >= 1; i--) {
auto const_offset = getConstantIntValue(offsets[i]);
if (!const_offset)
continue;
auto const_size = getConstantIntValue(sizes[i]);
if (!const_size)
continue;
auto const_stride = getConstantIntValue(strides[i]);
if (!const_stride)
continue;
auto const_offset_prev = getConstantIntValue(offsets[i - 1]);
if (!const_offset_prev)
continue;
auto const_stride_prev = getConstantIntValue(strides[i - 1]);
if (!(const_offset && const_size && const_stride && const_offset_prev &&
const_stride_prev))
if (!const_stride_prev)
continue;
if (*const_stride_prev == *const_size * *const_stride)
listsHaveChanged |=
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#map = affine_map<()[s0] -> (s0 * 32)>
#map1 = affine_map<()[s0, s1] -> (s0 + s1)>
#map2 = affine_map<(d0, d1) -> (d0 + d1)>
#map3 = affine_map<()[s0] -> (s0 * 64)>
module {

// CHECK-LABEL: test0
Expand Down Expand Up @@ -450,4 +451,36 @@ module {
return
}

// Offset propagated from scf.for and air.hier induction vars.
// CHECK-LABEL: test13

// CHECK: air.channel.put async [%{{.*}}] @channel_14[] (%{{.*}}[%c0, %1, %results, %c0] [%c8, %c2_0, %c32, %c32] [%c32, %c8192, %c256, %c1]) : (memref<2x128x256xi32>)

// Test input: an air.channel.put whose offsets are taken from an air.launch
// induction variable, from an affine.apply of another air.launch induction
// variable, and from an scf.for induction variable.
func.func @test13(%arg0: memref<2x128x256xi32>, %arg1: memref<2x256x128xi32>) {
%c2 = arith.constant 2 : index
// 2x2x2 launch; %arg3, %arg4 and %arg5 are its induction variables.
%0 = air.launch async (%arg3, %arg4, %arg5) in (%arg6=%c2, %arg7=%c2, %arg8=%c2) args(%arg10=%arg0, %arg11=%arg1) : memref<2x128x256xi32>, memref<2x256x128xi32> {
// Index constants used as loop bounds and as offset/size/stride operands
// below; not all of them are referenced by the ops in this function.
%c4096 = arith.constant 4096 : index
%c8 = arith.constant 8 : index
%c16384 = arith.constant 16384 : index
%c64 = arith.constant 64 : index
%c128 = arith.constant 128 : index
%c256 = arith.constant 256 : index
%c32 = arith.constant 32 : index
%c8192 = arith.constant 8192 : index
%c32768 = arith.constant 32768 : index
%c0 = arith.constant 0 : index
%c2_0 = arith.constant 2 : index
%c1 = arith.constant 1 : index
// %results = #map3 applied to launch induction variable %arg4, i.e. %arg4 * 64.
%async_token, %results = air.execute -> (index) {
%7 = affine.apply #map3()[%arg4]
air.execute_terminator %7 : index
}
// Loop from 0 to 256 in steps of 32, threading an async token through
// iter_args.
%2 = scf.for %arg12 = %c0 to %c256 step %c32 iter_args(%arg13 = %async_token) -> (!air.async.token) {
// 5-D access pattern: offset 0 is launch induction variable %arg3, offset 3
// is %results, and offset 4 is the scf.for induction variable %arg12.
%7 = air.channel.put async [%arg13, %async_token] @channel_14[] (%arg10[%arg3, %c0, %c0, %results, %arg12] [%c1, %c2_0, %c1, %c32, %c32] [%c32768, %c8192, %c32, %c256, %c1]) : (memref<2x128x256xi32>)
scf.yield %7 : !air.async.token
}
}
return
}

}
Loading