Skip to content

Commit

Permalink
removed KeepConstAndDecompressionForMatMul pass and added cpu callbac…
Browse files Browse the repository at this point in the history
…k for KeepConstAndDecompression
  • Loading branch information
antonvor committed Aug 10, 2023
1 parent 979cc72 commit f296887
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 40 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ namespace pass {
// Forward declarations of the transformation pass classes that live in ov::pass.
class TRANSFORMATIONS_API EnableDecompressionConvertConstantFolding;
class TRANSFORMATIONS_API DisableDecompressionConvertConstantFolding;
class TRANSFORMATIONS_API KeepConstAndDecompression;
class TRANSFORMATIONS_API KeepConstAndDecompressionForMatMul;

} // namespace pass
} // namespace ov
Expand Down Expand Up @@ -48,14 +47,3 @@ class ov::pass::KeepConstAndDecompression : public MatcherPass {
OPENVINO_RTTI("KeepConstAndDecompression", "0");
KeepConstAndDecompression();
};

/**
* @ingroup ie_transformation_common_api
* @brief Matches MatMul operations and, when the MatMul input with index 1 is a decompression Convert
* that has exactly one consumer, disables ConstantFolding for that Convert. If the Convert is fed
* directly by a Constant, the Constant is additionally marked with enable_keep_fp16_const, which
* prevents conversion of f16 Consts to f32.
*/
class ov::pass::KeepConstAndDecompressionForMatMul : public MatcherPass {
public:
OPENVINO_RTTI("KeepConstAndDecompressionForMatMul", "0");
// Builds the MatMul pattern and registers the matcher together with its callback.
KeepConstAndDecompressionForMatMul();
};
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,10 @@ pass::KeepConstAndDecompression::KeepConstAndDecompression() {
ov::is_shape_subgraph(node->shared_from_this()))
return false;

if (transformation_callback(node)) {
return false;
}

disable_constant_folding(node);

if (!is_type<ov::op::v0::Constant>(node->input_value(0).get_node_shared_ptr()))
Expand All @@ -70,29 +74,3 @@ pass::KeepConstAndDecompression::KeepConstAndDecompression() {
auto m = std::make_shared<pattern::Matcher>(node_pattern, matcher_name);
register_matcher(m, callback);
}

pass::KeepConstAndDecompressionForMatMul::KeepConstAndDecompressionForMatMul() {
    MATCHER_SCOPE(KeepConstAndDecompressionForMatMul);
    auto matmul_pattern = pass::pattern::wrap_type<ov::op::v0::MatMul>();

    // The callback only annotates nodes; it returns false on every path.
    matcher_pass_callback callback = [=](pass::pattern::Matcher& match) {
        const auto matmul_node = match.get_match_root();
        const auto weights_convert = matmul_node->input_value(1).get_node_shared_ptr();

        // Only a decompression Convert with exactly one consumer on output 0 is handled.
        const bool is_exclusive_decompression = is_type<ov::op::v0::Convert>(weights_convert) &&
                                                weights_convert->get_output_target_inputs(0).size() == 1 &&
                                                is_decompression(weights_convert);
        if (is_exclusive_decompression) {
            disable_constant_folding(weights_convert);

            // When the Convert reads straight from a Constant, mark that Constant as well.
            const auto convert_source = weights_convert->input_value(0).get_node_shared_ptr();
            if (is_type<ov::op::v0::Constant>(convert_source)) {
                enable_keep_fp16_const(convert_source);
            }
        }

        return false;
    };

    auto matmul_matcher = std::make_shared<pass::pattern::Matcher>(matmul_pattern, matcher_name);
    this->register_matcher(matmul_matcher, callback);
}
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,14 @@ void Transformations::PreLpt(const std::vector<ov::element::Type>& defaultPrecis
manager.set_per_pass_validation(false);
CPU_REGISTER_PASS_COMMON(manager, ov::pass::InitNodeInfo);
CPU_REGISTER_PASS_COMMON(manager, ov::pass::MarkShapeOfSubgraphs);
CPU_REGISTER_PASS_COMMON(manager, ov::pass::KeepConstAndDecompressionForMatMul);

CPU_REGISTER_PASS_COMMON(manager, ov::pass::KeepConstAndDecompression);
CPU_SET_CALLBACK_COMMON(manager,
[](const_node_ptr &node) -> bool {
const auto outputs = node->get_output_target_inputs(0);
return outputs.size() != 1 || !is_type<ov::op::v0::MatMul>(outputs.begin()->get_node());
},
ov::pass::KeepConstAndDecompression);

const bool useLpt = !defaultPrecisions.empty();
if (useLpt) {
Expand Down Expand Up @@ -434,7 +441,7 @@ void Transformations::PreLpt(const std::vector<ov::element::Type>& defaultPrecis
AUGRUCell node (see AUGRUCellFusion pass). In such cases, some constant paths will be unfolded, which can lead to crashes in the plugin. To avoid this,
we re-mark the decompression Converts and only then run ConstantFolding on the constant paths that are not inputs to a MatMul node */
CPU_REGISTER_PASS_COMMON(manager, ov::pass::EnableDecompressionConvertConstantFolding);
CPU_REGISTER_PASS_COMMON(manager, ov::pass::KeepConstAndDecompressionForMatMul);
CPU_REGISTER_PASS_COMMON(manager, ov::pass::KeepConstAndDecompression);
CPU_REGISTER_PASS_COMMON(manager, ov::pass::ConstantFolding);

manager.run_passes(model);
Expand Down

0 comments on commit f296887

Please sign in to comment.