Skip to content

Commit

Permalink
[gpu]:[SearchSorted]: Added unittests.
Browse files Browse the repository at this point in the history
  • Loading branch information
pkowalc1 committed Nov 4, 2024
1 parent 7835bb3 commit ae7aebf
Show file tree
Hide file tree
Showing 2 changed files with 162 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,6 @@
//

#pragma once
#include <algorithm>
#include <vector>

#include "openvino/op/util/attr_types.hpp"
#include "primitive.hpp"

namespace cldnn {
Expand All @@ -16,21 +12,36 @@ struct search_sorted : public primitive_base<search_sorted> {

search_sorted() : primitive_base("", {}) {}

search_sorted(const primitive_id& id, const input_info& sorted, const input_info& values, bool right_mode)
: primitive_base(id, {sorted, values}),
right_mode(right_mode) {}

/// @brief Enable/Disable right mode(check specification for details).
bool right_mode = false;

size_t hash() const override {
size_t seed = primitive::hash();
seed = hash_combine(seed, right_mode);
return seed;
}

bool operator==(const primitive& rhs) const override {
return compare_common_params(rhs);
if (!compare_common_params(rhs))
return false;

auto rhs_casted = downcast<const search_sorted>(rhs);

return right_mode == rhs_casted.right_mode;
}

void save(BinaryOutputBuffer& ob) const override {
primitive_base<search_sorted>::save(ob);
ob << right_mode;
}

void load(BinaryInputBuffer& ib) override {
primitive_base<search_sorted>::load(ib);
ib >> right_mode;
}
};
} // namespace cldnn
146 changes: 146 additions & 0 deletions src/plugins/intel_gpu/tests/unit/test_cases/search_sorted_gpu_test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include <intel_gpu/primitives/data.hpp>
#include <intel_gpu/primitives/input_layout.hpp>
#include <intel_gpu/primitives/search_sorted.hpp>

#include "test_utils.h"

using namespace cldnn;
using namespace ::tests;

namespace {

constexpr float EPS = 2e-3f;

namespace helpers {
// TODO: Move to common place.

// Converts float vector to another type vector.
template <typename T>
std::vector<T> ConverFloatVector(const std::vector<float>& vec) {
std::vector<T> ret;
ret.reserve(vec.size());
for (const auto& val : vec) {
ret.push_back(T(val));
}
return ret;
}

// Allocates tensoer with given shape and data.
template <typename TDataType>
memory::ptr AllocateTensor(ov::PartialShape shape, const std::vector<TDataType>& data) {
const layout lo = {shape, ov::element::from<TDataType>(), cldnn::format::bfyx};
EXPECT_EQ(lo.get_linear_size(), data.size());
memory::ptr tensor = get_test_engine().allocate_memory(lo);
set_values<TDataType>(tensor, data);
return tensor;
}
} // namespace helpers

struct SearchSortedTestParams {
ov::PartialShape sortedShape;
ov::PartialShape valuesShape;
bool rightMode;
std::vector<float> sortedData;
std::vector<float> valuesData;
std::vector<int64_t> expectedOutput;
std::string testcaseName;
};

class search_sorted_test : public ::testing::TestWithParam<SearchSortedTestParams> {
public:
static std::string getTestCaseName(const testing::TestParamInfo<SearchSortedTestParams>& obj) {
auto param = obj.param;
std::ostringstream result;
result << "_" << param.testcaseName;
return result.str();
}

struct SearchSortedInferenceParams {
bool rightMode;
memory::ptr sorted;
memory::ptr values;
memory::ptr expectedOutput;
};

template <ov::element::Type_t ET>
SearchSortedInferenceParams PrepareInferenceParams(const SearchSortedTestParams& testParam) {
using T = typename ov::element_type_traits<ET>::value_type;
SearchSortedInferenceParams ret;

ret.rightMode = testParam.rightMode;

ret.sorted =
helpers::AllocateTensor<T>(testParam.sortedShape, helpers::ConverFloatVector<T>(testParam.sortedData));
ret.values =
helpers::AllocateTensor<T>(testParam.valuesShape, helpers::ConverFloatVector<T>(testParam.valuesData));
ret.values = helpers::AllocateTensor<int64_t>(testParam.valuesShape, testParam.expectedOutput);

return ret;
}

void Execute(const SearchSortedInferenceParams& params) {
// Prepare the network.
auto stream = get_test_stream_ptr(get_test_default_config(engine_));

topology topology;
topology.add(input_layout("sorted", params.sorted->get_layout()));
topology.add(input_layout("values", params.sorted->get_layout()));
topology.add(search_sorted("search_sorted", input_info("sorted"), input_info("values"), params.rightMode));
topology.add(reorder("out", input_info("search_sorted"), cldnn::format::bfyx, data_types::f32));

cldnn::network::ptr network = get_network(engine_, topology, get_test_default_config(engine_), stream, false);

network->set_input_data("sorted", params.sorted);
network->set_input_data("values", params.values);

// Run and check results.
auto outputs = network->execute();

auto output = outputs.at("out").get_memory();
cldnn::mem_lock<float> output_ptr(output, get_test_stream());
cldnn::mem_lock<float> wanted_output_ptr(params.expectedOutput, get_test_stream());

ASSERT_EQ(output->get_layout(), params.expectedOutput->get_layout());
ASSERT_EQ(output_ptr.size(), wanted_output_ptr.size());
for (size_t i = 0; i < output_ptr.size(); ++i)
ASSERT_TRUE(are_equal(wanted_output_ptr[i], output_ptr[i], EPS));
}

private:
engine& engine_ = get_test_engine();
};

std::vector<SearchSortedTestParams> generateTestParams() {
std::vector<SearchSortedTestParams> params;
#define TEST_DATA(sorted_shape, values_shape, right_mode, sorted_data, values_data, expected_output_data, description) \
params.push_back(SearchSortedTestParams{sorted_shape, \
values_shape, \
right_mode, \
sorted_data, \
values_data, \
expected_output_data, \
description});

#include "unit_test_utils/tests_data/search_sorted_data.h"
#undef TEST_DATA
return params;
}

} // namespace

#define ROI_ALIGN_ROTATED_TEST_P(precision) \
TEST_P(search_sorted_test, ref_comp_##precision) { \
Execute(PrepareInferenceParams<ov::element::Type_t::precision>(GetParam())); \
}

ROI_ALIGN_ROTATED_TEST_P(f16);
ROI_ALIGN_ROTATED_TEST_P(f32);

INSTANTIATE_TEST_SUITE_P(search_sorted_test_suit,
search_sorted_test,
testing::ValuesIn(generateTestParams()),
search_sorted_test::getTestCaseName);

0 comments on commit ae7aebf

Please sign in to comment.