Skip to content

Commit

Permalink
[CPU] Move weights transposition to CPU plugin
Browse files Browse the repository at this point in the history
  • Loading branch information
antonvor committed Jul 19, 2023
1 parent bf2dca2 commit 6d59a62
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -108,12 +108,13 @@ ov::intel_cpu::ConvertMatMulToFC::ConvertMatMulToFC() {
std::swap(*(transpose_order.end() - 1), *(transpose_order.end() - 2));

auto transpose_const = ngraph::opset1::Constant::create(ngraph::element::i64, ngraph::Shape{ transpose_order.size() }, transpose_order);
auto transpose = ov::op::util::make_try_fold<ngraph::opset1::Transpose>(node, transpose_const);
auto transpose = std::make_shared<ngraph::opset1::Transpose>(node, transpose_const);
if (!ngraph::is_type<ngraph::opset1::Constant>(transpose)) {
new_ops.push_back(transpose_const);
MatcherPass::register_new_node(transpose);
}
transpose->set_friendly_name(transpose_name);
ov::disable_constant_folding(transpose);
new_ops.push_back(transpose);
return transpose;
};
Expand Down Expand Up @@ -143,7 +144,7 @@ ov::intel_cpu::ConvertMatMulToFC::ConvertMatMulToFC() {
NGRAPH_CHECK(K.is_static());
std::vector<int64_t> reshape_shape_values = { -1ll, static_cast<int64_t>(K.get_length()) };
auto reshape_shape = ngraph::opset1::Constant::create(ngraph::element::i64, ngraph::Shape{ 2 }, reshape_shape_values);
fc_input_b = ov::op::util::make_try_fold<ngraph::opset1::Reshape>(fc_input_b, reshape_shape, false);
fc_input_b = std::make_shared<ngraph::opset1::Reshape>(fc_input_b, reshape_shape, false);
if (!std::dynamic_pointer_cast<ngraph::opset1::Constant>(fc_input_b.get_node_shared_ptr())) {
new_ops.push_back(reshape_shape);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
#include "transformations/common_optimizations/augru_cell_fusion.hpp"
#include "transformations/common_optimizations/common_optimizations.hpp"
#include "transformations/common_optimizations/wrap_interpolate_into_transposes.hpp"
#include "transformations/common_optimizations/matmul_const_transposes_extraction.hpp"
#include "transformations/control_flow/unroll_tensor_iterator.hpp"
#include "transformations/fp16_compression/mark_decompression_convert_constant_folding.hpp"
#include "transformations/op_conversions/convert_batch_to_space.hpp"
Expand Down Expand Up @@ -434,6 +435,7 @@ void Transformations::PreLpt(const std::vector<ov::element::Type>& defaultPrecis
CPU_DISABLE_PASS_COMMON(manager, ov::pass::ConvertTopK3);
CPU_DISABLE_PASS_COMMON(manager, ov::pass::ConvertTopK11ToTopK3);
CPU_DISABLE_PASS_COMMON(manager, ov::pass::HSwishDecomposition);
CPU_DISABLE_PASS_COMMON(manager, ov::pass::MatMulConstTransposesExtraction);
CPU_DISABLE_PASS_X64(manager, ov::pass::HSigmoidDecomposition);

CPU_DISABLE_PASS_X64(manager, ov::pass::ReduceL1Decomposition);
Expand Down

0 comments on commit 6d59a62

Please sign in to comment.