Skip to content

Commit

Permalink
Fix type and members from GroupNormalizationFusionTestBase in derived…
Browse files Browse the repository at this point in the history
… classes' templates
  • Loading branch information
jhajducz committed Feb 5, 2025
1 parent 9ea92f2 commit ef4c918
Show file tree
Hide file tree
Showing 2 changed files with 119 additions and 105 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,80 +24,8 @@ template <element::Type_t T_elem>
class GroupNormalizationFusionTransformationTestsF
: public ov::test::GroupNormalizationFusionTestBase<T_elem>,
public testing::TestWithParam<GroupNormalizationFusionSubgraphTestValues> {
protected:
bool positiveTest;
ov::pass::Manager manager;
ov::pass::Manager manager_ref;
std::shared_ptr<ov::Model> model;
std::shared_ptr<ov::Model> model_ref;

virtual void read_test_parameters() {
const auto& params = GetParam();

dataShape = std::get<0>(params);
if (!dataShape.rank().is_static())
throw std::runtime_error("Rank of input tensor has to be static!");
if (dataShape.rank().get_max_length() < 2)
throw std::runtime_error("Expected at least two dimensions in input tensor!");
if (!dataShape[1].is_static())
throw std::runtime_error("Channel dimension in input tensor has to be static!");

numChannels = static_cast<size_t>(dataShape[1].get_max_length());
instanceNormGammaShape = std::get<1>(params);
instanceNormBetaShape = std::get<2>(params);
groupNormGammaShape = std::get<3>(params);
groupNormBetaShape = std::get<4>(params);
numGroups = std::get<5>(params);
epsilon = std::get<6>(params);
positiveTest = std::get<7>(params);

instanceNormGammaPresent = (instanceNormGammaShape != Shape{});
instanceNormBetaPresent = (instanceNormBetaShape != Shape{});

if (positiveTest) {
if ((instanceNormGammaShape != Shape{}) && (shape_size(instanceNormGammaShape) != numGroups))
throw std::runtime_error("Shape of instance norm gamma has to either be empty or contain "
"exactly <numGroups> elements");
if ((instanceNormBetaShape != Shape{}) && (shape_size(instanceNormBetaShape) != numGroups))
throw std::runtime_error("Shape of instance norm beta has to either be empty shape or contain "
"exactly <numGroups> elements");
if (shape_size(groupNormGammaShape) != numChannels)
throw std::runtime_error("Shape of group norm gamma has to contain exactly <numChannels> elements");
if (shape_size(groupNormBetaShape) != numChannels)
throw std::runtime_error("Shape of group norm beta has to contain exactly <numChannels> elements");

instanceNormGammaPresent = instanceNormGammaPresent && (shape_size(instanceNormGammaShape) == numGroups);
instanceNormBetaPresent = instanceNormBetaPresent && (shape_size(instanceNormBetaShape) == numGroups);
}
}

std::shared_ptr<ov::Model> create_ref_model() {
auto input = std::make_shared<ov::op::v0::Parameter>(T_elem_t, dataShape);

auto group_norm_beta_corr_vals = groupNormBetaVals;
if (instanceNormBetaPresent)
for (auto i = 0; i < group_norm_beta_corr_vals.size(); i++)
group_norm_beta_corr_vals[i] =
groupNormGammaVals[i] * instanceNormBetaVals[i / (numChannels / numGroups)] + groupNormBetaVals[i];
auto group_norm_beta_1d = op::v0::Constant::create(T_elem_t, Shape{numChannels}, group_norm_beta_corr_vals);

auto group_norm_gamma_corr_vals = groupNormGammaVals;
if (instanceNormGammaPresent)
for (auto i = 0; i < group_norm_gamma_corr_vals.size(); i++)
group_norm_gamma_corr_vals[i] =
groupNormGammaVals[i] * instanceNormGammaVals[i / (numChannels / numGroups)];
auto group_norm_gamma_1d = op::v0::Constant::create(T_elem_t, Shape{numChannels}, group_norm_gamma_corr_vals);

auto group_norm = std::make_shared<ov::op::v12::GroupNormalization>(input,
group_norm_gamma_1d,
group_norm_beta_1d,
numGroups,
epsilon);

return std::make_shared<Model>(NodeVector{group_norm}, ParameterVector{input});
}

public:
static constexpr element::Type T_elem_t = T_elem;
static std::string getTestCaseName(const testing::TestParamInfo<GroupNormalizationFusionSubgraphTestValues>& obj) {
const auto& params = obj.param;

Expand Down Expand Up @@ -127,8 +55,8 @@ class GroupNormalizationFusionTransformationTestsF

void run() {
read_test_parameters();
generate_weights_init_values();
model = create_model();
this->generate_weights_init_values();
model = this->create_model();

manager = ov::pass::Manager();
manager.register_pass<ov::pass::InitNodeInfo>();
Expand Down Expand Up @@ -164,7 +92,7 @@ class GroupNormalizationFusionTransformationTestsF
ASSERT_EQ(f_results[0]->get_output_partial_shape(0), f_ref_results[0]->get_output_partial_shape(0));
ASSERT_EQ(f_results[0]->get_output_partial_shape(0), f_parameters[0]->get_output_partial_shape(0));
ASSERT_EQ(f_ref_results[0]->get_output_partial_shape(0), f_ref_parameters[0]->get_output_partial_shape(0));
ASSERT_EQ(f_ref_results[0]->get_output_partial_shape(0), dataShape);
ASSERT_EQ(f_ref_results[0]->get_output_partial_shape(0), this->dataShape);

const auto& gn_node = f_results[0]->get_input_node_shared_ptr(0);
const auto& gn_ref_node = f_ref_results[0]->get_input_node_shared_ptr(0);
Expand All @@ -173,22 +101,103 @@ class GroupNormalizationFusionTransformationTestsF
ASSERT_EQ(gn_node->inputs().size(), gn_ref_node->inputs().size());
ASSERT_EQ(gn_node->inputs().size(), 3);
ASSERT_EQ(gn_node->get_input_partial_shape(0), gn_ref_node->get_input_partial_shape(0));
ASSERT_EQ(gn_node->get_input_partial_shape(0), dataShape);
ASSERT_EQ(gn_node->get_input_partial_shape(0), this->dataShape);
ASSERT_EQ(shape_size(gn_node->get_input_shape(1)), shape_size(gn_ref_node->get_input_shape(1)));
ASSERT_EQ(shape_size(gn_node->get_input_shape(1)), numChannels);
ASSERT_EQ(shape_size(gn_node->get_input_shape(1)), this->numChannels);
ASSERT_EQ(shape_size(gn_node->get_input_shape(2)), shape_size(gn_ref_node->get_input_shape(2)));
ASSERT_EQ(shape_size(gn_node->get_input_shape(2)), numChannels);
ASSERT_EQ(shape_size(gn_node->get_input_shape(2)), this->numChannels);

const auto& gn_node_casted = ov::as_type_ptr<ov::op::v12::GroupNormalization>(gn_node);
const auto& gn_ref_node_casted = ov::as_type_ptr<ov::op::v12::GroupNormalization>(gn_ref_node);
ASSERT_EQ(gn_node_casted->get_epsilon(), gn_ref_node_casted->get_epsilon());
ASSERT_EQ(gn_node_casted->get_epsilon(), epsilon);
ASSERT_EQ(gn_node_casted->get_epsilon(), this->epsilon);
ASSERT_EQ(gn_node_casted->get_num_groups(), gn_ref_node_casted->get_num_groups());
ASSERT_EQ(gn_node_casted->get_num_groups(), numGroups);
ASSERT_EQ(gn_node_casted->get_num_groups(), this->numGroups);
} else {
ASSERT_EQ(count_ops_of_type<ov::op::v12::GroupNormalization>(model), 0);
}
}

protected:
bool positiveTest;
ov::pass::Manager manager;
ov::pass::Manager manager_ref;
std::shared_ptr<ov::Model> model;
std::shared_ptr<ov::Model> model_ref;

void read_test_parameters() override {
const auto& params = GetParam();

this->dataShape = std::get<0>(params);
if (!this->dataShape.rank().is_static())
throw std::runtime_error("Rank of input tensor has to be static!");
if (this->dataShape.rank().get_max_length() < 2)
throw std::runtime_error("Expected at least two dimensions in input tensor!");
if (!this->dataShape[1].is_static())
throw std::runtime_error("Channel dimension in input tensor has to be static!");

this->numChannels = static_cast<size_t>(this->dataShape[1].get_max_length());
this->instanceNormGammaShape = std::get<1>(params);
this->instanceNormBetaShape = std::get<2>(params);
this->groupNormGammaShape = std::get<3>(params);
this->groupNormBetaShape = std::get<4>(params);
this->numGroups = std::get<5>(params);
this->epsilon = std::get<6>(params);
positiveTest = std::get<7>(params);

this->instanceNormGammaPresent = (this->instanceNormGammaShape != Shape{});
this->instanceNormBetaPresent = (this->instanceNormBetaShape != Shape{});

if (positiveTest) {
if ((this->instanceNormGammaShape != Shape{}) &&
(shape_size(this->instanceNormGammaShape) != this->numGroups))
throw std::runtime_error("Shape of instance norm gamma has to either be empty or contain "
"exactly <numGroups> elements");
if ((this->instanceNormBetaShape != Shape{}) &&
(shape_size(this->instanceNormBetaShape) != this->numGroups))
throw std::runtime_error("Shape of instance norm beta has to either be empty shape or contain "
"exactly <numGroups> elements");
if (shape_size(this->groupNormGammaShape) != this->numChannels)
throw std::runtime_error("Shape of group norm gamma has to contain exactly <numChannels> elements");
if (shape_size(this->groupNormBetaShape) != this->numChannels)
throw std::runtime_error("Shape of group norm beta has to contain exactly <numChannels> elements");

this->instanceNormGammaPresent =
this->instanceNormGammaPresent && (shape_size(this->instanceNormGammaShape) == this->numGroups);
this->instanceNormBetaPresent =
this->instanceNormBetaPresent && (shape_size(this->instanceNormBetaShape) == this->numGroups);
}
}

std::shared_ptr<ov::Model> create_ref_model() {
auto input = std::make_shared<ov::op::v0::Parameter>(T_elem_t, this->dataShape);

auto group_norm_beta_corr_vals = this->groupNormBetaVals;
if (this->instanceNormBetaPresent)
for (auto i = 0; i < group_norm_beta_corr_vals.size(); i++)
group_norm_beta_corr_vals[i] =
this->groupNormGammaVals[i] *
this->instanceNormBetaVals[i / (this->numChannels / this->numGroups)] +
this->groupNormBetaVals[i];
auto group_norm_beta_1d =
op::v0::Constant::create(T_elem_t, Shape{this->numChannels}, group_norm_beta_corr_vals);

auto group_norm_gamma_corr_vals = this->groupNormGammaVals;
if (this->instanceNormGammaPresent)
for (auto i = 0; i < group_norm_gamma_corr_vals.size(); i++)
group_norm_gamma_corr_vals[i] = this->groupNormGammaVals[i] *
this->instanceNormGammaVals[i / (this->numChannels / this->numGroups)];
auto group_norm_gamma_1d =
op::v0::Constant::create(T_elem_t, Shape{this->numChannels}, group_norm_gamma_corr_vals);

auto group_norm = std::make_shared<ov::op::v12::GroupNormalization>(input,
group_norm_gamma_1d,
group_norm_beta_1d,
this->numGroups,
this->epsilon);

return std::make_shared<Model>(NodeVector{group_norm}, ParameterVector{input});
}
};

class GroupNormalizationFusionTransformationTestsF_f32
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ class GroupNormalizationFusionSubgraphTestsF
public ov::test::SubgraphBaseTest,
public testing::WithParamInterface<GroupNormalizationFusionTransformationsTestValues> {
public:
static constexpr element::Type T_elem_t = T_elem;
static std::string getTestCaseName(
const testing::TestParamInfo<GroupNormalizationFusionTransformationsTestValues>& obj) {
const auto& params = obj.param;
Expand Down Expand Up @@ -194,52 +195,56 @@ class GroupNormalizationFusionSubgraphTestsF
SubgraphBaseTest::TearDown();
}

virtual void read_test_parameters() {
void read_test_parameters() override {
const auto& params = GetParam();

dataShape = std::get<0>(params);
if (!dataShape.rank().is_static())
this->dataShape = std::get<0>(params);
if (!this->dataShape.rank().is_static())
throw std::runtime_error("Rank of input tensor has to be static!");
if (dataShape.rank().get_max_length() < 2)
if (this->dataShape.rank().get_max_length() < 2)
throw std::runtime_error("Expected at least two dimensions in input tensor!");
if (!dataShape[1].is_static())
if (!this->dataShape[1].is_static())
throw std::runtime_error("Channel dimension in input tensor has to be static!");

numChannels = static_cast<size_t>(dataShape[1].get_max_length());
instanceNormGammaShape = std::get<1>(params);
instanceNormBetaShape = std::get<2>(params);
groupNormGammaShape = std::get<3>(params);
groupNormBetaShape = std::get<4>(params);
numGroups = std::get<5>(params);
epsilon = std::get<6>(params);
this->numChannels = static_cast<size_t>(this->dataShape[1].get_max_length());
this->instanceNormGammaShape = std::get<1>(params);
this->instanceNormBetaShape = std::get<2>(params);
this->groupNormGammaShape = std::get<3>(params);
this->groupNormBetaShape = std::get<4>(params);
this->numGroups = std::get<5>(params);
this->epsilon = std::get<6>(params);
positiveTest = std::get<7>(params);
targetDeviceName = std::get<8>(params);
targetConfiguration = std::get<9>(params);
refDevice = std::get<10>(params);
refConfiguration = std::get<11>(params);

instanceNormGammaPresent = (instanceNormGammaShape != Shape{});
instanceNormBetaPresent = (instanceNormBetaShape != Shape{});
this->instanceNormGammaPresent = (this->instanceNormGammaShape != Shape{});
this->instanceNormBetaPresent = (this->instanceNormBetaShape != Shape{});

inType = T_elem_t;
outType = T_elem_t;
targetDevice = targetDeviceName;
configuration = targetConfiguration;

if (positiveTest) {
if ((instanceNormGammaShape != Shape{}) && (shape_size(instanceNormGammaShape) != numGroups))
if ((this->instanceNormGammaShape != Shape{}) &&
(shape_size(this->instanceNormGammaShape) != this->numGroups))
throw std::runtime_error("Shape of instance norm gamma has to either be empty or contain "
"exactly <numGroups> elements");
if ((instanceNormBetaShape != Shape{}) && (shape_size(instanceNormBetaShape) != numGroups))
if ((this->instanceNormBetaShape != Shape{}) &&
(shape_size(this->instanceNormBetaShape) != this->numGroups))
throw std::runtime_error("Shape of instance norm beta has to either be empty shape or contain "
"exactly <numGroups> elements");
if (shape_size(groupNormGammaShape) != numChannels)
if (shape_size(this->groupNormGammaShape) != this->numChannels)
throw std::runtime_error("Shape of group norm gamma has to contain exactly <numChannels> elements");
if (shape_size(groupNormBetaShape) != numChannels)
if (shape_size(this->groupNormBetaShape) != this->numChannels)
throw std::runtime_error("Shape of group norm beta has to contain exactly <numChannels> elements");

instanceNormGammaPresent = instanceNormGammaPresent && (shape_size(instanceNormGammaShape) == numGroups);
instanceNormBetaPresent = instanceNormBetaPresent && (shape_size(instanceNormBetaShape) == numGroups);
this->instanceNormGammaPresent =
this->instanceNormGammaPresent && (shape_size(this->instanceNormGammaShape) == this->numGroups);
this->instanceNormBetaPresent =
this->instanceNormBetaPresent && (shape_size(this->instanceNormBetaShape) == this->numGroups);
}
}

Expand Down Expand Up @@ -305,7 +310,7 @@ class GroupNormalizationFusionSubgraphTestsF

void init_thresholds() override {
if (!targetStaticShapes.empty()) {
size_t problem_size = shape_size(dataShape.get_shape());
size_t problem_size = shape_size(this->dataShape.get_shape());

abs_threshold = pow(problem_size, 0.5) * test::utils::get_eps_by_ov_type(outType);
rel_threshold = abs_threshold;
Expand Down Expand Up @@ -399,8 +404,8 @@ class GroupNormalizationFusionSubgraphTestsF
std::string errorMessage;
try {
read_test_parameters();
generate_weights_init_values();
functionRefs = create_model();
this->generate_weights_init_values();
functionRefs = this->create_model();
function = functionRefs->clone();
pass::Manager m;
m.register_pass<ov::pass::GroupNormalizationFusion>();
Expand All @@ -415,7 +420,7 @@ class GroupNormalizationFusionSubgraphTestsF
if (!function->is_dynamic()) {
configure_device();
configure_ref_device();
auto input_shapes = static_partial_shapes_to_test_representation({dataShape});
auto input_shapes = static_partial_shapes_to_test_representation({this->dataShape});
init_input_shapes(input_shapes);
ASSERT_FALSE(targetStaticShapes.empty() && !function->get_parameters().empty())
<< "Target Static Shape is empty!!!";
Expand Down

0 comments on commit ef4c918

Please sign in to comment.