Skip to content

Commit

Permalink
[GPU] Optimize fc_bf_tiled kernel for large K + small N case (openvin…
Browse files Browse the repository at this point in the history
…otoolkit#26054)

### Details:
- Optimize fc_bf_tiled kernel for large K + small N case by setting
K_TILE_SIZE 4
- Perf gain on MTL (U7 155H +  32GB RAM + driver 31.0.101.5333)

![image](https://github.com/user-attachments/assets/2e6537fe-90d4-47e1-8d56-f816d3471559)
(No regression and on par on llama3 INT4 default and phi-3 mini INT4
default)
### Tickets:
 - 149212
  • Loading branch information
yeonbok authored Aug 22, 2024
1 parent c85902d commit eb16f7f
Show file tree
Hide file tree
Showing 6 changed files with 115 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ KERNEL(quantize_input)(
#define OUTPUT_BLOCK_WRITE(ptr, offset, val) BLOCK_WRITEN(OUTPUT_TYPE, TILE_OFM, ptr, offset, val)

#define SLM_FILTER_VEC MAKE_VECTOR_TYPE(ACCUMULATOR_TYPE, TILE_OFM)
#define SLM_FILTER_PACKED_VEC MAKE_VECTOR_TYPE(FILTER_TYPE, FILTER_LOAD_BLOCK_SIZE)
#define SLM_FILTER_PACKED_VEC MAKE_VECTOR_TYPE(FILTER_TYPE, FILTER_ACTUAL_LOAD_BLOCK_SIZE)
#define SLM_FILTER_UNPACKED_VEC MAKE_VECTOR_TYPE(ACCUMULATOR_TYPE, FILTER_ELEMENTS_PER_LOAD)


Expand Down Expand Up @@ -311,6 +311,9 @@ inline void FUNC(fc_bf_tiled_kernel_default)(
#if TILE_OFM != 2
#error "FC bf_tiled kernel: can't use SLM optimization with TILE_OFM != 2"
#endif
#if FILTER_LAYOUT_OS_IYX_OSV16 && TILE_K != 4
#error "FC bf_tiled kernel: can't use SLM optimization with TILE_K != 2 && OS_IYX_OSV16 layout"
#endif

// Skip first barrier synchronization if there is only single outer loop iteration.
#if MAIN_LOOP_ELEMENTS_COUNT / (TILE_IFM * SIMD) > 1
Expand All @@ -319,12 +322,19 @@ inline void FUNC(fc_bf_tiled_kernel_default)(

__local SLM_FILTER_VEC* slm_wei_vec = (__local SLM_FILTER_VEC*)wei_local_mem;

uint weights_idx = weights_offset + local_id * SIMD * FILTER_LOAD_ITERS * FILTER_LOAD_BLOCK_SIZE;
uint weights_idx = weights_offset + local_id * SIMD * FILTER_LOAD_ITERS * FILTER_ACTUAL_LOAD_BLOCK_SIZE;
uint wei_local_idx = local_id * SIMD * FILTER_LOAD_ITERS * FILTER_LOAD_BLOCK_SIZE + sglid;

unroll_for(uint load_iter = 0; load_iter < FILTER_LOAD_ITERS; ++load_iter) {
SLM_FILTER_PACKED_VEC wei_packed = BLOCK_READN(FILTER_TYPE, FILTER_LOAD_BLOCK_SIZE, weights, weights_idx);
#if FILTER_LAYOUT_OS_IYX_OSV16
SLM_FILTER_PACKED_VEC wei_packed0 = BLOCK_READN(FILTER_TYPE, FILTER_ACTUAL_LOAD_BLOCK_SIZE, weights, weights_idx);
SLM_FILTER_PACKED_VEC wei_packed1 = BLOCK_READN(FILTER_TYPE, FILTER_ACTUAL_LOAD_BLOCK_SIZE, weights, (weights_idx + ((IFM_SIZE / 2) * 16)));
SLM_FILTER_UNPACKED_VEC wei_unpacked;
wei_unpacked.s0123 = UNPACK_INT4(ACCUMULATOR_TYPE, *((INT4_PACKED_TYPE_PRELOAD*)&wei_packed0));
wei_unpacked.s4567 = UNPACK_INT4(ACCUMULATOR_TYPE, *((INT4_PACKED_TYPE_PRELOAD*)&wei_packed1));
#else
SLM_FILTER_PACKED_VEC wei_packed = BLOCK_READN(FILTER_TYPE, FILTER_LOAD_BLOCK_SIZE/*4*/, weights, weights_idx);
SLM_FILTER_UNPACKED_VEC wei_unpacked = UNPACK_INT4(ACCUMULATOR_TYPE, *((INT4_PACKED_TYPE_PRELOAD*)&wei_packed));
#endif
ACCUMULATOR_TYPE* w = (ACCUMULATOR_TYPE*)(&wei_unpacked);
unroll_for(uint fi = 0; fi < TILE_OFM; ++fi) {
unroll_for(uint kii = 0; kii < FILTER_LOAD_BLOCK_SIZE; ++kii) {
Expand Down Expand Up @@ -383,8 +393,7 @@ inline void FUNC(fc_bf_tiled_kernel_default)(
#endif

#undef STORE_TO_SLM

weights_idx += SIMD * FILTER_LOAD_BLOCK_SIZE;
weights_idx += SIMD * FILTER_ACTUAL_LOAD_BLOCK_SIZE;
}

wei_local_idx = sglid;
Expand Down Expand Up @@ -478,6 +487,8 @@ inline void FUNC(fc_bf_tiled_kernel_default)(
}
#if TILE_OFM == 1 && FILTER_LAYOUT_OS_IS_YX_OSV32_ISV2
weights_offset += TILE_K_OFM_PACKED * 2 * SIMD;
#elif FILTER_LAYOUT_OS_IYX_OSV16 && TILE_OFM == 2 && USE_SLM == 1
weights_offset += TILE_K_OFM_PACKED / 2 * SIMD;
#else
weights_offset += TILE_K_OFM_PACKED * SIMD;
#endif
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -405,6 +405,16 @@ inline uint get_os_zyxi_osv16_index(uint o, uint i, uint z, uint y, uint x, uint
((o) / (sub_group_size))*CAT(prefix, _OFM_PITCH) \
)

#define GET_FILTER_OS_IYX_OSV_INDEX_INT4_PACKED(prefix, o, i, y, x, sub_group_size) \
CAT(prefix, _OFFSET) + \
((o) % (sub_group_size)) + \
(sub_group_size)*( \
(x)*CAT(prefix, _X_PITCH) + \
(y)*CAT(prefix, _Y_PITCH) + \
(i)*CAT(prefix, _IFM_PITCH) + \
((o) / (sub_group_size))*(CAT(prefix, _OFM_PITCH)/2) \
)

#define GET_FILTER_OS_IS_YX_OSV_ISV_INDEX_INT4_PACKED(prefix, o, i, y, x, sub_group_size) \
CAT(prefix, _OFFSET) + \
((o) % (sub_group_size)) + \
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -143,12 +143,24 @@ inline half4 unpack_to_half(uint4x4_t v) __attribute__((overloadable)) {
return (half4)(f0.s0, f0.s1, f1.s0, f1.s1);
}

inline half4 unpack_to_half_osv32_isv2(uint4x4_t v) __attribute__((overloadable)) {
half2 f0 = unpack_to_half(v.s0);
half2 f1 = unpack_to_half(v.s1);
return (half4)(f0.s0, f0.s1, f1.s0, f1.s1);
}

inline half4 unpack_to_half(int4x4_t v) __attribute__((overloadable)) {
half2 f0 = unpack_to_half(v.s0);
half2 f1 = unpack_to_half(v.s1);
return (half4)(f0.s0, f0.s1, f1.s0, f1.s1);
}

inline half4 unpack_to_half_osv32_isv2(int4x4_t v) __attribute__((overloadable)) {
half2 f0 = unpack_to_half(v.s0);
half2 f1 = unpack_to_half(v.s1);
return (half4)(f0.s0, f0.s1, f1.s0, f1.s1);
}

inline half8 unpack_to_half(uint4x8_t v) __attribute__((overloadable)) {
half2 f0 = unpack_to_half(v.s0);
half2 f1 = unpack_to_half(v.s1);
Expand Down Expand Up @@ -211,4 +223,5 @@ inline uchar8 unpack_to_uchar_osv32_isv2(uint4x8_t v) __attribute__((overloadabl

#define UNPACK_INT4x2(target_type, value) CAT(unpack_to_, target_type)(value)
#define UNPACK_INT4x2_OSV32_ISV2(target_type, value) CAT(CAT(unpack_to_, target_type), _osv32_isv2)(value)
#define UNPACK_INT4x4_OSV32_ISV2(target_type, value) CAT(CAT(unpack_to_, target_type), _osv32_isv2)(value)
#define UNPACK_TRANSPOSED_INT4x2(target_type, value) CAT(unpack_transposed_to_, target_type)(value)
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,30 @@ KERNEL(reorder_weights_int4)(const __global INPUT0_TYPE* input, __global OUTPUT_

OUTPUT_TYPE out = in0 | (in1 << 4);
output[out_byte_offset] = out;
#elif defined(OUTPUT_LAYOUT_OS_IYX_OSV16)
// osv32_isv2 layout for int4 packed weight
// f0_k0k1 | f1_k0k1 | .... | f15_k0k1
// f0_k2k3 | f1_k2k3 | .... | f15_k2k3
// f0_k3k4 | f1_k3k4 | .... | f15_k3k4
// ...
// f0_k(K/2-2)k(K/2-1) | f1_k(K/2-2)k(K/2-1) | ....f15_k(K/2-2)k(K/2-1)
// -------------------------------------
// f16_k2k3 | f17_k2k3 | ... | f31_k2k3
// ...
const unsigned o = (uint)get_global_id(0);
const unsigned i = (uint)get_global_id(1) * 2;

const uint input0_offset = GET_FILTER_INDEX(INPUT0, 0, o, i, 0, 0);

INPUT0_TYPE in1 = input[input0_offset / 2] & 0xFF;

INPUT0_TYPE packed_out_channels = in1;

const uint output_idx = GET_FILTER_OS_IYX_OSV_INDEX_INT4_PACKED(OUTPUT, o, i/2, 0, 0, 16); // Calculate offset as osv16 due to packing
output[output_idx] = packed_out_channels;



#elif defined(OUTPUT_LAYOUT_OS_IYX_OSV32)
// os_iyx osv32 layout for int4 packed weight
// k0_f0f16 | k0_f1f17 | .... | k0_f15f31 || k1_f0f16 | k1_f1f17 | ... | k1_f15f31
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,19 @@ static bool should_dynamic_quantize(const fully_connected_params& params) {
return false;
}

static bool is_weight_with_small_ofm(const fully_connected_params& params, size_t output_f) {
size_t min_num_threads = params.engineInfo.computeUnitsCount * simd;
GPU_DEBUG_TRACE_DETAIL << "out_ofm (== weight N dim) size " << output_f << " is small compared to the available threads. "
<< "(computeUnitsCount : " << params.engineInfo.computeUnitsCount
<< " min_num_threads : " << min_num_threads << ")" << std::endl;
GPU_DEBUG_TRACE_DETAIL << "Use ofm_tile size 1 if the batch size is 1." << std::endl;
return (output_f / 2 /*most frequently used tile_ofm*/ <= min_num_threads);
}

static bool is_weight_with_large_ifm(const fully_connected_params& fc_params) {
return (fc_params.weights.IFM().v >= fc_params.weights.OFM().v * 3 && fc_params.weights.OFM().v <= 4096);
}

FullyConnected_bf_tiled::FullyConnected_bf_tiled() : FullyConnectedKernelBase("fully_connected_gpu_bf_tiled") {
for (unsigned tile_b = 1; tile_b <= 32; ++tile_b)
for (unsigned tile_ofm = 1; tile_ofm <= 4; tile_ofm *= 2)
Expand Down Expand Up @@ -324,14 +337,18 @@ FullyConnected_bf_tiled::GetAutoTuneParams(const fully_connected_params& params,
if (params.weights.GetDType() == WeightsType::UINT4 || params.weights.GetDType() == WeightsType::INT4) {
if (!params.is_shape_agnostic && batch == 1) {
// Tuning for Meteor Lake
size_t min_num_threads = params.engineInfo.computeUnitsCount * simd;
if (output_f / 2 <= min_num_threads && params.weights.GetLayout() == WeightsLayout::os_is_yx_osv32_isv2) {
GPU_DEBUG_TRACE_DETAIL << "FC bf tiled: Set ofm_tile 1. (output_f : " << output_f
<< ", computeUnitsCount : " << params.engineInfo.computeUnitsCount
<< " min_num_threads : " << min_num_threads << ")" << std::endl;
return selector.Default(tune_params(1, 1, 4, 2, 1, 1, EXE_MODE_DEFAULT));
if (is_weight_with_small_ofm(params, output_f)) {
if (params.weights.GetLayout() == WeightsLayout::os_is_yx_osv32_isv2) {
return selector.Default(tune_params(1, 1, 4, 2, 1, 1, EXE_MODE_DEFAULT));
} else if (params.weights.GetLayout() == WeightsLayout::os_iyx_osv16) {
return selector.Default(tune_params(1, 1, 4, 4, 1, 1, EXE_MODE_DEFAULT));
}
} else {
return selector.Default(tune_params(1, 2, 4, 2, 1, 1, EXE_MODE_DEFAULT));
if (params.weights.GetLayout() == WeightsLayout::os_iyx_osv16) {
return selector.Default(tune_params(1, 1, 4, 4, 1, 1, EXE_MODE_DEFAULT));
} else {
return selector.Default(tune_params(1, 2, 4, 2, 1, 1, EXE_MODE_DEFAULT));
}
}
} else {
// Try to use SLM kernels if possible
Expand All @@ -343,7 +360,10 @@ FullyConnected_bf_tiled::GetAutoTuneParams(const fully_connected_params& params,
selector.Case(tune_params(8, 2, 2, 4, 1, 1, EXE_MODE_DEFAULT, KernelType::SLM))
.Case(tune_params(8, 2, 1, 4, 1, 1, EXE_MODE_DEFAULT, KernelType::SLM));
}
return selector.Default(tune_params(8, 2, 1, 4, 1, 1, EXE_MODE_DEFAULT));
if (params.weights.GetLayout() == WeightsLayout::os_iyx_osv16)
return selector.Default(tune_params(8, 1, 1, 4, 1, 1, EXE_MODE_DEFAULT));
else
return selector.Default(tune_params(8, 2, 1, 4, 1, 1, EXE_MODE_DEFAULT));
}
} else if (params.compressed && params.engineInfo.supports_immad) {
return selector.Default(tune_params(1, 1, 1, 4, 1, 1, EXE_MODE_DEFAULT));
Expand Down Expand Up @@ -480,6 +500,8 @@ JitConstants FullyConnected_bf_tiled::GetJitConstants(const fully_connected_para
}
if (params.weights.GetLayout() == WeightsLayout::os_is_yx_osv32_isv2)
jit.AddConstant(MakeJitConstant("W_IDX", "fi * TILE_K + kii"));
else if (params.weights.GetLayout() == WeightsLayout::os_iyx_osv16)
jit.AddConstant(MakeJitConstant("W_IDX", "fi * TILE_K + kii"));
else
jit.AddConstant(MakeJitConstant("W_IDX", "kii * TILE_OFM + fi"));

Expand Down Expand Up @@ -512,9 +534,17 @@ JitConstants FullyConnected_bf_tiled::GetJitConstants(const fully_connected_para
jit.AddConstant(MakeJitConstant("USE_SLM", 1));
jit.AddConstant(MakeJitConstant("LWS_BATCHES", lws_batches));
jit.AddConstant(MakeJitConstant("FILTER_LOAD_ITERS", weights_load_iters));

if (params.weights.GetLayout() == WeightsLayout::os_iyx_osv16) {
jit.AddConstant(MakeJitConstant("FILTER_ACTUAL_LOAD_BLOCK_SIZE", block_read_size / 2));
jit.Merge(make_int4_packed_type_jit_constant("INT4_PACKED_TYPE_PRELOAD", params.weights.GetDType(), weights_elements_per_load / 2));
} else {
jit.AddConstant(MakeJitConstant("FILTER_ACTUAL_LOAD_BLOCK_SIZE", block_read_size));
jit.Merge(make_int4_packed_type_jit_constant("INT4_PACKED_TYPE_PRELOAD", params.weights.GetDType(), weights_elements_per_load));
}

jit.AddConstant(MakeJitConstant("FILTER_LOAD_BLOCK_SIZE", block_read_size));
jit.AddConstant(MakeJitConstant("FILTER_ELEMENTS_PER_LOAD", weights_elements_per_load));
jit.Merge(make_int4_packed_type_jit_constant("INT4_PACKED_TYPE_PRELOAD", params.weights.GetDType(), weights_elements_per_load));
} else {
jit.AddConstant(MakeJitConstant("USE_SLM", 0));
}
Expand Down Expand Up @@ -675,9 +705,16 @@ KernelsData FullyConnected_bf_tiled::GetTunedKernelsDataByIndex(const Params &pa
return {};

tune_params tparams = GetAutoTuneParams(fc_params, KernelType::ANY, autoTuneIndex);
auto output_f = get_output_aligned_bf_size(fc_params, false).second;

WeightsLayout weights_layout = WeightsLayout::os_iyx_osv16;
if (fc_params.compressed && fc_params.inputs[0].GetDType() == Datatype::F16
&& (fc_params.weights.GetDType() == WeightsType::INT4 || fc_params.weights.GetDType() == WeightsType::UINT4)
&& is_weight_with_small_ofm(fc_params, output_f) && is_weight_with_large_ifm(fc_params)
&& (fc_params.weights.GetLayout() == WeightsLayout::oiyx || fc_params.weights.GetLayout() == WeightsLayout::os_iyx_osv16)) {
// Large K + Small N case to use [osv16 + TILE_K 4] + TILE_OFM 1 for batch 1
weights_layout = WeightsLayout::os_iyx_osv16;
} else if (fc_params.compressed && fc_params.inputs[0].GetDType() == Datatype::F16
// ioyx => os_is_yx_osv32_isv2 is not supported yet
&& (fc_params.weights.GetLayout() == WeightsLayout::oiyx || fc_params.weights.GetLayout() == WeightsLayout::os_is_yx_osv32_isv2)
&& (fc_params.weights.GetDType() == WeightsType::INT4 || fc_params.weights.GetDType() == WeightsType::UINT4)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include "kernel_selector_common.h"
#include "kernel_selector_params.h"
#include "kernel_selector_utils.h"
#include "common_types.h"

namespace kernel_selector {

Expand All @@ -17,6 +18,7 @@ ParamsKey ReorderWeightsKernelInt4::GetSupportedKey() const {
k.EnableOutputWeightsType(WeightsType::INT4);
k.EnableInputWeightsLayout(WeightsLayout::oiyx);
k.EnableInputWeightsLayout(WeightsLayout::ioyx);
k.EnableOutputWeightsLayout(WeightsLayout::os_iyx_osv16);
k.EnableOutputWeightsLayout(WeightsLayout::os_iyx_osv32);
k.EnableOutputWeightsLayout(WeightsLayout::os_is_yx_osv32_isv2);
k.EnableOutputWeightsLayout(WeightsLayout::oiyx);
Expand All @@ -40,6 +42,8 @@ ReorderWeightsKernelInt4::DispatchData ReorderWeightsKernelInt4::SetDefault(cons
dispatchData.gws = { Align(output.OFM().v, 32) / 2, output.IFM().v, 1 };
} else if (output.GetLayout() == WeightsLayout::os_is_yx_osv32_isv2) {
dispatchData.gws = { Align(output.OFM().v, 32), output.IFM().v / 2, 1 };
} else if (output.GetLayout() == WeightsLayout::os_iyx_osv16) {
dispatchData.gws = { Align(output.OFM().v, 16), output.IFM().v / 2, 1 };
} else {
dispatchData.gws = { CeilDiv(output.LogicalSize(), 2), 1, 1 };
}
Expand All @@ -60,6 +64,7 @@ bool ReorderWeightsKernelInt4::Validate(const Params& params) const {

bool supported_case = input.GetLayout() == WeightsLayout::oiyx && output.GetLayout() == WeightsLayout::os_iyx_osv32;
supported_case |= input.GetLayout() == WeightsLayout::oiyx && output.GetLayout() == WeightsLayout::os_is_yx_osv32_isv2;
supported_case |= input.GetLayout() == WeightsLayout::oiyx && output.GetLayout() == WeightsLayout::os_iyx_osv16;
supported_case |= input.GetLayout() == WeightsLayout::ioyx && output.GetLayout() == WeightsLayout::oiyx;
supported_case |= input.GetLayout() == WeightsLayout::ioyx && output.GetLayout() == WeightsLayout::os_iyx_osv32;
return supported_case;
Expand Down

0 comments on commit eb16f7f

Please sign in to comment.