Skip to content

Commit

Permalink
change random generator, separate structure for ranges
Browse files Browse the repository at this point in the history
  • Loading branch information
sbalandi committed Apr 17, 2024
1 parent a415155 commit d0310f0
Show file tree
Hide file tree
Showing 7 changed files with 335 additions and 398 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,14 @@
#include "openvino/op/convert.hpp"
#include "openvino/op/range.hpp"
#include "openvino/op/roi_align.hpp"
#include "openvino/op/fake_quantize.hpp"
#include "openvino/op/deformable_convolution.hpp"
#include "openvino/op/gru_sequence.hpp"
#include "openvino/op/batch_norm.hpp"
#include "openvino/op/select.hpp"
#include "openvino/op/multiply.hpp"
#include "openvino/op/strided_slice.hpp"
#include "openvino/op/lstm_sequence.hpp"

namespace ov {
namespace test {
Expand All @@ -108,11 +116,11 @@ namespace utils {
static std::map<ov::NodeTypeInfo, std::vector<std::vector<ov::test::utils::InputGenerateData>>> inputRanges = {
// NodeTypeInfo: {IntRanges{}, RealRanges{}} (Ranges are used by generate<ov::Node>)
{ ov::op::v0::Erf::get_type_info_static(), {{{-3, 6}}, {{-3, 6, 10}}} },
{ ov::op::v1::Divide::get_type_info_static(), {{{101, 100}}, {{1, 2, 128}}} },
{ ov::op::v1::Divide::get_type_info_static(), {{{101, 100}}, {{2, 2, 128}}} },
{ ov::op::v1::FloorMod::get_type_info_static(), {{{2, 4}}, {{2, 2, 128}}} },
{ ov::op::v1::Mod::get_type_info_static(), {{{2, 4}}, {{2, 2, 128}}} },
{ ov::op::v1::ReduceMax::get_type_info_static(), {{{0, 5}}, {{-5, 5, 1000}}} },
{ ov::op::v1::ReduceMean::get_type_info_static(), {{{0, 5}}, {{0, 5, 1000}}} },
{ ov::op::v1::ReduceMean::get_type_info_static(), {{{0, 5, 1000}}, {{0, 5, 1000}}} },
{ ov::op::v1::ReduceMin::get_type_info_static(), {{{0, 5}}, {{0, 5, 1000}}} },
{ ov::op::v1::ReduceProd::get_type_info_static(), {{{0, 5}}, {{0, 5, 1000}}} },
{ ov::op::v1::ReduceSum::get_type_info_static(), {{{0, 5}}, {{0, 5, 1000}}} },
Expand Down Expand Up @@ -155,7 +163,7 @@ static std::map<ov::NodeTypeInfo, std::vector<std::vector<ov::test::utils::Input
{ ov::op::v0::Tan::get_type_info_static(), {{{0, 15}}, {{-1, 2, 32768}}} },
{ ov::op::v0::Elu::get_type_info_static(), {{{0, 15}}, {{-1, 2, 32768}}} },
{ ov::op::v0::Erf::get_type_info_static(), {{{0, 15}}, {{-1, 2, 32768}}} },
{ ov::op::v0::HardSigmoid::get_type_info_static(), {{{0, 15}}, {{-1, 2, 32768}}} },
{ ov::op::v0::HardSigmoid::get_type_info_static(), {{{0, 15}}, {{-1, 2, 32768}, {0.2, 0, 1, 1, true}, {0.5, 0, 1, 1, true}}} },
{ ov::op::v0::Selu::get_type_info_static(), {{{0, 15}}, {{-1, 2, 32768}}} },
{ ov::op::v0::Sigmoid::get_type_info_static(), {{{0, 15}}, {{-1, 2, 32768}}} },
{ ov::op::v0::Tanh::get_type_info_static(), {{{0, 15}}, {{-1, 2, 32768}}} },
Expand All @@ -178,8 +186,8 @@ static std::map<ov::NodeTypeInfo, std::vector<std::vector<ov::test::utils::Input
// new temp
{ ov::op::v1::Convolution::get_type_info_static(), {{{0, 15}}, {{0, 8, 32}}} },
{ ov::op::v1::ConvolutionBackpropData::get_type_info_static(), {{{0, 15}}, {{0, 8, 32}}} },
{ ov::op::v1::GroupConvolution::get_type_info_static(), {{{0, 15}}, {{0, 8, 32}}} },
{ ov::op::v1::GroupConvolutionBackpropData::get_type_info_static(), {{{0, 15}}, {{0, 8, 32}}} },
{ ov::op::v1::GroupConvolution::get_type_info_static(), {{{0, 15}}, {{0, 8, 32}}}},
{ ov::op::v1::GroupConvolutionBackpropData::get_type_info_static(), {{{0, 15}}, {{0, 8, 32}}}},
{ ov::op::v12::ScatterElementsUpdate::get_type_info_static(), {{{0, 15}}, {{0, 8, 32}}} },
{ ov::op::v3::ScatterUpdate::get_type_info_static(), {{{0, 15}}, {{0, 8, 32}}} },
{ ov::op::v0::Unsqueeze::get_type_info_static(), {{{0, 15}}, {{0, 8, 32}}} },
Expand All @@ -189,10 +197,16 @@ static std::map<ov::NodeTypeInfo, std::vector<std::vector<ov::test::utils::Input
{ ov::op::v0::LRN::get_type_info_static(), {{{0, 15}}, {{0, 8, 32}}} },
{ ov::op::v1::Pad::get_type_info_static(), {{{0, 15}}, {{0, 8, 32}}} },
{ ov::op::v3::Broadcast::get_type_info_static(), {{{0, 2000}}, {{0, 2000, 32768}}} },
{ ov::op::v9::NonMaxSuppression::get_type_info_static(), {{{0, 15}}, {{0, 8, 32}, {0, 1, 1000, 1, true}}} },
{ ov::op::v8::MatrixNms::get_type_info_static(), {{{0, 15}}, {{0, 8, 32}, {0, 1, 1000, 1, true}}} },
{ ov::op::v5::NonMaxSuppression::get_type_info_static(), {{{0, 15}, {0, 1, 1000, 1, true}}, {{0, 8, 32}, {0, 1, 1000, 1, true}}} },
{ ov::op::v9::NonMaxSuppression::get_type_info_static(), {{{0, 15}, {0, 1, 1000, 1, true}}, {{0, 8, 32}, {0, 1, 1000, 1, true}}} },
{ ov::op::v8::MatrixNms::get_type_info_static(), {{{0, 15}, {0, 1, 1000, 1, true}}, {{0, 8, 32}, {0, 1, 1000, 1, true}}} },
{ ov::op::v6::ExperimentalDetectronGenerateProposalsSingleImage::get_type_info_static(), {{{1, 0, 1, 1}}, {{1, 0, 1, 1}}}},
{ ov::op::v6::ExperimentalDetectronPriorGridGenerator::get_type_info_static(), {{{0, 0, 1, 1}}, {{-100, 200, 2, 1}, {0, 0, 1, 1}, {0, 0, 1, 1}}}},
{ ov::op::v6::ExperimentalDetectronPriorGridGenerator::get_type_info_static(), {{{0, 0, 1, 1}},
{{-100, 200, 2, 1}, {0, 0, 1, 1, true}, {0, 0, 1, 1, true}}}},
{ ov::op::v8::DeformableConvolution::get_type_info_static(), {{{0, 15}, {0, 2, 10, 1, true}, {0, 1, 20, 1, true}},
{{0, 8, 32}, {0, 2, 10, 1, true}, {0, 1, 20, 1, true}}}},
{ ov::op::v5::GRUSequence::get_type_info_static(), {{{0, 15}, {0, 15}, {0, 10, 1, 1, true}}, {{0, 8, 32}}}},
{ ov::op::v5::BatchNormInference::get_type_info_static(), {{{0, 3}}, {{0, 3, 1}}}},
{ ov::op::v5::RNNSequence::get_type_info_static(), {{{0, 15}, {0, 15}, {0, 10, 1, 1, true}}, {{0, 8, 32}, {0, 8, 32}, {0, 10, 1, 1, true}}} },
{ ov::op::v1::LogicalAnd::get_type_info_static(), {{{0, 2}}, {{0, 2}}} },
{ ov::op::v1::LogicalNot::get_type_info_static(), {{{0, 2}}, {{0, 2}}} },
Expand All @@ -208,13 +222,33 @@ static std::map<ov::NodeTypeInfo, std::vector<std::vector<ov::test::utils::Input
{ ov::op::v9::ROIAlign::get_type_info_static(), {{{0, 15}, {0, 1000, 1, 1, true}, {0, 1000, 1, 1, true}},
{{-1000, 2000, 32768}, {0, 1000, 1, 1, true}, {0, 1000, 1, 1, true}}} },
{ ov::op::v0::Convert::get_type_info_static(), {{{0, 1000}}, {{-100, 200, 32768}}} },
{ ov::op::v0::FakeQuantize::get_type_info_static(), {{{0, 15}}, {{0, 8, 32}}} },
{ ov::op::v0::FakeQuantize::get_type_info_static(), {{{0, 15}}, {{0, 8, 32}}} },
{ ov::op::v1::Select::get_type_info_static(), {{{0, 15}}, {{0, 8, 32}}} },
{ ov::op::v1::Multiply::get_type_info_static(), {{{0, 15}}, {{0, 8, 32}}} },
{ ov::op::v1::StridedSlice::get_type_info_static(), {{{0, 15}}, {{0, 8, 32}}} },
{ ov::op::v5::LSTMSequence::get_type_info_static(), {{{0, 15}}, {{0, 8, 32}}} },
};

ov::test::utils::InputGenerateData get_range_by_type(ov::element::Type temp_type, uint64_t kMaxRange);

std::string get_range_id(const std::shared_ptr<ov::Node>& node, size_t port, bool spectial = false);
class ModelRange {
std::vector<std::string> TYPE_ALIAS {"integral", "real"};

std::shared_ptr<ov::test::utils::InputGenerateData> general_real;
std::shared_ptr<ov::test::utils::InputGenerateData> general_integral;
// key for map calculated in get_range_id and contais [Op Type Name]_[integral/real]_[port]
std::map<std::string, std::shared_ptr<ov::test::utils::InputGenerateData>> node_ranges;
public:
void collect_ranges(const std::shared_ptr<ov::Model>& function, uint64_t kMaxRange);
void find_general_ranges();
std::string get_range_id(const std::shared_ptr<ov::Node>& node, size_t port);
ov::Tensor generate_input(std::shared_ptr<ov::Node> node, size_t port, const ov::Shape& targetShape);

const std::shared_ptr<ov::test::utils::InputGenerateData> get_general_real_range();
const std::shared_ptr<ov::test::utils::InputGenerateData> get_general_integral_range();
};

std::map<std::string, std::shared_ptr<ov::test::utils::InputGenerateData>> collect_ranges(const std::shared_ptr<ov::Model>& function, uint64_t kMaxRange);

} // namespace utils
} // namespace test
Expand Down
23 changes: 8 additions & 15 deletions src/tests/functional/shared_test_classes/src/base/ov_subgraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,12 @@
#include "functional_test_utils/crash_handler.hpp"

#include "shared_test_classes/base/ov_subgraph.hpp"
#include "shared_test_classes/base/utils/generate_inputs.hpp"
// #include "shared_test_classes/base/utils/generate_inputs.hpp"
#include "shared_test_classes/base/utils/compare_results.hpp"
#include "shared_test_classes/base/utils/calculate_thresholds.hpp"

#include "shared_test_classes/base/utils/ranges.hpp"


namespace ov {
namespace test {
Expand Down Expand Up @@ -320,28 +322,19 @@ void SubgraphBaseTest::compile_model() {

void SubgraphBaseTest::generate_inputs(const std::vector<ov::Shape>& targetInputStaticShapes) {
inputs.clear();
auto inputDataMap = ov::test::utils::collect_ranges(function, testing::internal::Random::kMaxRange);
auto inputMap = utils::getInputMap();
ov::test::utils::ModelRange modelRange;
modelRange.collect_ranges(function, testing::internal::Random::kMaxRange);
modelRange.find_general_ranges();

auto itTargetShape = targetInputStaticShapes.begin();
for (const auto &param : function->get_parameters()) {
std::shared_ptr<ov::Node> inputNode = param;
for (size_t i = 0; i < param->get_output_size(); i++) {
for (const auto &node : param->get_output_target_inputs(i)) {
std::shared_ptr<ov::Node> nodePtr = node.get_node()->shared_from_this();
auto it = inputMap.find(nodePtr->get_type_info());
ASSERT_NE(it, inputMap.end());
for (size_t port = 0; port < nodePtr->get_input_size(); ++port) {
if (nodePtr->get_input_node_ptr(port)->shared_from_this() == inputNode->shared_from_this()) {
if (!inputDataMap.empty()) {
std::string spetial_range_id = ov::test::utils::get_range_id(nodePtr, port, true);
if (inputDataMap.find(spetial_range_id) == inputDataMap.end()) {
spetial_range_id = ov::test::utils::get_range_id(nodePtr, port, false);
}
inputs.insert({param, it->second(nodePtr, port, param->get_element_type(), *itTargetShape,
inputDataMap[spetial_range_id])});
} else {
inputs.insert({param, it->second(nodePtr, port, param->get_element_type(), *itTargetShape, nullptr)});
}
inputs.insert({param, modelRange.generate_input(nodePtr, port, *itTargetShape)});
break;
}
}
Expand Down
Loading

0 comments on commit d0310f0

Please sign in to comment.