Skip to content

Commit

Permalink
[LPT] Quantized LSTMSequence & GRUSequence extended support (#25654)
Browse files Browse the repository at this point in the history
### Details:
- *Low Precision Transformations: Quantized LSTMSequence & GRUSequence
extended support*

### Tickets:
 - Current implementation for: *CVS-146067*
 - Will be changed in feature request: *CVS-147588*
  • Loading branch information
eshoguli authored Jul 27, 2024
1 parent 4a5bd43 commit 3056b53
Show file tree
Hide file tree
Showing 16 changed files with 608 additions and 76 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
// Copyright (C) 2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "transparent_base_transformation.hpp"

namespace ov {
namespace pass {
namespace low_precision {

/**
* @ingroup ov_transformation_common_api
* @brief BroadcastTransformation propagates dequantization operations through Broadcast operation.
*
* For more details about the transformation, refer to
* [BroadcastTransformation](@ref openvino_docs_OV_UG_lpt_BroadcastTransformation) page
* in the OpenVINO Developer Guide.
*/
class LP_TRANSFORMATIONS_API BroadcastTransformation : public TransparentBaseTransformation {
public:
    OPENVINO_RTTI("BroadcastTransformation", "0");

    // Constructs the transformation from the given low-precision parameters;
    // default-constructed Params are used when none are supplied.
    BroadcastTransformation(const Params& params = Params());

    // Reports whether `layer` is eligible for this transformation.
    // Overrides the check inherited from the base transformation.
    bool canBeTransformed(const TransformationContext& context,
                          std::shared_ptr<ov::Node> layer) const override;
};

} // namespace low_precision
} // namespace pass
} // namespace ov
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (C) 2022 Intel Corporation
// Copyright (C) 2022-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

Expand All @@ -23,6 +23,9 @@ class LP_TRANSFORMATIONS_API RecurrentCellTransformation : public LayerTransform
static std::shared_ptr<ov::Node> wrap_fake_quantize(const std::shared_ptr<ov::Node> parameter);
static std::shared_ptr<ov::Node> wrap_quantization(const std::shared_ptr<ov::Node> parameter);
static std::shared_ptr<ov::Node> wrap_dequantization(const std::shared_ptr<ov::Node> parameter, const bool with_subtract);

private:
void propagate(TransformationContext& context, const std::shared_ptr<ov::Node> node);
};

} // namespace low_precision
Expand Down
77 changes: 77 additions & 0 deletions src/common/low_precision_transformations/src/broadcast.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
// Copyright (C) 2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "low_precision/broadcast.hpp"

#include <memory>

#include "openvino/opsets/opset1.hpp"
#include "openvino/opsets/opset3.hpp"
#include "openvino/pass/pattern/op/or.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"
#include "low_precision/network_helper.hpp"

#include "itt.hpp"

using namespace ov::pass::low_precision;

// Registers a matcher that fires on opset1 or opset3 Broadcast nodes which have
// three inputs and whose first (data) input is produced by an opset1 Multiply.
BroadcastTransformation::BroadcastTransformation(const Params& params) : TransparentBaseTransformation(params) {
    MATCHER_SCOPE(BroadcastTransformation);

    // opset1::Broadcast: Multiply on the data input, the remaining two inputs are unconstrained.
    const auto multiplyV1 = ov::pass::pattern::wrap_type<ov::opset1::Multiply>();
    const auto secondInputV1 = ov::pass::pattern::any_input();
    const auto thirdInputV1 = ov::pass::pattern::any_input();
    const auto broadcastV1 = ov::pass::pattern::wrap_type<ov::opset1::Broadcast>(
        ov::OutputVector{ multiplyV1, secondInputV1, thirdInputV1 });

    // opset3::Broadcast: same input layout as above.
    const auto multiplyV3 = ov::pass::pattern::wrap_type<ov::opset1::Multiply>();
    const auto secondInputV3 = ov::pass::pattern::any_input();
    const auto thirdInputV3 = ov::pass::pattern::any_input();
    const auto broadcastV3 = ov::pass::pattern::wrap_type<ov::opset3::Broadcast>(
        ov::OutputVector{ multiplyV3, secondInputV3, thirdInputV3 });

    // Either Broadcast version is an acceptable match root.
    const auto anyBroadcast =
        std::make_shared<ov::pass::pattern::op::Or>(ov::OutputVector{ broadcastV1, broadcastV3 });

    ov::graph_rewrite_callback onMatch = [this](ov::pass::pattern::Matcher& matched) {
        const auto root = matched.get_match_root();
        // A positive transformation_callback result means the node must be left untouched.
        if (!transformation_callback(root)) {
            return transform(*context, matched);
        }
        return false;
    };

    auto matcherInstance = std::make_shared<ov::pass::pattern::Matcher>(anyBroadcast, matcher_name);
    this->register_matcher(matcherInstance, onMatch);
}

// Decides whether the dequantization in front of a Broadcast may be handled by this transformation.
//
// Returns false when:
//  - the base LayerTransformation check rejects the layer;
//  - no dequantization is found on the layer;
//  - the dequantization is not per-tensor and any of the following holds:
//      * the input rank or the channel dimension is dynamic,
//      * the target shape (input 1) or the axes mapping (input 2) is not a Constant,
//      * the channel index is outside the target shape or the axes mapping,
//      * the Broadcast changes the size of the channel dimension,
//      * the axes mapping moves the channel dimension to another position.
// Returns true otherwise (a per-tensor dequantization is always accepted).
bool BroadcastTransformation::canBeTransformed(const TransformationContext& context, std::shared_ptr<ov::Node> layer) const {
    if (!LayerTransformation::canBeTransformed(context, layer)) {
        return false;
    }

    const auto& dequantization = NetworkHelper::getDequantization(layer, defaultPrecisions);
    if (dequantization.empty()) {
        return false;
    }

    if (dequantization.isPerTensor()) {
        return true;
    }

    const auto channelDim = dequantization.channelDimIndex;

    const auto& inputShape = layer->get_input_partial_shape(0);
    if (inputShape.rank().is_dynamic() || inputShape[channelDim].is_dynamic()) {
        return false;
    }

    // The matcher accepts any producer for inputs 1 and 2, so they are not guaranteed to be
    // Constants: as_type_ptr yields nullptr in that case and must not be dereferenced.
    const auto targetShapeConstant = ov::as_type_ptr<ov::opset1::Constant>(layer->get_input_node_shared_ptr(1));
    if (targetShapeConstant == nullptr) {
        return false;
    }

    const auto targetShape = targetShapeConstant->cast_vector<int64_t>();
    // Guard the index: the target shape length is not tied to the channel index by any earlier check.
    if (targetShape.size() <= channelDim) {
        return false;
    }
    if (targetShape[channelDim] != inputShape[channelDim].get_length()) {
        return false;
    }

    const auto axesMappingConstant = ov::as_type_ptr<ov::opset1::Constant>(layer->get_input_node_shared_ptr(2));
    if (axesMappingConstant == nullptr) {
        return false;
    }

    const auto axesMapping = axesMappingConstant->cast_vector<int64_t>();
    if (axesMapping.size() <= channelDim) {
        return false;
    }
    // The channel axis has to stay at the same position in the output.
    if (static_cast<size_t>(axesMapping[channelDim]) != channelDim) {
        return false;
    }

    return true;
}
Original file line number Diff line number Diff line change
Expand Up @@ -401,6 +401,7 @@ std::shared_ptr<ov::Node> LayerTransformation::moveDequantizationAfter(
const FakeQuantizeDequantization& dequantization,
const bool updateOutputPrecision,
const bool moveSubtract) const {
OPENVINO_ASSERT(!dequantization.empty());
const auto result = ov::pass::low_precision::NetworkHelper::moveDequantizationAfter(operation,
dequantization,
updateOutputPrecision,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
#include "low_precision/assign_and_read_value.hpp"
#include "low_precision/avg_pool.hpp"
#include "low_precision/batch_to_space.hpp"
#include "low_precision/broadcast.hpp"
#include "low_precision/clamp.hpp"
#include "low_precision/convolution.hpp"
#include "low_precision/convolution_backprop_data.hpp"
Expand Down Expand Up @@ -240,6 +241,7 @@ bool ov::pass::low_precision::LowPrecision::run_on_model(const std::shared_ptr<o
ADD_MATCHER(common, AssignAndReadValueTransformation, f, params)
ADD_MATCHER(common, AvgPoolTransformation, params)
ADD_MATCHER(common, BatchToSpaceTransformation, params)
ADD_MATCHER(common, BroadcastTransformation, params)
ADD_MATCHER(common, ClampTransformation, params)
ADD_MATCHER(common, ConcatTransformation, params)
ADD_MATCHER(common, ConvolutionTransformation, params)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

#include "openvino/opsets/opset1.hpp"
#include "openvino/opsets/opset2.hpp"
#include "openvino/opsets/opset3.hpp"
#include "openvino/opsets/opset4.hpp"
#include "openvino/opsets/opset5.hpp"
#include "openvino/opsets/opset6.hpp"
Expand Down Expand Up @@ -152,6 +153,8 @@ bool ov::pass::low_precision::MarkupPrecisions::isPrecisionPreserved(const std::
{ name<opset1::Relu>() },
// TODO: there are conditions
{ name<opset2::BatchToSpace>() },
{ name<opset1::Broadcast>() },
{ name<opset3::Broadcast>() },
{ name<opset1::Pad>() },
{ name<ov::opset12::Pad>() },
{ name<opset1::Reshape>() },
Expand Down Expand Up @@ -192,6 +195,8 @@ bool ov::pass::low_precision::MarkupPrecisions::isSupported(const std::shared_pt
{ name<opset1::Add>() },
{ name<opset1::AvgPool>() },
{ name<opset2::BatchToSpace>() },
{ name<opset1::Broadcast>() },
{ name<opset3::Broadcast>() },
{ name<opset1::Clamp>() },
{ name<opset1::Concat>() },
// ?
Expand Down
Loading

0 comments on commit 3056b53

Please sign in to comment.