diff --git a/src/common/transformations/tests/common_optimizations/group_normalization_fusion_tests.cpp b/src/common/transformations/tests/common_optimizations/group_normalization_fusion_tests.cpp index abcac8b84011fa..9f8598de5be10a 100644 --- a/src/common/transformations/tests/common_optimizations/group_normalization_fusion_tests.cpp +++ b/src/common/transformations/tests/common_optimizations/group_normalization_fusion_tests.cpp @@ -24,80 +24,8 @@ template class GroupNormalizationFusionTransformationTestsF : public ov::test::GroupNormalizationFusionTestBase, public testing::TestWithParam { -protected: - bool positiveTest; - ov::pass::Manager manager; - ov::pass::Manager manager_ref; - std::shared_ptr model; - std::shared_ptr model_ref; - - virtual void read_test_parameters() { - const auto& params = GetParam(); - - dataShape = std::get<0>(params); - if (!dataShape.rank().is_static()) - throw std::runtime_error("Rank of input tensor has to be static!"); - if (dataShape.rank().get_max_length() < 2) - throw std::runtime_error("Expected at least two dimensions in input tensor!"); - if (!dataShape[1].is_static()) - throw std::runtime_error("Channel dimension in input tensor has to be static!"); - - numChannels = static_cast(dataShape[1].get_max_length()); - instanceNormGammaShape = std::get<1>(params); - instanceNormBetaShape = std::get<2>(params); - groupNormGammaShape = std::get<3>(params); - groupNormBetaShape = std::get<4>(params); - numGroups = std::get<5>(params); - epsilon = std::get<6>(params); - positiveTest = std::get<7>(params); - - instanceNormGammaPresent = (instanceNormGammaShape != Shape{}); - instanceNormBetaPresent = (instanceNormBetaShape != Shape{}); - - if (positiveTest) { - if ((instanceNormGammaShape != Shape{}) && (shape_size(instanceNormGammaShape) != numGroups)) - throw std::runtime_error("Shape of instance norm gamma has to either be empty or contain " - "exactly elements"); - if ((instanceNormBetaShape != Shape{}) && (shape_size(instanceNormBetaShape) != numGroups)) - throw std::runtime_error("Shape of instance norm beta has to either be empty shape or contain " - "exactly elements"); - if (shape_size(groupNormGammaShape) != numChannels) - throw std::runtime_error("Shape of group norm gamma has to contain exactly elements"); - if (shape_size(groupNormBetaShape) != numChannels) - throw std::runtime_error("Shape of group norm beta has to contain exactly elements"); - - instanceNormGammaPresent = instanceNormGammaPresent && (shape_size(instanceNormGammaShape) == numGroups); - instanceNormBetaPresent = instanceNormBetaPresent && (shape_size(instanceNormBetaShape) == numGroups); - } - } - - std::shared_ptr create_ref_model() { - auto input = std::make_shared(T_elem_t, dataShape); - - auto group_norm_beta_corr_vals = groupNormBetaVals; - if (instanceNormBetaPresent) - for (auto i = 0; i < group_norm_beta_corr_vals.size(); i++) - group_norm_beta_corr_vals[i] = - groupNormGammaVals[i] * instanceNormBetaVals[i / (numChannels / numGroups)] + groupNormBetaVals[i]; - auto group_norm_beta_1d = op::v0::Constant::create(T_elem_t, Shape{numChannels}, group_norm_beta_corr_vals); - - auto group_norm_gamma_corr_vals = groupNormGammaVals; - if (instanceNormGammaPresent) - for (auto i = 0; i < group_norm_gamma_corr_vals.size(); i++) - group_norm_gamma_corr_vals[i] = - groupNormGammaVals[i] * instanceNormGammaVals[i / (numChannels / numGroups)]; - auto group_norm_gamma_1d = op::v0::Constant::create(T_elem_t, Shape{numChannels}, group_norm_gamma_corr_vals); - - auto group_norm = std::make_shared(input, - group_norm_gamma_1d, - group_norm_beta_1d, - numGroups, - epsilon); - - return std::make_shared(NodeVector{group_norm}, ParameterVector{input}); - } - public: + static constexpr element::Type T_elem_t = T_elem; static std::string getTestCaseName(const testing::TestParamInfo& obj) { const auto& params = obj.param; @@ -127,8 +55,8 @@ class GroupNormalizationFusionTransformationTestsF void run() { read_test_parameters(); - generate_weights_init_values(); - model = create_model(); + this->generate_weights_init_values(); + model = this->create_model(); manager = ov::pass::Manager(); manager.register_pass(); @@ -164,7 +92,7 @@ class GroupNormalizationFusionTransformationTestsF ASSERT_EQ(f_results[0]->get_output_partial_shape(0), f_ref_results[0]->get_output_partial_shape(0)); ASSERT_EQ(f_results[0]->get_output_partial_shape(0), f_parameters[0]->get_output_partial_shape(0)); ASSERT_EQ(f_ref_results[0]->get_output_partial_shape(0), f_ref_parameters[0]->get_output_partial_shape(0)); - ASSERT_EQ(f_ref_results[0]->get_output_partial_shape(0), dataShape); + ASSERT_EQ(f_ref_results[0]->get_output_partial_shape(0), this->dataShape); const auto& gn_node = f_results[0]->get_input_node_shared_ptr(0); const auto& gn_ref_node = f_ref_results[0]->get_input_node_shared_ptr(0); @@ -173,22 +101,103 @@ class GroupNormalizationFusionTransformationTestsF ASSERT_EQ(gn_node->inputs().size(), gn_ref_node->inputs().size()); ASSERT_EQ(gn_node->inputs().size(), 3); ASSERT_EQ(gn_node->get_input_partial_shape(0), gn_ref_node->get_input_partial_shape(0)); - ASSERT_EQ(gn_node->get_input_partial_shape(0), dataShape); + ASSERT_EQ(gn_node->get_input_partial_shape(0), this->dataShape); ASSERT_EQ(shape_size(gn_node->get_input_shape(1)), shape_size(gn_ref_node->get_input_shape(1))); - ASSERT_EQ(shape_size(gn_node->get_input_shape(1)), numChannels); + ASSERT_EQ(shape_size(gn_node->get_input_shape(1)), this->numChannels); ASSERT_EQ(shape_size(gn_node->get_input_shape(2)), shape_size(gn_ref_node->get_input_shape(2))); - ASSERT_EQ(shape_size(gn_node->get_input_shape(2)), numChannels); + ASSERT_EQ(shape_size(gn_node->get_input_shape(2)), this->numChannels); const auto& gn_node_casted = ov::as_type_ptr(gn_node); const auto& gn_ref_node_casted = ov::as_type_ptr(gn_ref_node); ASSERT_EQ(gn_node_casted->get_epsilon(), gn_ref_node_casted->get_epsilon()); - ASSERT_EQ(gn_node_casted->get_epsilon(), epsilon); + ASSERT_EQ(gn_node_casted->get_epsilon(), this->epsilon); ASSERT_EQ(gn_node_casted->get_num_groups(), gn_ref_node_casted->get_num_groups()); - ASSERT_EQ(gn_node_casted->get_num_groups(), numGroups); + ASSERT_EQ(gn_node_casted->get_num_groups(), this->numGroups); } else { ASSERT_EQ(count_ops_of_type(model), 0); } } + +protected: + bool positiveTest; + ov::pass::Manager manager; + ov::pass::Manager manager_ref; + std::shared_ptr model; + std::shared_ptr model_ref; + + void read_test_parameters() override { + const auto& params = GetParam(); + + this->dataShape = std::get<0>(params); + if (!this->dataShape.rank().is_static()) + throw std::runtime_error("Rank of input tensor has to be static!"); + if (this->dataShape.rank().get_max_length() < 2) + throw std::runtime_error("Expected at least two dimensions in input tensor!"); + if (!this->dataShape[1].is_static()) + throw std::runtime_error("Channel dimension in input tensor has to be static!"); + + this->numChannels = static_cast(this->dataShape[1].get_max_length()); + this->instanceNormGammaShape = std::get<1>(params); + this->instanceNormBetaShape = std::get<2>(params); + this->groupNormGammaShape = std::get<3>(params); + this->groupNormBetaShape = std::get<4>(params); + this->numGroups = std::get<5>(params); + this->epsilon = std::get<6>(params); + positiveTest = std::get<7>(params); + + this->instanceNormGammaPresent = (this->instanceNormGammaShape != Shape{}); + this->instanceNormBetaPresent = (this->instanceNormBetaShape != Shape{}); + + if (positiveTest) { + if ((this->instanceNormGammaShape != Shape{}) && + (shape_size(this->instanceNormGammaShape) != this->numGroups)) + throw std::runtime_error("Shape of instance norm gamma has to either be empty or contain " + "exactly elements"); + if ((this->instanceNormBetaShape != Shape{}) && + (shape_size(this->instanceNormBetaShape) != this->numGroups)) + throw std::runtime_error("Shape of instance norm beta has to either be empty shape or contain " + "exactly elements"); + if (shape_size(this->groupNormGammaShape) != this->numChannels) + throw std::runtime_error("Shape of group norm gamma has to contain exactly elements"); + if (shape_size(this->groupNormBetaShape) != this->numChannels) + throw std::runtime_error("Shape of group norm beta has to contain exactly elements"); + + this->instanceNormGammaPresent = + this->instanceNormGammaPresent && (shape_size(this->instanceNormGammaShape) == this->numGroups); + this->instanceNormBetaPresent = + this->instanceNormBetaPresent && (shape_size(this->instanceNormBetaShape) == this->numGroups); + } + } + + std::shared_ptr create_ref_model() { + auto input = std::make_shared(T_elem_t, this->dataShape); + + auto group_norm_beta_corr_vals = this->groupNormBetaVals; + if (this->instanceNormBetaPresent) + for (auto i = 0; i < group_norm_beta_corr_vals.size(); i++) + group_norm_beta_corr_vals[i] = + this->groupNormGammaVals[i] * + this->instanceNormBetaVals[i / (this->numChannels / this->numGroups)] + + this->groupNormBetaVals[i]; + auto group_norm_beta_1d = + op::v0::Constant::create(T_elem_t, Shape{this->numChannels}, group_norm_beta_corr_vals); + + auto group_norm_gamma_corr_vals = this->groupNormGammaVals; + if (this->instanceNormGammaPresent) + for (auto i = 0; i < group_norm_gamma_corr_vals.size(); i++) + group_norm_gamma_corr_vals[i] = this->groupNormGammaVals[i] * + this->instanceNormGammaVals[i / (this->numChannels / this->numGroups)]; + auto group_norm_gamma_1d = + op::v0::Constant::create(T_elem_t, Shape{this->numChannels}, group_norm_gamma_corr_vals); + + auto group_norm = std::make_shared(input, + group_norm_gamma_1d, + group_norm_beta_1d, + this->numGroups, + this->epsilon); + + return std::make_shared(NodeVector{group_norm}, ParameterVector{input}); + } }; class GroupNormalizationFusionTransformationTestsF_f32 diff --git a/src/tests/functional/shared_test_classes/include/shared_test_classes/subgraph/group_normalization_fusion.hpp b/src/tests/functional/shared_test_classes/include/shared_test_classes/subgraph/group_normalization_fusion.hpp index 3ddc264000646b..761e99daad3599 100644 --- a/src/tests/functional/shared_test_classes/include/shared_test_classes/subgraph/group_normalization_fusion.hpp +++ b/src/tests/functional/shared_test_classes/include/shared_test_classes/subgraph/group_normalization_fusion.hpp @@ -132,6 +132,7 @@ class GroupNormalizationFusionSubgraphTestsF public ov::test::SubgraphBaseTest, public testing::WithParamInterface { public: + static constexpr element::Type T_elem_t = T_elem; static std::string getTestCaseName( const testing::TestParamInfo& obj) { const auto& params = obj.param; @@ -194,32 +195,32 @@ class GroupNormalizationFusionSubgraphTestsF SubgraphBaseTest::TearDown(); } - virtual void read_test_parameters() { + void read_test_parameters() override { const auto& params = GetParam(); - dataShape = std::get<0>(params); - if (!dataShape.rank().is_static()) + this->dataShape = std::get<0>(params); + if (!this->dataShape.rank().is_static()) throw std::runtime_error("Rank of input tensor has to be static!"); - if (dataShape.rank().get_max_length() < 2) + if (this->dataShape.rank().get_max_length() < 2) throw std::runtime_error("Expected at least two dimensions in input tensor!"); - if (!dataShape[1].is_static()) + if (!this->dataShape[1].is_static()) throw std::runtime_error("Channel dimension in input tensor has to be static!"); - numChannels = static_cast(dataShape[1].get_max_length()); - instanceNormGammaShape = std::get<1>(params); - instanceNormBetaShape = std::get<2>(params); - groupNormGammaShape = std::get<3>(params); - groupNormBetaShape = std::get<4>(params); - numGroups = std::get<5>(params); - epsilon = std::get<6>(params); + this->numChannels = static_cast(this->dataShape[1].get_max_length()); + this->instanceNormGammaShape = std::get<1>(params); + this->instanceNormBetaShape = std::get<2>(params); + this->groupNormGammaShape = std::get<3>(params); + this->groupNormBetaShape = std::get<4>(params); + this->numGroups = std::get<5>(params); + this->epsilon = std::get<6>(params); positiveTest = std::get<7>(params); targetDeviceName = std::get<8>(params); targetConfiguration = std::get<9>(params); refDevice = std::get<10>(params); refConfiguration = std::get<11>(params); - instanceNormGammaPresent = (instanceNormGammaShape != Shape{}); - instanceNormBetaPresent = (instanceNormBetaShape != Shape{}); + this->instanceNormGammaPresent = (this->instanceNormGammaShape != Shape{}); + this->instanceNormBetaPresent = (this->instanceNormBetaShape != Shape{}); inType = T_elem_t; outType = T_elem_t; @@ -227,19 +228,23 @@ class GroupNormalizationFusionSubgraphTestsF configuration = targetConfiguration; if (positiveTest) { - if ((instanceNormGammaShape != Shape{}) && (shape_size(instanceNormGammaShape) != numGroups)) + if ((this->instanceNormGammaShape != Shape{}) && + (shape_size(this->instanceNormGammaShape) != this->numGroups)) throw std::runtime_error("Shape of instance norm gamma has to either be empty or contain " "exactly elements"); - if ((instanceNormBetaShape != Shape{}) && (shape_size(instanceNormBetaShape) != numGroups)) + if ((this->instanceNormBetaShape != Shape{}) && + (shape_size(this->instanceNormBetaShape) != this->numGroups)) throw std::runtime_error("Shape of instance norm beta has to either be empty shape or contain " "exactly elements"); - if (shape_size(groupNormGammaShape) != numChannels) + if (shape_size(this->groupNormGammaShape) != this->numChannels) throw std::runtime_error("Shape of group norm gamma has to contain exactly elements"); - if (shape_size(groupNormBetaShape) != numChannels) + if (shape_size(this->groupNormBetaShape) != this->numChannels) throw std::runtime_error("Shape of group norm beta has to contain exactly elements"); - instanceNormGammaPresent = instanceNormGammaPresent && (shape_size(instanceNormGammaShape) == numGroups); - instanceNormBetaPresent = instanceNormBetaPresent && (shape_size(instanceNormBetaShape) == numGroups); + this->instanceNormGammaPresent = + this->instanceNormGammaPresent && (shape_size(this->instanceNormGammaShape) == this->numGroups); + this->instanceNormBetaPresent = + this->instanceNormBetaPresent && (shape_size(this->instanceNormBetaShape) == this->numGroups); } } @@ -305,7 +310,7 @@ class GroupNormalizationFusionSubgraphTestsF void init_thresholds() override { if (!targetStaticShapes.empty()) { - size_t problem_size = shape_size(dataShape.get_shape()); + size_t problem_size = shape_size(this->dataShape.get_shape()); abs_threshold = pow(problem_size, 0.5) * test::utils::get_eps_by_ov_type(outType); rel_threshold = abs_threshold; @@ -399,8 +404,8 @@ class GroupNormalizationFusionSubgraphTestsF std::string errorMessage; try { read_test_parameters(); - generate_weights_init_values(); - functionRefs = create_model(); + this->generate_weights_init_values(); + functionRefs = this->create_model(); function = functionRefs->clone(); pass::Manager m; m.register_pass(); @@ -415,7 +420,7 @@ class GroupNormalizationFusionSubgraphTestsF if (!function->is_dynamic()) { configure_device(); configure_ref_device(); - auto input_shapes = static_partial_shapes_to_test_representation({dataShape}); + auto input_shapes = static_partial_shapes_to_test_representation({this->dataShape}); init_input_shapes(input_shapes); ASSERT_FALSE(targetStaticShapes.empty() && !function->get_parameters().empty()) << "Target Static Shape is empty!!!";