[AMD] Fix issue with rank=1 in tryFitCvtIntoLDS (triton-lang#5084)
Co-authored-by: Sam Ginzburg <[email protected]>
SamGinzburg and Sam Ginzburg authored Nov 7, 2024
1 parent 9378d8f · commit 4af6cf5
Showing 3 changed files with 38 additions and 5 deletions.
test/TritonGPU/amd/optimize-lds-usage.mlir (19 additions, 0 deletions)
@@ -118,3 +118,22 @@ module attributes {"triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-war
     tt.return
   }
 }
+
+// -----
+
+// Checks that the optimization does not crash on a 1d tensor
+// CHECK-LABEL: convert_1d
+// CHECK: triton_gpu.local_alloc
+// CHECK-NEXT: triton_gpu.convert_layout
+// CHECK-NEXT: triton_gpu.local_load
+#blocked = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}>
+#mma = #triton_gpu.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [32, 32], isTransposed = true}>
+#shared = #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1], hasLeadingOffset = false}>
+module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 64 : i32} {
+  tt.func public @convert_1d(%arg0: tensor<128xf32, #triton_gpu.slice<{dim = 0, parent = #mma}>>, %arg1: tensor<128x128xf32, #mma>) attributes {noinline = false} {
+    %alloc = triton_gpu.local_alloc %arg1 : (tensor<128x128xf32, #mma>) -> !tt.memdesc<128x128xf32, #shared, #triton_gpu.shared_memory>
+    %1 = triton_gpu.convert_layout %arg0 : tensor<128xf32, #triton_gpu.slice<{dim = 0, parent = #mma}>> -> tensor<128xf32, #blocked>
+    %load = triton_gpu.local_load %alloc : !tt.memdesc<128x128xf32, #shared, #triton_gpu.shared_memory> -> tensor<128x128xf32, #mma>
+    tt.return
+  }
+}
third_party/amd/lib/TritonAMDGPUToLLVM/OptimizeLDSUsage.cpp (13 additions, 3 deletions)
@@ -96,7 +96,8 @@ class OptimizeAMDLDSUsage
     auto dstEnc = dstType.getEncoding();
 
     auto ctx = srcEnc.getContext();
-    auto rank = srcType.getShape().size();
+    auto rank = srcType.getRank();
+
     unsigned numWarps = triton::gpu::getNumWarpsPerCTA(srcEnc);
     auto warpSize = triton::gpu::getWarpSize(srcEnc);
 
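Aside: the switch from getShape().size() to getRank() touches the line at the root of the crash. With rank == 1, the threadsPerWarp[rank - 2] index in the next hunk is out of bounds either way: as a size_t from getShape().size() the subtraction wraps around to SIZE_MAX, and as a signed rank it would be -1. The rank-1 special case added below removes that index entirely. A minimal standalone illustration of the unsigned wraparound (a sketch, not Triton code):

#include <cstddef>
#include <cstdio>

int main() {
  // Model of the old code path: rank came from getShape().size(), a size_t.
  std::size_t rank = 1;
  // Unsigned subtraction wraps: 1 - 2 becomes SIZE_MAX, not -1, so
  // threadsPerWarp[rank - 2] indexed far past the end of the vector.
  std::printf("rank - 2 = %zu\n", rank - 2);
  return 0;
}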
@@ -109,11 +110,20 @@
     // Create a list of temporary layouts
     SmallVector<unsigned> elemsPerThread(rank, 1);
     SmallVector<unsigned> threadsPerWarp(rank, 1);
-    threadsPerWarp[rank - 1] = warpSize / 8;
-    threadsPerWarp[rank - 2] = warpSize / threadsPerWarp[rank - 1];
+
+    // Special case for rank == 1
+    if (rank == 1) {
+      threadsPerWarp[0] = warpSize;
+    } else {
+      assert(rank > 1);
+      threadsPerWarp[rank - 1] = warpSize / 8;
+      threadsPerWarp[rank - 2] = warpSize / threadsPerWarp[rank - 1];
+    }
+
     auto layoutCTA = triton::gpu::getCTALayout(srcEnc);
     auto order = triton::gpu::getOrder(srcEnc);
     SmallVector<unsigned> dummyWarpsPerCTA(rank, 1);
+
     auto baseFallbackLayout = triton::gpu::BlockedEncodingAttr::get(
         ctx, elemsPerThread, threadsPerWarp, dummyWarpsPerCTA, order,
         layoutCTA);
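Taken together, the patched block distributes the warpSize lanes of a wavefront across the tensor dimensions: all lanes land on the single dimension when rank == 1, and otherwise they are split across the two fastest-varying dimensions. A standalone sketch of that behavior, with a hypothetical function name that is not part of the Triton API:

#include <cassert>
#include <cstdio>
#include <vector>

// Hypothetical standalone model of the patched logic above: distribute
// warpSize lanes over `rank` dimensions when building a fallback layout.
std::vector<unsigned> fallbackThreadsPerWarp(unsigned rank, unsigned warpSize) {
  std::vector<unsigned> threadsPerWarp(rank, 1);
  if (rank == 1) {
    // All lanes map onto the only dimension; no rank - 2 index exists.
    threadsPerWarp[0] = warpSize;
  } else {
    assert(rank > 1);
    threadsPerWarp[rank - 1] = warpSize / 8;
    threadsPerWarp[rank - 2] = warpSize / threadsPerWarp[rank - 1];
  }
  return threadsPerWarp;
}

int main() {
  // AMD CDNA wavefronts have 64 lanes (threads-per-warp = 64 in the test).
  for (unsigned rank : {1u, 2u, 3u}) {
    auto tpw = fallbackThreadsPerWarp(rank, 64);
    std::printf("rank %u:", rank);
    for (unsigned t : tpw)
      std::printf(" %u", t);
    std::printf("\n"); // rank 1: 64 / rank 2: 8 8 / rank 3: 1 8 8
  }
  return 0;
}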
third_party/amd/lib/TritonAMDGPUToLLVM/OptimizeLDSUtility.cpp (6 additions, 2 deletions)
@@ -68,9 +68,13 @@ Attribute createTmpLayout(Attribute layout, ArrayRef<unsigned> warpsPerCTA) {
         ctx, src.getOpIdx(), createTmpLayout(src.getParent(), warpsPerCTA),
         src.getKWidth());
   }
-  if (auto src = dyn_cast<triton::gpu::SliceEncodingAttr>(layout))
+  if (auto src = dyn_cast<triton::gpu::SliceEncodingAttr>(layout)) {
+    // TODO: think of a way to construct slice layouts based on warpsPerCTA
+    // argument
+    auto parentWarpsPerCTA = triton::gpu::getWarpsPerCTA(src.getParent());
     return triton::gpu::SliceEncodingAttr::get(
-        ctx, src.getDim(), createTmpLayout(src.getParent(), warpsPerCTA));
+        ctx, src.getDim(), createTmpLayout(src.getParent(), parentWarpsPerCTA));
+  }
   assert("Encountered unsupported layout");
   return Attribute();
 }
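The utility change fixes a rank mismatch rather than an indexing bug: a SliceEncodingAttr describes a tensor with one dimension fewer than its parent layout, so the caller's warpsPerCTA, sized for the sliced tensor, cannot be forwarded when recursing into the parent. Until the TODO above is addressed, the patch reuses the parent's own warpsPerCTA. A toy model of that bookkeeping, with illustrative types and names that are not the Triton API:

#include <cassert>
#include <cstdio>
#include <vector>

// Illustrative stand-in for a slice layout: it removes dimension `dim`
// from a parent layout with the given per-dimension warp counts.
struct SliceLayoutModel {
  unsigned dim;
  std::vector<unsigned> parentWarpsPerCTA;
};

// Mirrors the patched createTmpLayout branch: the caller's vector has one
// entry too few for the parent, so fall back to the parent's own counts.
std::vector<unsigned>
warpsForParent(const SliceLayoutModel &slice,
               const std::vector<unsigned> &callerWarpsPerCTA) {
  assert(callerWarpsPerCTA.size() + 1 == slice.parentWarpsPerCTA.size() &&
         "caller's warpsPerCTA is sized for the sliced, rank n-1 tensor");
  return slice.parentWarpsPerCTA;
}

int main() {
  // Matches the test case: #mma has warpsPerCTA = [4, 1]; the slice is rank 1.
  SliceLayoutModel slice{/*dim=*/0, /*parentWarpsPerCTA=*/{4, 1}};
  std::vector<unsigned> callerWarps{4};
  auto warps = warpsForParent(slice, callerWarps);
  std::printf("recurse into parent with warpsPerCTA = [%u, %u]\n",
              warps[0], warps[1]);
  return 0;
}

One pre-existing quirk in the surrounding context, untouched by this commit: assert("Encountered unsupported layout") always passes, because a string literal decays to a non-null pointer, so the fallthrough returns a null Attribute rather than aborting.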
