Skip to content

Commit

Permalink
[GPU]: SearchSorted: Added dynamic shape support.
Browse files Browse the repository at this point in the history
  • Loading branch information
pkowalc1 committed Nov 14, 2024
1 parent bf46223 commit 8f8c333
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,33 +23,31 @@ OUTPUT_TYPE binary_search_thread(const INPUT0_TYPE search_val,
sorted_end_idx = half_idx;
else
sorted_begin_idx = half_idx + 1;

//printf("ThreadIdx:%i, sorted_begin_idx=%i, sorted_end_idx=%i, search_val: %f, half_val: %f\n", get_global_id(0), sorted_begin_idx, sorted_end_idx, search_val, half_val );
}

return sorted_begin_idx;
}

#undef CMP

KERNEL(search_sorted_ref)(const __global INPUT0_TYPE* restrict sorted,
const __global INPUT1_TYPE* restrict values,
__global OUTPUT_TYPE* restrict output)
KERNEL(search_sorted_ref)(
OPTIONAL_SHAPE_INFO_ARG
const __global INPUT0_TYPE* restrict sorted,
const __global INPUT1_TYPE* restrict values,
__global OUTPUT_TYPE* restrict output)
{
// INPUT0_TYPE has to be egual to INPUT1_TYPE
const int this_thread_idx = get_global_id(0);
const INPUT0_TYPE search_val = values[this_thread_idx];

const int SORTED_STRIDE = INPUT0_BATCH_NUM*INPUT0_FEATURE_NUM*INPUT0_SIZE_Y;
const int SORTED_STRIDE = INPUT0_BATCH_NUM*INPUT0_FEATURE_NUM*INPUT0_SIZE_Y*INPUT0_SIZE_Z;

// NOTE: SORTED_STRIDE-1 handles here a special case when sorted is actually 1D
// tensor and values is ND tensor. In such case we effectively want sorted_offset
// to be 0.
const int sorted_offset = min(this_thread_idx/INPUT1_SIZE_X, SORTED_STRIDE-1);

OUTPUT_TYPE sorted_begin_idx = sorted_offset * INPUT0_SIZE_X;
OUTPUT_TYPE sorted_end_idx = sorted_begin_idx + INPUT0_SIZE_X;

const OUTPUT_TYPE idx = binary_search_thread(search_val,
sorted + sorted_begin_idx,
0,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,19 @@ KernelsData SearchSortedKernelBase::GetCommonKernelsData(const Params& params) c
auto jit = CreateJit(kernelName, cldnn_jit, entry_point);

auto& kernel = k_data.kernels[0];
FillCLKernelData(kernel, dispatchData, params.engineInfo, kernelName, jit, entry_point, "", false, false, 2);
FillCLKernelData(kernel,
dispatchData,
params.engineInfo,
kernelName,
jit,
entry_point,
"",
false,
false,
2,
GetFusedPrimitiveInputsCount(params),
1,
prim_params.outputs[0].is_dynamic());

return {k_data};
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ ParamsKey SearchSortedKernelRef::GetSupportedKey() const {
k.EnableTensorPitches();
k.EnableBatching();
k.EnableDifferentTypes();

k.EnableDynamicShapesSupport();
return k;
}

Expand Down

0 comments on commit 8f8c333

Please sign in to comment.