From d09b982cdc0a9dbf2ed34cf25e92a41c02e770fa Mon Sep 17 00:00:00 2001 From: Ben Ashbaugh Date: Thu, 18 Jan 2024 21:15:46 -0800 Subject: [PATCH] add support for split barriers This may also be helpful to keep subgroups running approximately together, which could also improve cache utilization. --- .../matrix_kernel_tiled.cl | 53 +++++++++++++++++++ 1 file changed, 53 insertions(+) diff --git a/samples/99_matrixexperiments/matrix_kernel_tiled.cl b/samples/99_matrixexperiments/matrix_kernel_tiled.cl index 3046222..cb3c6ca 100644 --- a/samples/99_matrixexperiments/matrix_kernel_tiled.cl +++ b/samples/99_matrixexperiments/matrix_kernel_tiled.cl @@ -10,6 +10,17 @@ #error "NN is undefined! This should be defined as the number of matrix tiles in the N dimension." #endif +#if !defined(cl_intel_split_work_group_barrier) || defined(NO_SPLIT_BARRIERS) +#if !defined(cl_intel_split_work_group_barrier) +#warning "Unexpected: cl_intel_split_work_group_barrier is not supported?" +#endif +#define split_barrier_arrive() +#define split_barrier_wait() +#else +#define split_barrier_arrive() intel_work_group_barrier_arrive(0) +#define split_barrier_wait() intel_work_group_barrier_wait(0) +#endif + #define MM_KERNEL_NAMEX(PREFIX, tM, tN, MM, NN) PREFIX ## _m ## tM ## _n ## tN ## _ ## MM ## x ## NN #define MM_KERNEL_NAME(PREFIX, tM, tN, MM, NN) MM_KERNEL_NAMEX(PREFIX, tM, tN, MM, NN) @@ -36,6 +47,8 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_rowmajor_tiled, 8, 8, MM, NN)(global fl } } + split_barrier_arrive(); + for (int k = 0; k < K; k += tK) { int8 aData[MM]; for (int mm = 0; mm < MM; mm++) { @@ -52,8 +65,13 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_rowmajor_tiled, 8, 8, MM, NN)(global fl sum[mm][nn] = mat_mul_sg8(aData[mm], bData[nn], sum[mm][nn]); } } + + split_barrier_wait(); + split_barrier_arrive(); } + split_barrier_wait(); + for (int nn = 0; nn < NN; nn++) { for (int mm = 0; mm < MM; mm++) { store_c_rowmajor_fp32_m8_nx(C, sum[mm][nn], m + mm * tM, n + nn * tN, N); @@ -77,6 +95,8 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_vnni_tiled, 8, 8, MM, NN)(global float* } } + split_barrier_arrive(); + for (int k = 0; k < K; k += tK) { int8 aData[MM]; for (int mm = 0; mm < MM; mm++) { @@ -93,8 +113,13 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_vnni_tiled, 8, 8, MM, NN)(global float* sum[mm][nn] = mat_mul_sg8(aData[mm], bData[nn], sum[mm][nn]); } } + + split_barrier_wait(); + split_barrier_arrive(); } + split_barrier_wait(); + for (int mm = 0; mm < MM; mm++) { for (int nn = 0; nn < NN; nn++) { store_c_rowmajor_fp32_m8_nx(C, sum[mm][nn], m + mm * tM, n + nn * tN, N); @@ -120,6 +145,8 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_rowmajor_tiled, 8, 16, MM, NN)(global f } } + split_barrier_arrive(); + for (int k = 0; k < K; k += tK) { short8 aData[MM]; for (int mm = 0; mm < MM; mm++) { @@ -136,8 +163,13 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_rowmajor_tiled, 8, 16, MM, NN)(global f sum[mm][nn] = mat_mul_sg16(aData[mm], bData[nn], sum[mm][nn]); } } + + split_barrier_wait(); + split_barrier_arrive(); } + split_barrier_wait(); + for (int mm = 0; mm < MM; mm++) { for (int nn = 0; nn < NN; nn++) { store_c_rowmajor_fp32_m8_nx(C, sum[mm][nn], m + mm * tM, n + nn * tN, N); @@ -161,6 +193,8 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_vnni_tiled, 8, 16, MM, NN)(global float } } + split_barrier_arrive(); + for (int k = 0; k < K; k += tK) { short8 aData[MM]; for (int mm = 0; mm < MM; mm++) { @@ -177,8 +211,13 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_vnni_tiled, 8, 16, MM, NN)(global float sum[mm][nn] = mat_mul_sg16(aData[mm], bData[nn], sum[mm][nn]); } } + + split_barrier_wait(); + split_barrier_arrive(); } + split_barrier_wait(); + for (int mm = 0; mm < MM; mm++) { for (int nn = 0; nn < NN; nn++) { store_c_rowmajor_fp32_m8_nx(C, sum[mm][nn], m + mm * tM, n + nn * tN, N); @@ -205,6 +244,8 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_blockread_rowmajor_tiled, 8, 16, MM, NN } } + split_barrier_arrive(); + for (int k = 0; k < K; k += tK) { short8 aData[MM]; //if (MM % 2 == 0) { @@ -229,8 +270,13 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_blockread_rowmajor_tiled, 8, 16, MM, NN sum[mm][nn] = mat_mul_sg16(aData[mm], bData[nn], sum[mm][nn]); } } + + split_barrier_wait(); + split_barrier_arrive(); } + split_barrier_wait(); + for (int mm = 0; mm < MM; mm++) { for (int nn = 0; nn < NN; nn++) { intel_subgroup_block_write_u32_m8k16(C, N * sizeof(float), M, N * sizeof(float), (int2)(n + nn * tN, m + mm * tM), as_uint8(sum[mm][nn])); @@ -255,6 +301,8 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_blockread_vnni_tiled, 8, 16, MM, NN)(gl } } + split_barrier_arrive(); + for (int k = 0; k < K; k += tK) { short8 aData[MM]; //if (MM % 2 == 0) { @@ -279,8 +327,13 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_blockread_vnni_tiled, 8, 16, MM, NN)(gl sum[mm][nn] = mat_mul_sg16(aData[mm], bData[nn], sum[mm][nn]); } } + + split_barrier_wait(); + split_barrier_arrive(); } + split_barrier_wait(); + for (int mm = 0; mm < MM; mm++) { for (int nn = 0; nn < NN; nn++) { intel_subgroup_block_write_u32_m8k16(C, N * sizeof(float), M, N * sizeof(float), (int2)(n + nn * tN, m + mm * tM), as_uint8(sum[mm][nn]));