[CPU] [ARM] FullyConnected: int8 support
eshoguli committed Jul 10, 2024
1 parent a58e4a5 commit b56d725
Showing 89 changed files with 358 additions and 49 deletions.
@@ -38,9 +38,9 @@ static void initACLTensorParams(const MemoryPtr& memoryPtr,
}
}

-static ACLInfo initTensorInfo(const arm_compute::TensorShape& tensorShape,
-const arm_compute::DataType& dataType,
-const arm_compute::DataLayout& dataLayout) {
+ACLInfo ACLCommonExecutor::initTensorInfo(const arm_compute::TensorShape& tensorShape,
+const arm_compute::DataType& dataType,
+const arm_compute::DataLayout& dataLayout) {
ACLInfo aclMemoryInfo = nullptr;
if (dataType != arm_compute::DataType::UNKNOWN) {
aclMemoryInfo = std::make_shared<arm_compute::TensorInfo>(
@@ -52,6 +52,10 @@ class ACLCommonExecutor : public Executor {
protected:
ACLTensorAttrs aclTensorAttrs;

+virtual ACLInfo initTensorInfo(const arm_compute::TensorShape& tensorShape,
+const arm_compute::DataType& dataType,
+const arm_compute::DataLayout& dataLayout);

private:
ACLMemoryTensors aclMemoryTensors;
ACLFunction iFunction = nullptr;
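Note: the two hunks above convert initTensorInfo from a file-local helper into a virtual method of ACLCommonExecutor, i.e. a template-method hook: the base executor still drives tensor setup, while a derived executor can remap data types before the arm_compute::TensorInfo is created. A minimal self-contained sketch of the pattern (simplified stand-in types, not the actual plugin classes):

    #include <iostream>
    #include <memory>
    #include <string>

    // Stand-ins for arm_compute::TensorInfo and ACLInfo (shared_ptr<TensorInfo>).
    struct TensorInfo { std::string dataType; };
    using Info = std::shared_ptr<TensorInfo>;

    class CommonExecutor {
    public:
        void update() {
            // The base class drives setup but defers type selection to the hook.
            Info info = initTensorInfo("S8");
            std::cout << "tensor data type: " << info->dataType << "\n";
        }
    protected:
        virtual Info initTensorInfo(const std::string& dataType) {
            return std::make_shared<TensorInfo>(TensorInfo{dataType});
        }
    };

    class FullyConnectedExecutor : public CommonExecutor {
    protected:
        Info initTensorInfo(const std::string& dataType) override {
            // Remap plain int8 to a quantized type, then delegate to the base.
            return CommonExecutor::initTensorInfo(dataType == "S8" ? "QASYMM8_SIGNED" : dataType);
        }
    };

    int main() {
        FullyConnectedExecutor{}.update();  // prints: tensor data type: QASYMM8_SIGNED
    }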
@@ -19,6 +19,9 @@ ACLFullyConnectedExecutor::ACLFullyConnectedExecutor(const FCAttrs &attrs, const
aclTensorAttrs.hasLayoutTypeNHWC = memory.at(ARG_SRC)->getDescPtr()->hasLayoutType(LayoutType::nspc);
fullyConnectedLayerInfo.weights_trained_layout = getAclDataLayoutByMemoryDesc(memory.at(ARG_WEI)->getDescPtr());
fullyConnectedLayerInfo.transpose_weights = !attrs.weightsNonTransposed;
+if (!attrs.dequantizationScales.empty()) {
+dequantizationScale = attrs.dequantizationScales[0];
+}

// Add postops
if (!postOps.empty() && postOps.size() == 1) {
@@ -32,10 +35,20 @@
}

bool ACLFullyConnectedExecutor::supports(const FCConfig &config) {
-VERIFY(one_of(srcType(config), ov::element::f16, ov::element::f32), UNSUPPORTED_SRC_PRECISIONS);
+// issue #<create and put number here>
+const auto attrs = static_cast<FCAttrs>(config.attrs);
+if (std::any_of(
+attrs.dequantizationScales.begin(),
+attrs.dequantizationScales.end(),
+[](float value) { return value != 1.f;})) {
+return false;
+}

+VERIFY(one_of(srcType(config), ov::element::f16, ov::element::f32, ov::element::i8), UNSUPPORTED_SRC_PRECISIONS);
VERIFY(postOpsNumbers(config) < 2, UNSUPPORTED_NUMBER_OF_POSTOPS);
VERIFY(one_of(srcRank(config), 2U, 3U, 4U), UNSUPPORTED_SRC_RANK);
VERIFY(one_of(weiRank(config), 2U, 3U), UNSUPPORTED_WEI_RANK);
+VERIFY(static_cast<FCAttrs>(config.attrs).dequantizationScales.size() <= 1, UNSUPPORTED_PER_CHANNEL_QUANTIZATION);
return true;
}
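Note: supports() is deliberately conservative here: any dequantization scale other than 1.f bails out early (per the issue placeholder above), and more than one scale, i.e. per-channel quantization, is rejected via UNSUPPORTED_PER_CHANNEL_QUANTIZATION, since the executor keeps only a single per-tensor dequantizationScale.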

@@ -74,16 +87,43 @@ arm_compute::Status ACLFullyConnectedExecutor::validateTensorsInfo(const ACLMemo
}

ACLFunction ACLFullyConnectedExecutor::configureFunction(const ACLMemoryTensors & aclMemoryTensors) {
+const auto dstTensor = aclMemoryTensors.at(ACLArgs::ACL_DST).get();
+if (dequantizationScale != 1.0) {
+dstTensor->info()->set_quantization_info(arm_compute::QuantizationInfo(dequantizationScale, 0));
+}

auto neFC = std::make_unique<arm_compute::NEFullyConnectedLayer>();
neFC->configure(
aclMemoryTensors[ACLArgs::ACL_SRC_0].get(),
aclMemoryTensors[ACLArgs::ACL_WEI].get(),
aclMemoryTensors[ACLArgs::ACL_BIAS].get(),
-aclMemoryTensors[ACLArgs::ACL_DST].get(),
+dstTensor,
fullyConnectedLayerInfo,
weightsInfo);
return neFC;
}

+ACLInfo ACLFullyConnectedExecutor::initTensorInfo(const arm_compute::TensorShape& tensorShape,
+const arm_compute::DataType& dataType,
+const arm_compute::DataLayout& dataLayout) {
+arm_compute::DataType fcDataType;
+switch (dataType) {
+case arm_compute::DataType::S8: {
+fcDataType = arm_compute::DataType::QASYMM8_SIGNED;
+break;
+}
+case arm_compute::DataType::U8: {
+fcDataType = arm_compute::DataType::QASYMM8;
+break;
+}
+default: {
+fcDataType = dataType;
+break;
+}
+}
+
+return ACLCommonExecutor::initTensorInfo(tensorShape, fcDataType, dataLayout);
+}

} // namespace intel_cpu
} // namespace ov
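Together, the hunks above wire int8 through ACL's quantized-tensor machinery: initTensorInfo remaps S8/U8 to QASYMM8_SIGNED/QASYMM8 so arm_compute dispatches to its quantized FullyConnected kernels, and configureFunction attaches the per-tensor dequantization scale to the destination tensor info. A standalone sketch of the underlying Compute Library API (illustrative shape and scale values, assuming the ACL headers are available):

    #include <arm_compute/core/TensorInfo.h>
    #include <arm_compute/core/TensorShape.h>
    #include <arm_compute/core/Types.h>

    int main() {
        // A 16x64 int8 tensor; ACL lists dimensions fastest-varying first.
        arm_compute::TensorShape shape(64, 16);
        arm_compute::TensorInfo info(shape, 1, arm_compute::DataType::QASYMM8_SIGNED);

        // Per-tensor quantization: scale 0.02, zero point 0 (mirrors
        // QuantizationInfo(dequantizationScale, 0) in the hunk above).
        info.set_quantization_info(arm_compute::QuantizationInfo(0.02f, 0));
        return 0;
    }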
@@ -28,9 +28,16 @@ class ACLFullyConnectedExecutor : public ACLCommonExecutor {
impl_desc_type implType() const override {
return impl_desc_type::gemm_acl;
}

+protected:
+ACLInfo initTensorInfo(const arm_compute::TensorShape& tensorShape,
+const arm_compute::DataType& dataType,
+const arm_compute::DataLayout& dataLayout) override;

private:
arm_compute::FullyConnectedLayerInfo fullyConnectedLayerInfo;
arm_compute::WeightsInfo weightsInfo;
+float dequantizationScale = 1.f;
};

using ACLFullyConnectedExecutorPtr = std::shared_ptr<ACLFullyConnectedExecutor>;
@@ -101,7 +101,7 @@ inline int axisCast(const std::size_t axis, const std::size_t shapeSize, ACLAxis
* @param precision precision to be converted
* @return ComputeLibrary DataType or UNKNOWN if precision is not mapped to DataType
*/
-inline arm_compute::DataType precisionToAclDataType(ov::element::Type precision) {
+inline arm_compute::DataType precisionToAclDataType(const ov::element::Type& precision) {
switch (precision) {
case ov::element::i8: return arm_compute::DataType::S8;
case ov::element::u8: return arm_compute::DataType::U8;
@@ -17,6 +17,7 @@
#define UNSUPPORTED_DST_RANK " unsupported dst rank"
#define UNSUPPORTED_DST_STRIDES " unsupported dst strides"
#define HEURISTICS_MISMATCH " heuristics mismatch"
+#define UNSUPPORTED_PER_CHANNEL_QUANTIZATION " unsupported per-channel quantization"

#define VERIFY(condition, ...) \
do { \
@@ -77,7 +77,7 @@ static const TypeMapping dnnlFCTypeMapping {

static const TypeMapping aclFCTypeMapping {
// {src, wei, bia, dst} pt<src, wei, bias, dst>
-{{_f32 | _f16, _any, _any, _any}, pt(bypass(), use<0>(), use<0>(), use<0>())},
+{{_i8, _i8, _any, _any}, pt(just<i8>(), just<i8>(), just<i32>(), just<i32>())},
{{_any, _any, _any, _any}, pt(just<f32>(), just<f32>(), just<f32>(), just<f32>())}
};
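Note: the rows appear to be matched top-down, so the new first row pins an i8 src/wei pair to i8 inputs with i32 bias and destination (matching the quantized ACL kernels), while all remaining precisions fall through to the f32 row below; the previous f32/f16 bypass row is dropped.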

@@ -72,15 +72,58 @@ bool isFullyConnected(const std::shared_ptr<const ov::Node>& node) {
ov::op::util::is_on_constant_path(out_weights);
}

-bool SupportsFusingWithConvolution_Simple(const std::shared_ptr<const Node> &node) {
+// TODO: move to base type
+bool canBePerformedAsScaleShift(const std::shared_ptr<const Node> &node, const int channelAxis) {
+size_t fusingPort = 0;
+size_t numNonConstInputs = 0;
+ov::PartialShape dataShape;
+for (size_t i = 0; i < node->get_input_size(); i++) {
+const auto parent = node->get_input_node_shared_ptr(i);
+if (!ov::is_type<ov::op::v0::Constant>(parent)) {
+fusingPort = i;
+dataShape = node->get_input_partial_shape(i);
+// only one non-const parent is allowed
+if (++numNonConstInputs != 1)
+return false;
+} else {
+// every const parent must have exactly one child
+const auto out = parent->outputs();
+const bool has_only_child = (out.size() == 1) && (out[0].get_target_inputs().size() == 1);
+if (!has_only_child)
+return false;
+}
+}
+
+const auto isBroadcastableToDataInput = [&]() {
+for (size_t i = 0; i < node->get_input_size(); i++) {
+if (i == fusingPort)
+continue;
+const ov::PartialShape weightShape = node->get_input_partial_shape(i);
+if (!isPerTensorOrPerChannelBroadcastable(dataShape.get_max_shape(), weightShape.get_max_shape(), channelAxis, true))
+return false;
+}
+return true;
+};
+
+// Prelu and MulAdd are still ignored
+// isConvertablePowerStatic() is ignored
+return (ov::is_type<ov::opset1::Add>(node) ||
+ov::is_type<ov::opset1::Multiply>(node) ||
+ov::is_type<ov::opset1::Subtract>(node) ||
+ov::is_type<ov::opset1::Divide>(node)) &&
+isBroadcastableToDataInput();
+}
+
+bool SupportsFusingWithConvolution_Simple(const std::shared_ptr<const Node> &node, const int channelAxis = DEFAULT_AXIS) {
// Note: some other operations support this fusing (SoftPlus, Sqrt).
// Skip them here, when they are supported by Snippets ARM. Ticket: 141170.
return ov::is_type<ov::op::v0::Abs>(node) ||
ov::is_type<ov::op::v0::Clamp>(node) ||
ov::is_type<ov::op::v0::Elu>(node) ||
ov::is_type<ov::op::v0::Relu>(node) ||
ov::is_type<ov::op::v0::Sigmoid>(node) ||
-ov::is_type<ov::op::v0::Tanh>(node);
+ov::is_type<ov::op::v0::Tanh>(node) ||
+canBePerformedAsScaleShift(node, channelAxis);
}
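The new check above accepts an eltwise with exactly one non-constant input whose constant inputs are per-tensor or per-channel broadcastable, i.e. something executable as a ScaleShift. A hedged illustration of a graph that should pass it (hypothetical shapes, public OpenVINO graph API):

    #include <memory>
    #include <vector>
    #include "openvino/opsets/opset1.hpp"

    int main() {
        // Non-constant data input of shape {1, 16, 8, 8}.
        auto data = std::make_shared<ov::opset1::Parameter>(ov::element::f32, ov::Shape{1, 16, 8, 8});
        // Per-channel constant on axis 1, broadcastable to the data shape.
        auto scale = ov::opset1::Constant::create(ov::element::f32, ov::Shape{1, 16, 1, 1},
                                                  std::vector<float>(16, 0.5f));
        // One non-const input + a per-channel-broadcastable constant:
        // canBePerformedAsScaleShift(mul, /*channelAxis=*/1) should return true.
        auto mul = std::make_shared<ov::opset1::Multiply>(data, scale);
        return mul ? 0 : 1;
    }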
// Convolution is a special case, since it supports peculiar fusings
bool isSuitableConvolutionParent(const std::shared_ptr<const Node> &node) {
@@ -231,7 +274,10 @@ bool SnippetsMarkSkipped::run_on_model(const std::shared_ptr<ov::Model> &m) {
PropagateIfHasOnlyChild(node, fusingChainType);
} else if (isSuitableChildForFusingSimple(node)) {
#if defined (OV_CPU_WITH_ACL)
-if (one_of(fusingChainType, NodeFusingType::FusedWithConvolution, NodeFusingType::FusedWithBinaryConvolution)) {
+if (one_of(fusingChainType,
+NodeFusingType::FusedWithConvolution,
+NodeFusingType::FusedWithBinaryConvolution,
+NodeFusingType::FusedWithFC)) {
PropagateIfHasOnlyChild(node, NodeFusingType::FusedTerminator);
continue;
}
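Note: adding NodeFusingType::FusedWithFC to this check means that, on ACL builds, an eltwise following a FullyConnected is marked as a fusing terminator instead of being tokenized by Snippets, leaving it to be fused into the ACL FullyConnected executor (for example the dequantization Multiply handled above).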
@@ -4,7 +4,7 @@

#pragma once

#include "openvino/pass/graph_rewrite.hpp"
#include "transformations/snippets/common/pass/snippets_mark_skipped_base.hpp"

namespace ov {
namespace intel_cpu {
@@ -14,10 +14,9 @@ namespace intel_cpu {
* @brief Mark operations that should be ignored by snippets on tokenization stage. A typical example is eltwise operations
* that will be fused into convolutions on plugin side.
*/
-class SnippetsMarkSkipped : public ov::pass::ModelPass {
+class SnippetsMarkSkipped : public SnippetsMarkSkippedBase {
public:
OPENVINO_RTTI("SnippetsMarkSkipped", "0");
-SnippetsMarkSkipped() : ModelPass() {}
bool run_on_model(const std::shared_ptr<ov::Model> &) override;
};

@@ -0,0 +1,64 @@
// Copyright (C) 2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "snippets_mark_skipped_base.hpp"

#include "snippets/pass/tokenization.hpp"
#include "snippets/op/subgraph.hpp"
#include "snippets/utils.hpp"

#include "transformations/utils/utils.hpp"
#include "transformations/utils.hpp"
#include "utils/general_utils.h"
#include "utils/cpu_utils.hpp"
#include "cpu/x64/cpu_isa_traits.hpp"

#include "itt.hpp"


namespace ov {
namespace intel_cpu {

bool SnippetsMarkSkippedBase::canBePerformedAsScaleShift(const std::shared_ptr<const Node> &node, const int channelAxis) {
size_t fusingPort = 0;
size_t numNonConstInputs = 0;
ov::PartialShape dataShape;
for (size_t i = 0; i < node->get_input_size(); i++) {
const auto parent = node->get_input_node_shared_ptr(i);
if (!ov::is_type<ov::op::v0::Constant>(parent)) {
fusingPort = i;
dataShape = node->get_input_partial_shape(i);
// only one non-const parent is allowed
if (++numNonConstInputs != 1)
return false;
} else {
// every const parent must have exactly one child
const auto out = parent->outputs();
const bool has_only_child = (out.size() == 1) && (out[0].get_target_inputs().size() == 1);
if (!has_only_child)
return false;
}
}

const auto isBroadcastableToDataInput = [&]() {
for (size_t i = 0; i < node->get_input_size(); i++) {
if (i == fusingPort)
continue;
const ov::PartialShape weightShape = node->get_input_partial_shape(i);
if (!isPerTensorOrPerChannelBroadcastable(dataShape.get_max_shape(), weightShape.get_max_shape(), channelAxis, true))
return false;
}
return true;
};

// Prelu and MulAdd are still ignored
// isConvertablePowerStatic() is ignored
return (ov::is_type<ov::opset1::Add>(node) ||
ov::is_type<ov::opset1::Multiply>(node) ||
ov::is_type<ov::opset1::Subtract>(node) ||
ov::is_type<ov::opset1::Divide>(node)) &&
isBroadcastableToDataInput();
}

} // namespace intel_cpu
} // namespace ov
@@ -0,0 +1,22 @@
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "openvino/pass/graph_rewrite.hpp"

namespace ov {
namespace intel_cpu {

/**
* @interface SnippetsMarkSkippedBase
* @brief Base class to mark operations that should be ignored by snippets on tokenization stage.
*/
class SnippetsMarkSkippedBase : public ov::pass::ModelPass {
protected:
bool canBePerformedAsScaleShift(const std::shared_ptr<const Node> &node, const int channelAxis);
};

} // namespace intel_cpu
} // namespace ov
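The base class contributes only the protected heuristic; concrete passes derive from it and keep their own run_on_model, as the surrounding SnippetsMarkSkipped headers do. A minimal sketch of such a derived pass (hypothetical class name and marking logic):

    #include "openvino/pass/pass.hpp"
    #include "transformations/snippets/common/pass/snippets_mark_skipped_base.hpp"

    namespace ov {
    namespace intel_cpu {

    class MyMarkSkipped : public SnippetsMarkSkippedBase {
    public:
        OPENVINO_RTTI("MyMarkSkipped", "0");
        bool run_on_model(const std::shared_ptr<ov::Model>& m) override {
            for (const auto& node : m->get_ordered_ops()) {
                // Reuse the shared ScaleShift heuristic with channel axis 1.
                if (canBePerformedAsScaleShift(node, 1)) {
                    // ... mark the node as skipped for Snippets tokenization
                }
            }
            return false;  // the model itself is not modified
        }
    };

    }  // namespace intel_cpu
    }  // namespace ov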
@@ -4,7 +4,7 @@

#pragma once

#include "openvino/pass/graph_rewrite.hpp"
#include "transformations/snippets/common/pass/snippets_mark_skipped_base.hpp"

namespace ov {
namespace intel_cpu {
@@ -14,10 +14,10 @@ namespace intel_cpu {
* @brief Mark operations that should be ignored by snippets on tokenization stage. A typical example is eltwise operations
* that will be fused into convolutions on plugin side.
*/
-class SnippetsMarkSkipped : public ov::pass::ModelPass {
+class SnippetsMarkSkipped : public SnippetsMarkSkippedBase {
public:
OPENVINO_RTTI("SnippetsMarkSkipped", "0");
-SnippetsMarkSkipped(bool enableBF16 = false) : ModelPass(), enableBF16(enableBF16) {}
+SnippetsMarkSkipped(bool enableBF16 = false) : SnippetsMarkSkippedBase(), enableBF16(enableBF16) {}
bool run_on_model(const std::shared_ptr<ov::Model> &) override;
private:
bool enableBF16 = false;
4 changes: 3 additions & 1 deletion src/plugins/intel_cpu/tests/functional/CMakeLists.txt
@@ -41,6 +41,7 @@ if(NOT (ARM OR AARCH64))
list(APPEND EXCLUDED_SOURCE_PATHS
${CMAKE_CURRENT_SOURCE_DIR}/custom/single_layer_tests/instances/arm
${CMAKE_CURRENT_SOURCE_DIR}/custom/subgraph_tests/src/arm
+${CMAKE_CURRENT_SOURCE_DIR}/shared_tests_instances/low_precision_transformations/arm
${CMAKE_CURRENT_SOURCE_DIR}/utils/arm)
else()
list(APPEND EXCLUDED_SOURCE_PATHS
@@ -67,7 +68,8 @@ endif()
if(NOT X86_64)
list(APPEND EXCLUDED_SOURCE_PATHS
${CMAKE_CURRENT_SOURCE_DIR}/custom/single_layer_tests/instances/x64
-${CMAKE_CURRENT_SOURCE_DIR}/custom/subgraph_tests/src/x64)
+${CMAKE_CURRENT_SOURCE_DIR}/custom/subgraph_tests/src/x64
+${CMAKE_CURRENT_SOURCE_DIR}/shared_tests_instances/low_precision_transformations/x64)
endif()

ov_add_test_target(