diff --git a/src/plugins/intel_gpu/include/intel_gpu/primitives/search_sorted.hpp b/src/plugins/intel_gpu/include/intel_gpu/primitives/search_sorted.hpp index 45641af26350ae..1288b39e25e617 100644 --- a/src/plugins/intel_gpu/include/intel_gpu/primitives/search_sorted.hpp +++ b/src/plugins/intel_gpu/include/intel_gpu/primitives/search_sorted.hpp @@ -3,10 +3,6 @@ // #pragma once -#include -#include - -#include "openvino/op/util/attr_types.hpp" #include "primitive.hpp" namespace cldnn { @@ -16,21 +12,36 @@ struct search_sorted : public primitive_base { search_sorted() : primitive_base("", {}) {} + search_sorted(const primitive_id& id, const input_info& sorted, const input_info& values, bool right_mode) + : primitive_base(id, {sorted, values}), + right_mode(right_mode) {} + + /// @brief Enable/Disable right mode(check specification for details). + bool right_mode = false; + size_t hash() const override { size_t seed = primitive::hash(); + seed = hash_combine(seed, right_mode); return seed; } bool operator==(const primitive& rhs) const override { - return compare_common_params(rhs); + if (!compare_common_params(rhs)) + return false; + + auto rhs_casted = downcast(rhs); + + return right_mode == rhs_casted.right_mode; } void save(BinaryOutputBuffer& ob) const override { primitive_base::save(ob); + ob << right_mode; } void load(BinaryInputBuffer& ib) override { primitive_base::load(ib); + ib >> right_mode; } }; } // namespace cldnn diff --git a/src/plugins/intel_gpu/tests/unit/test_cases/search_sorted_gpu_test.cpp b/src/plugins/intel_gpu/tests/unit/test_cases/search_sorted_gpu_test.cpp new file mode 100644 index 00000000000000..a46ddf4992e86c --- /dev/null +++ b/src/plugins/intel_gpu/tests/unit/test_cases/search_sorted_gpu_test.cpp @@ -0,0 +1,146 @@ +// Copyright (C) 2018-2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include +#include +#include + +#include "test_utils.h" + +using namespace cldnn; +using namespace ::tests; + +namespace { + +constexpr float EPS = 2e-3f; + +namespace helpers { +// TODO: Move to common place. + +// Converts float vector to another type vector. +template +std::vector ConverFloatVector(const std::vector& vec) { + std::vector ret; + ret.reserve(vec.size()); + for (const auto& val : vec) { + ret.push_back(T(val)); + } + return ret; +} + +// Allocates tensoer with given shape and data. +template +memory::ptr AllocateTensor(ov::PartialShape shape, const std::vector& data) { + const layout lo = {shape, ov::element::from(), cldnn::format::bfyx}; + EXPECT_EQ(lo.get_linear_size(), data.size()); + memory::ptr tensor = get_test_engine().allocate_memory(lo); + set_values(tensor, data); + return tensor; +} +} // namespace helpers + +struct SearchSortedTestParams { + ov::PartialShape sortedShape; + ov::PartialShape valuesShape; + bool rightMode; + std::vector sortedData; + std::vector valuesData; + std::vector expectedOutput; + std::string testcaseName; +}; + +class search_sorted_test : public ::testing::TestWithParam { +public: + static std::string getTestCaseName(const testing::TestParamInfo& obj) { + auto param = obj.param; + std::ostringstream result; + result << "_" << param.testcaseName; + return result.str(); + } + + struct SearchSortedInferenceParams { + bool rightMode; + memory::ptr sorted; + memory::ptr values; + memory::ptr expectedOutput; + }; + + template + SearchSortedInferenceParams PrepareInferenceParams(const SearchSortedTestParams& testParam) { + using T = typename ov::element_type_traits::value_type; + SearchSortedInferenceParams ret; + + ret.rightMode = testParam.rightMode; + + ret.sorted = + helpers::AllocateTensor(testParam.sortedShape, helpers::ConverFloatVector(testParam.sortedData)); + ret.values = + helpers::AllocateTensor(testParam.valuesShape, helpers::ConverFloatVector(testParam.valuesData)); + ret.values = helpers::AllocateTensor(testParam.valuesShape, testParam.expectedOutput); + + return ret; + } + + void Execute(const SearchSortedInferenceParams& params) { + // Prepare the network. + auto stream = get_test_stream_ptr(get_test_default_config(engine_)); + + topology topology; + topology.add(input_layout("sorted", params.sorted->get_layout())); + topology.add(input_layout("values", params.sorted->get_layout())); + topology.add(search_sorted("search_sorted", input_info("sorted"), input_info("values"), params.rightMode)); + topology.add(reorder("out", input_info("search_sorted"), cldnn::format::bfyx, data_types::f32)); + + cldnn::network::ptr network = get_network(engine_, topology, get_test_default_config(engine_), stream, false); + + network->set_input_data("sorted", params.sorted); + network->set_input_data("values", params.values); + + // Run and check results. + auto outputs = network->execute(); + + auto output = outputs.at("out").get_memory(); + cldnn::mem_lock output_ptr(output, get_test_stream()); + cldnn::mem_lock wanted_output_ptr(params.expectedOutput, get_test_stream()); + + ASSERT_EQ(output->get_layout(), params.expectedOutput->get_layout()); + ASSERT_EQ(output_ptr.size(), wanted_output_ptr.size()); + for (size_t i = 0; i < output_ptr.size(); ++i) + ASSERT_TRUE(are_equal(wanted_output_ptr[i], output_ptr[i], EPS)); + } + +private: + engine& engine_ = get_test_engine(); +}; + +std::vector generateTestParams() { + std::vector params; +#define TEST_DATA(sorted_shape, values_shape, right_mode, sorted_data, values_data, expected_output_data, description) \ + params.push_back(SearchSortedTestParams{sorted_shape, \ + values_shape, \ + right_mode, \ + sorted_data, \ + values_data, \ + expected_output_data, \ + description}); + +#include "unit_test_utils/tests_data/search_sorted_data.h" +#undef TEST_DATA + return params; +} + +} // namespace + +#define ROI_ALIGN_ROTATED_TEST_P(precision) \ + TEST_P(search_sorted_test, ref_comp_##precision) { \ + Execute(PrepareInferenceParams(GetParam())); \ + } + +ROI_ALIGN_ROTATED_TEST_P(f16); +ROI_ALIGN_ROTATED_TEST_P(f32); + +INSTANTIATE_TEST_SUITE_P(search_sorted_test_suit, + search_sorted_test, + testing::ValuesIn(generateTestParams()), + search_sorted_test::getTestCaseName);