From 57643b3f4746b3f53334fd6ce8020dd6c902c7f4 Mon Sep 17 00:00:00 2001
From: Mario Lezcano Casado <3291265+lezcano@users.noreply.github.com>
Date: Thu, 7 Nov 2024 20:36:33 +0000
Subject: [PATCH] Consolidate `getOrder` as "element order" and implement
 `getRepOrder` for general and NVIDIA layouts (#5089)

This partially reverts commit 38a11b859fff79ea214256d3f1cfe43d54e36c2c.

Supersedes https://github.com/triton-lang/triton/pull/5085

It also documents that we are implicitly choosing a way to tile a full
tensor depending on the layout. See
https://github.com/triton-lang/triton/pull/5085#issuecomment-2460925683
---
 include/triton/Analysis/Utility.h             |  2 +-
 .../Conversion/TritonGPUToLLVM/Utility.h      |  7 +--
 include/triton/Dialect/TritonGPU/IR/Dialect.h |  5 +-
 .../Dialect/TritonGPU/IR/TritonGPUAttrDefs.td | 11 +++++
 lib/Analysis/Allocation.cpp                   |  8 +---
 lib/Analysis/AxisInfo.cpp                     |  6 +--
 lib/Analysis/Utility.cpp                      | 10 ++--
 .../TritonGPUToLLVM/ReduceOpToLLVM.cpp        |  4 +-
 .../TritonGPUToLLVM/ScanOpToLLVM.cpp          |  6 +--
 .../TritonGPUToLLVM/ViewOpToLLVM.cpp          |  6 +--
 lib/Dialect/TritonGPU/IR/Dialect.cpp          | 46 ++++++++++++++++++-
 .../TritonGPU/IR/LinearLayoutConversions.cpp  | 19 ++++----
 .../Transforms/ReduceDataDuplication.cpp      | 10 ++--
 test/Conversion/amd/load_store.mlir           | 29 ++++++++++++
 .../DecomposeUnsupportedConversions.cpp       |  2 +-
 .../DotOpToLLVM/MMAv2.cpp                     | 10 ++++
 16 files changed, 136 insertions(+), 45 deletions(-)

diff --git a/include/triton/Analysis/Utility.h b/include/triton/Analysis/Utility.h
index 20da9784495d..df6029db0de2 100644
--- a/include/triton/Analysis/Utility.h
+++ b/include/triton/Analysis/Utility.h
@@ -66,7 +66,7 @@ class ReduceOpHelper {
   // The shape of the shared memory space needed for the reduction.
   SmallVector<unsigned> getScratchRepShape();
 
-  SmallVector<unsigned> getThreadOrderWithAxisAtBeginning();
+  SmallVector<unsigned> getOrderWithAxisAtBeginning();
 
   unsigned getScratchSizeInBytes();
 
diff --git a/include/triton/Conversion/TritonGPUToLLVM/Utility.h b/include/triton/Conversion/TritonGPUToLLVM/Utility.h
index 3ae94d3339ad..fdb503eed247 100644
--- a/include/triton/Conversion/TritonGPUToLLVM/Utility.h
+++ b/include/triton/Conversion/TritonGPUToLLVM/Utility.h
@@ -520,15 +520,16 @@ emitBaseIndexWithinCTAForBlockedLayout(Location loc, RewriterBase &rewriter,
   auto sizePerThread = blockedLayout.getSizePerThread();
   auto threadsPerWarp = blockedLayout.getThreadsPerWarp();
   auto warpsPerCTA = blockedLayout.getWarpsPerCTA();
-  auto order = blockedLayout.getOrder();
+  auto threadOrder = blockedLayout.getThreadOrder();
+  auto warpOrder = blockedLayout.getWarpOrder();
   auto shapePerCTA = triton::gpu::getShapePerCTA(blockedLayout, shape);
   unsigned rank = shape.size();
 
   // delinearize threadId to get the base index
   SmallVector<Value> multiDimWarpId =
-      delinearize(rewriter, loc, warpId, warpsPerCTA, order);
+      delinearize(rewriter, loc, warpId, warpsPerCTA, warpOrder);
   SmallVector<Value> multiDimThreadId =
-      delinearize(rewriter, loc, laneId, threadsPerWarp, order);
+      delinearize(rewriter, loc, laneId, threadsPerWarp, threadOrder);
 
   SmallVector<Value> multiDimBase(rank);
   for (unsigned k = 0; k < rank; ++k) {
diff --git a/include/triton/Dialect/TritonGPU/IR/Dialect.h b/include/triton/Dialect/TritonGPU/IR/Dialect.h
index cfc00926ddc2..a9b49448c1d0 100644
--- a/include/triton/Dialect/TritonGPU/IR/Dialect.h
+++ b/include/triton/Dialect/TritonGPU/IR/Dialect.h
@@ -76,9 +76,8 @@ SmallVector<unsigned>
 getWarpsPerCTAWithUniqueData(Attribute layout, ArrayRef<int64_t> tensorShape);
 
 // Returns the dimensions of the tensor from minor (fast-varying) to
-// major (slow-varying). For blocked, mma, and dotOperand layouts,
-// though the elements are in registers, the order refers to memory
-// layout of the original tensor in global memory.
+// major (slow-varying). For distributed layouts, this represents
+// the order of the elements within a thread.
 // For shared Layout, the order refers to which dimension of the original
 // tensor is contiguous in shared memory.
 SmallVector<unsigned> getOrder(Attribute layout);
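Note on the comment change above: an "order" in this patch is purely an element
order, i.e. dimension indices listed fastest-varying first. A minimal
self-contained sketch of that convention (the delinearize below is
illustrative, not the Triton helper of the same name):

#include <array>
#include <iostream>

// Split a linear id into a multi-dim index, filling order[0] first.
std::array<unsigned, 2> delinearize(unsigned linear,
                                    std::array<unsigned, 2> shape,
                                    std::array<unsigned, 2> order) {
  std::array<unsigned, 2> multi{};
  for (unsigned dim : order) {
    multi[dim] = linear % shape[dim];
    linear /= shape[dim];
  }
  return multi;
}

int main() {
  // order = {1, 0}: dim 1 is fastest, i.e. row-major for a rank-2 tensor.
  for (unsigned id = 0; id < 8; ++id) {
    auto idx = delinearize(id, /*shape=*/{2, 4}, /*order=*/{1, 0});
    std::cout << "id " << id << " -> (" << idx[0] << ", " << idx[1] << ")\n";
  }
}

With order = {1, 0} the ids walk a 2x4 tensor row by row; with {0, 1} they
would walk it column by column. This is the sense in which
emitBaseIndexWithinCTAForBlockedLayout above now delinearizes lane and warp ids
with two potentially different orders.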
diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td
index 8f9a1a850fd5..33308fb24569 100644
--- a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td
+++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td
@@ -474,9 +474,16 @@ layout = [0 4 8 12]
          [3 7 11 15]
 For the Threads Per Warp and Values Per Thread level, the linear id
 distribution is variant for each sub-class encoding.
+
+If the layout does not completely cover the tensor, we tile it until we cover the entire tensor.
+We call each individual tile "rep".
 }];
 
   let methods = [
+    InterfaceMethod<"Get the order of reps (tiles of this layout that tile the whole tensor). The fastest-changing axis first",
+                    "SmallVector<unsigned>",
+                    "getRepOrder">,
+
     // Interface for the meta information about the multiple thread hierarchy.
     InterfaceMethod<"Get the shape of the CTAs per CGA.",
                     "SmallVector<unsigned>",
@@ -563,6 +570,7 @@ L(T) = [ {0,8} , {1,9} , {2,10}, {3,11},
          {0,8} , {1, 9} , {2, 10}, {3, 11},
 }];
 
   code extraDistributedDeclaration = extraBaseClassDeclaration # [{
+    SmallVector<unsigned> getRepOrder() const;
     SmallVector<unsigned> getCTAsPerCGA() const;
     SmallVector<unsigned> getCTAOrder() const;
     SmallVector<unsigned> getCTASplitNum() const;
@@ -914,6 +922,7 @@ V [ 0,4,8...60 1,5...61 2,6...62 3,7...63 ] [ 128,132...188 129,
     unsigned getTotalElemsPerThreadForOperand(ArrayRef<int64_t> shape, Type eltTy, int kWidth, int opIdx) const;
     SmallVector<int64_t> getInstrShapeForOperand(int kWidth, int opIdx) const;
     SmallVector<int64_t> getRepForOperand(ArrayRef<int64_t> operandShape, int kWidth, int opIdx) const;
+    SmallVector<unsigned> getRepOrderForOperand(int opIdx) const;
 
     SmallVector<unsigned> getContigPerThread() {
       auto rank = getWarpsPerCTA().size();
@@ -1022,6 +1031,7 @@ Row |          warp 0  warp 2
     SmallVector<int64_t> getElemsPerInstrForOperands() const;
     SmallVector<int64_t> getRepForOperand(ArrayRef<int64_t> operandShape,
                                           Type elemType, int kWidth, int opIdx) const;
+    SmallVector<unsigned> getRepOrderForOperand(int opIdx) const;
     static SmallVector<unsigned> getMNKDimPerInstr();
 
     SmallVector<unsigned> getContigPerThread() {
@@ -1217,6 +1227,7 @@ For example, the matrix L corresponding to blockTileSize=[32,16] is:
     int getMMAv1Vec(int opIdx) const;
     SmallVector<int64_t> getRepForOperand(ArrayRef<int64_t> shape, int bitwidth, int opIdx) const;
+    SmallVector<unsigned> getRepOrderForOperand(int opIdx) const;
 
     bool supportReduction() const {
       if (isAmpere() || isHopper()) {
diff --git a/lib/Analysis/Allocation.cpp b/lib/Analysis/Allocation.cpp
index 020a9ea4d3bc..131c1ff67e84 100644
--- a/lib/Analysis/Allocation.cpp
+++ b/lib/Analysis/Allocation.cpp
@@ -84,12 +84,8 @@ ScratchConfig getScratchConfigForCvt(RankedTensorType srcTy,
 
   assert(cvtNeedsSharedMemory(srcTy, dstTy));
 
-  // FIXME This is NOT entirely correct
-  // This should be getElemOrder, but we don't have such a method
-  // TODO Implement getElemOrder and make sure it's consistent with
-  // getContigPerThread
-  auto inOrd = gpu::getThreadOrder(srcLayout);
-  auto outOrd = gpu::getThreadOrder(dstLayout);
+  auto inOrd = gpu::getOrder(srcLayout);
+  auto outOrd = gpu::getOrder(dstLayout);
   scratchConfig.order = outOrd;
 
   unsigned srcContigPerThread =
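Note on the "rep" wording introduced in TritonGPUAttrDefs.td above: when one
tile of the layout is smaller than the tensor, the tile repeats. A sketch of
the resulting rep grid under that assumption (names are illustrative, not the
Triton API):

#include <cstdint>
#include <vector>

// Number of reps along each dimension; a partial tile at the border still
// counts as a full rep (ceil division).
std::vector<int64_t> repsPerDim(const std::vector<int64_t> &tensorShape,
                                const std::vector<int64_t> &shapePerTile) {
  std::vector<int64_t> reps;
  for (size_t d = 0; d < tensorShape.size(); ++d)
    reps.push_back((tensorShape[d] + shapePerTile[d] - 1) / shapePerTile[d]);
  return reps;
}

For example, a 64x64 tensor covered by a 16x32 tile gives a 4x2 rep grid;
getRepOrder then fixes which dimension of that grid is enumerated fastest.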
diff --git a/lib/Analysis/AxisInfo.cpp b/lib/Analysis/AxisInfo.cpp
index d2c2c9fd8da3..f0c5ae3167ec 100644
--- a/lib/Analysis/AxisInfo.cpp
+++ b/lib/Analysis/AxisInfo.cpp
@@ -1213,7 +1213,7 @@ unsigned ModuleAxisInfoAnalysis::getPtrContiguity(Value ptr) {
 
   // Here order should be ordered by contiguous first, so the first element
   // should have the largest contiguous.
-  auto order = triton::gpu::getThreadOrder(layout);
+  auto order = triton::gpu::getOrder(layout);
   unsigned align = getPtrAlignment(ptr);
 
   auto uniqueContigPerThread =
@@ -1235,7 +1235,7 @@ unsigned ModuleAxisInfoAnalysis::getPtrAlignment(Value ptr) {
   if (!axisInfo)
     return 1;
   auto layout = tensorTy.getEncoding();
-  auto order = triton::gpu::getThreadOrder(layout);
+  auto order = triton::gpu::getOrder(layout);
   auto maxMultipleBytes = axisInfo->getDivisibility(order[0]);
   auto maxContig = axisInfo->getContiguity(order[0]);
   auto elemNumBits = triton::getPointeeBitWidth(tensorTy);
@@ -1262,7 +1262,7 @@ unsigned ModuleAxisInfoAnalysis::getMaskAlignment(Value mask) {
   auto *axisInfo = getAxisInfo(mask);
   if (!axisInfo)
     return 1;
-  auto maskOrder = triton::gpu::getThreadOrder(tensorTy.getEncoding());
+  auto maskOrder = triton::gpu::getOrder(tensorTy.getEncoding());
   auto alignment = std::max<unsigned>(axisInfo->getConstancy(maskOrder[0]), 1);
   LDBG("getMaskAlignment maskOrder[0] " << maskOrder[0] << " alignment "
        << alignment);
diff --git a/lib/Analysis/Utility.cpp b/lib/Analysis/Utility.cpp
index 501e19722089..ac72b4f26cd6 100644
--- a/lib/Analysis/Utility.cpp
+++ b/lib/Analysis/Utility.cpp
@@ -32,9 +32,9 @@ int getParentAxis(Attribute layout, int axis) {
   return axis;
 }
 
-SmallVector<unsigned> getParentThreadOrder(Attribute layout) {
+SmallVector<unsigned> getParentOrder(Attribute layout) {
   if (auto sliceEncoding = mlir::dyn_cast<SliceEncodingAttr>(layout)) {
-    return getParentThreadOrder(sliceEncoding.getParent());
+    return getParentOrder(sliceEncoding.getParent());
   }
   return getThreadOrder(layout);
 }
@@ -44,12 +44,12 @@ SmallVector<unsigned> getParentThreadOrder(Attribute layout) {
 // TODO(jlebar): Move this class into namespace triton.
 bool ReduceOpHelper::isReductionOnLayoutFastAxis() {
   return getParentAxis(getSrcLayout(), axis) ==
-         getParentThreadOrder(getSrcLayout())[0];
+         getParentOrder(getSrcLayout())[0];
 }
 
-SmallVector<unsigned> ReduceOpHelper::getThreadOrderWithAxisAtBeginning() {
+SmallVector<unsigned> ReduceOpHelper::getOrderWithAxisAtBeginning() {
   auto srcLayout = getSrcLayout();
-  auto order = getThreadOrder(srcLayout);
+  auto order = getOrder(srcLayout);
   auto it = std::find(order.begin(), order.end(), axis);
   // delete the axis from order
   order.erase(it);
diff --git a/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp
index 829d4e7104f0..26dc8a537973 100644
--- a/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp
@@ -322,7 +322,7 @@ struct ReduceOpConversion
         getMultiDimWarpId(helper, warpId, loc, rewriter);
     Value warpIdAxis = multiDimWarpId[axis];
 
-    auto smemOrder = helper.getThreadOrderWithAxisAtBeginning();
+    auto smemOrder = helper.getOrderWithAxisAtBeginning();
     for (auto it : accs) {
       const SmallVector<unsigned> &key = it.first;
       SmallVector<Value> &acc = it.second;
@@ -409,7 +409,7 @@ struct ReduceOpConversion
     Location loc = op.getLoc();
     auto srcLayout = helper.getSrcLayout();
     auto axis = op.getAxis();
-    auto smemOrder = helper.getThreadOrderWithAxisAtBeginning();
+    auto smemOrder = helper.getOrderWithAxisAtBeginning();
     SmallVector<Value> results(op.getNumOperands());
     for (unsigned i = 0; i < op.getNumOperands(); ++i) {
       auto elemTy = getElementType(op, i);
diff --git a/lib/Conversion/TritonGPUToLLVM/ScanOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ScanOpToLLVM.cpp
index 969b227c8dda..64e6ca787780 100644
--- a/lib/Conversion/TritonGPUToLLVM/ScanOpToLLVM.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/ScanOpToLLVM.cpp
@@ -389,10 +389,10 @@ ScanOpConversion::getDelinearizedIds(ConversionPatternRewriter &rewriter,
   auto threadsPerWarp = triton::gpu::getThreadsPerWarp(srcEncoding);
   auto warpsPerCTA = triton::gpu::getWarpsPerCTA(srcEncoding);
-  auto order = triton::gpu::getOrder(srcEncoding);
+  auto threadOrder = triton::gpu::getThreadOrder(srcEncoding);
   auto warpOrder = triton::gpu::getWarpOrder(srcEncoding);
   SmallVector<Value> multiDimLaneId =
-      delinearize(rewriter, loc, laneId, threadsPerWarp, order);
+      delinearize(rewriter, loc, laneId, threadsPerWarp, threadOrder);
   SmallVector<Value> multiDimWarpId =
       delinearize(rewriter, loc, warpId, warpsPerCTA, warpOrder);
 
@@ -402,7 +402,7 @@ ScanOpConversion::getDelinearizedIds(ConversionPatternRewriter &rewriter,
   multiDimLaneId[axis] = i32_val(0);
   threadsPerWarp[axis] = 1;
   Value laneIdParallel =
-      linearize(rewriter, loc, multiDimLaneId, threadsPerWarp, order);
+      linearize(rewriter, loc, multiDimLaneId, threadsPerWarp, threadOrder);
   multiDimWarpId[axis] = i32_val(0);
   warpsPerCTA[axis] = 1;
   Value warpIdParallel =
diff --git a/lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp
index be5feedd67bf..8ba0fd3356f6 100644
--- a/lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp
@@ -181,9 +181,9 @@ struct SplitOpConversion : public ConvertOpToLLVMPattern<SplitOp> {
     int numContiguousValues = 1;
     auto encoding = cast<BlockedEncodingAttr>(
        cast<RankedTensorType>(op.getSrc().getType()).getEncoding());
-    int splitDim = encoding.getThreadOrder().size() - 1;
-    for (int i = 0; i < encoding.getThreadOrder().size(); i++) {
-      if (encoding.getThreadOrder()[i] == splitDim)
+    int splitDim = encoding.getOrder().size() - 1;
+    for (int i = 0; i < encoding.getOrder().size(); i++) {
+      if (encoding.getOrder()[i] == splitDim)
         break;
       numContiguousValues *= encoding.getSizePerThread()[i];
     }
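Note on the Scan/Blocked changes above: the lane grid and the warp grid of one
layout may be linearized with different orders, so a single shared order (the
old code) is only correct when the two happen to agree. A self-contained
illustration with made-up sizes (not taken from a real layout):

#include <array>
#include <iostream>

std::array<unsigned, 2> delinearize(unsigned id, std::array<unsigned, 2> shape,
                                    std::array<unsigned, 2> order) {
  std::array<unsigned, 2> out{};
  for (unsigned dim : order) {
    out[dim] = id % shape[dim];
    id /= shape[dim];
  }
  return out;
}

int main() {
  auto lane = delinearize(/*laneId=*/5, /*threadsPerWarp=*/{4, 8},
                          /*threadOrder=*/{1, 0});
  auto warp = delinearize(/*warpId=*/1, /*warpsPerCTA=*/{2, 2},
                          /*warpOrder=*/{0, 1});
  std::cout << "lane (" << lane[0] << ", " << lane[1] << "), warp ("
            << warp[0] << ", " << warp[1] << ")\n"; // lane (0, 5), warp (1, 0)
}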
diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp
index 8462c24aea67..3338638d48b5 100644
--- a/lib/Dialect/TritonGPU/IR/Dialect.cpp
+++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp
@@ -261,6 +261,14 @@ SmallVector<unsigned> getOrderForDotOperand(unsigned opIdx, unsigned rank,
   return getMatrixOrder(rank, rowMajor);
 }
 
+SmallVector<unsigned> getRepOrder(Attribute layout) {
+  if (auto distributedLayout = mlir::dyn_cast<DistributedEncodingTrait>(layout))
+    return distributedLayout.getRepOrder();
+  else
+    llvm::report_fatal_error("Unimplemented usage of getRepOrder");
+  return {};
+}
+
 SmallVector<unsigned> getWarpOrder(Attribute layout) {
   if (auto distributedLayout = mlir::dyn_cast<DistributedEncodingTrait>(layout))
     return distributedLayout.getWarpOrder();
@@ -269,13 +277,13 @@ SmallVector<unsigned> getWarpOrder(Attribute layout) {
   return {};
 }
 
+// Returns the order of the elements in a layout from the fastest running
+// dimension to the slowest
 SmallVector<unsigned> getOrder(Attribute layout) {
   if (auto blockedLayout = dyn_cast<BlockedEncodingAttr>(layout)) {
     return llvm::to_vector(blockedLayout.getOrder());
   }
   if (auto mmaLayout = dyn_cast<MmaEncodingTrait>(layout)) {
-    // Order doesn't really matter. We just have to be consistent when unpacking
-    // the output elements in the LLVM lowerings. We choose row-major
     auto distributedLayout = cast<DistributedEncodingTrait>(layout);
     auto rank = distributedLayout.getWarpsPerCTA().size();
     return getMatrixOrder(rank, /*rowMajor*/ true);
@@ -643,6 +651,9 @@ unsigned BlockedEncodingAttr::getTotalElemsPerThread(ArrayRef<int64_t> shape,
 // If we only had BlockedEncodingAttr, we could simply return ArrayRefs here.
 // But we need to have a consistent interface with e.g. SliceEncodingAttr, which
 // computes some of these fields.
+SmallVector<unsigned> BlockedEncodingAttr::getRepOrder() const {
+  return SmallVector<unsigned>(getOrder());
+}
 SmallVector<unsigned> BlockedEncodingAttr::getCTAsPerCGA() const {
   return SmallVector<unsigned>(getCTALayout().getCTAsPerCGA());
 }
@@ -709,6 +720,10 @@ unsigned SliceEncodingAttr::getTotalElemsPerThread(ArrayRef<int64_t> shape,
                                                    Type eltTy) const {
   return product(getElemsPerThread(shape, eltTy));
 }
+SmallVector<unsigned> SliceEncodingAttr::getRepOrder() const {
+  auto parentRepOrder = ::getRepOrder(getParent());
+  return eraseOrder(parentRepOrder, getDim());
+}
 SmallVector<unsigned> SliceEncodingAttr::getCTASplitNum() const {
   SmallVector<unsigned> res = ::getCTASplitNum(getParent());
   res.erase(res.begin() + getDim());
@@ -1651,6 +1666,10 @@ AMDMfmaEncodingAttr::getInstrShapeForOperand(int kWidth, int opIdx) const {
   return {kDim, nDim};
 }
 
+SmallVector<unsigned> AMDMfmaEncodingAttr::getRepOrder() const {
+  llvm::report_fatal_error("NYI. AMDMfmaEncodingAttr::getRepOrder");
+}
+
 SmallVector<int64_t>
 AMDMfmaEncodingAttr::getRepForOperand(ArrayRef<int64_t> operandShape,
                                       int kWidth, int opIdx) const {
@@ -1734,6 +1753,9 @@ AMDWmmaEncodingAttr::getShapePerCTATile(ArrayRef<int64_t> tensorShape) const {
   shapePerCTATile[rank - 1] *= mnkDim[1];
   return shapePerCTATile;
 }
+SmallVector<unsigned> AMDWmmaEncodingAttr::getRepOrder() const {
+  llvm::report_fatal_error("NYI. AMDWmmaEncodingAttr::getRepOrder");
+}
 SmallVector<unsigned> AMDWmmaEncodingAttr::getCTAsPerCGA() const {
   return SmallVector<unsigned>(getCTALayout().getCTAsPerCGA());
 }
@@ -1858,6 +1880,10 @@ bool NvidiaMmaEncodingAttr::isAmpere() const { return getVersionMajor() == 2; }
 
 bool NvidiaMmaEncodingAttr::isHopper() const { return getVersionMajor() == 3; }
 
+SmallVector<unsigned> NvidiaMmaEncodingAttr::getRepOrder() const {
+  auto rank = getWarpsPerCTA().size();
+  return getMatrixOrder(rank, /*rowMajor*/ true);
+}
 SmallVector<unsigned> NvidiaMmaEncodingAttr::getCTAsPerCGA() const {
   return SmallVector<unsigned>(getCTALayout().getCTAsPerCGA());
 }
@@ -2011,6 +2037,13 @@ SmallVector<unsigned> NvidiaMmaEncodingAttr::getMMAv1ShapePerWarp(int opIdx) const {
 int NvidiaMmaEncodingAttr::getMMAv1Vec(int opIdx) const {
   return 2 * getMMAv1Rep(opIdx)[opIdx];
 }
+
+SmallVector<unsigned>
+NvidiaMmaEncodingAttr::getRepOrderForOperand(int opIdx) const {
+  auto rank = getWarpsPerCTA().size();
+  return getOrderForDotOperand(opIdx, rank, /*kMajor*/ true);
+}
+
 SmallVector<int64_t>
 NvidiaMmaEncodingAttr::getRepForOperand(ArrayRef<int64_t> shape, int bitwidth,
                                         int opIdx) const {
@@ -2147,6 +2180,15 @@ NvidiaMmaEncodingAttr::getSizePerThreadForOperand(int kWidth, int opIdx) const {
 //===----------------------------------------------------------------------===//
 // DotOperand Encoding
 //===----------------------------------------------------------------------===//
+SmallVector<unsigned> DotOperandEncodingAttr::getRepOrder() const {
+  if (auto mma = mlir::dyn_cast<NvidiaMmaEncodingAttr>(getParent())) {
+    return mma.getRepOrderForOperand(getOpIdx());
+  }
+  llvm::report_fatal_error(
+      "getRepOrder not implemented for DotOperandEncodingAttr");
+  return {};
+}
+
 SmallVector<unsigned> DotOperandEncodingAttr::getThreadsPerWarp() const {
   auto parent = getParent();
   if (auto mma = mlir::dyn_cast<NvidiaMmaEncodingAttr>(parent)) {
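Note on NvidiaMmaEncodingAttr::getRepOrder above: it reuses
getMatrixOrder(rank, /*rowMajor*/ true). A sketch of what that produces,
inferred from its uses in this patch (an illustrative reimplementation, not the
Triton function):

#include <numeric>
#include <utility>
#include <vector>

std::vector<unsigned> matrixOrder(unsigned rank, bool rowMajor) {
  // Fastest dimension first: {rank-1, rank-2, ..., 0} when row-major;
  // swapping the two leading entries gives the column-major order.
  std::vector<unsigned> order(rank);
  std::iota(order.rbegin(), order.rend(), 0);
  if (!rowMajor && rank >= 2)
    std::swap(order[0], order[1]);
  return order;
}

For rank 3 and rowMajor=true this yields {2, 1, 0}: the reps of an NVIDIA MMA
layout tile the tensor row-major, with batch dimensions slowest.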
AMDWmmaEncodingAttr::getRepOrder"); +} SmallVector AMDWmmaEncodingAttr::getCTAsPerCGA() const { return SmallVector(getCTALayout().getCTAsPerCGA()); } @@ -1858,6 +1880,10 @@ bool NvidiaMmaEncodingAttr::isAmpere() const { return getVersionMajor() == 2; } bool NvidiaMmaEncodingAttr::isHopper() const { return getVersionMajor() == 3; } +SmallVector NvidiaMmaEncodingAttr::getRepOrder() const { + auto rank = getWarpsPerCTA().size(); + return getMatrixOrder(rank, /*rowMajor*/ true); +} SmallVector NvidiaMmaEncodingAttr::getCTAsPerCGA() const { return SmallVector(getCTALayout().getCTAsPerCGA()); } @@ -2011,6 +2037,13 @@ SmallVector NvidiaMmaEncodingAttr::getMMAv1ShapePerWarp(int opIdx) const { int NvidiaMmaEncodingAttr::getMMAv1Vec(int opIdx) const { return 2 * getMMAv1Rep(opIdx)[opIdx]; } + +SmallVector +NvidiaMmaEncodingAttr::getRepOrderForOperand(int opIdx) const { + auto rank = getWarpsPerCTA().size(); + return getOrderForDotOperand(opIdx, rank, /*kMajor*/ true); +} + SmallVector NvidiaMmaEncodingAttr::getRepForOperand(ArrayRef shape, int bitwidth, int opIdx) const { @@ -2147,6 +2180,15 @@ NvidiaMmaEncodingAttr::getSizePerThreadForOperand(int kWidth, int opIdx) const { //===----------------------------------------------------------------------===// // DotOperand Encoding //===----------------------------------------------------------------------===// +SmallVector DotOperandEncodingAttr::getRepOrder() const { + if (auto mma = mlir::dyn_cast(getParent())) { + return mma.getRepOrderForOperand(getOpIdx()); + } + llvm::report_fatal_error( + "getRepOrder not implemented for DotOperandEncodingAttr"); + return {}; +} + SmallVector DotOperandEncodingAttr::getThreadsPerWarp() const { auto parent = getParent(); if (auto mma = mlir::dyn_cast(parent)) { diff --git a/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp b/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp index 2668f384978e..43c87af487a1 100644 --- a/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp +++ b/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp @@ -292,9 +292,9 @@ LinearLayout ampereMmaToLinearLayout(ArrayRef shape, MLIRContext *ctx = mma.getContext(); SmallVector dimNames = standardOutDimNames(ctx, rank); - auto orderedDimNames = permuteDimNames(dimNames, getOrder(mma)); - // By using `reverse(dimNames)` below, we set the order to be row-major - assert(getOrder(mma) == getMatrixOrder(rank, /*rowMajor=*/true)); + + auto orderedDimNames = permuteDimNames(dimNames, mma.getRepOrder()); + assert(mma.getRepOrder() == getMatrixOrder(rank, /*rowMajor=*/true)); LinearLayout ctaLayout( {{S("register"), {{1, 0}, {0, 8}}}, @@ -327,7 +327,6 @@ LinearLayout hopperMmaToLinearLayout(ArrayRef shape, assert(n == 8 || n == 16 || n == 32 || n == 64 || n == 128 || n == 256); assert(k == 8 || k == 16 || k == 32); - // TODO Make the getOrder of Hopper explicit here via an assert MLIRContext *ctx = mma.getContext(); LinearLayout ctaLayout( {{S("register"), {{1, 0}, {0, 8}}}, @@ -875,14 +874,18 @@ LinearLayout ampereDotToLinearLayout(ArrayRef shape, assert(mma.isAmpere()); MLIRContext *ctx = mma.getContext(); - // A and B have kMajor order - assert(getOrder(dot) == + + // The A and B operands are tiled in a kMajor fashion + auto kMajorOrder = dot.getRepOrder(); + assert(kMajorOrder == getOrderForDotOperand(dot.getOpIdx(), rank, /*kMajor=*/true)); auto kMajorDims = - permuteDimNames(standardOutDimNames(ctx, rank), getOrder(dot)); + permuteDimNames(standardOutDimNames(ctx, rank), kMajorOrder); + // This agrees with the order of the elements, which 
diff --git a/lib/Dialect/TritonGPU/Transforms/ReduceDataDuplication.cpp b/lib/Dialect/TritonGPU/Transforms/ReduceDataDuplication.cpp
index dce6c7f2af1b..b1e296c1bbe4 100644
--- a/lib/Dialect/TritonGPU/Transforms/ReduceDataDuplication.cpp
+++ b/lib/Dialect/TritonGPU/Transforms/ReduceDataDuplication.cpp
@@ -44,17 +44,17 @@ class TritonGPUReduceDataDuplicationPass
       return;
     if (!cvtNeedsSharedMemory(srcType, dstType))
       return;
-    auto srcThreadOrder = triton::gpu::getThreadOrder(srcEncoding);
-    auto rank = srcThreadOrder.size();
+    auto srcOrder = triton::gpu::getOrder(srcEncoding);
+    auto rank = srcOrder.size();
     SmallVector<unsigned> sharedOrder;
     if (rank == 3) {
       // add all elements except the element that is zero
       for (unsigned i = 0; i < rank; ++i)
-        if (srcThreadOrder[i] != 0)
-          sharedOrder.emplace_back(srcThreadOrder[i]);
+        if (srcOrder[i] != 0)
+          sharedOrder.emplace_back(srcOrder[i]);
       sharedOrder.emplace_back(0);
     } else {
-      sharedOrder = srcThreadOrder;
+      sharedOrder = srcOrder;
     }
     auto sharedMemorySpace =
         triton::gpu::SharedMemorySpaceAttr::get(srcType.getContext());
diff --git a/test/Conversion/amd/load_store.mlir b/test/Conversion/amd/load_store.mlir
index 93796439b012..543ed4f2df12 100644
--- a/test/Conversion/amd/load_store.mlir
+++ b/test/Conversion/amd/load_store.mlir
@@ -27,3 +27,32 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} {
     tt.return
   }
 }
+
+// -----
+
+#mma = #triton_gpu.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [1, 1], instrShape = [16, 16], isTransposed = true}>
+module attributes {"triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} {
+  // CHECK-LABEL: global_store_mfma_vec16
+  tt.func public @global_store_mfma_vec16(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}) attributes {noinline = false} {
+    %cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #mma>
+    %cst_0 = arith.constant dense<1.230000e+02> : tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>
+    %cst_1 = arith.constant dense<1.230000e+02> : tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>>
+    %0 = tt.dot %cst_0, %cst_1, %cst : tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<32x32xf32, #mma>
+    %1 = math.exp2 %0 : tensor<32x32xf32, #mma>
+    %2 = arith.truncf %1 : tensor<32x32xf32, #mma> to tensor<32x32xf16, #mma>
+    %c32_i32 = arith.constant 32 : i32
+    %100 = tt.get_program_id x : i32
+    %101 = arith.muli %100, %c32_i32 : i32
+    %102 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #mma}>>
+    %300 = tt.expand_dims %102 {axis = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #mma}>> -> tensor<1x32xi32, #mma>
+    %200 = tt.broadcast %300 : tensor<1x32xi32, #mma> -> tensor<32x32xi32, #mma>
+    %103 = tt.splat %101 : i32 -> tensor<32x32xi32, #mma>
+    %104 = arith.addi %103, %200 : tensor<32x32xi32, #mma>
+    %105 = tt.splat %arg0 : !tt.ptr<f16> -> tensor<32x32x!tt.ptr<f16>, #mma>
+    %106 = tt.addptr %105, %104 : tensor<32x32x!tt.ptr<f16>, #mma>, tensor<32x32xi32, #mma>
+    // Store 16 elements with four vectorized store instructions
+    // CHECK-COUNT-4: llvm.intr.masked.store {{.*}}, {{.*}}, {{.*}} {alignment = 16 : i32} : vector<4xf16>, vector<4xi1> into !llvm.ptr
+    tt.store %106, %2 : tensor<32x32x!tt.ptr<f16>, #mma>
+    tt.return
+  }
+}
diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DecomposeUnsupportedConversions.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DecomposeUnsupportedConversions.cpp
index 1aa2b516a559..40cb55bbc00d 100644
--- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DecomposeUnsupportedConversions.cpp
+++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DecomposeUnsupportedConversions.cpp
@@ -54,7 +54,7 @@ class DecomposeLocalLoadToDotOperand
         type.getShape(), type.getElementType(),
         triton::gpu::SharedEncodingAttr::get(
             op.getContext(), dstDotOp, type.getShape(),
-            triton::gpu::getThreadOrder(parentEnc),
+            triton::gpu::getOrder(parentEnc),
            triton::gpu::getCTALayout(parentEnc), type.getElementType()),
         srcType.getMemorySpace());
     auto tmp = rewriter.create<triton::gpu::LocalLoadOp>(
diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv2.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv2.cpp
index ebf511b40ca6..7b7ca7d1e238 100644
--- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv2.cpp
+++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv2.cpp
@@ -9,6 +9,7 @@ using namespace mlir;
 using namespace mlir::triton;
 
 using ::mlir::triton::gpu::DotOperandEncodingAttr;
+using ::mlir::triton::gpu::getOrderForDotOperand;
 using ::mlir::triton::gpu::NvidiaMmaEncodingAttr;
 
 using ValueTableV2 = std::map<std::array<int, 3>, Value>;
@@ -412,11 +413,20 @@ LogicalResult convertDot(const LLVMTypeConverter *typeConverter,
   int repM = repA[1], repN = repB[2], repK = repA[2];
   int repBatch = repA[0];
 
+  // We can reuse the same iteration order in
+  // getValuesFromDotOperandLayoutStruct as both a and b are K-major
+  assert(dotOpA.getRepOrder() == getOrderForDotOperand(dotOpA.getOpIdx(),
+                                                       aShapePerCTA.size(),
+                                                       /*kMajor=*/true));
   auto ha = getValuesFromDotOperandLayoutStruct(
       typeConverter, loc, rewriter, loadedA, repBatch, repM, repK, aTensorTy);
+
+  assert(dotOpB.getRepOrder() == getOrderForDotOperand(dotOpB.getOpIdx(),
+                                                       bShapePerCTA.size(),
+                                                       /*kMajor=*/true));
   auto hb = getValuesFromDotOperandLayoutStruct(
       typeConverter, loc, rewriter, loadedB, repBatch, repN, repK, bTensorTy);
+
   auto fc = unpackLLElements(loc, loadedC, rewriter);
   auto numMmaRets = dTensorTy.getElementType().getIntOrFloatBitWidth() / 8;
   int numCPackedElem = 4 / numMmaRets;
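Note on the two asserts added to convertDot above: with both operands k-major,
one loop nest (batch outermost, k innermost) enumerates the reps of A and B
consistently, which is what lets getValuesFromDotOperandLayoutStruct be shared.
A sketch of that iteration order with made-up rep counts:

#include <cstdio>

int main() {
  const int repBatch = 1, repM = 2, repN = 4, repK = 2;
  for (int b = 0; b < repBatch; ++b)
    for (int m = 0; m < repM; ++m)
      for (int k = 0; k < repK; ++k)
        std::printf("A rep (b=%d, m=%d, k=%d)\n", b, m, k); // k fastest
  for (int b = 0; b < repBatch; ++b)
    for (int n = 0; n < repN; ++n)
      for (int k = 0; k < repK; ++k)
        std::printf("B rep (b=%d, n=%d, k=%d)\n", b, n, k); // k fastest
}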