diff --git a/test/xrt/16_matmul_8x16_core_transform_bf16/gen.py b/test/xrt/16_matmul_8x16_core_transform_bf16/gen.py index aaac23569..0c26da215 100644 --- a/test/xrt/16_matmul_8x16_core_transform_bf16/gen.py +++ b/test/xrt/16_matmul_8x16_core_transform_bf16/gen.py @@ -22,14 +22,13 @@ #map3 = affine_map<()[s0] -> (s0 * 8)> #map4 = affine_map<(d0) -> (d0 * 4)> module { - func.func @forward(%arg0: memref<512x1024xbf16>, %arg1: memref<128x8x8x4x16xbf16>, %arg2: memref<512x512xf32>) -> memref<512x512xf32> { + func.func @forward(%arg0: memref<512x256xbf16>, %arg1: memref<32x8x8x4x16xbf16>, %arg2: memref<512x512xf32>) -> memref<512x512xf32> { %c8 = arith.constant 8 : index %c16 = arith.constant 16 : index %c64 = arith.constant 64 : index %c32 = arith.constant 32 : index %c128 = arith.constant 128 : index %c256 = arith.constant 256 : index - %c1024 = arith.constant 1024 : index %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index %cst = arith.constant 0.000000e+00 : f32 @@ -40,16 +39,16 @@ %0 = affine.apply #map(%arg3) %1 = affine.apply #map(%arg4) %subview = memref.subview %arg2[%0, %1] [128, 128] [1, 1] : memref<512x512xf32> to memref<128x128xf32, strided<[512, 1], offset: ?>> - %alloc_0 = memref.alloc() : memref<128x1024xbf16, 1> - scf.for %arg5 = %c0 to %c1024 step %c256 { - %subview_3 = memref.subview %arg0[%0, %arg5] [128, 256] [1, 1] : memref<512x1024xbf16> to memref<128x256xbf16, strided<[1024, 1], offset: ?>> - %subview_4 = memref.subview %alloc_0[0, %arg5] [128, 256] [1, 1] : memref<128x1024xbf16, 1> to memref<128x256xbf16, strided<[1024, 1], offset: ?>, 1> - linalg.copy ins(%subview_3 : memref<128x256xbf16, strided<[1024, 1], offset: ?>>) outs(%subview_4 : memref<128x256xbf16, strided<[1024, 1], offset: ?>, 1>) + %alloc_0 = memref.alloc() : memref<128x256xbf16, 1> + scf.for %arg5 = %c0 to %c256 step %c256 { + %subview_3 = memref.subview %arg0[%0, %arg5] [128, 256] [1, 1] : memref<512x256xbf16> to memref<128x256xbf16, strided<[256, 1], offset: ?>> + %subview_4 = memref.subview %alloc_0[0, %arg5] [128, 256] [1, 1] : memref<128x256xbf16, 1> to memref<128x256xbf16, strided<[256, 1], offset: ?>, 1> + linalg.copy ins(%subview_3 : memref<128x256xbf16, strided<[256, 1], offset: ?>>) outs(%subview_4 : memref<128x256xbf16, strided<[256, 1], offset: ?>, 1>) } - %alloc_1 = memref.alloc() : memref<128x8x8x16xbf16, 1> - scf.for %arg5 = %c0 to %c128 step %c32 { - %subview_3 = memref.subview %arg1[%arg5, 0, 0, %arg4, 0] [32, 8, 8, 1, 16] [1, 1, 1, 1, 1] : memref<128x8x8x4x16xbf16> to memref<32x8x8x16xbf16, strided<[4096, 512, 64, 1], offset: ?>> - %subview_4 = memref.subview %alloc_1[%arg5, 0, 0, 0] [32, 8, 8, 16] [1, 1, 1, 1] : memref<128x8x8x16xbf16, 1> to memref<32x8x8x16xbf16, strided<[1024, 128, 16, 1], offset: ?>, 1> + %alloc_1 = memref.alloc() : memref<32x8x8x16xbf16, 1> + scf.for %arg5 = %c0 to %c32 step %c32 { + %subview_3 = memref.subview %arg1[%arg5, 0, 0, %arg4, 0] [32, 8, 8, 1, 16] [1, 1, 1, 1, 1] : memref<32x8x8x4x16xbf16> to memref<32x8x8x16xbf16, strided<[4096, 512, 64, 1], offset: ?>> + %subview_4 = memref.subview %alloc_1[%arg5, 0, 0, 0] [32, 8, 8, 16] [1, 1, 1, 1] : memref<32x8x8x16xbf16, 1> to memref<32x8x8x16xbf16, strided<[1024, 128, 16, 1], offset: ?>, 1> linalg.copy ins(%subview_3 : memref<32x8x8x16xbf16, strided<[4096, 512, 64, 1], offset: ?>>) outs(%subview_4 : memref<32x8x8x16xbf16, strided<[1024, 128, 16, 1], offset: ?>, 1>) } %alloc_2 = memref.alloc() : memref<128x128xf32, 1> @@ -60,14 +59,14 @@ %subview_3 = memref.subview %alloc_2[%2, %3] [64, 64] [1, 1] : memref<128x128xf32, 1> to memref<64x64xf32, strided<[128, 1], offset: ?>, 1> %alloc_4 = memref.alloc() : memref<64x64xf32, 2> linalg.fill ins(%cst : f32) outs(%alloc_4 : memref<64x64xf32, 2>) - scf.for %arg7 = %c0 to %c128 step %c8 { + scf.for %arg7 = %c0 to %c32 step %c8 { %4 = affine.apply #map3()[%arg7] - %subview_5 = memref.subview %alloc_0[%2, %4] [64, 64] [1, 1] : memref<128x1024xbf16, 1> to memref<64x64xbf16, strided<[1024, 1], offset: ?>, 1> - %subview_6 = memref.subview %alloc_1[%arg7, %map, 0, 0] [8, 4, 8, 16] [1, 1, 1, 1] : memref<128x8x8x16xbf16, 1> to memref<8x4x8x16xbf16, strided<[1024, 128, 16, 1], offset: ?>, 1> + %subview_5 = memref.subview %alloc_0[%2, %4] [64, 64] [1, 1] : memref<128x256xbf16, 1> to memref<64x64xbf16, strided<[256, 1], offset: ?>, 1> + %subview_6 = memref.subview %alloc_1[%arg7, %map, 0, 0] [8, 4, 8, 16] [1, 1, 1, 1] : memref<32x8x8x16xbf16, 1> to memref<8x4x8x16xbf16, strided<[1024, 128, 16, 1], offset: ?>, 1> %transpose = memref.transpose %subview_6 (d0, d1, d2, d3) -> (d0, d2, d1, d3) : memref<8x4x8x16xbf16, strided<[1024, 128, 16, 1], offset: ?>, 1> to memref<8x8x4x16xbf16, strided<[1024, 16, 128, 1], offset: ?>, 1> %alloc_7 = memref.alloc() : memref<64x64xbf16, 2> %alloc_8 = memref.alloc() : memref<8x8x4x16xbf16, 2> - memref.copy %subview_5, %alloc_7 : memref<64x64xbf16, strided<[1024, 1], offset: ?>, 1> to memref<64x64xbf16, 2> + memref.copy %subview_5, %alloc_7 : memref<64x64xbf16, strided<[256, 1], offset: ?>, 1> to memref<64x64xbf16, 2> memref.copy %transpose, %alloc_8 : memref<8x8x4x16xbf16, strided<[1024, 16, 128, 1], offset: ?>, 1> to memref<8x8x4x16xbf16, 2> %reshape = memref.reshape %alloc_8(%alloc) : (memref<8x8x4x16xbf16, 2>, memref<2xindex>) -> memref<64x64xbf16, 2> linalg.matmul {cast = #linalg.type_fn} ins(%alloc_7, %reshape : memref<64x64xbf16, 2>, memref<64x64xbf16, 2>) outs(%alloc_4 : memref<64x64xf32, 2>) @@ -78,8 +77,8 @@ memref.dealloc %alloc_4 : memref<64x64xf32, 2> } {mapping = [#gpu.thread, #gpu.thread]} linalg.copy ins(%alloc_2 : memref<128x128xf32, 1>) outs(%subview : memref<128x128xf32, strided<[512, 1], offset: ?>>) - memref.dealloc %alloc_0 : memref<128x1024xbf16, 1> - memref.dealloc %alloc_1 : memref<128x8x8x16xbf16, 1> + memref.dealloc %alloc_0 : memref<128x256xbf16, 1> + memref.dealloc %alloc_1 : memref<32x8x8x16xbf16, 1> memref.dealloc %alloc_2 : memref<128x128xf32, 1> } return %arg2 : memref<512x512xf32> diff --git a/test/xrt/16_matmul_8x16_core_transform_bf16/kernel.cpp b/test/xrt/16_matmul_8x16_core_transform_bf16/kernel.cpp old mode 100755 new mode 100644 index 96acf80d1..b914570ce --- a/test/xrt/16_matmul_8x16_core_transform_bf16/kernel.cpp +++ b/test/xrt/16_matmul_8x16_core_transform_bf16/kernel.cpp @@ -16,7 +16,7 @@ #include -template +template void zero_scalar(T *__restrict c) { for (int i = 0; i < M * N; i++) { c[i] = 0.0f; diff --git a/test/xrt/16_matmul_8x16_core_transform_bf16/run.py b/test/xrt/16_matmul_8x16_core_transform_bf16/run.py index 2817d4a63..56ab8e242 100644 --- a/test/xrt/16_matmul_8x16_core_transform_bf16/run.py +++ b/test/xrt/16_matmul_8x16_core_transform_bf16/run.py @@ -11,7 +11,7 @@ M = 512 N = 512 -K = 1024 +K = 256 Tx = 16 Ty = 8