[CPU] [ARM] FullyConnected: int8 support
eshoguli committed Jul 2, 2024
1 parent a58e4a5 commit 743281f
Showing 80 changed files with 120 additions and 23 deletions.
@@ -38,9 +38,9 @@ static void initACLTensorParams(const MemoryPtr& memoryPtr,
     }
 }
 
-static ACLInfo initTensorInfo(const arm_compute::TensorShape& tensorShape,
-                              const arm_compute::DataType& dataType,
-                              const arm_compute::DataLayout& dataLayout) {
+ACLInfo ACLCommonExecutor::initTensorInfo(const arm_compute::TensorShape& tensorShape,
+                                          const arm_compute::DataType& dataType,
+                                          const arm_compute::DataLayout& dataLayout) {
     ACLInfo aclMemoryInfo = nullptr;
     if (dataType != arm_compute::DataType::UNKNOWN) {
         aclMemoryInfo = std::make_shared<arm_compute::TensorInfo>(
@@ -52,6 +52,10 @@ class ACLCommonExecutor : public Executor {
 protected:
     ACLTensorAttrs aclTensorAttrs;
 
+    virtual ACLInfo initTensorInfo(const arm_compute::TensorShape& tensorShape,
+                                   const arm_compute::DataType& dataType,
+                                   const arm_compute::DataLayout& dataLayout);
+
 private:
     ACLMemoryTensors aclMemoryTensors;
     ACLFunction iFunction = nullptr;
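Note (reviewer addition, not part of the commit): promoting initTensorInfo from a file-local static function to a protected virtual member is a template-method refactor: ACLCommonExecutor keeps ownership of TensorInfo construction, while a subclass may remap the data type first. A minimal, self-contained sketch of that dispatch, using hypothetical class names:

    #include <iostream>

    // Hypothetical stand-ins for ACLCommonExecutor and a quantized subclass.
    struct Base {
        virtual ~Base() = default;
        void build() { std::cout << initType() << '\n'; }  // shared construction path
    protected:
        virtual const char* initType() { return "S8"; }    // default type mapping
    };

    struct Quantized : Base {
    protected:
        const char* initType() override { return "QASYMM8_SIGNED"; }  // subclass remap
    };

    int main() {
        Quantized q;
        q.build();  // prints QASYMM8_SIGNED: base drives the flow, subclass picks the type
    }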
@@ -32,7 +32,15 @@ ACLFullyConnectedExecutor::ACLFullyConnectedExecutor(const FCAttrs &attrs, const
 }
 
 bool ACLFullyConnectedExecutor::supports(const FCConfig &config) {
-    VERIFY(one_of(srcType(config), ov::element::f16, ov::element::f32), UNSUPPORTED_SRC_PRECISIONS);
+    const auto attrs = static_cast<FCAttrs>(config.attrs);
+    if (std::any_of(
+            attrs.dequantizationScales.begin(),
+            attrs.dequantizationScales.end(),
+            [](float value) { return value != 1.f; })) {
+        return false;
+    }
+
+    VERIFY(one_of(srcType(config), ov::element::f16, ov::element::f32, ov::element::i8), UNSUPPORTED_SRC_PRECISIONS);
     VERIFY(postOpsNumbers(config) < 2, UNSUPPORTED_NUMBER_OF_POSTOPS);
     VERIFY(one_of(srcRank(config), 2U, 3U, 4U), UNSUPPORTED_SRC_RANK);
     VERIFY(one_of(weiRank(config), 2U, 3U), UNSUPPORTED_WEI_RANK);
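Note (reviewer addition, not part of the commit): the new early return bails out of supports() whenever any dequantization scale differs from exactly 1.0f; the int8 path wired here apparently does not fold per-channel rescaling into the ACL kernel. Under that predicate:

    // Illustration of the std::any_of guard above (hypothetical values):
    const std::vector<float> trivial{1.f, 1.f, 1.f};    // all ones: supports() continues
    const std::vector<float> perChannel{0.5f, 0.25f};   // any non-1.0 value: supports() returns false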
@@ -85,5 +93,27 @@ ACLFunction ACLFullyConnectedExecutor::configureFunction(const ACLMemoryTensors
     return neFC;
 }
 
+ACLInfo ACLFullyConnectedExecutor::initTensorInfo(const arm_compute::TensorShape& tensorShape,
+                                                  const arm_compute::DataType& dataType,
+                                                  const arm_compute::DataLayout& dataLayout) {
+    arm_compute::DataType fcDataType;
+    switch (dataType) {
+        case arm_compute::DataType::S8: {
+            fcDataType = arm_compute::DataType::QASYMM8_SIGNED;
+            break;
+        }
+        case arm_compute::DataType::U8: {
+            fcDataType = arm_compute::DataType::QASYMM8;
+            break;
+        }
+        default: {
+            fcDataType = dataType;
+            break;
+        }
+    }
+
+    return ACLCommonExecutor::initTensorInfo(tensorShape, fcDataType, dataLayout);
+}
+
 } // namespace intel_cpu
 } // namespace ov
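Note (reviewer addition): the override matters because Compute Library treats plain integer types (S8, U8) and quantized types (QASYMM8_SIGNED, QASYMM8) as distinct; the quantized GEMM kernels expect the latter, which can carry a QuantizationInfo. A sketch against the stock ACL API (scale and zero point are placeholder values):

    #include <arm_compute/core/TensorInfo.h>
    #include <arm_compute/core/QuantizationInfo.h>

    // Sketch: describe an int8 tensor as QASYMM8_SIGNED rather than plain S8.
    arm_compute::TensorInfo makeInt8Info(const arm_compute::TensorShape& shape) {
        arm_compute::TensorInfo info(shape, 1, arm_compute::DataType::QASYMM8_SIGNED);
        info.set_quantization_info(arm_compute::QuantizationInfo(1.f, 0));  // placeholder scale/offset
        return info;
    }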
@@ -28,6 +28,12 @@ class ACLFullyConnectedExecutor : public ACLCommonExecutor {
     impl_desc_type implType() const override {
         return impl_desc_type::gemm_acl;
     }
+
+protected:
+    ACLInfo initTensorInfo(const arm_compute::TensorShape& tensorShape,
+                           const arm_compute::DataType& dataType,
+                           const arm_compute::DataLayout& dataLayout) override;
+
 private:
     arm_compute::FullyConnectedLayerInfo fullyConnectedLayerInfo;
     arm_compute::WeightsInfo weightsInfo;
41 changes: 38 additions & 3 deletions src/plugins/intel_cpu/src/nodes/executors/acl/acl_utils.hpp
@@ -96,15 +96,50 @@ inline int axisCast(const std::size_t axis, const std::size_t shapeSize, ACLAxis
     }
 }
 
+enum class QuantizedDataType {
+    NONE,    // not quantized
+    QSYMM,   // quantized, symmetric
+    QASYMM   // quantized, asymmetric
+};
+
 /**
  * @brief Return ComputeLibrary DataType that corresponds to the given precision
  * @param precision precision to be converted
  * @return ComputeLibrary DataType or UNKNOWN if precision is not mapped to DataType
  */
-inline arm_compute::DataType precisionToAclDataType(ov::element::Type precision) {
+inline arm_compute::DataType precisionToAclDataType(
+    const ov::element::Type& precision,
+    const QuantizedDataType quantized = QuantizedDataType::NONE) {
     switch (precision) {
-        case ov::element::i8: return arm_compute::DataType::S8;
-        case ov::element::u8: return arm_compute::DataType::U8;
+        case ov::element::i8: {
+            switch (quantized) {
+                case QuantizedDataType::QASYMM: {
+                    return arm_compute::DataType::QASYMM8_SIGNED;
+                }
+                case QuantizedDataType::NONE: {
+                    return arm_compute::DataType::S8;
+                }
+                default: {
+                    return arm_compute::DataType::UNKNOWN;
+                }
+            }
+        }
+        case ov::element::u8: {
+            switch (quantized) {
+                case QuantizedDataType::QSYMM: {
+                    return arm_compute::DataType::QSYMM8;
+                }
+                case QuantizedDataType::QASYMM: {
+                    return arm_compute::DataType::QASYMM8;
+                }
+                case QuantizedDataType::NONE: {
+                    return arm_compute::DataType::U8;
+                }
+                default: {
+                    return arm_compute::DataType::UNKNOWN;
+                }
+            }
+        }
         case ov::element::i16: return arm_compute::DataType::S16;
         case ov::element::u16: return arm_compute::DataType::U16;
         case ov::element::i32: return arm_compute::DataType::S32;
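Usage sketch (reviewer addition); the results follow directly from the switch above:

    const auto plain = precisionToAclDataType(ov::element::i8);
    // plain == arm_compute::DataType::S8 (default QuantizedDataType::NONE)
    const auto asymm = precisionToAclDataType(ov::element::i8, QuantizedDataType::QASYMM);
    // asymm == arm_compute::DataType::QASYMM8_SIGNED
    const auto unknown = precisionToAclDataType(ov::element::i8, QuantizedDataType::QSYMM);
    // unknown == arm_compute::DataType::UNKNOWN (no symmetric mapping for i8 here)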
@@ -78,6 +78,7 @@ static const TypeMapping dnnlFCTypeMapping {
 static const TypeMapping aclFCTypeMapping {
     // {src, wei, bia, dst}              pt<src, wei, bias, dst>
     {{_f32 | _f16, _any, _any, _any},    pt(bypass(), use<0>(), use<0>(), use<0>())},
+    {{_i8, _i8, _any, _any},             pt(just<i8>(), just<i8>(), bypass(), just<i32>())},
     {{_any, _any, _any, _any},           pt(just<f32>(), just<f32>(), just<f32>(), just<f32>())}
 };
 
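Reading the new row (reviewer interpretation): for an i8 src with i8 weights the mapping keeps both inputs in i8 (just<i8>()), leaves the bias precision as-is (bypass()), and pins the destination to i32, consistent with an integer GEMM that accumulates into 32 bits.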
4 changes: 3 additions & 1 deletion src/plugins/intel_cpu/tests/functional/CMakeLists.txt
@@ -41,6 +41,7 @@ if(NOT (ARM OR AARCH64))
     list(APPEND EXCLUDED_SOURCE_PATHS
         ${CMAKE_CURRENT_SOURCE_DIR}/custom/single_layer_tests/instances/arm
         ${CMAKE_CURRENT_SOURCE_DIR}/custom/subgraph_tests/src/arm
+        ${CMAKE_CURRENT_SOURCE_DIR}/shared_tests_instances/low_precision_transformations/arm
         ${CMAKE_CURRENT_SOURCE_DIR}/utils/arm)
 else()
     list(APPEND EXCLUDED_SOURCE_PATHS
@@ -67,7 +68,8 @@ endif()
 if(NOT X86_64)
     list(APPEND EXCLUDED_SOURCE_PATHS
         ${CMAKE_CURRENT_SOURCE_DIR}/custom/single_layer_tests/instances/x64
-        ${CMAKE_CURRENT_SOURCE_DIR}/custom/subgraph_tests/src/x64)
+        ${CMAKE_CURRENT_SOURCE_DIR}/custom/subgraph_tests/src/x64
+        ${CMAKE_CURRENT_SOURCE_DIR}/shared_tests_instances/low_precision_transformations/x64)
 endif()
 
 ov_add_test_target(
@@ -44,6 +44,7 @@ INSTANTIATE_TEST_SUITE_P(smoke_LPT, FullyConnectedTransformation,
         ::testing::ValuesIn(netPrecisions),
         ::testing::ValuesIn(shapes),
         ::testing::Values(ov::test::utils::DEVICE_CPU),
-        ::testing::ValuesIn(trasformationParamValues)),
+        ::testing::ValuesIn(trasformationParamValues),
+        ::testing::ValuesIn({ov::element::i8, ov::element::u8})),
     FullyConnectedTransformation::getTestCaseName);
 } // namespace
@@ -396,8 +396,10 @@ std::vector<std::string> disabledTestPatterns() {
     retVector.emplace_back(R"(smoke_TestsDFT_(1|2|3|4)d/DFTLayerTest.Inference.*)");
     // Issue 88764, 91647, 108802: accuracy issue
     retVector.emplace_back(R"(MultipleLSTMCellTest/MultipleLSTMCellTest.CompareWithRefs.*)");
+#if !defined(OPENVINO_ARCH_ARM64)
     // int8 / code-generation specific
     retVector.emplace_back(R"(smoke_LPT.*)");
+#endif
     // Compressed weights are not supported
     retVector.emplace_back(R"(smoke_MatMulCompressedWeights.*)");
     retVector.emplace_back(R"(smoke_MatMulSharedCompressedWeights.*)");
2 changes: 1 addition & 1 deletion src/plugins/intel_cpu/thirdparty/ComputeLibrary
@@ -45,6 +45,7 @@ INSTANTIATE_TEST_SUITE_P(smoke_LPT, FullyConnectedTransformation,
         ::testing::ValuesIn(netPrecisions),
         ::testing::ValuesIn(shapes),
         ::testing::Values(ov::test::utils::DEVICE_GPU),
-        ::testing::ValuesIn(trasformationParamValues)),
+        ::testing::ValuesIn(trasformationParamValues),
+        ::testing::ValuesIn({ov::element::i8, ov::element::u8})),
     FullyConnectedTransformation::getTestCaseName);
 } // namespace
@@ -20,7 +20,8 @@ typedef std::tuple<
     ov::element::Type,
     MatMulShapes,
     std::string,
-    ov::pass::low_precision::LayerTransformation::Params> FullyConnectedTransformationParams;
+    ov::pass::low_precision::LayerTransformation::Params,
+    ov::element::Type> FullyConnectedTransformationParams;
 
 namespace LayerTestsDefinitions {
 
@@ -20,14 +20,16 @@ std::string FullyConnectedTransformation::getTestCaseName(const testing::TestPar
     MatMulShapes shapes;
     std::string targetDevice;
     ov::pass::low_precision::LayerTransformation::Params params;
-    std::tie(precision, shapes, targetDevice, params) = obj.param;
+    ov::element::Type weightsType;
+    std::tie(precision, shapes, targetDevice, params, weightsType) = obj.param;
 
     std::ostringstream result;
     result <<
         get_test_case_name_by_params(precision, shapes.inputA, targetDevice, params) <<
         shapes.inputB << "_" <<
         shapes.transposeA << "_" <<
-        shapes.transposeB;
+        shapes.transposeB << "_" <<
+        weightsType;
 
     return result.str();
 }
@@ -36,7 +38,8 @@ void FullyConnectedTransformation::SetUp() {
     ov::element::Type precision;
     MatMulShapes shapes;
     ov::pass::low_precision::LayerTransformation::Params params;
-    std::tie(precision, shapes, targetDevice, params) = this->GetParam();
+    ov::element::Type weightsType;
+    std::tie(precision, shapes, targetDevice, params, weightsType) = this->GetParam();
 
     init_input_shapes({ shapes.inputA, shapes.inputB });
 
@@ -45,12 +48,17 @@
         shapes.inputA,
         shapes.inputB,
         shapes.transposeA,
-        shapes.transposeB);
+        shapes.transposeB,
+        weightsType == ov::element::i8);
 }
 
 TEST_P(FullyConnectedTransformation, CompareWithRefImpl) {
     SKIP_IF_CURRENT_TEST_IS_DISABLED();
     run();
+
+    const auto actualPrecision = get_runtime_precision_by_type("FullyConnected");
+    const auto weightsType = std::get<4>(GetParam());
+    EXPECT_EQ(actualPrecision, weightsType.to_string());
 };
 
 } // namespace LayerTestsDefinitions
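Note (reviewer addition): the test now checks execution precision, not just accuracy. std::get<4>(GetParam()) reads the newly added weightsType tuple element, and get_runtime_precision_by_type("FullyConnected") must report the same type, confirming the FullyConnected node actually ran in int8 rather than falling back to f32.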
@@ -27,7 +27,8 @@ class MatMulFunction {
         const ov::PartialShape inputShape1,
         const ov::PartialShape inputShape2,
         const bool transpose1,
-        const bool transpose2);
+        const bool transpose2,
+        const bool signedOnWeights = false);
 
     static std::shared_ptr<ov::Model> getOriginal(
         const ov::element::Type precision,
13 changes: 9 additions & 4 deletions src/tests/ov_helpers/ov_lpt_models/src/mat_mul.cpp
@@ -54,12 +54,17 @@ std::shared_ptr<ov::Model> MatMulFunction::getOriginal(
     const ov::PartialShape inputShape1,
     const ov::PartialShape inputShape2,
     const bool transpose1,
-    const bool transpose2) {
+    const bool transpose2,
+    const bool signedOnWeights) {
     const auto paramNode = std::make_shared<ov::opset1::Parameter>(precision, inputShape1);
     const std::vector<size_t> constShapes(inputShape1.rank().get_length(), 1ul);
-    const auto fakeQuantizeOnAcitvations = ov::test::utils::make_fake_quantize(
-        paramNode, precision, 256ul, constShapes,
-        { 0.f }, { 255.f / 4.f }, { 0.f }, { 255.f / 4.f });
+    const auto fakeQuantizeOnAcitvations = signedOnWeights ?
+        ov::test::utils::make_fake_quantize(
+            paramNode, precision, 256ul, constShapes,
+            { -128.f / 4.f }, { 127.f / 4.f }, { -128.f / 4.f }, { 127.f / 4.f }) :
+        ov::test::utils::make_fake_quantize(
+            paramNode, precision, 256ul, constShapes,
+            { 0.f }, { 255.f / 4.f }, { 0.f }, { 255.f / 4.f });
     fakeQuantizeOnAcitvations->set_friendly_name("fakeQuantizeOnAcitvations");
 
     auto weightsConst = std::make_shared<ov::op::v0::Constant>(
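Worked check (reviewer addition): both FakeQuantize variants spread 256 levels over a span of 63.75, so the effective quantization scale is identical and only the placement of the range changes:

    // signed branch:   low = -128.f / 4 = -32.0,  high = 127.f / 4 = 31.75
    // unsigned branch: low = 0.0,                 high = 255.f / 4 = 63.75
    const float signedScale   = (127.f / 4.f - (-128.f / 4.f)) / 255.f;  // 63.75 / 255 = 0.25
    const float unsignedScale = (255.f / 4.f - 0.f) / 255.f;             // 63.75 / 255 = 0.25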
