add support for launching more than one subgroup per work group
This should enable better cache reuse across subgroups.
bashbaug committed Jan 19, 2024
1 parent 4caea7b commit 0fb3d66
Showing 2 changed files with 24 additions and 12 deletions.
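
For orientation before the diffs: a minimal host-side sketch of how one of these tiled kernels might be launched after this change. It is not part of the commit; the handle and size parameters are assumptions, and the tile sizes shown match the 8x16 kernels below. The key point is that the required work-group size becomes (tN, SGS_PER_WG, 1), so only the local Y size changes; the global sizes stay the same because each subgroup still computes one (tM*MM) x (tN*NN) tile of C.

#include <CL/cl.h>

/* Hypothetical host-side sketch (not part of this commit). The queue and
 * kernel handles and the matrix sizes are assumed to exist; M and N are
 * assumed to divide evenly by the tile sizes. */
static cl_int enqueue_tiled_gemm(
    cl_command_queue queue, cl_kernel kernel,
    size_t M, size_t N)                  /* C is M x N */
{
    const size_t tM = 8, tN = 16;        /* dpas tile shape (8x16 kernels) */
    const size_t MM = 2, NN = 2;         /* illustrative register tiling factors */
    const size_t SGS_PER_WG = 4;         /* must match the value the program
                                          * was compiled with (default is 4) */

    /* Global sizes are unchanged by this commit: one work-item column per
     * NN columns of C, one work-item row per tM*MM rows of C. */
    size_t gws[2] = { N / NN, M / (tM * MM) };

    /* New: the local size must match reqd_work_group_size(tN, SGS_PER_WG, 1),
     * so gws[1] must also be a multiple of SGS_PER_WG. */
    size_t lws[2] = { tN, SGS_PER_WG };

    return clEnqueueNDRangeKernel(queue, kernel, 2, NULL, gws, lws,
                                  0, NULL, NULL);
}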
samples/99_matrixexperiments/matrix_helpers.cl: 7 additions & 0 deletions
@@ -9,6 +9,13 @@ float bf16_to_fp32(ushort u)
 
 #if defined(cl_intel_subgroups) && defined(cl_intel_subgroups_short)
 
+inline int compute_m(const int num_sgs, const int tM, const int MM)
+{
+    const int m_start = get_group_id(1) * num_sgs;
+    const int m_index = num_sgs > 1 ? m_start + get_sub_group_id() : m_start;
+    return m_index * tM * MM;
+}
+
 // Emulated SIMD8 dpas:
 __attribute__((overloadable))
 float emu_sub_group_bf16_bf16_matrix_mad_k16(int a, int8 b, float acc)
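
The helper assigns each subgroup its own block of C rows: work-group `get_group_id(1)` covers `num_sgs` consecutive row blocks, and `get_sub_group_id()` selects one of them. Since the column offset `n` in the kernels below still depends only on `get_group_id(0)`, all subgroups in a work-group sweep the same tiles of B while reading different rows of A, which appears to be the cache reuse the commit message refers to. A small host-side re-implementation, with the work-item built-ins replaced by plain parameters, shows the mapping (the harness itself is illustrative, not part of the commit):

#include <stdio.h>

/* Host-side re-implementation of compute_m for illustration:
 * get_group_id(1) and get_sub_group_id() become explicit parameters. */
static int compute_m(int num_sgs, int group_id_1, int sub_group_id,
                     int tM, int MM)
{
    const int m_start = group_id_1 * num_sgs;
    const int m_index = num_sgs > 1 ? m_start + sub_group_id : m_start;
    return m_index * tM * MM;
}

int main(void)
{
    /* With tM = 8 and MM = 2, each subgroup owns 16 consecutive rows of C. */
    const int tM = 8, MM = 2, SGS_PER_WG = 4;
    for (int g = 0; g < 2; g++)
        for (int sg = 0; sg < SGS_PER_WG; sg++)
            printf("work-group %d, subgroup %d -> rows [%d, %d)\n",
                   g, sg,
                   compute_m(SGS_PER_WG, g, sg, tM, MM),
                   compute_m(SGS_PER_WG, g, sg, tM, MM) + tM * MM);
    return 0;
}

This prints rows [0, 16) through [48, 64) for the four subgroups of work-group 0, with work-group 1 continuing at [64, 80); with num_sgs = 1 the helper reduces to the old get_group_id(1) * tM * MM indexing.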
samples/99_matrixexperiments/matrix_kernel_tiled.cl: 17 additions & 12 deletions
@@ -13,15 +13,20 @@
 #define MM_KERNEL_NAMEX(PREFIX, tM, tN, MM, NN) PREFIX ## _m ## tM ## _n ## tN ## _ ## MM ## x ## NN
 #define MM_KERNEL_NAME(PREFIX, tM, tN, MM, NN) MM_KERNEL_NAMEX(PREFIX, tM, tN, MM, NN)
 
+#if !defined(SGS_PER_WG)
+// Launch four subgroups per work-group, to maximize cache reuse.
+#define SGS_PER_WG 4
+#endif
+
 #if HAS_SIMD8
 
-__attribute__((intel_reqd_sub_group_size(8))) __attribute__((reqd_work_group_size(8, 1, 1)))
+__attribute__((intel_reqd_sub_group_size(8))) __attribute__((reqd_work_group_size(8, SGS_PER_WG, 1)))
 kernel void MM_KERNEL_NAME(bfloat16_dpas_rowmajor_tiled, 8, 8, MM, NN)(global float* C, global ushort* A, global ushort* B, int K)
 {
     const int tM = 8;
     const int tN = 8;
     const int N = get_global_size(0) * NN;
-    const int m = get_group_id(1) * tM * MM;
+    const int m = compute_m(SGS_PER_WG, tM, MM);
     const int n = get_group_id(0) * tN * NN;
 
     float8 sum[MM][NN];
@@ -56,13 +61,13 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_rowmajor_tiled, 8, 8, MM, NN)(global float* C, global ushort* A, global ushort* B, int K)
     }
 }
 
-__attribute__((intel_reqd_sub_group_size(8))) __attribute__((reqd_work_group_size(8, 1, 1)))
+__attribute__((intel_reqd_sub_group_size(8))) __attribute__((reqd_work_group_size(8, SGS_PER_WG, 1)))
 kernel void MM_KERNEL_NAME(bfloat16_dpas_vnni_tiled, 8, 8, MM, NN)(global float* C, global ushort* A, global ushort* B, int K)
 {
     const int tM = 8;
     const int tN = 8;
     const int N = get_global_size(0) * NN;
-    const int m = get_group_id(1) * tM * MM;
+    const int m = compute_m(SGS_PER_WG, tM, MM);
     const int n = get_group_id(0) * tN * NN;
 
     float8 sum[MM][NN];
@@ -99,13 +104,13 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_vnni_tiled, 8, 8, MM, NN)(global float* C, global ushort* A, global ushort* B, int K)
 
 #endif // HAS_SIMD8
 
-__attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1)))
+__attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, SGS_PER_WG, 1)))
 kernel void MM_KERNEL_NAME(bfloat16_dpas_rowmajor_tiled, 8, 16, MM, NN)(global float* C, global ushort* A, global ushort* B, int K)
 {
     const int tM = 8;
     const int tN = 16;
     const int N = get_global_size(0) * NN;
-    const int m = get_group_id(1) * tM * MM;
+    const int m = compute_m(SGS_PER_WG, tM, MM);
     const int n = get_group_id(0) * tN * NN;
 
     float8 sum[MM][NN];
@@ -140,13 +145,13 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_rowmajor_tiled, 8, 16, MM, NN)(global float* C, global ushort* A, global ushort* B, int K)
     }
 }
 
-__attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1)))
+__attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, SGS_PER_WG, 1)))
 kernel void MM_KERNEL_NAME(bfloat16_dpas_vnni_tiled, 8, 16, MM, NN)(global float* C, global ushort* A, global ushort* B, int K)
 {
     const int tM = 8;
     const int tN = 16;
     const int N = get_global_size(0) * NN;
-    const int m = get_group_id(1) * tM * MM;
+    const int m = compute_m(SGS_PER_WG, tM, MM);
     const int n = get_group_id(0) * tN * NN;
 
     float8 sum[MM][NN];
@@ -183,14 +188,14 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_vnni_tiled, 8, 16, MM, NN)(global float* C, global ushort* A, global ushort* B, int K)
 
 #ifdef cl_intel_subgroup_extended_block_read
 
-__attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1)))
+__attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, SGS_PER_WG, 1)))
 kernel void MM_KERNEL_NAME(bfloat16_dpas_blockread_rowmajor_tiled, 8, 16, MM, NN)(global float* C, global ushort* A, global ushort* B, int K)
 {
     const int tM = 8;
     const int tN = 16;
     const int M = get_global_size(1) * tM * MM;
     const int N = get_global_size(0) * NN;
-    const int m = get_group_id(1) * tM * MM;
+    const int m = compute_m(SGS_PER_WG, tM, MM);
     const int n = get_group_id(0) * tN * NN;
 
     float8 sum[MM][NN];
@@ -233,14 +238,14 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_blockread_rowmajor_tiled, 8, 16, MM, NN)(global float* C, global ushort* A, global ushort* B, int K)
     }
 }
 
-__attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1)))
+__attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, SGS_PER_WG, 1)))
 kernel void MM_KERNEL_NAME(bfloat16_dpas_blockread_vnni_tiled, 8, 16, MM, NN)(global float* C, global ushort* A, global ushort* B, int K)
 {
     const int tM = 8;
     const int tN = 16;
     const int M = get_global_size(1) * tM * MM;
     const int N = get_global_size(0) * NN;
-    const int m = get_group_id(1) * tM * MM;
+    const int m = compute_m(SGS_PER_WG, tM, MM);
     const int n = get_group_id(0) * tN * NN;
 
     float8 sum[MM][NN];
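
Because SGS_PER_WG is only a default (guarded by #if !defined), the old one-subgroup-per-work-group behavior stays available at compile time. A hedged sketch, reusing the assumed host handles from the earlier example:

/* Illustrative: rebuild with a single subgroup per work-group. compute_m
 * then reduces to get_group_id(1) * tM * MM, matching the pre-commit
 * indexing, and the local size shrinks back to (tN, 1, 1). */
clBuildProgram(program, 1, &device, "-DMM=2 -DNN=2 -DSGS_PER_WG=1", NULL, NULL);
size_t lws[2] = { tN, 1 };   /* matches reqd_work_group_size(tN, 1, 1) */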
