Skip to content

Commit

Permalink
refactor cpu test instances
Browse files Browse the repository at this point in the history
  • Loading branch information
antonvor committed Oct 4, 2023
1 parent 5498491 commit 6dfeab1
Show file tree
Hide file tree
Showing 4 changed files with 99 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
namespace {
using ov::test::Gather7LayerTest;
using ov::test::Gather8LayerTest;
using ov::test::Gather8withIndicesDataLayerTest;

const std::vector<ov::element::Type> model_types = {
ov::element::f32,
Expand Down Expand Up @@ -203,19 +204,16 @@ const auto gatherParamsVec3 = testing::Combine(

INSTANTIATE_TEST_CASE_P(smoke_Vec3, Gather8LayerTest, gatherParamsVec3, Gather8LayerTest::getTestCaseName);

gather7ParamsTuple dummyParams = {
std::vector<size_t>{2, 3},
std::vector<size_t>{2, 2},
std::tuple<int, int>{1, 1},
InferenceEngine::Precision::FP32,
InferenceEngine::Precision::UNSPECIFIED,
InferenceEngine::Precision::UNSPECIFIED,
InferenceEngine::Layout::ANY,
InferenceEngine::Layout::ANY,
ov::test::utils::DEVICE_CPU,

const ov::test::gather7ParamsTuple dummyParams = {
ov::test::static_shapes_to_test_representation(std::vector<ov::Shape>{{2, 3}}), // input shape
ov::Shape{2, 2}, // indices shape
std::tuple<int, int>{1, 1}, // axis, batch
ov::element::f32, // model type
ov::test::utils::DEVICE_CPU // device
};

std::vector<std::vector<int>> indicesData = {
const std::vector<std::vector<int64_t>> indicesData = {
{0, 1, 2, 0}, // positive in bound
{-1, -2, -3, -1}, // negative in bound
{-1, 0, 1, 2}, // positive and negative in bound
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,5 +23,9 @@ TEST_P(Gather8LayerTest, Inference) {
TEST_P(Gather8IndiceScalarLayerTest, Inference) {
run();
};

TEST_P(Gather8withIndicesDataLayerTest, Inference) {
run();
};
} // namespace test
} // namespace ov
Original file line number Diff line number Diff line change
Expand Up @@ -63,5 +63,20 @@ class Gather8IndiceScalarLayerTest : public testing::WithParamInterface<gather7P
protected:
void SetUp() override;
};

typedef std::tuple<
gather7ParamsTuple,
std::vector<int64_t> // indices data
> gather8withIndicesDataParamsTuple;

class Gather8withIndicesDataLayerTest : public testing::WithParamInterface<gather8withIndicesDataParamsTuple>,
virtual public ov::test::SubgraphBaseTest {
public:
static std::string getTestCaseName(const testing::TestParamInfo<gather8withIndicesDataParamsTuple>& obj);

protected:
void SetUp() override;
};

} // namespace test
} // namespace ov
71 changes: 71 additions & 0 deletions src/tests/functional/shared_test_classes/src/single_op/gather.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -167,5 +167,76 @@ void Gather8IndiceScalarLayerTest::SetUp() {
auto result = std::make_shared<ov::op::v0::Result>(gather);
function = std::make_shared<ov::Model>(result, ov::ParameterVector{param}, "gather");
}

std::string Gather8withIndicesDataLayerTest::getTestCaseName(const testing::TestParamInfo<gather8withIndicesDataParamsTuple>& obj) {
gather7ParamsTuple basicParams;
std::vector<int64_t> indicesData;
std::tie(basicParams, indicesData) = obj.param;

std::tuple<int, int> axis_batch_idx;
std::vector<int> indices;
ov::Shape indices_shape;
std::vector<InputShape> shapes;
ov::element::Type model_type;
std::string device_name;
std::tie(shapes, indices_shape, axis_batch_idx, model_type, device_name) = basicParams;

std::ostringstream result;
result << "IS=(";
for (size_t i = 0lu; i < shapes.size(); i++) {
result << ov::test::utils::partialShape2str({shapes[i].first}) << (i < shapes.size() - 1lu ? "_" : "");
}
result << ")_TS=";
for (size_t i = 0lu; i < shapes.front().second.size(); i++) {
result << "{";
for (size_t j = 0lu; j < shapes.size(); j++) {
result << ov::test::utils::vec2str(shapes[j].second[i]) << (j < shapes.size() - 1lu ? "_" : "");
}
result << "}_";
}
result << "axis=" << std::get<0>(axis_batch_idx) << "_";
result << "batch_idx=" << std::get<1>(axis_batch_idx) << "_";
result << "indices_shape=" << ov::test::utils::vec2str(indices_shape) << "_";
result << "netPRC=" << model_type.get_type_name() << "_";
result << "trgDev=" << device_name << "_";

result << "indicesData=" << ov::test::utils::vec2str(indicesData) << "_";

return result.str();
}

void Gather8withIndicesDataLayerTest::SetUp() {
gather7ParamsTuple basicParams;
std::vector<int64_t> indicesData;
std::tie(basicParams, indicesData) = GetParam();

std::tuple<int, int> axis_batch_idx;
ov::Shape indices_shape;
std::vector<InputShape> shapes;
ov::element::Type model_type;
std::tie(shapes, indices_shape, axis_batch_idx, model_type, targetDevice) = basicParams;
init_input_shapes(shapes);

int axis = std::get<0>(axis_batch_idx);
int batch_idx = std::get<1>(axis_batch_idx);

auto param = std::make_shared<ov::op::v0::Parameter>(model_type, inputDynamicShapes.front());

// create indices tensor and fill data
ov::Tensor indices_node_tensor{ov::element::i64, indices_shape};
auto indices_tensor_data = indices_node_tensor.data<int64_t>();
for (size_t i = 0; i < shape_size(indices_shape); i++) {
indices_tensor_data[i] = indicesData[i];
}

auto indices_node = std::make_shared<ov::op::v0::Constant>(indices_node_tensor);
auto axis_node = ov::op::v0::Constant::create(ov::element::i64, ov::Shape(), {axis});

auto gather = std::make_shared<ov::op::v8::Gather>(param, indices_node, axis_node, batch_idx);

auto result = std::make_shared<ov::op::v0::Result>(gather);
function = std::make_shared<ov::Model>(result, ov::ParameterVector{param}, "gather");
}

} // namespace test
} // namespace ov

0 comments on commit 6dfeab1

Please sign in to comment.