diff --git a/src/common/transformations/include/transformations/fp16_compression/mark_decompression_convert_constant_folding.hpp b/src/common/transformations/include/transformations/fp16_compression/mark_decompression_convert_constant_folding.hpp index b85eff6d575bde..29d335b0db1c06 100644 --- a/src/common/transformations/include/transformations/fp16_compression/mark_decompression_convert_constant_folding.hpp +++ b/src/common/transformations/include/transformations/fp16_compression/mark_decompression_convert_constant_folding.hpp @@ -14,7 +14,6 @@ namespace pass { class TRANSFORMATIONS_API EnableDecompressionConvertConstantFolding; class TRANSFORMATIONS_API DisableDecompressionConvertConstantFolding; class TRANSFORMATIONS_API KeepConstAndDecompression; -class TRANSFORMATIONS_API KeepConstAndDecompressionForMatMul; } // namespace pass } // namespace ov @@ -48,14 +47,3 @@ class ov::pass::KeepConstAndDecompression : public MatcherPass { OPENVINO_RTTI("KeepConstAndDecompression", "0"); KeepConstAndDecompression(); }; - -/** - * @ingroup ie_transformation_common_api - * @brief Disables ConstantFolding for Convert operation (just before MatMul operation only) and prevents conversion - * of f16 Consts to f32. - */ -class ov::pass::KeepConstAndDecompressionForMatMul : public MatcherPass { -public: - OPENVINO_RTTI("KeepConstAndDecompressionForMatMul", "0"); - KeepConstAndDecompressionForMatMul(); -}; diff --git a/src/common/transformations/src/transformations/fp16_compression/mark_decompression_convert_constant_folding.cpp b/src/common/transformations/src/transformations/fp16_compression/mark_decompression_convert_constant_folding.cpp index f47718f80a5d76..de03d931ac47cc 100644 --- a/src/common/transformations/src/transformations/fp16_compression/mark_decompression_convert_constant_folding.cpp +++ b/src/common/transformations/src/transformations/fp16_compression/mark_decompression_convert_constant_folding.cpp @@ -59,6 +59,10 @@ pass::KeepConstAndDecompression::KeepConstAndDecompression() { ov::is_shape_subgraph(node->shared_from_this())) return false; + if (transformation_callback(node)) { + return false; + } + disable_constant_folding(node); if (!is_type(node->input_value(0).get_node_shared_ptr())) @@ -70,29 +74,3 @@ pass::KeepConstAndDecompression::KeepConstAndDecompression() { auto m = std::make_shared(node_pattern, matcher_name); register_matcher(m, callback); } - -pass::KeepConstAndDecompressionForMatMul::KeepConstAndDecompressionForMatMul() { - MATCHER_SCOPE(KeepConstAndDecompressionForMatMul); - auto matmul = pass::pattern::wrap_type(); - - matcher_pass_callback callback = [=](pass::pattern::Matcher& m) { - auto node = m.get_match_root(); - - // input to matmul is decompression Convert - const auto& inp_convert = node->input_value(1).get_node_shared_ptr(); - if (!is_type(inp_convert) || inp_convert->get_output_target_inputs(0).size() != 1 || - !is_decompression(inp_convert)) - return false; - - disable_constant_folding(inp_convert); - - if (!is_type(inp_convert->input_value(0).get_node_shared_ptr())) - return false; - enable_keep_fp16_const(inp_convert->input_value(0).get_node_shared_ptr()); - - return false; - }; - - auto m = std::make_shared(matmul, matcher_name); - this->register_matcher(m, callback); -} diff --git a/src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp b/src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp index f8356a8f793e76..560a141c52745e 100644 --- a/src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp +++ b/src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp @@ -203,7 +203,14 @@ void Transformations::PreLpt(const std::vector& defaultPrecis manager.set_per_pass_validation(false); CPU_REGISTER_PASS_COMMON(manager, ov::pass::InitNodeInfo); CPU_REGISTER_PASS_COMMON(manager, ov::pass::MarkShapeOfSubgraphs); - CPU_REGISTER_PASS_COMMON(manager, ov::pass::KeepConstAndDecompressionForMatMul); + + CPU_REGISTER_PASS_COMMON(manager, ov::pass::KeepConstAndDecompression); + CPU_SET_CALLBACK_COMMON(manager, + [](const_node_ptr &node) -> bool { + const auto outputs = node->get_output_target_inputs(0); + return outputs.size() != 1 || !is_type(outputs.begin()->get_node()); + }, + ov::pass::KeepConstAndDecompression); const bool useLpt = !defaultPrecisions.empty(); if (useLpt) { @@ -434,7 +441,7 @@ void Transformations::PreLpt(const std::vector& defaultPrecis AUGRUCell node (see AUGRUCellFusion pass). In such cases, some constant paths will be unfolded, which can lead to crashes in the plugin. To avoid this, we re-mark decompression converts again and finally do CF for those constant paths that are not inputs to MatMul node */ CPU_REGISTER_PASS_COMMON(manager, ov::pass::EnableDecompressionConvertConstantFolding); - CPU_REGISTER_PASS_COMMON(manager, ov::pass::KeepConstAndDecompressionForMatMul); + CPU_REGISTER_PASS_COMMON(manager, ov::pass::KeepConstAndDecompression); CPU_REGISTER_PASS_COMMON(manager, ov::pass::ConstantFolding); manager.run_passes(model);