enabled eltwise fusion for rms
e-ddykim committed Jan 14, 2025
1 parent 6198961 commit c7414cf
Showing 7 changed files with 130 additions and 7 deletions.
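For context: RMS normalization scales each element of a row by the reciprocal root-mean-square of that row and a per-element gamma, and this commit lets the GPU plugin apply a following element-wise operation (such as a residual Add) inside the same RMS kernel instead of writing the normalized tensor out and launching a second kernel. A minimal scalar C++ sketch of the fused math; the function name, epsilon handling, and the Add operand are illustrative, not the plugin's host code:

#include <cmath>
#include <cstddef>
#include <vector>

// One pass per row: normalize, scale by gamma, then apply the fused
// eltwise op in-register, so no intermediate normalized tensor is written.
void rms_with_fused_add(const std::vector<float>& input,
                        const std::vector<float>& gamma,     // per-element scale, length n
                        const std::vector<float>& residual,  // fused eltwise operand
                        std::vector<float>& output,
                        float epsilon) {
    const std::size_t n = gamma.size();  // innermost (normalized) axis
    for (std::size_t row = 0; row < input.size() / n; ++row) {
        const float* x = &input[row * n];
        float sum_sq = 0.0f;
        for (std::size_t i = 0; i < n; ++i)
            sum_sq += x[i] * x[i];
        const float rms = 1.0f / std::sqrt(sum_sq / n + epsilon);
        for (std::size_t i = 0; i < n; ++i) {
            float normalized = rms * x[i] * gamma[i];
            // This is the role FUSED_OPS / FUSED_OPS_RESULT play in the kernels below.
            output[row * n + i] = normalized + residual[row * n + i];
        }
    }
}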
@@ -41,6 +41,7 @@
 #include "extract_image_patches_inst.h"
 #include "reduce_inst.h"
 #include "group_normalization_inst.h"
+#include "rms_inst.h"
 #include <vector>
 #include <map>
 #include <list>
@@ -964,6 +965,7 @@ void prepare_primitive_fusing::fuse_simple_primitives(program &p) {
                 (parents[i].first->is_type<mvn>() &&
                  mvn_supports_fusings(parents[i].first->as<mvn>())) ||
                 (parents[i].first->is_type<group_normalization>()) ||
+                (parents[i].first->is_type<rms>()) ||
                 (parents[i].first->is_type<deconvolution>()) ||
                 (parents[i].first->is_type<permute>()) ||
                 (parents[i].first->is_type<resample>()) ||
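With rms added to this list, the parent-type check in fuse_simple_primitives now treats an RMS node as a valid fusion target for an element-wise consumer, which is what routes the eltwise through the FUSED_OPS machinery added to the kernels below.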
@@ -28,7 +28,11 @@ KERNEL(rms_gpu_bfyx_opt)(
     OPTIONAL_SHAPE_INFO_ARG
     const __global INPUT0_TYPE* input,
     const __global INPUT1_TYPE* gamma,
-    __global OUTPUT_TYPE* output)
+    __global OUTPUT_TYPE* output
+#if HAS_FUSED_OPS_DECLS
+    , FUSED_OPS_DECLS
+#endif
+)
 {
     const uint data_idx = get_global_id(1);
     const uint in_data_idx = get_global_id(0);
@@ -100,18 +104,53 @@ KERNEL(rms_gpu_bfyx_opt)(
 
     rms = slm_buf[0];
 
+#if HAS_FUSED_OPS
+    uint b, f, z, y, x;
+#if INPUT_RANK == 1
+    f = z = y = x = 1;
+#elif INPUT_RANK == 2
+    z = y = x = 1;
+    b = data_idx;
+#elif INPUT_RANK == 3
+    x = 1;
+    f = data_idx % OUTPUT_FEATURE_NUM;
+    b = data_idx / OUTPUT_FEATURE_NUM;
+#else
+    x = data_idx; // temp variable to calc indexes
+    y = x % OUTPUT_SIZE_Y; x = x / OUTPUT_SIZE_Y;
+    z = x % OUTPUT_SIZE_Z; x = x / OUTPUT_SIZE_Z;
+    f = x % OUTPUT_FEATURE_NUM; x = x / OUTPUT_FEATURE_NUM;
+    b = x % OUTPUT_BATCH_NUM; x = x / OUTPUT_BATCH_NUM;
+#endif
+#endif
+
     i = 0;
     if ((workers_per_data > SUB_GROUP_SIZE) && USE_BLOCK_WRITE)
     {
         for (; i < items_num - (items_num % SUBGROUP_BLOCK_SIZE); i += SUBGROUP_BLOCK_SIZE)
         {
             ACC_TYPE vec_gamma = TO_ACC_TYPE(BLOCK_READ(gamma, subgroup_offset + i * get_sub_group_size()));
             OUTPUT_VEC_TYPE vec_tmp;
+#if HAS_FUSED_OPS
+            LAST_DIM = subgroup_offset + i * get_sub_group_size() + get_sub_group_local_id();
+#endif
 #if SUBGROUP_BLOCK_SIZE == 1
-            vec_tmp = TO_OUTPUT_TYPE(rms * data[i] * vec_gamma);
+            OUTPUT_TYPE normalized = TO_OUTPUT_TYPE(rms * data[i] * vec_gamma);
+#if HAS_FUSED_OPS
+            FUSED_OPS;
+            normalized = FUSED_OPS_RESULT;
+#endif
+            vec_tmp = normalized;
 #else
-            unroll_for (int j = 0; j < SUBGROUP_BLOCK_SIZE; j++)
-                vec_tmp[j] = TO_OUTPUT_TYPE(rms * data[i + j] * vec_gamma[j]);
+            unroll_for (int j = 0; j < SUBGROUP_BLOCK_SIZE; j++) {
+                OUTPUT_TYPE normalized = TO_OUTPUT_TYPE(rms * data[i + j] * vec_gamma[j]);
+#if HAS_FUSED_OPS
+                // index of element i + j for this lane (an accumulating `+= j * ...`
+                // would compound across the unrolled iterations)
+                LAST_DIM = subgroup_offset + (i + j) * get_sub_group_size() + get_sub_group_local_id();
+                FUSED_OPS;
+                normalized = FUSED_OPS_RESULT;
+#endif
+                vec_tmp[j] = normalized;
+            }
 #endif
             BLOCK_WRITE(output, data_offset + subgroup_offset + i * get_sub_group_size(), vec_tmp);
         }
@@ -120,13 +159,25 @@
     for (; i < items_num; i++)
     {
         ACCUMULATOR_TYPE temp = TO_ACCUMULATOR_TYPE(gamma[subgroup_offset + get_sub_group_local_id() + i * get_sub_group_size()]);
-        output[data_offset + subgroup_offset + get_sub_group_local_id() + i * get_sub_group_size()] = TO_OUTPUT_TYPE(rms * data[i] * temp);
+        OUTPUT_TYPE normalized = TO_OUTPUT_TYPE(rms * data[i] * temp);
+#if HAS_FUSED_OPS
+        LAST_DIM = subgroup_offset + get_sub_group_local_id() + i * get_sub_group_size();
+        FUSED_OPS;
+        normalized = FUSED_OPS_RESULT;
+#endif
+        output[data_offset + subgroup_offset + get_sub_group_local_id() + i * get_sub_group_size()] = normalized;
     }
 
     if (in_data_idx < leftovers)
     {
         ACCUMULATOR_TYPE temp = TO_ACCUMULATOR_TYPE(gamma[workers_per_data * items_num + in_data_idx]);
-        output[data_offset + workers_per_data * items_num + in_data_idx] = TO_OUTPUT_TYPE(rms * data[items_num] * temp);
+        OUTPUT_TYPE normalized = TO_OUTPUT_TYPE(rms * data[items_num] * temp);
+#if HAS_FUSED_OPS
+        LAST_DIM = workers_per_data * items_num + in_data_idx;
+        FUSED_OPS;
+        normalized = FUSED_OPS_RESULT;
+#endif
+        output[data_offset + workers_per_data * items_num + in_data_idx] = normalized;
     }
 }
 #undef USE_BLOCK_WRITE
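In the HAS_FUSED_OPS block above, the kernel rebuilds (b, f, z, y, x) coordinates from its flat per-row index (data_idx) so the generated fused-op code can index its own operands; LAST_DIM, which the host code below maps to the innermost logical axis, is then overwritten per element inside the loops. A host-side C++ sketch of the rank-4-and-up decomposition branch, with placeholder sizes:

#include <cstdio>

struct Coord { unsigned b, f, z, y; };

// Mirrors the #else branch: peel coordinates off the flat row index,
// innermost remaining axis first. The x coordinate is supplied per
// element via LAST_DIM in the kernel, so it is not derived here.
static Coord decompose_row_index(unsigned data_idx,
                                 unsigned size_y, unsigned size_z,
                                 unsigned feature_num, unsigned batch_num) {
    Coord c;
    unsigned t = data_idx;                 // temp variable, like `x` in the kernel
    c.y = t % size_y;      t /= size_y;
    c.z = t % size_z;      t /= size_z;
    c.f = t % feature_num; t /= feature_num;
    c.b = t % batch_num;
    return c;
}

int main() {
    // e.g. row 37 of a tensor with Y=4, Z=1, F=8, B=2
    Coord c = decompose_row_index(37, 4, 1, 8, 2);
    std::printf("b=%u f=%u z=%u y=%u\n", c.b, c.f, c.z, c.y);
    return 0;
}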
@@ -8,7 +8,11 @@ KERNEL(rms_gpu_ref)(
     OPTIONAL_SHAPE_INFO_ARG
     const __global INPUT0_TYPE* input,
     const __global INPUT1_TYPE* gamma,
-    __global OUTPUT_TYPE* output)
+    __global OUTPUT_TYPE* output
+#if HAS_FUSED_OPS_DECLS
+    , FUSED_OPS_DECLS
+#endif
+)
 {
     const uint b = get_global_id(0);
     const uint f = get_global_id(1);
@@ -38,6 +42,10 @@
     const uint gamma_idx = z;
 #endif
     OUTPUT_TYPE result = TO_OUTPUT_TYPE(rms) * TO_OUTPUT_TYPE(input[input_idx]) * TO_OUTPUT_TYPE(gamma[gamma_idx]);
+#if HAS_FUSED_OPS
+    FUSED_OPS;
+    result = FUSED_OPS_RESULT;
+#endif
     output[output_idx] = result;
 }
 }
@@ -97,6 +97,35 @@ JitConstants RMSKernelBfyxOpt::GetJitConstants(const rms_params& params, DispatchData dispatchData) const {
     }
     jit.AddConstant(MakeJitConstant("SUB_GROUP_SIZE", subgroup_size));
     jit.AddConstant(MakeJitConstant("SUBGROUP_BLOCK_SIZE", dispatchData.subgroupBlockSize));
+    if (!params.fused_ops.empty()) {
+        jit.AddConstant(MakeJitConstant("INPUT_RANK", params.ov_input_rank));
+        switch (params.ov_input_rank) {
+        case 1:
+            jit.AddConstant(MakeJitConstant("LAST_DIM", "b"));
+            break;
+        case 2:
+            jit.AddConstant(MakeJitConstant("LAST_DIM", "f"));
+            break;
+        case 3:
+            jit.AddConstant(MakeJitConstant("LAST_DIM", "y"));
+            break;
+        default:
+            jit.AddConstant(MakeJitConstant("LAST_DIM", "x"));
+            break;
+        }
+
+        std::vector<std::string> idx_order;
+        if (params.inputs[0].GetDims().size() == 5) {
+            idx_order = { "(b)", "(f)", "(z)", "(y)", "(x)" };
+        } else if (params.inputs[0].GetDims().size() <= 4) {
+            idx_order = { "(b)", "(f)", "(y)", "(x)" };
+        } else {
+            OPENVINO_THROW("rms_bfyx_opt doesn't support 6D or higher dims.");
+        }
+
+        auto conf = FusedOpsConfiguration("", idx_order, "normalized", params.outputs[0].GetDType(), 1);
+        jit.Merge(MakeFusedOpsJitConstants(params, { conf }));
+    }
+
     return jit;
 }
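A note on the wiring: idx_order supplies the coordinate expressions the JIT substitutes when the generated fused-op code indexes its extra inputs, the string "normalized" names the kernel variable that feeds FUSED_OPS (matching the variable the kernel changes assign from FUSED_OPS_RESULT), and the trailing 1 is the per-element vector width. Pairing this with the LAST_DIM constant means the kernel only has to reconstruct the remaining coordinates from its flat work-item index.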
@@ -18,6 +18,11 @@ class RMSKernelBfyxOpt : public RMSKernelBase {
     ParamsKey GetSupportedKey() const override;
 
 protected:
+    std::vector<FusedOpType> GetSupportedFusedOps() const override {
+        return {
+            FusedOpType::ELTWISE
+        };
+    }
     bool Validate(const Params&) const override;
     DispatchData SetDefault(const rms_params& params) const override;
     JitConstants GetJitConstants(const rms_params& params, DispatchData dispatchData) const override;
@@ -25,6 +25,26 @@ ParamsKey RMSKernelRef::GetSupportedKey() const {
     return k;
 }
 
+JitConstants RMSKernelRef::GetJitConstants(const rms_params& params, DispatchData dispatchData) const {
+    auto jit = Parent::GetJitConstants(params, dispatchData);
+
+    if (!params.fused_ops.empty()) {
+        std::vector<std::string> idx_order;
+        if (params.inputs[0].GetDims().size() == 5) {
+            idx_order = { "(b)", "(f)", "(z)", "(y)", "(x)" };
+        } else if (params.inputs[0].GetDims().size() <= 4) {
+            idx_order = { "(b)", "(f)", "(y)", "(x)" };
+        } else {
+            OPENVINO_THROW("rms_ref doesn't support 6D or higher dims.");
+        }
+
+        auto conf = FusedOpsConfiguration("", idx_order, "result", params.outputs[0].GetDType(), 1);
+        jit.Merge(MakeFusedOpsJitConstants(params, { conf }));
+    }
+
+    return jit;
+}
+
 KernelsData RMSKernelRef::GetKernelsData(const Params& params) const {
     return GetCommonKernelsData(params);
 }
@@ -16,5 +16,13 @@ class RMSKernelRef : public RMSKernelBase {
     KernelsData GetKernelsData(const Params& params) const override;
     KernelsPriority GetKernelsPriority(const Params& params) const override;
     ParamsKey GetSupportedKey() const override;
+
+protected:
+    std::vector<FusedOpType> GetSupportedFusedOps() const override {
+        return {
+            FusedOpType::ELTWISE
+        };
+    }
+    JitConstants GetJitConstants(const rms_params& params, DispatchData dispatchData) const override;
 };
 } // namespace kernel_selector
