diff --git a/test/xrt/04_gemm_w_pack/aie.py b/test/xrt/04_gemm_w_pack/aie.py index f8d924346..1e0a048c3 100644 --- a/test/xrt/04_gemm_w_pack/aie.py +++ b/test/xrt/04_gemm_w_pack/aie.py @@ -18,28 +18,27 @@ #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d2, d1, d4, d5, d8, d7)> #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d4, d3, d6, d7)> module { - func.func @forward(%arg0: memref<2048x2048xi32>, %arg1: memref<2048x2048xi32>) -> memref<2048x2048xi32> { + func.func @forward(%arg0: memref<256x256xi32>, %arg1: memref<256x256xi32>) -> memref<256x256xi32> { %c32 = arith.constant 32 : index %c256 = arith.constant 256 : index - %c2048 = arith.constant 2048 : index %c0 = arith.constant 0 : index %c64 = arith.constant 64 : index %c0_i32 = arith.constant 0 : i32 - %alloc = memref.alloc() : memref<2048x2048xi32> - scf.parallel (%arg2, %arg3) = (%c0, %c0) to (%c2048, %c2048) step (%c64, %c64) { - %subview = memref.subview %alloc[%arg2, %arg3] [64, 64] [1, 1] : memref<2048x2048xi32> to memref<64x64xi32, strided<[2048, 1], offset: ?>> + %alloc = memref.alloc() : memref<256x256xi32> + scf.parallel (%arg2, %arg3) = (%c0, %c0) to (%c256, %c256) step (%c64, %c64) { + %subview = memref.subview %alloc[%arg2, %arg3] [64, 64] [1, 1] : memref<256x256xi32> to memref<64x64xi32, strided<[256, 1], offset: ?>> %alloc_0 = memref.alloc() : memref<1x1x64x64xi32, 1> scf.parallel (%arg4, %arg5) = (%c0, %c0) to (%c64, %c64) step (%c32, %c32) { %alloc_2 = memref.alloc() : memref<1x1x4x8x4x8xi32, 2> linalg.fill ins(%c0_i32 : i32) outs(%alloc_2 : memref<1x1x4x8x4x8xi32, 2>) %subview_3 = memref.subview %alloc_0[0, 0, %arg4, %arg5] [1, 1, 32, 32] [1, 1, 1, 1] : memref<1x1x64x64xi32, 1> to memref<1x1x32x32xi32, strided<[4096, 4096, 64, 1], offset: ?>, 1> - scf.for %arg6 = %c0 to %c2048 step %c256 { - %subview_5 = memref.subview %arg0[%arg2, %arg6] [64, 256] [1, 1] : memref<2048x2048xi32> to memref<64x256xi32, strided<[2048, 1], offset: ?>> - %subview_6 = memref.subview %arg1[%arg6, %arg3] [256, 64] [1, 1] : memref<2048x2048xi32> to memref<256x64xi32, strided<[2048, 1], offset: ?>> + scf.for %arg6 = %c0 to %c256 step %c256 { + %subview_5 = memref.subview %arg0[%arg2, %arg6] [64, 256] [1, 1] : memref<256x256xi32> to memref<64x256xi32, strided<[256, 1], offset: ?>> + %subview_6 = memref.subview %arg1[%arg6, %arg3] [256, 64] [1, 1] : memref<256x256xi32> to memref<256x64xi32, strided<[256, 1], offset: ?>> %alloc_7 = memref.alloc() : memref<1x1x64x256xi32, 1> %alloc_8 = memref.alloc() : memref<1x1x256x64xi32, 1> - air.dma_memcpy_nd (%alloc_7[] [] [], %subview_5[] [] []) : (memref<1x1x64x256xi32, 1>, memref<64x256xi32, strided<[2048, 1], offset: ?>>) - air.dma_memcpy_nd (%alloc_8[] [] [], %subview_6[] [] []) : (memref<1x1x256x64xi32, 1>, memref<256x64xi32, strided<[2048, 1], offset: ?>>) + air.dma_memcpy_nd (%alloc_7[] [] [], %subview_5[] [] []) : (memref<1x1x64x256xi32, 1>, memref<64x256xi32, strided<[256, 1], offset: ?>>) + air.dma_memcpy_nd (%alloc_8[] [] [], %subview_6[] [] []) : (memref<1x1x256x64xi32, 1>, memref<256x64xi32, strided<[256, 1], offset: ?>>) scf.for %arg7 = %c0 to %c256 step %c32 { %subview_9 = memref.subview %alloc_7[0, 0, %arg4, %arg7] [1, 1, 32, 32] [1, 1, 1, 1] : memref<1x1x64x256xi32, 1> to memref<1x1x32x32xi32, strided<[16384, 16384, 256, 1], offset: ?>, 1> %subview_10 = memref.subview %alloc_8[0, 0, %arg7, %arg5] [1, 1, 32, 32] [1, 1, 1, 1] : memref<1x1x256x64xi32, 1> to memref<1x1x32x32xi32, strided<[16384, 16384, 64, 1], offset: ?>, 1> @@ -70,11 +69,11 @@ } %subview_1 = memref.subview %alloc_0[0, 0, 0, 0] [1, 1, 64, 64] [1, 1, 1, 1] : memref<1x1x64x64xi32, 1> to memref<64x64xi32, 1> %transpose = memref.transpose %subview_1 (d0, d1) -> (d0, d1) : memref<64x64xi32, 1> to memref<64x64xi32, strided<[64, 1]>, 1> - air.dma_memcpy_nd (%subview[] [] [], %transpose[] [] []) : (memref<64x64xi32, strided<[2048, 1], offset: ?>>, memref<64x64xi32, strided<[64, 1]>, 1>) + air.dma_memcpy_nd (%subview[] [] [], %transpose[] [] []) : (memref<64x64xi32, strided<[256, 1], offset: ?>>, memref<64x64xi32, strided<[64, 1]>, 1>) memref.dealloc %alloc_0 : memref<1x1x64x64xi32, 1> scf.reduce } - return %alloc : memref<2048x2048xi32> + return %alloc : memref<256x256xi32> } } """ diff --git a/test/xrt/04_gemm_w_pack/run.lit b/test/xrt/04_gemm_w_pack/run.lit index feb90b7e8..9a2d42ccd 100644 --- a/test/xrt/04_gemm_w_pack/run.lit +++ b/test/xrt/04_gemm_w_pack/run.lit @@ -6,4 +6,3 @@ // RUN: %python aiecc.py --no-aiesim --aie-generate-cdo --aie-generate-npu --no-compile-host --xclbin-name=aie.xclbin --npu-insts-name=insts.txt aie.mlir // RUN: clang %S/test.cpp -O3 -o test.exe -std=c++11 -Wall %xrt_flags -lrt -lstdc++ -lboost_program_options -lboost_filesystem // RUN: %run_on_npu ./test.exe -x aie.xclbin -k MLIR_AIE -i insts.txt -// XFAIL: * diff --git a/test/xrt/04_gemm_w_pack/test.cpp b/test/xrt/04_gemm_w_pack/test.cpp index fd68d7f96..3b12de2c9 100644 --- a/test/xrt/04_gemm_w_pack/test.cpp +++ b/test/xrt/04_gemm_w_pack/test.cpp @@ -10,9 +10,9 @@ #include "xrt/xrt_device.h" #include "xrt/xrt_kernel.h" -#define M 2048 -#define N 2048 -#define K 2048 +#define M 256 +#define N 256 +#define K 256 #define A_VOLUME M *K #define B_VOLUME N *K