Skip to content

Commit

Permalink
[MLIR][NVGPU] Improve and Cleanup verifier of TMA OPs (llvm#70923)
Browse files Browse the repository at this point in the history
This PR improves and cleans up the verifiers of the TmaCreateDescriptor and
TmaAsyncLoad ops and unifies them.

The PR now verifies the following, which were not checked before:
- address space
- rank match between descriptor and memref
- element type match between descriptor and memref
- shape match between descriptor and memref
  • Loading branch information
grypp authored Nov 8, 2023
1 parent 96b5e09 commit 6eb97f0
Show file tree
Hide file tree
Showing 4 changed files with 118 additions and 31 deletions.
2 changes: 2 additions & 0 deletions mlir/include/mlir/Dialect/NVGPU/IR/NVGPUDialect.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ constexpr int kWarpSize = 32;

/// M dimension size of a single `wgmma.mma_async` instruction tile.
constexpr int kWgmmaSizeM = 64;
/// Maximum tensor rank (number of dimensions) supported by TMA transfers.
constexpr int kMaxTMATensorDimension = 5;

#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/NVGPU/IR/NVGPUAttrDefs.h.inc"
Expand Down
102 changes: 72 additions & 30 deletions mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -335,34 +335,83 @@ LogicalResult LdMatrixOp::verify() {
// NVGPU_TmaAsyncLoadOp
//===----------------------------------------------------------------------===//

LogicalResult TmaAsyncLoadOp::verify() {
// Destination memref
auto dstMemref = llvm::cast<MemRefType>(getDst().getType());
std::optional<InFlightDiagnostic> verifyTmaDescriptorWithMemref(
Operation *op, nvgpu::TensorMapDescriptorType descType,
std::optional<MemRefType> memrefType = std::nullopt) {
MemRefType descMemref = descType.getTensor();
// Limitation
if (descType.getInterleave() != TensorMapInterleaveKind::INTERLEAVE_NONE)
return op->emitError() << "Interleave options are not supported yet.";

// Address space check for shared memory check
if (!NVGPUDialect::hasSharedMemoryAddressSpace(descMemref)) {
return op->emitError() << "the tensor map descriptor has incorrect address "
"space, it must be shared memory address space.";
}
// Support only static shape for the time being
if (!descMemref.hasStaticShape())
return op->emitError() << "the tensor map descriptor must be static shaped";

// No verification if memref type is not provided
if (!memrefType.has_value())
return std::nullopt;

MemRefType dstMemref = memrefType.value();

// Check element type
if (descMemref.getElementType() != dstMemref.getElementType()) {
return op->emitError() << "the element type of tensor map descriptor and "
"memref must be same";
}

if (!NVGPUDialect::hasSharedMemoryAddressSpace(dstMemref)) {
return emitError()
<< "The operation stores data to shared memory, but "
"the destination memref does not have a memory space of "
<< NVGPUDialect::kSharedMemoryAddressSpace;
return op->emitError() << "the destination memref has incorrect address "
"space, it must be shared memory address space.";
}
if (getCoordinates().size() > 5) {
return emitError() << "Maximum 5 coordinates are supported.";
if (!dstMemref.hasStaticShape())
return op->emitError() << "the destination memref must be static shaped";

if (dstMemref.getRank() != descMemref.getRank()) {
return op->emitError() << "the shape of tensor map descriptor and "
"memref must have same rank";
}
if (getCoordinates().size() != size_t(dstMemref.getRank())) {
return emitError() << "Destination memref rank is "
<< size_t(dstMemref.getRank()) << " but there are "
<< getCoordinates().size()
<< " coordinates. They must match.";
if (!descMemref.getShape().equals(dstMemref.getShape())) {
return op->emitError() << "memref and tensor map shapes mismatch "
<< descMemref << " != " << dstMemref;
}

return std::nullopt;
}

LogicalResult TmaAsyncLoadOp::verify() {
  // Verify descriptor/destination-memref compatibility via the shared helper.
  nvgpu::TensorMapDescriptorType descType = getTensorMapDescriptor().getType();
  std::optional<InFlightDiagnostic> error =
      verifyTmaDescriptorWithMemref(*this, descType, getDst().getType());
  if (error.has_value())
    return error.value();

  // TMA transfers address at most `kMaxTMATensorDimension` (5) dimensions.
  if (getCoordinates().size() > static_cast<size_t>(kMaxTMATensorDimension)) {
    return emitError() << "Maximum " << kMaxTMATensorDimension
                       << " coordinates are supported.";
  }
  // One coordinate per tensor dimension. The cast avoids a signed/unsigned
  // comparison warning (size() is size_t, getRank() is int64_t).
  if (static_cast<int64_t>(getCoordinates().size()) !=
      descType.getTensor().getRank()) {
    return emitError() << "number of coordinates do not match with the rank of "
                          "tensor descriptor map.";
  }

  return success();
}

LogicalResult TmaCreateDescriptorOp::verify() {
  // TMA descriptors describe at most `kMaxTMATensorDimension` (5) dimensions.
  // The message talks about the box dimensions, not coordinates, since that is
  // what this op carries.
  if (getBoxDimensions().size() > static_cast<size_t>(kMaxTMATensorDimension)) {
    return emitError() << "Maximum " << kMaxTMATensorDimension
                       << " dimensional box is supported.";
  }

  // Delegate the descriptor-only checks (interleave, address space, static
  // shape) to the shared TMA verifier; no memref is involved here.
  std::optional<InFlightDiagnostic> error =
      verifyTmaDescriptorWithMemref(*this, getTensorMap().getType());
  if (error.has_value())
    return error.value();

  return success();
}
Expand All @@ -372,17 +421,10 @@ LogicalResult TmaCreateDescriptorOp::verify() {
//===----------------------------------------------------------------------===//

LogicalResult WarpgroupGenerateDescriptorOp::verify() {
MemRefType memrefType = getTensor().getType();
MemRefType tensorMapType = getTensorMap().getType().getTensor();

if (memrefType != tensorMapType)
return emitError() << "memref and tensor map type mismatch";

if (!memrefType.hasStaticShape() || !tensorMapType.hasStaticShape())
return emitError() << "supports only static shapes";

if (memrefType.getRank() != 2)
return emitError() << "supports only 2d memref is supported for now";
std::optional<InFlightDiagnostic> error =
verifyTmaDescriptorWithMemref(*this, getTensorMap().getType());
if (error.has_value())
return error.value();

if (getTensorMap().getType().getSwizzle() !=
TensorMapSwizzleKind::SWIZZLE_128B) {
Expand Down
2 changes: 1 addition & 1 deletion mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -640,7 +640,7 @@ func.func @mbarrier_txcount_pred() {
!tensorMap1d = !nvgpu.tensormap.descriptor<tensor = memref<128xf32,3>, swizzle=none, l2promo = none, oob = nan, interleave = none>
!tensorMap2d = !nvgpu.tensormap.descriptor<tensor = memref<32x32xf32,3>, swizzle=swizzle_32b, l2promo = none, oob = zero, interleave = none>
!tensorMap3d = !nvgpu.tensormap.descriptor<tensor = memref<2x32x32xf32,3>, swizzle=swizzle_64b, l2promo = l2promo_64b, oob = zero, interleave = none>
// Note: interleave must be `none`; the TMA verifier rejects interleaved
// layouts, so the earlier `interleave_16b` variant of this alias is invalid.
!tensorMap4d = !nvgpu.tensormap.descriptor<tensor = memref<2x2x32x32xf32,3>, swizzle=swizzle_128b,l2promo = l2promo_128b,oob = zero, interleave = none>
!tensorMap5d = !nvgpu.tensormap.descriptor<tensor = memref<2x2x2x32x32xf32,3>, swizzle=none, l2promo = none, oob = zero, interleave = none>
!mbarrier = !nvgpu.mbarrier.group<memorySpace = #gpu.address_space<workgroup>>
func.func @async_tma_load(%tensorMap1d: !tensorMap1d, %tensorMap2d: !tensorMap2d, %tensorMap3d: !tensorMap3d, %tensorMap4d: !tensorMap4d, %tensorMap5d: !tensorMap5d,
Expand Down
43 changes: 43 additions & 0 deletions mlir/test/Dialect/NVGPU/invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -265,3 +265,46 @@ func.func @warpgroup_mma_wrong_input(%descA: !tDescA, %descB: !tDescB, %acc: !tR
%0 = nvgpu.warpgroup.mma %descA, %descB, %acc: !tDescA, !tDescB, !tResult -> !tResult
return
}

// -----

// 2-D descriptor in shared memory (address space 3). The first load is valid;
// the second supplies 6 coordinates, exceeding the 5-D TMA limit.
!desc = !nvgpu.tensormap.descriptor<tensor = memref<32x32xf32,3>, swizzle=swizzle_32b, l2promo = none, oob = zero, interleave = none>
!mbarrier = !nvgpu.mbarrier.group<memorySpace = #gpu.address_space<workgroup>>
func.func @tma_load_1(%desc: !desc, %buffer1: memref<128xf32,3>, %buffer2: memref<32x32xf32,3>, %buffer3: memref<32x32xf32>, %mbarrier: !mbarrier) {
  %c0 = arith.constant 0 : index
  // This load is valid: 2 coordinates match the 2-D descriptor.
  nvgpu.tma.async.load %desc[%c0, %c0], %mbarrier[%c0] to %buffer2 : !desc, !mbarrier -> memref<32x32xf32,3>
  // expected-error @+1 {{Maximum 5 coordinates are supported.}}
  nvgpu.tma.async.load %desc[%c0, %c0, %c0, %c0, %c0, %c0], %mbarrier[%c0] to %buffer2 : !desc, !mbarrier -> memref<32x32xf32,3>
  return
}
// -----

// Descriptor tensor memref<32x32xf32> has no shared-memory address space (3),
// so the descriptor address-space check must fire.
!desc = !nvgpu.tensormap.descriptor<tensor = memref<32x32xf32>, swizzle=swizzle_32b, l2promo = none, oob = zero, interleave = none>
!mbarrier = !nvgpu.mbarrier.group<memorySpace = #gpu.address_space<workgroup>>
func.func @tma_load_2(%desc: !desc, %buffer1: memref<128xf32,3>, %buffer2: memref<32x32xf32,3>, %buffer3: memref<32x32xf32>, %mbarrier: !mbarrier) {
  %c0 = arith.constant 0 : index
  // expected-error @+1 {{the tensor map descriptor has incorrect address space, it must be shared memory address space.}}
  nvgpu.tma.async.load %desc[%c0, %c0], %mbarrier[%c0] to %buffer2 : !desc, !mbarrier -> memref<32x32xf32,3>
  return
}
// -----

// Valid descriptor, but the destination %buffer3 (memref<32x32xf32>) is not in
// shared memory, so the destination address-space check must fire.
!desc = !nvgpu.tensormap.descriptor<tensor = memref<32x32xf32,3>, swizzle=swizzle_32b, l2promo = none, oob = zero, interleave = none>
!mbarrier = !nvgpu.mbarrier.group<memorySpace = #gpu.address_space<workgroup>>
func.func @tma_load_3(%desc: !desc, %buffer1: memref<128xf32,3>, %buffer2: memref<32x32xf32,3>, %buffer3: memref<32x32xf32>, %mbarrier: !mbarrier) {
  %c0 = arith.constant 0 : index
  // expected-error @+1 {{the destination memref has incorrect address space, it must be shared memory address space}}
  nvgpu.tma.async.load %desc[%c0, %c0], %mbarrier[%c0] to %buffer3 : !desc, !mbarrier -> memref<32x32xf32>
  return
}
// -----

// 2-D descriptor loaded into a 1-D destination (%buffer1: memref<128xf32,3>):
// the rank-mismatch check must fire.
!desc = !nvgpu.tensormap.descriptor<tensor = memref<32x32xf32,3>, swizzle=swizzle_32b, l2promo = none, oob = zero, interleave = none>
!mbarrier = !nvgpu.mbarrier.group<memorySpace = #gpu.address_space<workgroup>>
func.func @tma_load_4(%desc: !desc, %buffer1: memref<128xf32,3>, %buffer2: memref<32x32xf32,3>, %buffer3: memref<32x32xf32>, %mbarrier: !mbarrier) {
  %c0 = arith.constant 0 : index
  // expected-error @+1 {{the shape of tensor map descriptor and memref must have same rank}}
  nvgpu.tma.async.load %desc[%c0, %c0], %mbarrier[%c0] to %buffer1 : !desc, !mbarrier -> memref<128xf32,3>
  return
}

0 comments on commit 6eb97f0

Please sign in to comment.