forked from openvinotoolkit/openvino
-
Notifications
You must be signed in to change notification settings - Fork 0
/
opset.cpp
125 lines (105 loc) · 4.2 KB
/
opset.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
// Copyright (C) 2018-2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "openvino/opsets/opset.hpp"
#include <gtest/gtest.h>
#include "openvino/op/op.hpp"
#include "openvino/opsets/opset1.hpp"
#include "openvino/opsets/opset10.hpp"
#include "openvino/opsets/opset2.hpp"
#include "openvino/opsets/opset3.hpp"
#include "openvino/opsets/opset4.hpp"
#include "openvino/opsets/opset5.hpp"
#include "openvino/opsets/opset6.hpp"
#include "openvino/opsets/opset7.hpp"
#include "openvino/opsets/opset8.hpp"
#include "openvino/opsets/opset9.hpp"
/// One parameterization of OpsetTests: a callable yielding the opset under test,
/// paired with the number of operation types that opset is expected to register.
struct OpsetTestParams {
    using OpsetGetterFunction = std::function<const ov::OpSet&()>;

    OpsetTestParams(const OpsetGetterFunction& getter, const uint32_t ops_count)
        : opset_getter(getter),
          expected_ops_count(ops_count) {}

    /// Produces a reference to the opset under test.
    OpsetGetterFunction opset_getter;
    /// Expected size of the opset's get_types_info() collection.
    uint32_t expected_ops_count;
};
/// Value-parameterized fixture; each instantiation runs against one OpsetTestParams entry.
class OpsetTests : public testing::TestWithParam<OpsetTestParams> {};

/// Names each instantiation "opset<N>", where N is the 1-based position of the
/// parameter in the instantiation list. The list must therefore be ordered
/// opset1, opset2, ... for the names to match the opsets actually tested.
struct OpsetTestNameGenerator {
    std::string operator()(const testing::TestParamInfo<OpsetTestParams>& info) const {
        const auto opset_number = info.index + 1;
        return std::string("opset") + std::to_string(opset_number);
    }
};
// Every opset must be able to instantiate a "Parameter" node by name.
TEST_P(OpsetTests, create_parameter) {
    const ov::OpSet& opset = GetParam().opset_getter();
    // OpSet::create hands back a raw pointer that this test owns; wrap it so it is released.
    const std::unique_ptr<ov::Node> parameter{opset.create("Parameter")};
    ASSERT_NE(nullptr, parameter);
    EXPECT_TRUE(ov::op::util::is_parameter(parameter.get()));
}
// Prints every type name registered in the opset, then checks the total count.
TEST_P(OpsetTests, opset_dump) {
    const auto& opset = GetParam().opset_getter();
    const auto& registered_types = opset.get_types_info();

    // Listing the names first makes a count mismatch easier to diagnose from the log.
    std::cout << "All opset operations: ";
    for (const auto& type_info : registered_types) {
        std::cout << type_info.name << " ";
    }
    std::cout << std::endl;

    ASSERT_EQ(GetParam().expected_ops_count, registered_types.size());
}
// Runs OpsetTests once per opset. Each entry pairs an opset getter with the number of
// operation types opset_dump expects get_types_info() to report for it.
// OpsetTestNameGenerator names each run "opset" + (index + 1), so the entries must stay
// in ascending opset order; inserting or reordering entries would mislabel the tests.
INSTANTIATE_TEST_SUITE_P(opset,
OpsetTests,
testing::Values(OpsetTestParams{ov::get_opset1, 110},
OpsetTestParams{ov::get_opset2, 112},
OpsetTestParams{ov::get_opset3, 127},
OpsetTestParams{ov::get_opset4, 137},
OpsetTestParams{ov::get_opset5, 145},
OpsetTestParams{ov::get_opset6, 152},
OpsetTestParams{ov::get_opset7, 156},
OpsetTestParams{ov::get_opset8, 167},
OpsetTestParams{ov::get_opset9, 173},
OpsetTestParams{ov::get_opset10, 177}),
OpsetTestNameGenerator{});
class MyOpOld : public ov::op::Op {
public:
static constexpr ov::DiscreteTypeInfo type_info{"MyOpOld", static_cast<uint64_t>(0)};
const ov::DiscreteTypeInfo& get_type_info() const override {
return type_info;
}
MyOpOld() = default;
std::shared_ptr<Node> clone_with_new_inputs(const ov::OutputVector& inputs) const override {
return nullptr;
}
};
constexpr ov::DiscreteTypeInfo MyOpOld::type_info;
class MyOpNewFromOld : public MyOpOld {
public:
OPENVINO_OP("MyOpNewFromOld", "custom_opset", MyOpOld);
MyOpNewFromOld() = default;
std::shared_ptr<Node> clone_with_new_inputs(const ov::OutputVector& inputs) const override {
return nullptr;
}
};
class MyOpIncorrect : public MyOpOld {
public:
OPENVINO_OP("MyOpIncorrect", "custom_opset", MyOpOld);
MyOpIncorrect() = default;
std::shared_ptr<Node> clone_with_new_inputs(const ov::OutputVector& inputs) const override {
return nullptr;
}
};
class MyOpNew : public ov::op::Op {
public:
OPENVINO_OP("MyOpNew", "custom_opset", MyOpOld);
MyOpNew() = default;
std::shared_ptr<Node> clone_with_new_inputs(const ov::OutputVector& inputs) const override {
return nullptr;
}
};
// Registers three custom ops in an empty OpSet and checks that all three are counted
// and can be looked up by their type name.
TEST(opset, custom_opset) {
    ov::OpSet opset;
    opset.insert<MyOpIncorrect>();
    opset.insert<MyOpNewFromOld>();
    opset.insert<MyOpNew>();

    // Unsigned literal so gtest's comparison helper does not compare size_t against a
    // signed int (which trips -Wsign-compare under -Wall/-Werror builds).
    EXPECT_EQ(opset.get_types_info().size(), 3u);
    EXPECT_TRUE(opset.contains_type("MyOpNewFromOld"));
    EXPECT_TRUE(opset.contains_type("MyOpNew"));
    EXPECT_TRUE(opset.contains_type("MyOpIncorrect"));
}