diff --git a/src/common/low_precision_transformations/include/low_precision/mat_mul.hpp b/src/common/low_precision_transformations/include/low_precision/mat_mul.hpp index bc0077a716f701..1c1e32f1ee1a15 100644 --- a/src/common/low_precision_transformations/include/low_precision/mat_mul.hpp +++ b/src/common/low_precision_transformations/include/low_precision/mat_mul.hpp @@ -26,6 +26,9 @@ class LP_TRANSFORMATIONS_API MatMulTransformation : public LayerTransformation { bool transform(TransformationContext &context, ov::pass::pattern::Matcher &m) override; bool isPrecisionPreserved(std::shared_ptr layer) const noexcept override; bool canBeTransformed(const TransformationContext& context, std::shared_ptr layer) const override; + +protected: + virtual void handleDequantization(const std::shared_ptr& dequantization) const {} }; } // namespace low_precision diff --git a/src/common/low_precision_transformations/include/low_precision/mat_mul_with_dequantization.hpp b/src/common/low_precision_transformations/include/low_precision/mat_mul_with_dequantization.hpp new file mode 100644 index 00000000000000..4913d1cf1057f4 --- /dev/null +++ b/src/common/low_precision_transformations/include/low_precision/mat_mul_with_dequantization.hpp @@ -0,0 +1,33 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include "mat_mul.hpp" + +namespace ov { +namespace pass { +namespace low_precision { + +/** + * @ingroup ov_transformation_common_api + * @brief MatMulWithDequantizationTransformation propagates dequantization operations through MatMul operation and keep dequantisation as is. + * + * For more details about the transformation, refer to + * [MatMulWithDequantizationTransformation](@ref openvino_docs_OV_UG_lpt_MatMulWithDequantizationTransformation) page + * in the OpenVINO Developer Guide. + */ +class LP_TRANSFORMATIONS_API MatMulWithDequantizationTransformation : public MatMulTransformation { +public: + OPENVINO_RTTI("MatMulWithDequantizationTransformation", "0"); + MatMulWithDequantizationTransformation(const Params& params = Params()); + +protected: + void handleDequantization(const std::shared_ptr& dequantization) const override; +}; + +} // namespace low_precision +} // namespace pass +} // namespace ov diff --git a/src/common/low_precision_transformations/src/low_precision.cpp b/src/common/low_precision_transformations/src/low_precision.cpp index 52479a5c2dc1fa..e3046e166956ea 100644 --- a/src/common/low_precision_transformations/src/low_precision.cpp +++ b/src/common/low_precision_transformations/src/low_precision.cpp @@ -52,7 +52,11 @@ #include "low_precision/fake_quantize.hpp" #include "low_precision/group_convolution.hpp" #include "low_precision/interpolate.hpp" +#ifdef OPENVINO_ARCH_ARM64 +#include "low_precision/mat_mul_with_dequantization.hpp" +#else #include "low_precision/mat_mul.hpp" +#endif #include "low_precision/max_pool.hpp" #include "low_precision/multiply_partial.hpp" #include "low_precision/mvn.hpp" @@ -252,7 +256,11 @@ bool ov::pass::low_precision::LowPrecision::run_on_model(const std::shared_ptr +#include "low_precision/rt_info/bias_attribute.hpp" + +using namespace ov; +using namespace ov::pass; +using namespace ov::pass::low_precision; + +MatMulWithDequantizationTransformation::MatMulWithDequantizationTransformation(const Params& params) : MatMulTransformation(params) { +} + +void MatMulWithDequantizationTransformation::handleDequantization(const std::shared_ptr& dequantization) const { + const auto& dequantization_constant = is_type(dequantization->get_input_node_shared_ptr(1)) ? + as_type(dequantization->get_input_node_ptr(1)) : + as_type(dequantization->get_input_node_ptr(0)); + if ((dequantization_constant == nullptr) || (ov::shape_size(dequantization_constant->get_shape()) != 1ull)) { + return; + } + + ov::mark_as_bias(dequantization); +} diff --git a/src/common/low_precision_transformations/src/network_helper.cpp b/src/common/low_precision_transformations/src/network_helper.cpp index 1aebfb111d7892..c4afc48e4360f7 100644 --- a/src/common/low_precision_transformations/src/network_helper.cpp +++ b/src/common/low_precision_transformations/src/network_helper.cpp @@ -17,6 +17,7 @@ #include "low_precision/common/ie_lpt_exception.hpp" #include "low_precision/layer_transformation.hpp" #include "low_precision/network_helper.hpp" +#include "low_precision/rt_info/bias_attribute.hpp" #include "low_precision/rt_info/intervals_alignment_attribute.hpp" #include "low_precision/rt_info/precision_preserved_attribute.hpp" #include "low_precision/rt_info/quantization_alignment_attribute.hpp" @@ -1192,7 +1193,7 @@ FakeQuantizeDequantization NetworkHelper::getDequantization(const std::shared_pt const std::shared_ptr multiply = ov::as_type_ptr(dataNode.get_node_shared_ptr()); std::shared_ptr multiplyConstant; if (multiply != nullptr) { - if (!FakeQuantizeDequantization::checkShape(multiply)) { + if (!FakeQuantizeDequantization::checkShape(multiply) || ov::marked_as_bias(multiply)) { return FakeQuantizeDequantization(); } @@ -1207,6 +1208,9 @@ FakeQuantizeDequantization NetworkHelper::getDequantization(const std::shared_pt std::shared_ptr subtractConvert; std::shared_ptr subtractConstant; if (subtract != nullptr) { + if (ov::marked_as_bias(subtract)) { + return FakeQuantizeDequantization(); + } if (!FakeQuantizeDequantization::checkShape(subtract)) { return FakeQuantizeDequantization(dataNode, nullptr, nullptr, nullptr, nullptr, multiply, multiplyConstant); } @@ -1220,6 +1224,9 @@ FakeQuantizeDequantization NetworkHelper::getDequantization(const std::shared_pt const std::shared_ptr convert = ov::as_type_ptr(dataNode.get_node_shared_ptr()); if (convert != nullptr) { + if (ov::marked_as_bias(convert)) { + return FakeQuantizeDequantization(); + } auto el_type = convert->input(0).get_element_type(); auto foundIt = std::find(defaultPrecisions.begin(), defaultPrecisions.end(), el_type); if (foundIt == defaultPrecisions.end() && diff --git a/src/plugins/intel_cpu/src/cpu_memory.cpp b/src/plugins/intel_cpu/src/cpu_memory.cpp index 8e5fe8d72fd1f2..4fed8e2ef860b2 100644 --- a/src/plugins/intel_cpu/src/cpu_memory.cpp +++ b/src/plugins/intel_cpu/src/cpu_memory.cpp @@ -403,6 +403,7 @@ void DnnlMemoryBlock::notifyUpdate() { StaticMemory::StaticMemory(const dnnl::engine& eng, MemoryDescPtr desc, const void* data, bool pads_zeroing) : m_eng(eng), m_pMemDesc(desc) { + OPENVINO_ASSERT(!desc->empty() || (desc->empty() && (data == nullptr))); if (desc->getPrecision() == element::string) { OPENVINO_THROW("[CPU] StaticMemory object cannot be created for string data."); } @@ -412,7 +413,7 @@ StaticMemory::StaticMemory(const dnnl::engine& eng, MemoryDescPtr desc, const vo m_size = m_pMemDesc->getCurrentMemSize(); - if (data) { + if (data || desc->empty()) { m_pMemBlock = std::make_shared(const_cast(data), m_size); } else { m_pMemBlock = std::make_shared(m_size); diff --git a/src/plugins/intel_cpu/src/nodes/executors/acl/acl_common_executor.cpp b/src/plugins/intel_cpu/src/nodes/executors/acl/acl_common_executor.cpp index 646cf47c1bcf6c..5d12ee70fe23cc 100644 --- a/src/plugins/intel_cpu/src/nodes/executors/acl/acl_common_executor.cpp +++ b/src/plugins/intel_cpu/src/nodes/executors/acl/acl_common_executor.cpp @@ -38,9 +38,9 @@ static void initACLTensorParams(const MemoryPtr& memoryPtr, } } -static std::shared_ptr initTensorInfo(const arm_compute::TensorShape& tensorShape, - const arm_compute::DataType& dataType, - const arm_compute::DataLayout& dataLayout) { +std::shared_ptr ACLCommonExecutor::initTensorInfo(const arm_compute::TensorShape& tensorShape, + const arm_compute::DataType& dataType, + const arm_compute::DataLayout& dataLayout) { std::shared_ptr aclMemoryInfo = nullptr; if (dataType != arm_compute::DataType::UNKNOWN) { aclMemoryInfo = std::make_shared( @@ -72,6 +72,9 @@ bool ACLCommonExecutor::update(const MemoryArgs &memory) { ACLTypes aclDataType{}; ACLLayouts aclDataLayout{}; for (auto& cpu_mem_ptr : memory) { + if (cpu_mem_ptr.second->getSize() == 0) { + continue; + } const ACLArgs index = argConvert.at(cpu_mem_ptr.first); initACLTensorParams(cpu_mem_ptr.second, aclTensorAttrs, aclMemoryShapes[index], diff --git a/src/plugins/intel_cpu/src/nodes/executors/acl/acl_common_executor.hpp b/src/plugins/intel_cpu/src/nodes/executors/acl/acl_common_executor.hpp index 1a5a00c7a85f7a..ff36eea94352ec 100644 --- a/src/plugins/intel_cpu/src/nodes/executors/acl/acl_common_executor.hpp +++ b/src/plugins/intel_cpu/src/nodes/executors/acl/acl_common_executor.hpp @@ -50,6 +50,11 @@ class ACLCommonExecutor : public Executor { protected: ACLTensorAttrs aclTensorAttrs; + + virtual std::shared_ptr initTensorInfo(const arm_compute::TensorShape& tensorShape, + const arm_compute::DataType& dataType, + const arm_compute::DataLayout& dataLayout); + private: ACLTensors aclMemoryTensors; ACLInfos aclMemoryInfos; diff --git a/src/plugins/intel_cpu/src/nodes/executors/acl/acl_fullyconnected.cpp b/src/plugins/intel_cpu/src/nodes/executors/acl/acl_fullyconnected.cpp index cc42691950a3ff..124e972e9fbdca 100644 --- a/src/plugins/intel_cpu/src/nodes/executors/acl/acl_fullyconnected.cpp +++ b/src/plugins/intel_cpu/src/nodes/executors/acl/acl_fullyconnected.cpp @@ -7,6 +7,7 @@ #include "ov_optional.hpp" #include "acl_fullyconnected.hpp" +#include "acl_fullyconnected_utils.hpp" #include "acl_utils.hpp" #include "nodes/executors/executor.hpp" #include "nodes/executors/memory_arguments.hpp" @@ -22,232 +23,6 @@ namespace ov { namespace intel_cpu { -static VectorDims makeDummyInputDims(const Shape& inShape, const Shape& wShape) { - const auto& weightDims = wShape.getStaticDims(); - - auto inMinDims = inShape.getMinDims(); - auto inMaxDims = inShape.getMaxDims(); - inMinDims.back() = weightDims.back(); - inMaxDims.back() = weightDims.back(); - - return MemoryDescUtils::makeDummyShape(Shape(inMinDims, inMaxDims)).getStaticDims(); -} - -static VectorDims makeDummyOutputDims(const VectorDims& inShape, const VectorDims& wShape, const size_t out_rank) { - size_t activationRank = inShape.size(); - size_t channelRank = wShape.size() - 1; - // activation weight output_shape - // NCHW CoCHW NCo - // TNC CoC TNCo - // NC CoC NCo - VectorDims outputShape(out_rank, 1); - // set Co - outputShape.back() = wShape[0]; - // set batch dims - size_t batchRank = activationRank - channelRank; - size_t startIdx = out_rank - batchRank - 1; - for (size_t i = 0; i < batchRank; i++) { - outputShape[i + startIdx] = inShape[i]; - } - - return outputShape; -} - -static DnnlMemoryDescPtr makeTransposedWeightDescriptor(const DnnlMemoryDescPtr srcDesc, - const DnnlMemoryDescPtr dstDesc) { - const auto& weiDesc = srcDesc->getDnnlDesc(); - const auto reorderedWeiDesc = dnnl::memory::desc{weiDesc.get_dims(), weiDesc.get_data_type(), dnnl::memory::format_tag::ba}; - const auto transposedWeiDesc = reorderedWeiDesc.reshape(dstDesc->getDnnlDesc().get_dims()); - - return DnnlExtensionUtils::makeDescriptor(transposedWeiDesc); -} - -static ov::optional convertWeightPrecision(MemoryPtr input, MemoryPtr output, ov::element::Type weightPrecision) { - MemoryArgs memoryArgs; - memoryArgs[ARG_SRC] = input; - memoryArgs[ARG_DST] = output; - - auto aclWeightsConverter = std::make_shared(); - if (aclWeightsConverter->update(memoryArgs)) { - aclWeightsConverter->execute(memoryArgs); - return ov::optional(memoryArgs.at(ARG_DST)); - } - - if (!node::Convert::isSupportedDesc(input->getDesc()) || - !node::Convert::isSupportedDesc(output->getDesc())) { - return {}; - } - - auto data = static_cast(input->getData()); - std::vector tmpBuff; - tmpBuff.resize(output->getSize()); - cpu_convert(data, tmpBuff.data(), DnnlExtensionUtils::DataTypeToElementType(input->getDataType()), - weightPrecision, input->getSize() / input->getDesc().getPrecision().size()); - - return ov::optional(std::make_shared(output->getPrimitive().get_engine(), - output->getDesc().cloneWithNewPrecision(weightPrecision), - tmpBuff.data())); -} - -static ov::optional reorderDataFallback(MemoryPtr input, MemoryPtr output, ExecutorContext::CPtr context) { - if (output->getDataType() == input->getDataType()) { - return {}; - } - const auto inPrc = DnnlExtensionUtils::DataTypeToElementType(input->getDataType()); - auto convertedDstMemoryDesc = output->getDesc().cloneWithNewPrecision(inPrc); - dnnl::reorder reorderWithoutConvert = getReorderPrim(context->getRuntimeCache(), - output->getPrimitive().get_engine(), - input->getPrimitive().get_desc(), - MemoryDescUtils::convertToDnnlMemoryDesc(convertedDstMemoryDesc)->getDnnlDesc()); - - if (reorderWithoutConvert && parse_impl_name(reorderWithoutConvert.get_primitive_desc()->impl()->name()) != ref_any) { - auto convertOutput = convertWeightPrecision(input, output, inPrc); - if (!convertOutput) { - return {}; - } - input = *convertOutput; - - if (reorderWithoutConvert) { - dnnl::stream loc_stream(output->getPrimitive().get_engine(), dnnl::stream::flags::in_order); - reorderWithoutConvert.execute(loc_stream, {{DNNL_ARG_FROM, input->getPrimitive()}, {DNNL_ARG_TO, output->getPrimitive()}}); - return ov::optional(output); - } - } - return {}; -} - -static MemoryPtr reorderData(DnnlMemoryDescPtr srcWeightDesc, - DnnlMemoryDescPtr dstWeightDesc, - MemoryCPtr weightsMem, - ExecutorContext::CPtr context) { - MemoryPtr input = std::make_shared(context->getEngine(), srcWeightDesc, weightsMem->getData()); - MemoryPtr output = std::make_shared(context->getEngine(), dstWeightDesc); - if (!input->getDesc().isDefined() || !output->getDesc().isDefined()) - OPENVINO_THROW("Can't reorder data with dynamic shapes"); - - if (input->getShape().hasZeroDims() || output->getShape().hasZeroDims()) { - return output; - } - - if (input->getDesc().isCompatible(output->getDesc())) { - auto srcPtr = static_cast(input->getData()); - auto dstPtr = static_cast(output->getData()); - auto copySize = output->getSize(); - cpu_memcpy(dstPtr, srcPtr, copySize); - return output; - } - - // try directly reorder - auto engine = output->getPrimitive().get_engine(); - dnnl::reorder directReorder = getReorderPrim(context->getRuntimeCache(), - engine, - input->getPrimitive().get_desc(), - output->getPrimitive().get_desc()); - - if (!directReorder || parse_impl_name(directReorder.get_primitive_desc()->impl()->name()) == ref_any) { - // try precision conversion then do the reorder - auto fallbackOutput = reorderDataFallback(input, output, context); - if (fallbackOutput) { - return *fallbackOutput; - } - } - // if precision conversion does not work then do direct reference reorder - if (directReorder) { - dnnl::stream loc_stream(engine, dnnl::stream::flags::in_order); - directReorder.execute(loc_stream, {{DNNL_ARG_FROM, input->getPrimitive()}, {DNNL_ARG_TO, output->getPrimitive()}}); - } else { - OPENVINO_THROW("Could not make onednn reorder."); - } - return output; -} - -static MemoryPtr reorderWeights(const MemoryArgs &memory, - const ExecutorContext::CPtr context, - ACLFCAttrs& aclfcAttrs, - DnnlMemoryDescPtr dnnlSrcDesc, - DnnlMemoryDescPtr dnnlDstDesc) { - auto create = [&]() { - MemoryPtr weightsMemory = memory.at(ARG_WEI); - if (aclfcAttrs.isWeightsRepacked || aclfcAttrs.isConvertedWeights) { - weightsMemory = reorderData(dnnlSrcDesc, dnnlDstDesc, memory.at(ARG_WEI), context); - DEBUG_LOG("ACLFullyConnectedExecutor: cache miss, perform packing"); - } - return weightsMemory; - }; - - auto weightCache = context->getWeightsCache(); - if (weightCache != nullptr) { - const auto& wgtDims = memory.at(ARG_WEI)->getStaticDims(); - const auto N = wgtDims[0]; - const auto K = wgtDims[1]; - std::string format = "fc_acl_" + std::to_string(N) + "_" + std::to_string(K); - const std::string string_hash = format + "_" + std::to_string(memory.at(ARG_WEI)->getSize()) + "_" + - std::to_string(reinterpret_cast(memory.at(ARG_WEI)->getData())); - DEBUG_LOG("ACLFullyConnectedExecutor: findOrCreate, string_hash: ", string_hash); - return *weightCache->findOrCreate(string_hash, create); - } - - DEBUG_LOG("ACLFullyConnectedExecutor: Weights cache is not available"); - return create(); -} - -static MemoryPtr prepareWeightMemory(const MemoryArgs &memory, - const ExecutorContext::CPtr context, - const FCAttrs &attrs, - ACLFCAttrs& aclfcAttrs, - const PostOps &postOps, - arm_compute::WeightFormat& expectedWeightFormat, - arm_compute::TensorInfo& weiTensorInfo) { - MemoryArgs memoryArgs; - memoryArgs[ARG_BIAS] = memory.at(ARG_BIAS); - memoryArgs[ARG_WEI] = memory.at(ARG_WEI); - if (memory.at(ARG_SRC_0)->getShape().isDynamic()) { - const auto& inShape = memory.at(ARG_SRC_0)->getShape(); - const auto& wShape = memory.at(ARG_WEI)->getShape(); - const auto& inDymmyDims = makeDummyInputDims(inShape, wShape); - const auto& outDymmyDims = makeDummyOutputDims(inDymmyDims, wShape.getStaticDims(), memory.at(ARG_DST)->getShape().getRank()); - memoryArgs[ARG_SRC_0] = std::make_shared(context->getEngine(), - memory.at(ARG_SRC_0)->getDescPtr()->cloneWithNewDims(inDymmyDims)); - memoryArgs[ARG_DST] = std::make_shared(context->getEngine(), - memory.at(ARG_DST)->getDescPtr()->cloneWithNewDims(outDymmyDims)); - } else { - memoryArgs[ARG_SRC_0] = memory.at(ARG_SRC_0); - memoryArgs[ARG_DST] = memory.at(ARG_DST); - } - // TODO: ACLWeightFormatGenerator should be replaced with Reorder executor - // that calls ACL NEReorder + NETranspose or dnnl::reorder depending on backend availability - auto aclWeightsRepack = std::make_shared(attrs, postOps, memoryArgs); - bool isNeededReorder = aclWeightsRepack->update(memoryArgs); - expectedWeightFormat = isNeededReorder ? aclWeightsRepack->getOptImplWeightFormat() : arm_compute::WeightFormat::UNSPECIFIED; - weiTensorInfo = aclWeightsRepack->getTensorInfo(ACLArgs::ACL_WEI); - - MemoryPtr dstMemPtr = std::make_shared(context->getEngine(), - memory.at(ARG_WEI)->getDescPtr()->cloneWithNewPrecision(aclfcAttrs.inputPrecision)); - auto dstDesc = dstMemPtr->getDescPtr(); - auto dnnlDstDesc = MemoryDescUtils::convertToDnnlMemoryDesc(dstDesc); - auto weiDesc = memory.at(ARG_WEI)->getDescPtr(); - auto dnnlSrcDesc = MemoryDescUtils::convertToDnnlMemoryDesc(weiDesc); - - if (isNeededReorder) { - dnnl::impl::dim_t o_dim = 0; - dnnl::impl::dim_t inner_dim = 1; - std::vector remaining_dims = {}; - auto weights_md_ = dnnlDstDesc->getDnnlDesc().get(); - dnnl::impl::cpu::acl::acl_utils::reorder_to_weight_format(weiTensorInfo, *weights_md_, expectedWeightFormat, - inner_dim, o_dim, remaining_dims, {}); - if (aclfcAttrs.weightsNonTransposed) { - dnnlSrcDesc = makeTransposedWeightDescriptor(dnnlSrcDesc, dnnlDstDesc); - } - aclfcAttrs.isWeightsRepacked = true; - return reorderWeights(memory, context, aclfcAttrs, dnnlSrcDesc, dnnlDstDesc); - } - if (!aclfcAttrs.weightsNonTransposed) { - dnnlDstDesc = makeTransposedWeightDescriptor(dnnlDstDesc, dnnlSrcDesc); - aclfcAttrs.isWeightsRepacked = true; - } - return reorderWeights(memory, context, aclfcAttrs, dnnlSrcDesc, dnnlDstDesc); -} - static bool checkPostOps(const PostOps &postOps) { if (postOps.empty()) { return true; @@ -292,7 +67,7 @@ ACLFullyConnectedExecutor::ACLFullyConnectedExecutor(const FCAttrs &attrs, const MemoryArgs &memory, const ExecutorContext::CPtr context) { initFCAttrs(attrs, aclTensorAttrs, aclfcAttrs, memory, fullyConnectedLayerInfo, postOps); - packedWeights = prepareWeightMemory(memory, context, attrs, aclfcAttrs, postOps, expectedWeightFormat, weiTensorInfo); + packedWeights = acl_fc_executor::prepareWeightMemory(memory, context, attrs, aclfcAttrs, postOps, expectedWeightFormat, weiTensorInfo); } bool ACLFullyConnectedExecutor::supports(const FCConfig &config) { @@ -305,20 +80,8 @@ bool ACLFullyConnectedExecutor::supports(const FCConfig &config) { return true; } -static arm_compute::TensorShape normalizeDimsTo2D(const arm_compute::TensorShape shape) { - size_t norm_dim = std::accumulate(shape.begin() + 1, shape.end(), 1, std::multiplies()); - return arm_compute::TensorShape(shape[0], norm_dim); -} - -static void updateFCTensorsShapes(ACLShapes& aclMemoryShapes) { - aclMemoryShapes[ACLArgs::ACL_WEI] = normalizeDimsTo2D(aclMemoryShapes[ACLArgs::ACL_WEI]); - aclMemoryShapes[ACLArgs::ACL_SRC_0] = normalizeDimsTo2D(aclMemoryShapes[ACLArgs::ACL_SRC_0]); - aclMemoryShapes[ACLArgs::ACL_DST] = normalizeDimsTo2D(aclMemoryShapes[ACLArgs::ACL_DST]); - std::swap(aclMemoryShapes[ACLArgs::ACL_WEI][0], aclMemoryShapes[ACLArgs::ACL_WEI][1]); -} - void ACLFullyConnectedExecutor::updateTensorsShapes(ACLShapes& aclMemoryShapes) { - updateFCTensorsShapes(aclMemoryShapes); + acl_fc_executor::updateFCTensorsShapes(aclMemoryShapes); } arm_compute::Status ACLFullyConnectedExecutor::validateTensorsInfo(const ACLInfos & aclMemoryInfos) { @@ -358,48 +121,5 @@ ACLFunction ACLFullyConnectedExecutor::configureFunction(const ACLTensors & aclM return neFC; } -arm_compute::Status acl_fc_executor::ACLWeightsConverter::validateTensorsInfo(const ACLInfos &aclMemoryInfos) { - return arm_compute::NECast::validate(aclMemoryInfos[ACLArgs::ACL_SRC_0].get(), - aclMemoryInfos[ACLArgs::ACL_DST].get(), - arm_compute::ConvertPolicy::SATURATE); -} - -ACLFunction acl_fc_executor::ACLWeightsConverter::configureFunction(const ACLTensors &aclMemoryTensors) { - auto neCast = std::make_unique(); - neCast->configure(aclMemoryTensors[ACLArgs::ACL_SRC_0].get(), - aclMemoryTensors[ACLArgs::ACL_DST].get(), - arm_compute::ConvertPolicy::SATURATE); - return neCast; -} - -acl_fc_executor::ACLWeightFormatGenerator::ACLWeightFormatGenerator(const FCAttrs &attrs, - const PostOps &postOps, - const MemoryArgs &memory) { - initFCAttrs(attrs, aclTensorAttrs, aclfcAttrs, memory, fullyConnectedLayerInfo, postOps); -} - -void acl_fc_executor::ACLWeightFormatGenerator::updateTensorsShapes(ACLShapes &aclMemoryShapes) { - updateFCTensorsShapes(aclMemoryShapes); -} - -arm_compute::Status acl_fc_executor::ACLWeightFormatGenerator::validateTensorsInfo(const ACLInfos &aclMemoryInfos) { - if (aclfcAttrs.isConvertedWeights) { - aclMemoryInfos[ACLArgs::ACL_WEI]->set_data_type(aclMemoryInfos[ACLArgs::ACL_SRC_0]->data_type()); - } - int icTotal = aclMemoryInfos[ACLArgs::ACL_SRC_0]->dimension(0); - return arm_compute::NEFullyConnectedLayer::has_opt_impl( - expectedWeightFormat, - aclMemoryInfos[ACLArgs::ACL_SRC_0].get(), - aclMemoryInfos[ACLArgs::ACL_WEI].get(), - aclMemoryInfos[ACLArgs::ACL_BIAS].get(), - aclMemoryInfos[ACLArgs::ACL_DST].get(), - fullyConnectedLayerInfo, - arm_compute::WeightsInfo(false, 1, 1, icTotal, false, arm_compute::WeightFormat::ANY)); -} - -ACLFunction acl_fc_executor::ACLWeightFormatGenerator::configureFunction(const ACLTensors &aclMemoryTensors) { - return std::make_unique(); -} - } // namespace intel_cpu } // namespace ov diff --git a/src/plugins/intel_cpu/src/nodes/executors/acl/acl_fullyconnected.hpp b/src/plugins/intel_cpu/src/nodes/executors/acl/acl_fullyconnected.hpp index fcbcb1475efa15..4db9b95031c803 100644 --- a/src/plugins/intel_cpu/src/nodes/executors/acl/acl_fullyconnected.hpp +++ b/src/plugins/intel_cpu/src/nodes/executors/acl/acl_fullyconnected.hpp @@ -1,4 +1,4 @@ -// Copyright (C) 2018-2024 Intel Corporation +// Copyright (C) 2024 Intel Corporation // SPDX-License-Identifier: Apache-2.0 // @@ -6,46 +6,11 @@ #include "acl_common_executor.hpp" #include "nodes/executors/fullyconnected_config.hpp" +#include "acl_fullyconnected_utils.hpp" namespace ov { namespace intel_cpu { -struct ACLFCAttrs { - ov::element::Type inputPrecision; - bool isConvertedWeights = false; - bool isWeightsRepacked = false; - bool weightsNonTransposed; -}; - -namespace acl_fc_executor { - -class ACLWeightsConverter : public ACLCommonExecutor { -public: - ACLWeightsConverter() = default; - void updateTensorsShapes(ACLShapes& aclMemoryShapes) override {} - arm_compute::Status validateTensorsInfo(const ACLInfos & aclMemoryInfos) override; - ACLFunction configureFunction(const ACLTensors & aclMemoryTensors) override; -}; - -class ACLWeightFormatGenerator : public ACLCommonExecutor { -public: - ACLWeightFormatGenerator(const FCAttrs& attrs, - const PostOps& postOps, - const MemoryArgs& memory); - void updateTensorsShapes(ACLShapes& aclMemoryShapes) override; - arm_compute::Status validateTensorsInfo(const ACLInfos & aclMemoryInfos) override; - ACLFunction configureFunction(const ACLTensors & aclMemoryTensors) override; - arm_compute::WeightFormat getOptImplWeightFormat() { - return expectedWeightFormat; - } -private: - arm_compute::FullyConnectedLayerInfo fullyConnectedLayerInfo; - ACLFCAttrs aclfcAttrs; - arm_compute::WeightFormat expectedWeightFormat; -}; - -} // namespace acl_fc_executor - class ACLFullyConnectedExecutor : public ACLCommonExecutor { public: ACLFullyConnectedExecutor(const FCAttrs& attrs, diff --git a/src/plugins/intel_cpu/src/nodes/executors/acl/acl_fullyconnected_utils.cpp b/src/plugins/intel_cpu/src/nodes/executors/acl/acl_fullyconnected_utils.cpp new file mode 100644 index 00000000000000..2df8c449b00c8f --- /dev/null +++ b/src/plugins/intel_cpu/src/nodes/executors/acl/acl_fullyconnected_utils.cpp @@ -0,0 +1,379 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// +#include "acl_fullyconnected.hpp" +#include "acl_utils.hpp" +#include "nodes/convert.h" +#include "nodes/common/cpu_convert.h" +#include "nodes/common/cpu_memcpy.h" +#include "nodes/common/reorder_prim.h" +#include "nodes/executors/executor.hpp" +#include "nodes/executors/memory_arguments.hpp" +#include "memory_desc/cpu_memory_desc_utils.h" +#include "utils/debug_capabilities.h" + +#include +#include + +namespace ov { +namespace intel_cpu { + +VectorDims acl_fc_executor::makeDummyInputDims(const Shape& inShape, const Shape& wShape) { + const auto& weightDims = wShape.getStaticDims(); + + auto inMinDims = inShape.getMinDims(); + auto inMaxDims = inShape.getMaxDims(); + inMinDims.back() = weightDims.back(); + inMaxDims.back() = weightDims.back(); + + return MemoryDescUtils::makeDummyShape(Shape(inMinDims, inMaxDims)).getStaticDims(); +} + +VectorDims acl_fc_executor::makeDummyOutputDims(const VectorDims& inShape, const VectorDims& wShape, const size_t out_rank) { + size_t activationRank = inShape.size(); + size_t channelRank = wShape.size() - 1; + // activation weight output_shape + // NCHW CoCHW NCo + // TNC CoC TNCo + // NC CoC NCo + VectorDims outputShape(out_rank, 1); + // set Co + outputShape.back() = wShape[0]; + // set batch dims + size_t batchRank = activationRank - channelRank; + size_t startIdx = out_rank - batchRank - 1; + for (size_t i = 0; i < batchRank; i++) { + outputShape[i + startIdx] = inShape[i]; + } + + return outputShape; +} + +DnnlMemoryDescPtr acl_fc_executor::makeTransposedWeightDescriptor(const DnnlMemoryDescPtr srcDesc, + const DnnlMemoryDescPtr dstDesc) { + const auto& weiDesc = srcDesc->getDnnlDesc(); + const auto reorderedWeiDesc = dnnl::memory::desc{weiDesc.get_dims(), weiDesc.get_data_type(), dnnl::memory::format_tag::ba}; + const auto transposedWeiDesc = reorderedWeiDesc.reshape(dstDesc->getDnnlDesc().get_dims()); + + return DnnlExtensionUtils::makeDescriptor(transposedWeiDesc); +} + +ov::optional acl_fc_executor::convertWeightPrecision(MemoryPtr input, MemoryPtr output, ov::element::Type weightPrecision) { + MemoryArgs memoryArgs; + memoryArgs[ARG_SRC] = input; + memoryArgs[ARG_DST] = output; + + auto aclWeightsConverter = std::make_shared(); + if (aclWeightsConverter->update(memoryArgs)) { + aclWeightsConverter->execute(memoryArgs); + return ov::optional(memoryArgs.at(ARG_DST)); + } + + if (!node::Convert::isSupportedDesc(input->getDesc()) || + !node::Convert::isSupportedDesc(output->getDesc())) { + return {}; + } + + auto data = static_cast(input->getData()); + std::vector tmpBuff; + tmpBuff.resize(output->getSize()); + cpu_convert(data, tmpBuff.data(), DnnlExtensionUtils::DataTypeToElementType(input->getDataType()), + weightPrecision, input->getSize() / input->getDesc().getPrecision().size()); + + return ov::optional(std::make_shared(output->getPrimitive().get_engine(), + output->getDesc().cloneWithNewPrecision(weightPrecision), + tmpBuff.data())); +} + +ov::optional acl_fc_executor::reorderDataFallback(MemoryPtr input, MemoryPtr output, ExecutorContext::CPtr context) { + if (output->getDataType() == input->getDataType()) { + return {}; + } + const auto inPrc = DnnlExtensionUtils::DataTypeToElementType(input->getDataType()); + auto convertedDstMemoryDesc = output->getDesc().cloneWithNewPrecision(inPrc); + dnnl::reorder reorderWithoutConvert = getReorderPrim(context->getRuntimeCache(), + output->getPrimitive().get_engine(), + input->getPrimitive().get_desc(), + MemoryDescUtils::convertToDnnlMemoryDesc(convertedDstMemoryDesc)->getDnnlDesc()); + + if (reorderWithoutConvert && parse_impl_name(reorderWithoutConvert.get_primitive_desc()->impl()->name()) != ref_any) { + auto convertOutput = convertWeightPrecision(input, output, inPrc); + if (!convertOutput) { + return {}; + } + input = *convertOutput; + + if (reorderWithoutConvert) { + dnnl::stream loc_stream(output->getPrimitive().get_engine(), dnnl::stream::flags::in_order); + reorderWithoutConvert.execute(loc_stream, {{DNNL_ARG_FROM, input->getPrimitive()}, {DNNL_ARG_TO, output->getPrimitive()}}); + return ov::optional(output); + } + } + return {}; +} + +MemoryPtr acl_fc_executor::reorderData(DnnlMemoryDescPtr srcWeightDesc, + DnnlMemoryDescPtr dstWeightDesc, + MemoryCPtr weightsMem, + ExecutorContext::CPtr context) { + MemoryPtr input = std::make_shared(context->getEngine(), srcWeightDesc, weightsMem->getData()); + MemoryPtr output = std::make_shared(context->getEngine(), dstWeightDesc); + if (!input->getDesc().isDefined() || !output->getDesc().isDefined()) + OPENVINO_THROW("Can't reorder data with dynamic shapes"); + + if (input->getShape().hasZeroDims() || output->getShape().hasZeroDims()) { + return output; + } + + if (input->getDesc().isCompatible(output->getDesc())) { + auto srcPtr = static_cast(input->getData()); + auto dstPtr = static_cast(output->getData()); + auto copySize = output->getSize(); + cpu_memcpy(dstPtr, srcPtr, copySize); + return output; + } + + // try directly reorder + auto engine = output->getPrimitive().get_engine(); + dnnl::reorder directReorder = getReorderPrim(context->getRuntimeCache(), + engine, + input->getPrimitive().get_desc(), + output->getPrimitive().get_desc()); + + if (!directReorder || parse_impl_name(directReorder.get_primitive_desc()->impl()->name()) == ref_any) { + // try precision conversion then do the reorder + auto fallbackOutput = reorderDataFallback(input, output, context); + if (fallbackOutput) { + return *fallbackOutput; + } + } + // if precision conversion does not work then do direct reference reorder + if (directReorder) { + dnnl::stream loc_stream(engine, dnnl::stream::flags::in_order); + directReorder.execute(loc_stream, {{DNNL_ARG_FROM, input->getPrimitive()}, {DNNL_ARG_TO, output->getPrimitive()}}); + } else { + OPENVINO_THROW("Could not make onednn reorder."); + } + return output; +} + +MemoryPtr acl_fc_executor::reorderWeights(const MemoryArgs &memory, + const ExecutorContext::CPtr context, + ACLFCAttrs& aclfcAttrs, + DnnlMemoryDescPtr dnnlSrcDesc, + DnnlMemoryDescPtr dnnlDstDesc) { + auto create = [&]() { + MemoryPtr weightsMemory = memory.at(ARG_WEI); + if (aclfcAttrs.isWeightsRepacked || aclfcAttrs.isConvertedWeights) { + weightsMemory = reorderData(dnnlSrcDesc, dnnlDstDesc, memory.at(ARG_WEI), context); + DEBUG_LOG("ACLFullyConnectedExecutor: cache miss, perform packing"); + } + return weightsMemory; + }; + + auto weightCache = context->getWeightsCache(); + if (weightCache != nullptr) { + const auto& wgtDims = memory.at(ARG_WEI)->getStaticDims(); + const auto N = wgtDims[0]; + const auto K = wgtDims[1]; + std::string format = "fc_acl_" + std::to_string(N) + "_" + std::to_string(K); + const std::string string_hash = format + "_" + std::to_string(memory.at(ARG_WEI)->getSize()) + "_" + + std::to_string(reinterpret_cast(memory.at(ARG_WEI)->getData())); + DEBUG_LOG("ACLFullyConnectedExecutor: findOrCreate, string_hash: ", string_hash); + return *weightCache->findOrCreate(string_hash, create); + } + + DEBUG_LOG("ACLFullyConnectedExecutor: Weights cache is not available"); + return create(); +} + +MemoryPtr acl_fc_executor::prepareWeightMemory(const MemoryArgs &memory, + const ExecutorContext::CPtr context, + const FCAttrs &attrs, + ACLFCAttrs& aclfcAttrs, + const PostOps &postOps, + arm_compute::WeightFormat& expectedWeightFormat, + arm_compute::TensorInfo& weiTensorInfo) { + MemoryArgs memoryArgs; + memoryArgs[ARG_BIAS] = memory.at(ARG_BIAS); + memoryArgs[ARG_WEI] = memory.at(ARG_WEI); + if (memory.at(ARG_SRC_0)->getShape().isDynamic()) { + const auto& inShape = memory.at(ARG_SRC_0)->getShape(); + const auto& wShape = memory.at(ARG_WEI)->getShape(); + const auto& inDymmyDims = makeDummyInputDims(inShape, wShape); + const auto& outDymmyDims = makeDummyOutputDims(inDymmyDims, wShape.getStaticDims(), memory.at(ARG_DST)->getShape().getRank()); + memoryArgs[ARG_SRC_0] = std::make_shared(context->getEngine(), + memory.at(ARG_SRC_0)->getDescPtr()->cloneWithNewDims(inDymmyDims)); + memoryArgs[ARG_DST] = std::make_shared(context->getEngine(), + memory.at(ARG_DST)->getDescPtr()->cloneWithNewDims(outDymmyDims)); + } else { + memoryArgs[ARG_SRC_0] = memory.at(ARG_SRC_0); + memoryArgs[ARG_DST] = memory.at(ARG_DST); + } + // TODO: ACLWeightFormatGenerator should be replaced with Reorder executor + // that calls ACL NEReorder + NETranspose or dnnl::reorder depending on backend availability + auto aclWeightsRepack = std::make_shared(attrs, postOps, memoryArgs); + bool isNeededReorder = aclWeightsRepack->update(memoryArgs); + expectedWeightFormat = isNeededReorder ? aclWeightsRepack->getOptImplWeightFormat() : arm_compute::WeightFormat::UNSPECIFIED; + weiTensorInfo = aclWeightsRepack->getTensorInfo(ACLArgs::ACL_WEI); + + MemoryPtr dstMemPtr = std::make_shared(context->getEngine(), + memory.at(ARG_WEI)->getDescPtr()->cloneWithNewPrecision(aclfcAttrs.inputPrecision)); + auto dstDesc = dstMemPtr->getDescPtr(); + auto dnnlDstDesc = MemoryDescUtils::convertToDnnlMemoryDesc(dstDesc); + auto weiDesc = memory.at(ARG_WEI)->getDescPtr(); + auto dnnlSrcDesc = MemoryDescUtils::convertToDnnlMemoryDesc(weiDesc); + + if (isNeededReorder) { + dnnl::impl::dim_t o_dim = 0; + dnnl::impl::dim_t inner_dim = 1; + std::vector remaining_dims = {}; + auto weights_md_ = dnnlDstDesc->getDnnlDesc().get(); + dnnl::impl::cpu::acl::acl_utils::reorder_to_weight_format(weiTensorInfo, *weights_md_, expectedWeightFormat, + inner_dim, o_dim, remaining_dims, {}); + if (aclfcAttrs.weightsNonTransposed) { + dnnlSrcDesc = makeTransposedWeightDescriptor(dnnlSrcDesc, dnnlDstDesc); + } + aclfcAttrs.isWeightsRepacked = true; + return reorderWeights(memory, context, aclfcAttrs, dnnlSrcDesc, dnnlDstDesc); + } + if (!aclfcAttrs.weightsNonTransposed) { + dnnlDstDesc = makeTransposedWeightDescriptor(dnnlDstDesc, dnnlSrcDesc); + aclfcAttrs.isWeightsRepacked = true; + } + return reorderWeights(memory, context, aclfcAttrs, dnnlSrcDesc, dnnlDstDesc); +} + +static bool checkPostOps(const PostOps &postOps) { + // Add postops + if (!postOps.empty() && postOps.size() == 1) { + if (const auto activation = std::dynamic_pointer_cast(postOps[0])) { + if (checkActivationLayerInfo(convertToEltwiseAlgorithm(activation->type()))) { + return true; + } + } + } + return false; +} + +static void initFCAttrs(const FCAttrs &attrs, + ACLTensorAttrs& aclTensorAttrs, + ACLFCAttrs& aclfcAttrs, + const MemoryArgs &memory, + arm_compute::FullyConnectedLayerInfo& fullyConnectedLayerInfo, + const PostOps &postOps) { + aclTensorAttrs.hasLayoutTypeNHWC = memory.at(ARG_SRC)->getDescPtr()->hasLayoutType(LayoutType::nspc); + fullyConnectedLayerInfo.weights_trained_layout = getAclDataLayoutByMemoryDesc(memory.at(ARG_WEI)->getDescPtr()); + aclfcAttrs.inputPrecision = memory.at(ARG_SRC)->getDescPtr()->getPrecision(); + fullyConnectedLayerInfo.transpose_weights = false; + aclfcAttrs.weightsNonTransposed = attrs.weightsNonTransposed; + + if (checkPostOps(postOps)) { + auto activation = std::dynamic_pointer_cast(postOps[0]); + fullyConnectedLayerInfo.activation_info = getActivationLayerInfo( + convertToEltwiseAlgorithm(activation->type()), + activation->alpha(), activation->beta(), activation->gamma()); + } + + if (memory.at(ARG_SRC)->getPrecision() != memory.at(ARG_WEI)->getPrecision()) { + aclfcAttrs.isConvertedWeights = true; + } +} + +arm_compute::TensorShape acl_fc_executor::normalizeDimsTo2D(const arm_compute::TensorShape shape) { + size_t norm_dim = std::accumulate(shape.begin() + 1, shape.end(), 1, std::multiplies()); + return arm_compute::TensorShape(shape[0], norm_dim); +} + +void acl_fc_executor::updateFCTensorsShapes(ACLShapes& aclMemoryShapes) { + aclMemoryShapes[ACLArgs::ACL_WEI] = normalizeDimsTo2D(aclMemoryShapes[ACLArgs::ACL_WEI]); + aclMemoryShapes[ACLArgs::ACL_SRC_0] = normalizeDimsTo2D(aclMemoryShapes[ACLArgs::ACL_SRC_0]); + aclMemoryShapes[ACLArgs::ACL_DST] = normalizeDimsTo2D(aclMemoryShapes[ACLArgs::ACL_DST]); + std::swap(aclMemoryShapes[ACLArgs::ACL_WEI][0], aclMemoryShapes[ACLArgs::ACL_WEI][1]); +} + +arm_compute::Status acl_fc_executor::ACLWeightsConverter::validateTensorsInfo(const ACLInfos &aclMemoryInfos) { + return arm_compute::NECast::validate(aclMemoryInfos[ACLArgs::ACL_SRC_0].get(), + aclMemoryInfos[ACLArgs::ACL_DST].get(), + arm_compute::ConvertPolicy::SATURATE); +} + +ACLFunction acl_fc_executor::ACLWeightsConverter::configureFunction(const ACLTensors &aclMemoryTensors) { + auto neCast = std::make_unique(); + neCast->configure(aclMemoryTensors[ACLArgs::ACL_SRC_0].get(), + aclMemoryTensors[ACLArgs::ACL_DST].get(), + arm_compute::ConvertPolicy::SATURATE); + return neCast; +} + + +arm_compute::Status acl_fc_executor::ACLWeightsTranspose::validateTensorsInfo(const ACLInfos &aclMemoryInfos) { + return arm_compute::NETranspose::validate(aclMemoryInfos[ACLArgs::ACL_SRC_0].get(), + aclMemoryInfos[ACLArgs::ACL_DST].get()); +} + +ACLFunction acl_fc_executor::ACLWeightsTranspose::configureFunction(const ACLTensors &aclMemoryTensors) { + auto neTranspose = std::make_unique(); + neTranspose->configure(aclMemoryTensors[ACLArgs::ACL_SRC_0].get(), + aclMemoryTensors[ACLArgs::ACL_DST].get()); + return neTranspose; +} + +acl_fc_executor::ACLWeightFormatGenerator::ACLWeightFormatGenerator(const FCAttrs &attrs, + const PostOps &postOps, + const MemoryArgs &memory) { + initFCAttrs(attrs, aclTensorAttrs, aclfcAttrs, memory, fullyConnectedLayerInfo, postOps); +} + +void acl_fc_executor::ACLWeightFormatGenerator::updateTensorsShapes(ACLShapes &aclMemoryShapes) { + updateFCTensorsShapes(aclMemoryShapes); +} + +arm_compute::Status acl_fc_executor::ACLWeightFormatGenerator::validateTensorsInfo(const ACLInfos &aclMemoryInfos) { + if (aclfcAttrs.isConvertedWeights) { + aclMemoryInfos[ACLArgs::ACL_WEI]->set_data_type(aclMemoryInfos[ACLArgs::ACL_SRC_0]->data_type()); + } + return arm_compute::NEFullyConnectedLayer::has_opt_impl( + expectedWeightFormat, + aclMemoryInfos[ACLArgs::ACL_SRC_0].get(), + aclMemoryInfos[ACLArgs::ACL_WEI].get(), + aclMemoryInfos[ACLArgs::ACL_BIAS].get(), + aclMemoryInfos[ACLArgs::ACL_DST].get(), + fullyConnectedLayerInfo, + weightsInfo); +} + +ACLFunction acl_fc_executor::ACLWeightFormatGenerator::configureFunction(const ACLTensors &aclMemoryTensors) { + return std::make_unique(); +} + +arm_compute::Status acl_fc_executor::ACLWeightsReorder::validateTensorsInfo(const ACLInfos &aclMemoryInfos) { +#if defined(OPENVINO_ARCH_ARM64) + return arm_compute::NEReorderLayer::validate(aclMemoryInfos[ACLArgs::ACL_SRC_0].get(), + aclMemoryInfos[ACLArgs::ACL_DST].get(), + inWeightFormat, + outWeightFormat); +#else + return arm_compute::NECopy::validate(aclMemoryInfos[ACLArgs::ACL_SRC_0].get(), + aclMemoryInfos[ACLArgs::ACL_DST].get()); +#endif +} + +ACLFunction acl_fc_executor::ACLWeightsReorder::configureFunction(const ACLTensors &aclMemoryTensors) { +#if defined(OPENVINO_ARCH_ARM64) + auto neReorderLayer = std::make_unique(); + neReorderLayer->configure(aclMemoryTensors[ACLArgs::ACL_SRC_0].get(), + aclMemoryTensors[ACLArgs::ACL_DST].get(), + inWeightFormat, + outWeightFormat); + return neReorderLayer; +#else + auto neCopy = std::make_unique(); + neCopy->configure(aclMemoryTensors[ACLArgs::ACL_SRC_0].get(), + aclMemoryTensors[ACLArgs::ACL_DST].get()); + return neCopy; +#endif +} + +} // namespace intel_cpu +} // namespace ov \ No newline at end of file diff --git a/src/plugins/intel_cpu/src/nodes/executors/acl/acl_fullyconnected_utils.hpp b/src/plugins/intel_cpu/src/nodes/executors/acl/acl_fullyconnected_utils.hpp new file mode 100644 index 00000000000000..63905ffa29a740 --- /dev/null +++ b/src/plugins/intel_cpu/src/nodes/executors/acl/acl_fullyconnected_utils.hpp @@ -0,0 +1,108 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// +#pragma once +#include "acl_common_executor.hpp" +#include "ov_optional.hpp" +#include "nodes/executors/fullyconnected_config.hpp" + +namespace ov { +namespace intel_cpu { + +struct ACLFCAttrs { + ov::element::Type inputPrecision; + bool isConvertedWeights = false; + bool isWeightsRepacked = false; + bool weightsNonTransposed; +}; + +namespace acl_fc_executor { + +VectorDims makeDummyInputDims(const Shape& inShape, const Shape& wShape); + +VectorDims makeDummyOutputDims(const VectorDims& inShape, const VectorDims& wShape, const size_t out_rank); + +DnnlMemoryDescPtr makeTransposedWeightDescriptor(const DnnlMemoryDescPtr srcDesc, + const DnnlMemoryDescPtr dstDesc); + +ov::optional convertWeightPrecision(MemoryPtr input, + MemoryPtr output, + ov::element::Type weightPrecision); + +ov::optional reorderDataFallback(MemoryPtr input, + MemoryPtr output, + ExecutorContext::CPtr context); + +MemoryPtr reorderData(DnnlMemoryDescPtr srcWeightDesc, + DnnlMemoryDescPtr dstWeightDesc, + MemoryCPtr weightsMem, + ExecutorContext::CPtr context); + +MemoryPtr reorderWeights(const MemoryArgs &memory, + const ExecutorContext::CPtr context, + ACLFCAttrs& aclfcAttrs, + DnnlMemoryDescPtr dnnlSrcDesc, + DnnlMemoryDescPtr dnnlDstDesc); + +MemoryPtr prepareWeightMemory(const MemoryArgs &memory, + const ExecutorContext::CPtr context, + const FCAttrs &attrs, + ACLFCAttrs& aclfcAttrs, + const PostOps &postOps, + arm_compute::WeightFormat& expectedWeightFormat, + arm_compute::TensorInfo& weiTensorInfo); + +arm_compute::TensorShape normalizeDimsTo2D(const arm_compute::TensorShape shape); + +void updateFCTensorsShapes(ACLShapes& aclMemoryShapes); + +class ACLWeightsConverter : public ACLCommonExecutor { +public: + ACLWeightsConverter() = default; + void updateTensorsShapes(ACLShapes& aclMemoryShapes) override {} + arm_compute::Status validateTensorsInfo(const ACLInfos & aclMemoryInfos) override; + ACLFunction configureFunction(const ACLTensors & aclMemoryTensors) override; +}; + +class ACLWeightsTranspose : public ACLCommonExecutor { +public: + ACLWeightsTranspose() = default; + void updateTensorsShapes(ACLShapes& aclMemoryShapes) override {} + arm_compute::Status validateTensorsInfo(const ACLInfos & aclMemoryInfos) override; + ACLFunction configureFunction(const ACLTensors & aclMemoryTensors) override; +}; + +class ACLWeightFormatGenerator : public ACLCommonExecutor { +public: + ACLWeightFormatGenerator(const FCAttrs& attrs, + const PostOps& postOps, + const MemoryArgs& memory); + void updateTensorsShapes(ACLShapes& aclMemoryShapes) override; + arm_compute::Status validateTensorsInfo(const ACLInfos & aclMemoryInfos) override; + ACLFunction configureFunction(const ACLTensors & aclMemoryTensors) override; + arm_compute::WeightFormat getOptImplWeightFormat() { + return expectedWeightFormat; + } +private: + arm_compute::FullyConnectedLayerInfo fullyConnectedLayerInfo; + arm_compute::WeightsInfo weightsInfo; + ACLFCAttrs aclfcAttrs; + arm_compute::WeightFormat expectedWeightFormat; +}; + +class ACLWeightsReorder : public ACLCommonExecutor { +public: + ACLWeightsReorder(arm_compute::WeightFormat inWeightFormat, + arm_compute::WeightFormat outWeightFormat) + : inWeightFormat(inWeightFormat), outWeightFormat(outWeightFormat) {} + void updateTensorsShapes(ACLShapes& aclMemoryShapes) override {} + arm_compute::Status validateTensorsInfo(const ACLInfos & aclMemoryInfos) override; + ACLFunction configureFunction(const ACLTensors & aclMemoryTensors) override; +private: + arm_compute::WeightFormat inWeightFormat; + arm_compute::WeightFormat outWeightFormat; +}; + +} // namespace acl_fc_executor +} // namespace intel_cpu +} // namespace ov \ No newline at end of file diff --git a/src/plugins/intel_cpu/src/nodes/executors/acl/acl_lowp_fullyconnected.cpp b/src/plugins/intel_cpu/src/nodes/executors/acl/acl_lowp_fullyconnected.cpp new file mode 100644 index 00000000000000..099fee648532cf --- /dev/null +++ b/src/plugins/intel_cpu/src/nodes/executors/acl/acl_lowp_fullyconnected.cpp @@ -0,0 +1,143 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "acl_lowp_fullyconnected.hpp" + +#include "acl_fullyconnected_utils.hpp" +#include "arm_compute/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.h" +#include "memory_desc/cpu_memory_desc_utils.h" +#include "nodes/common/cpu_convert.h" +#include "nodes/executors/acl/acl_utils.hpp" +#include "nodes/executors/executor.hpp" +#include "nodes/executors/memory_arguments.hpp" +#include "nodes/executors/debug_messages.hpp" +#include "nodes/executors/implementation_utils.hpp" +#include "utils/debug_capabilities.h" + +namespace ov { +namespace intel_cpu { + +static bool checkPostOps(const PostOps &postOps) { + if (postOps.empty()) { + return true; + } + + if (postOps.size() != 1) { + return false; + } + + const auto activation = std::dynamic_pointer_cast(postOps[0]); + return checkActivationLayerInfo(convertToEltwiseAlgorithm(activation->type())); +} + +static void initFCAttrs(const FCAttrs &attrs, + ACLTensorAttrs& aclTensorAttrs, + ACLFCAttrs& aclfcAttrs, + const MemoryArgs &memory, + arm_compute::GEMMInfo& fullyConnectedLayerInfo, + const PostOps &postOps) { + aclTensorAttrs.hasLayoutTypeNHWC = memory.at(ARG_SRC)->getDescPtr()->hasLayoutType(LayoutType::nspc); + aclfcAttrs.inputPrecision = memory.at(ARG_SRC)->getDescPtr()->getPrecision(); + aclfcAttrs.weightsNonTransposed = attrs.weightsNonTransposed; + + if (!postOps.empty()) { + auto activation = std::dynamic_pointer_cast(postOps[0]); + fullyConnectedLayerInfo.set_activation_info(getActivationLayerInfo( + convertToEltwiseAlgorithm(activation->type()), + activation->alpha(), activation->beta(), activation->gamma())); + } + + if (memory.at(ARG_SRC)->getPrecision() != memory.at(ARG_WEI)->getPrecision()) { + aclfcAttrs.isConvertedWeights = true; + } +} + +ACLLowpFullyConnectedExecutor::ACLLowpFullyConnectedExecutor(const FCAttrs &attrs, + const PostOps &postOps, + const MemoryArgs &memory, + const ExecutorContext::CPtr& context) : dequantizationScales(attrs.dequantizationScales) { + initFCAttrs(attrs, aclTensorAttrs, aclfcAttrs, memory, gemmInfo, postOps); + packedWeights = acl_fc_executor::prepareWeightMemory(memory, context, attrs, aclfcAttrs, postOps, expectedWeightFormat, weiTensorInfo); +} + +bool ACLLowpFullyConnectedExecutor::supports(const FCConfig &config) { + const auto src0 = srcType(config); + const auto src1 = weiType(config); + const auto dst = dstType(config); + if ((src0 != ov::element::i8) || (src1 != ov::element::i8) || (dst != ov::element::f32)) { + return false; + } + + VERIFY(checkPostOps(config.postOps), UNSUPPORTED_TYPE_OF_POSTOPS); + VERIFY(one_of(srcRank(config), 2U, 3U, 4U), UNSUPPORTED_SRC_RANK); + VERIFY(one_of(weiRank(config), 2U, 3U, 4U), UNSUPPORTED_WEI_RANK); + VERIFY(static_cast(config.attrs).dequantizationScales.size() <= 1, UNSUPPORTED_PER_CHANNEL_QUANTIZATION); + return true; +} + +void ACLLowpFullyConnectedExecutor::updateTensorsShapes(ACLShapes& aclMemoryShapes) { + acl_fc_executor::updateFCTensorsShapes(aclMemoryShapes); +} + +arm_compute::Status ACLLowpFullyConnectedExecutor::validateTensorsInfo(const ACLInfos & aclMemoryInfos) { + auto &tensor_info = aclMemoryInfos[ACLArgs::ACL_SRC_0]; + if (dequantizationScales.empty()) { + tensor_info->set_quantization_info(arm_compute::QuantizationInfo(1.f)); + } else { + tensor_info->set_quantization_info(arm_compute::QuantizationInfo(dequantizationScales[0])); + } + + auto& tensor_info_weights = aclMemoryInfos[ACLArgs::ACL_WEI]; + tensor_info_weights->set_quantization_info(arm_compute::QuantizationInfo(1.f)); + + const auto matMulValid = arm_compute::NEGEMMLowpMatrixMultiplyCore::validate( + aclMemoryInfos[ACLArgs::ACL_SRC_0].get(), + aclMemoryInfos[ACLArgs::ACL_WEI].get(), + aclMemoryInfos[ACLArgs::ACL_BIAS].get(), + aclMemoryInfos[ACLArgs::ACL_DST].get(), + gemmInfo); + return matMulValid; +} + +ACLFunction ACLLowpFullyConnectedExecutor::configureFunction(const ACLTensors & aclMemoryTensors) { + auto gemm = std::make_unique(); + gemm->configure( + aclMemoryTensors[ACLArgs::ACL_SRC_0].get(), + aclMemoryTensors[ACLArgs::ACL_WEI].get(), + aclMemoryTensors[ACLArgs::ACL_BIAS].get(), + aclMemoryTensors.at(ACLArgs::ACL_DST).get(), + gemmInfo); + + if (aclfcAttrs.isConvertedWeights || !aclfcAttrs.weightsNonTransposed) { + aclTensorAttrs.memoryUsageIndicator[ACLArgs::ACL_WEI] = false; + aclMemoryTensors[ACLArgs::ACL_WEI]->allocator()->import_memory(packedWeights->getData()); + } + return gemm; +} + +std::shared_ptr ACLLowpFullyConnectedExecutor::initTensorInfo( + const arm_compute::TensorShape& tensorShape, + const arm_compute::DataType& dataType, + const arm_compute::DataLayout& dataLayout) { + arm_compute::DataType result; + switch (dataType) { + case arm_compute::DataType::S8: { + result = arm_compute::DataType::QASYMM8_SIGNED; + break; + } + case arm_compute::DataType::U8: { + result = arm_compute::DataType::QASYMM8; + break; + } + default: { + result = dataType; + break; + } + } + + return ACLCommonExecutor::initTensorInfo(tensorShape, result, dataLayout); +} + +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/nodes/executors/acl/acl_lowp_fullyconnected.hpp b/src/plugins/intel_cpu/src/nodes/executors/acl/acl_lowp_fullyconnected.hpp new file mode 100644 index 00000000000000..96dac857c47441 --- /dev/null +++ b/src/plugins/intel_cpu/src/nodes/executors/acl/acl_lowp_fullyconnected.hpp @@ -0,0 +1,51 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include "acl_common_executor.hpp" +#include "nodes/executors/fullyconnected_config.hpp" +#include "acl_fullyconnected_utils.hpp" + +namespace ov { +namespace intel_cpu { + +class ACLLowpFullyConnectedExecutor : public ACLCommonExecutor { +public: + ACLLowpFullyConnectedExecutor(const FCAttrs& attrs, + const PostOps& postOps, + const MemoryArgs& memory, + const ExecutorContext::CPtr& context); + + static bool supports(const FCConfig& config); + + void updateTensorsShapes(ACLShapes& aclMemoryShapes) override; + + arm_compute::Status validateTensorsInfo(const ACLInfos & aclMemoryInfos) override; + + ACLFunction configureFunction(const ACLTensors & aclMemoryTensors) override; + + impl_desc_type implType() const override { + return impl_desc_type::gemm_acl; + } + +protected: + std::shared_ptr initTensorInfo(const arm_compute::TensorShape& tensorShape, + const arm_compute::DataType& dataType, + const arm_compute::DataLayout& dataLayout) override; + +private: + arm_compute::GEMMInfo gemmInfo; + arm_compute::WeightFormat expectedWeightFormat; + arm_compute::TensorInfo weiTensorInfo; + + MemoryCPtr packedWeights; + ACLFCAttrs aclfcAttrs; + std::vector dequantizationScales; +}; + +using ACLLowpFullyConnectedExecutorPtr = std::shared_ptr; + +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/nodes/executors/acl/acl_utils.hpp b/src/plugins/intel_cpu/src/nodes/executors/acl/acl_utils.hpp index a3d151192e601b..e20ba4f9283077 100644 --- a/src/plugins/intel_cpu/src/nodes/executors/acl/acl_utils.hpp +++ b/src/plugins/intel_cpu/src/nodes/executors/acl/acl_utils.hpp @@ -110,7 +110,7 @@ inline int axisCast(const std::size_t axis, const std::size_t shapeSize, ACLAxis * @param precision precision to be converted * @return ComputeLibrary DataType or UNKNOWN if precision is not mapped to DataType */ -inline arm_compute::DataType precisionToAclDataType(ov::element::Type precision) { +inline arm_compute::DataType precisionToAclDataType(const ov::element::Type& precision) { switch (precision) { case ov::element::i8: return arm_compute::DataType::S8; case ov::element::u8: return arm_compute::DataType::U8; diff --git a/src/plugins/intel_cpu/src/nodes/executors/debug_messages.hpp b/src/plugins/intel_cpu/src/nodes/executors/debug_messages.hpp index 26ae6ace59631b..206842014365a0 100644 --- a/src/plugins/intel_cpu/src/nodes/executors/debug_messages.hpp +++ b/src/plugins/intel_cpu/src/nodes/executors/debug_messages.hpp @@ -18,6 +18,7 @@ #define UNSUPPORTED_DST_RANK " unsupported dst rank" #define UNSUPPORTED_DST_STRIDES " unsupported dst strides" #define HEURISTICS_MISMATCH " heuristics mismatch" +#define UNSUPPORTED_PER_CHANNEL_QUANTIZATION " unsupported per-channel quantization" #define VERIFY(condition, ...) \ do { \ diff --git a/src/plugins/intel_cpu/src/nodes/executors/fullyconnected_implementations.cpp b/src/plugins/intel_cpu/src/nodes/executors/fullyconnected_implementations.cpp index 5a8b1ef78b6dbb..ff965bc3fdc858 100644 --- a/src/plugins/intel_cpu/src/nodes/executors/fullyconnected_implementations.cpp +++ b/src/plugins/intel_cpu/src/nodes/executors/fullyconnected_implementations.cpp @@ -30,6 +30,7 @@ #if defined(OV_CPU_WITH_ACL) #include "nodes/executors/acl/acl_fullyconnected.hpp" +#include "nodes/executors/acl/acl_lowp_fullyconnected.hpp" #endif #if defined(OV_CPU_WITH_SHL) @@ -87,6 +88,11 @@ static const TypeMapping aclFCTypeMapping { {{_any, _any, _any, _any}, pt(just(), just(), just(), just())} }; +static const TypeMapping aclLowpFCTypeMapping { + // {src, wei, bia, dst} pt + {{_i8, _i8, _any, _f32}, pt(bypass(), bypass(), just(), bypass())} +}; + static const MappingNotation dnnlConvolutionMappingNotation { ARG_SRC, ARG_WEI, ARG_BIAS, ARG_DST }; @@ -370,6 +376,35 @@ const std::vector>& getImplementations() { const ExecutorContext::CPtr context) { return std::make_shared(attrs, postOps, memory, context); }) + OV_CPU_INSTANCE_ACL( + "fullyconnected_acl_lowp", + ExecutorType::Acl, + OperationType::FullyConnected, + ShapeTolerance::Agnostic, + // supports + [](const FCConfig& config) -> bool { + VERIFY(noSparseDecompression(config), UNSUPPORTED_SPARSE_WEIGHTS); + VERIFY(noWeightsDecompression(config), UNSUPPORTED_WEIGHTS_DECOMPRESSION); + return ACLLowpFullyConnectedExecutor::supports(config); + }, + // requiresFallback + [](const FCConfig& config) -> ov::optional> { + return requiresFallbackCommon(config, + aclLowpFCTypeMapping, + aclFCLayoutConfig, + aclFullyConnectedMappingNotation); + }, + // acceptsShapes + [](const MemoryArgs& memory) -> bool { + return true; + }, + // create + [](const FCAttrs& attrs, + const PostOps& postOps, + const MemoryArgs& memory, + const ExecutorContext::CPtr context) { + return std::make_shared(attrs, postOps, memory, context); + }) OV_CPU_INSTANCE_SHL( "fullyconnected_shl", ExecutorType::Shl, diff --git a/src/plugins/intel_cpu/src/transformations/snippets/aarch64/pass/snippets_mark_skipped.cpp b/src/plugins/intel_cpu/src/transformations/snippets/aarch64/pass/snippets_mark_skipped.cpp index c38d088ef95e7b..7343497e855f5b 100644 --- a/src/plugins/intel_cpu/src/transformations/snippets/aarch64/pass/snippets_mark_skipped.cpp +++ b/src/plugins/intel_cpu/src/transformations/snippets/aarch64/pass/snippets_mark_skipped.cpp @@ -7,6 +7,7 @@ #include "snippets/op/subgraph.hpp" #include "snippets/utils/utils.hpp" +#include "low_precision/rt_info/bias_attribute.hpp" #include "transformations/utils/utils.hpp" #include "transformations/utils.hpp" #include "utils/general_utils.h" @@ -227,6 +228,11 @@ bool SnippetsMarkSkipped::run_on_model(const std::shared_ptr &m) { for (auto &node : m->get_ordered_ops()) { if (is_skipped_op(node)) continue; + + if (ov::marked_as_bias(node)) { + SetNodeFusingType(node, NodeFusingType::FusedWithMisc); + } + if (isSuitableConvolutionParent(node)) { // Initiate fusing chain SetNodeFusingType(node, NodeFusingType::FusedWithConvolution); diff --git a/src/plugins/intel_cpu/tests/functional/CMakeLists.txt b/src/plugins/intel_cpu/tests/functional/CMakeLists.txt index 3092356e1189b6..40a4fc4a1739c4 100644 --- a/src/plugins/intel_cpu/tests/functional/CMakeLists.txt +++ b/src/plugins/intel_cpu/tests/functional/CMakeLists.txt @@ -59,6 +59,7 @@ if(NOT (ARM OR AARCH64)) ${CMAKE_CURRENT_SOURCE_DIR}/custom/single_layer_tests/instances/arm ${CMAKE_CURRENT_SOURCE_DIR}/custom/subgraph_tests/src/arm ${CMAKE_CURRENT_SOURCE_DIR}/shared_tests_instances/snippets/arm + ${CMAKE_CURRENT_SOURCE_DIR}/shared_tests_instances/low_precision_transformations/aarch64 ${CMAKE_CURRENT_SOURCE_DIR}/utils/arm) else() # temporary disable all custom tests for ARM @@ -81,7 +82,8 @@ endif() if(NOT X86_64) list(APPEND EXCLUDED_SOURCE_PATHS ${CMAKE_CURRENT_SOURCE_DIR}/custom/single_layer_tests/instances/x64 - ${CMAKE_CURRENT_SOURCE_DIR}/custom/subgraph_tests/src/x64) + ${CMAKE_CURRENT_SOURCE_DIR}/custom/subgraph_tests/src/x64 + ${CMAKE_CURRENT_SOURCE_DIR}/shared_tests_instances/low_precision_transformations/x64) endif() ov_add_test_target( diff --git a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/aarch64/fully_connected_transformation.cpp b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/aarch64/fully_connected_transformation.cpp new file mode 100644 index 00000000000000..f764ea3ca1156e --- /dev/null +++ b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/aarch64/fully_connected_transformation.cpp @@ -0,0 +1,103 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include + +#include "low_precision_transformations/fully_connected_transformation.hpp" +#include "common_test_utils/test_constants.hpp" + +using namespace LayerTestsDefinitions; + +namespace { +const std::vector netPrecisions = { + ov::element::f32 +}; + +const std::vector shapes = { + { + ov::PartialShape{ 1, 16 }, + ov::PartialShape{ 16, 8 }, + false, + false + }, + { + ov::PartialShape{ 1, 1, 16 }, + ov::PartialShape{ 1, 16, 8 }, + false, + false + }, + { + ov::PartialShape{ 1, 16 }, + ov::PartialShape{ 8, 16 }, + false, + true + }, + { + ov::PartialShape{ 1, 1, 16 }, + ov::PartialShape{ 1, 8, 16 }, + false, + true + }, + { + ov::PartialShape{ 16, 1 }, + ov::PartialShape{ 16, 8 }, + true, + false + }, + { + ov::PartialShape{ 1, 16, 1 }, + ov::PartialShape{ 1, 16, 8 }, + true, + false + }, + { + ov::PartialShape{ 16, 1 }, + ov::PartialShape{ 8, 16 }, + true, + true + }, + { + ov::PartialShape{ 1, 16, 1 }, + ov::PartialShape{ 1, 8, 16 }, + true, + true + } +}; + +const std::vector trasformationParamValues = { + LayerTestsUtils::LayerTransformationParamsNGraphFactory::createParams() +}; + +const std::vector activations = { + { + true, // activation + false, // per-channel + true, // FQ + "fullyConnected,fullyConnected/DequantizationMultiply,relu" + }, + { + false, // activation + false, // per-channel + true, // FQ + "fullyConnected_original,fullyConnected" + }, + { + true, // activation + true, // per-channel + false, // FQ + "fullyConnected,relu_original" // dequantization is not supported for per-channel quantization + }, +}; + +INSTANTIATE_TEST_SUITE_P(smoke_LPT, FullyConnectedTransformation, + ::testing::Combine( + ::testing::ValuesIn(netPrecisions), + ::testing::ValuesIn(shapes), + ::testing::Values(ov::test::utils::DEVICE_CPU), + ::testing::ValuesIn(trasformationParamValues), + ::testing::ValuesIn({ov::element::i8 /*, ov::element::u8*/}), + ::testing::ValuesIn(activations), + ::testing::Values("gemm_acl_i8")), + FullyConnectedTransformation::getTestCaseName); +} // namespace diff --git a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/add_transformation.cpp b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/x64/add_transformation.cpp similarity index 100% rename from src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/add_transformation.cpp rename to src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/x64/add_transformation.cpp diff --git a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/assign_and_read_value_transformation.cpp b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/x64/assign_and_read_value_transformation.cpp similarity index 100% rename from src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/assign_and_read_value_transformation.cpp rename to src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/x64/assign_and_read_value_transformation.cpp diff --git a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/batch_to_space_transformation.cpp b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/x64/batch_to_space_transformation.cpp similarity index 100% rename from src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/batch_to_space_transformation.cpp rename to src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/x64/batch_to_space_transformation.cpp diff --git a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/clamp_transformation.cpp b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/x64/clamp_transformation.cpp similarity index 100% rename from src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/clamp_transformation.cpp rename to src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/x64/clamp_transformation.cpp diff --git a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/concat_transformation.cpp b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/x64/concat_transformation.cpp similarity index 100% rename from src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/concat_transformation.cpp rename to src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/x64/concat_transformation.cpp diff --git a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/concat_with_child_and_output.cpp b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/x64/concat_with_child_and_output.cpp similarity index 100% rename from src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/concat_with_child_and_output.cpp rename to src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/x64/concat_with_child_and_output.cpp diff --git a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/concat_with_different_precision_on_children.cpp b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/x64/concat_with_different_precision_on_children.cpp similarity index 100% rename from src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/concat_with_different_precision_on_children.cpp rename to src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/x64/concat_with_different_precision_on_children.cpp diff --git a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/concat_with_intermediate_transformation.cpp b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/x64/concat_with_intermediate_transformation.cpp similarity index 100% rename from src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/concat_with_intermediate_transformation.cpp rename to src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/x64/concat_with_intermediate_transformation.cpp diff --git a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/concat_with_neighbors_transformation.cpp b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/x64/concat_with_neighbors_transformation.cpp similarity index 100% rename from src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/concat_with_neighbors_transformation.cpp rename to src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/x64/concat_with_neighbors_transformation.cpp diff --git a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/concat_with_split_transformation.cpp b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/x64/concat_with_split_transformation.cpp similarity index 100% rename from src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/concat_with_split_transformation.cpp rename to src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/x64/concat_with_split_transformation.cpp diff --git a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/convolution_backprop_data_transformation.cpp b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/x64/convolution_backprop_data_transformation.cpp similarity index 100% rename from src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/convolution_backprop_data_transformation.cpp rename to src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/x64/convolution_backprop_data_transformation.cpp diff --git a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/convolution_qdq_transformation.cpp b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/x64/convolution_qdq_transformation.cpp similarity index 100% rename from src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/convolution_qdq_transformation.cpp rename to src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/x64/convolution_qdq_transformation.cpp diff --git a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/convolution_transformation.cpp b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/x64/convolution_transformation.cpp similarity index 100% rename from src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/convolution_transformation.cpp rename to src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/x64/convolution_transformation.cpp diff --git a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/depth_to_space_transformation.cpp b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/x64/depth_to_space_transformation.cpp similarity index 100% rename from src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/depth_to_space_transformation.cpp rename to src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/x64/depth_to_space_transformation.cpp diff --git a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/elementwise_branch_selection_transformation.cpp b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/x64/elementwise_branch_selection_transformation.cpp similarity index 100% rename from src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/elementwise_branch_selection_transformation.cpp rename to src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/x64/elementwise_branch_selection_transformation.cpp diff --git a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/eliminate_fake_quantize_transformation.cpp b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/x64/eliminate_fake_quantize_transformation.cpp similarity index 100% rename from src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/eliminate_fake_quantize_transformation.cpp rename to src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/x64/eliminate_fake_quantize_transformation.cpp diff --git a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/fq_and_avg_pool_transformation.cpp b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/x64/fq_and_avg_pool_transformation.cpp similarity index 100% rename from src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/fq_and_avg_pool_transformation.cpp rename to src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/x64/fq_and_avg_pool_transformation.cpp diff --git a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/fq_and_max_pool_transformation.cpp b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/x64/fq_and_max_pool_transformation.cpp similarity index 100% rename from src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/fq_and_max_pool_transformation.cpp rename to src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/x64/fq_and_max_pool_transformation.cpp diff --git a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/fq_and_two_output_branches_with_convolution.cpp b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/x64/fq_and_two_output_branches_with_convolution.cpp similarity index 100% rename from src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/fq_and_two_output_branches_with_convolution.cpp rename to src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/x64/fq_and_two_output_branches_with_convolution.cpp diff --git a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/fq_precision_selection_transformation.cpp b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/x64/fq_precision_selection_transformation.cpp similarity index 100% rename from src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/fq_precision_selection_transformation.cpp rename to src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/x64/fq_precision_selection_transformation.cpp diff --git a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/fq_transformation.cpp b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/x64/fq_transformation.cpp similarity index 100% rename from src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/fq_transformation.cpp rename to src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/x64/fq_transformation.cpp diff --git a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/fq_with_dq_not_optimal_transformation.cpp b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/x64/fq_with_dq_not_optimal_transformation.cpp similarity index 100% rename from src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/fq_with_dq_not_optimal_transformation.cpp rename to src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/x64/fq_with_dq_not_optimal_transformation.cpp diff --git a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/fully_connected_transformation.cpp b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/x64/fully_connected_transformation.cpp similarity index 61% rename from src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/fully_connected_transformation.cpp rename to src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/x64/fully_connected_transformation.cpp index 0368215a5cf5a4..e351fb607c6e8b 100644 --- a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/fully_connected_transformation.cpp +++ b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/x64/fully_connected_transformation.cpp @@ -14,7 +14,7 @@ const std::vector netPrecisions = { ov::element::f32 }; -const std::vector shapes = { +const std::vector shapes = { { ov::PartialShape{ 1, 16 }, ov::PartialShape{ 16, 8 }, @@ -39,11 +39,35 @@ const std::vector trasform LayerTestsUtils::LayerTransformationParamsNGraphFactory::createParams() }; +const std::vector activations = { + { + true, // activation + false, // per-channel + true, // FQ + "fullyconnected,relu_original,relu" + }, + { + false, // activation + false, // per-channel + true, // FQ + "fullyConnected_original,fullyConnected" + }, + { + true, // activation + true, // per-channel + false, // FQ + "fullyconnected,relu_original,relu" + }, +}; + INSTANTIATE_TEST_SUITE_P(smoke_LPT, FullyConnectedTransformation, ::testing::Combine( ::testing::ValuesIn(netPrecisions), ::testing::ValuesIn(shapes), ::testing::Values(ov::test::utils::DEVICE_CPU), - ::testing::ValuesIn(trasformationParamValues)), + ::testing::ValuesIn(trasformationParamValues), + ::testing::ValuesIn({ov::element::i8/*, ov::element::u8*/}), + ::testing::ValuesIn(activations), + ::testing::Values("")), FullyConnectedTransformation::getTestCaseName); } // namespace diff --git a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/fuse_convert_transformation.cpp b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/x64/fuse_convert_transformation.cpp similarity index 100% rename from src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/fuse_convert_transformation.cpp rename to src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/x64/fuse_convert_transformation.cpp diff --git a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/fuse_dequantize_to_fq_transformation.cpp b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/x64/fuse_dequantize_to_fq_transformation.cpp similarity index 100% rename from src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/fuse_dequantize_to_fq_transformation.cpp rename to src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/x64/fuse_dequantize_to_fq_transformation.cpp diff --git a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/fuse_fq_and_scale_shift_transformation.cpp b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/x64/fuse_fq_and_scale_shift_transformation.cpp similarity index 100% rename from src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/fuse_fq_and_scale_shift_transformation.cpp rename to src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/x64/fuse_fq_and_scale_shift_transformation.cpp diff --git a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/fuse_multiply_to_fq_transformation.cpp b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/x64/fuse_multiply_to_fq_transformation.cpp similarity index 100% rename from src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/fuse_multiply_to_fq_transformation.cpp rename to src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/x64/fuse_multiply_to_fq_transformation.cpp diff --git a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/fuse_subtract_to_fq_transformation.cpp b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/x64/fuse_subtract_to_fq_transformation.cpp similarity index 100% rename from src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/fuse_subtract_to_fq_transformation.cpp rename to src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/x64/fuse_subtract_to_fq_transformation.cpp diff --git a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/gather_transformation.cpp b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/x64/gather_transformation.cpp similarity index 100% rename from src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/gather_transformation.cpp rename to src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/x64/gather_transformation.cpp diff --git a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/gemm_transformation.cpp b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/x64/gemm_transformation.cpp similarity index 100% rename from src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/gemm_transformation.cpp rename to src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/x64/gemm_transformation.cpp diff --git a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/group_convolution_transformation.cpp b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/x64/group_convolution_transformation.cpp similarity index 100% rename from src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/group_convolution_transformation.cpp rename to src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/x64/group_convolution_transformation.cpp diff --git a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/groupconvolution_qdq_transformation.cpp b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/x64/groupconvolution_qdq_transformation.cpp similarity index 100% rename from src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/groupconvolution_qdq_transformation.cpp rename to src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/x64/groupconvolution_qdq_transformation.cpp diff --git a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/interpolate_transformation.cpp b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/x64/interpolate_transformation.cpp similarity index 100% rename from src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/interpolate_transformation.cpp rename to src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/x64/interpolate_transformation.cpp diff --git a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/mat_mul_transformation.cpp b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/x64/mat_mul_transformation.cpp similarity index 100% rename from src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/mat_mul_transformation.cpp rename to src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/x64/mat_mul_transformation.cpp diff --git a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/mat_mul_with_constant_transformation.cpp b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/x64/mat_mul_with_constant_transformation.cpp similarity index 100% rename from src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/mat_mul_with_constant_transformation.cpp rename to src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/x64/mat_mul_with_constant_transformation.cpp diff --git a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/mat_mul_with_optimized_constant_fq.cpp b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/x64/mat_mul_with_optimized_constant_fq.cpp similarity index 100% rename from src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/mat_mul_with_optimized_constant_fq.cpp rename to src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/x64/mat_mul_with_optimized_constant_fq.cpp diff --git a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/move_fake_quantize_transformation.cpp b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/x64/move_fake_quantize_transformation.cpp similarity index 100% rename from src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/move_fake_quantize_transformation.cpp rename to src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/x64/move_fake_quantize_transformation.cpp diff --git a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/multiply_to_group_convolution.cpp b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/x64/multiply_to_group_convolution.cpp similarity index 100% rename from src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/multiply_to_group_convolution.cpp rename to src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/x64/multiply_to_group_convolution.cpp diff --git a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/multiply_transformation.cpp b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/x64/multiply_transformation.cpp similarity index 100% rename from src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/multiply_transformation.cpp rename to src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/x64/multiply_transformation.cpp diff --git a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/multiply_with_one_parent.cpp b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/x64/multiply_with_one_parent.cpp similarity index 100% rename from src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/multiply_with_one_parent.cpp rename to src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/x64/multiply_with_one_parent.cpp diff --git a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/mvn_transformation.cpp b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/x64/mvn_transformation.cpp similarity index 100% rename from src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/mvn_transformation.cpp rename to src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/x64/mvn_transformation.cpp diff --git a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/normalize_transformation.cpp b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/x64/normalize_transformation.cpp similarity index 100% rename from src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/normalize_transformation.cpp rename to src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/x64/normalize_transformation.cpp diff --git a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/output_layers.cpp b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/x64/output_layers.cpp similarity index 100% rename from src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/output_layers.cpp rename to src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/x64/output_layers.cpp diff --git a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/output_layers_concat.cpp b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/x64/output_layers_concat.cpp similarity index 100% rename from src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/output_layers_concat.cpp rename to src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/x64/output_layers_concat.cpp diff --git a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/output_layers_concat_multi_channel.cpp b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/x64/output_layers_concat_multi_channel.cpp similarity index 100% rename from src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/output_layers_concat_multi_channel.cpp rename to src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/x64/output_layers_concat_multi_channel.cpp diff --git a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/pad_transformation.cpp b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/x64/pad_transformation.cpp similarity index 100% rename from src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/pad_transformation.cpp rename to src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/x64/pad_transformation.cpp diff --git a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/prelu_transformation.cpp b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/x64/prelu_transformation.cpp similarity index 100% rename from src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/prelu_transformation.cpp rename to src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/x64/prelu_transformation.cpp diff --git a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/pull_reshape_through_dequantization.cpp b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/x64/pull_reshape_through_dequantization.cpp similarity index 100% rename from src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/pull_reshape_through_dequantization.cpp rename to src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/x64/pull_reshape_through_dequantization.cpp diff --git a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/recurrent_cell_transformation.cpp b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/x64/recurrent_cell_transformation.cpp similarity index 100% rename from src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/recurrent_cell_transformation.cpp rename to src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/x64/recurrent_cell_transformation.cpp diff --git a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/reduce_max_transformation.cpp b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/x64/reduce_max_transformation.cpp similarity index 100% rename from src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/reduce_max_transformation.cpp rename to src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/x64/reduce_max_transformation.cpp diff --git a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/reduce_mean_transformation.cpp b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/x64/reduce_mean_transformation.cpp similarity index 100% rename from src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/reduce_mean_transformation.cpp rename to src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/x64/reduce_mean_transformation.cpp diff --git a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/reduce_min_transformation.cpp b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/x64/reduce_min_transformation.cpp similarity index 100% rename from src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/reduce_min_transformation.cpp rename to src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/x64/reduce_min_transformation.cpp diff --git a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/reduce_sum_transformation.cpp b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/x64/reduce_sum_transformation.cpp similarity index 100% rename from src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/reduce_sum_transformation.cpp rename to src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/x64/reduce_sum_transformation.cpp diff --git a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/relu_transformation.cpp b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/x64/relu_transformation.cpp similarity index 100% rename from src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/relu_transformation.cpp rename to src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/x64/relu_transformation.cpp diff --git a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/reshape_transformation.cpp b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/x64/reshape_transformation.cpp similarity index 100% rename from src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/reshape_transformation.cpp rename to src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/x64/reshape_transformation.cpp diff --git a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/shuffle_channels_transformation.cpp b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/x64/shuffle_channels_transformation.cpp similarity index 100% rename from src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/shuffle_channels_transformation.cpp rename to src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/x64/shuffle_channels_transformation.cpp diff --git a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/space_to_batch_transformation.cpp b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/x64/space_to_batch_transformation.cpp similarity index 100% rename from src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/space_to_batch_transformation.cpp rename to src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/x64/space_to_batch_transformation.cpp diff --git a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/split_transformation.cpp b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/x64/split_transformation.cpp similarity index 100% rename from src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/split_transformation.cpp rename to src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/x64/split_transformation.cpp diff --git a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/squeeze_transformation.cpp b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/x64/squeeze_transformation.cpp similarity index 100% rename from src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/squeeze_transformation.cpp rename to src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/x64/squeeze_transformation.cpp diff --git a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/strided_slice_transformation.cpp b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/x64/strided_slice_transformation.cpp similarity index 100% rename from src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/strided_slice_transformation.cpp rename to src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/x64/strided_slice_transformation.cpp diff --git a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/subtract_multiply_to_multiply_add.cpp b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/x64/subtract_multiply_to_multiply_add.cpp similarity index 100% rename from src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/subtract_multiply_to_multiply_add.cpp rename to src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/x64/subtract_multiply_to_multiply_add.cpp diff --git a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/subtract_transformation.cpp b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/x64/subtract_transformation.cpp similarity index 100% rename from src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/subtract_transformation.cpp rename to src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/x64/subtract_transformation.cpp diff --git a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/transpose_after_matmul_transformation.cpp b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/x64/transpose_after_matmul_transformation.cpp similarity index 100% rename from src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/transpose_after_matmul_transformation.cpp rename to src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/x64/transpose_after_matmul_transformation.cpp diff --git a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/transpose_transformation.cpp b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/x64/transpose_transformation.cpp similarity index 100% rename from src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/transpose_transformation.cpp rename to src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/x64/transpose_transformation.cpp diff --git a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/unsqueeze_transformation.cpp b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/x64/unsqueeze_transformation.cpp similarity index 100% rename from src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/unsqueeze_transformation.cpp rename to src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/x64/unsqueeze_transformation.cpp diff --git a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/variadic_split_transformation.cpp b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/x64/variadic_split_transformation.cpp similarity index 100% rename from src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/variadic_split_transformation.cpp rename to src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/x64/variadic_split_transformation.cpp diff --git a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/skip_tests_config.cpp b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/skip_tests_config.cpp index 90820d550df179..d808016a4946bd 100644 --- a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/skip_tests_config.cpp +++ b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/skip_tests_config.cpp @@ -479,8 +479,10 @@ std::vector disabledTestPatterns() { retVector.emplace_back(R"(smoke_TestsDFT_(1|2|3|4)d/DFTLayerTest.Inference.*)"); // Issue 88764, 91647, 108802: accuracy issue retVector.emplace_back(R"(MultipleLSTMCellTest/MultipleLSTMCellTest.CompareWithRefs.*)"); +#if !defined(OPENVINO_ARCH_ARM64) // int8 / code-generation specific retVector.emplace_back(R"(smoke_LPT.*)"); +#endif // Compressed weights are not supported retVector.emplace_back(R"(smoke_MatMulCompressedWeights.*)"); retVector.emplace_back(R"(smoke_MatMulSharedCompressedWeights.*)"); diff --git a/src/plugins/intel_gpu/tests/functional/shared_tests_instances/low_precision_transformations/fully_connected_transformation.cpp b/src/plugins/intel_gpu/tests/functional/shared_tests_instances/low_precision_transformations/fully_connected_transformation.cpp index 71978473696a0b..4146e370cd1165 100644 --- a/src/plugins/intel_gpu/tests/functional/shared_tests_instances/low_precision_transformations/fully_connected_transformation.cpp +++ b/src/plugins/intel_gpu/tests/functional/shared_tests_instances/low_precision_transformations/fully_connected_transformation.cpp @@ -15,7 +15,7 @@ const std::vector netPrecisions = { ov::element::f16 }; -const std::vector shapes = { +const std::vector shapes = { { { 1, 16 }, { 16, 8 }, @@ -40,11 +40,35 @@ const std::vector trasform LayerTestsUtils::LayerTransformationParamsNGraphFactory::createParams() }; +const std::vector activations = { + { + true, // activation + false, // per-channel + true, // FQ + "" + }, + { + false, // activation + false, // per-channel + true, // FQ + "" + }, + { + true, // activation + true, // per-channel + false, // FQ + "" + }, +}; + INSTANTIATE_TEST_SUITE_P(smoke_LPT, FullyConnectedTransformation, ::testing::Combine( ::testing::ValuesIn(netPrecisions), ::testing::ValuesIn(shapes), ::testing::Values(ov::test::utils::DEVICE_GPU), - ::testing::ValuesIn(trasformationParamValues)), + ::testing::ValuesIn(trasformationParamValues), + ::testing::ValuesIn({ov::element::i8/*, ov::element::u8*/}), + ::testing::ValuesIn(activations), + ::testing::Values("")), FullyConnectedTransformation::getTestCaseName); } // namespace diff --git a/src/tests/functional/plugin/shared/include/low_precision_transformations/fully_connected_transformation.hpp b/src/tests/functional/plugin/shared/include/low_precision_transformations/fully_connected_transformation.hpp index 731ce44224e33b..232df8f7dd49c4 100644 --- a/src/tests/functional/plugin/shared/include/low_precision_transformations/fully_connected_transformation.hpp +++ b/src/tests/functional/plugin/shared/include/low_precision_transformations/fully_connected_transformation.hpp @@ -8,7 +8,7 @@ #include #include "shared_test_classes/base/low_precision_transformations/layer_transformation.hpp" -class MatMulShapes { +class FullyConnectedShapes { public: ov::PartialShape inputA; ov::PartialShape inputB; @@ -16,11 +16,22 @@ class MatMulShapes { bool transposeB; }; +class FullyConnectedParams { +public: + bool activation; + bool perChannelWeights; + bool fq; + std::string originalLayersNames; +}; + typedef std::tuple< ov::element::Type, - MatMulShapes, + FullyConnectedShapes, std::string, - ov::pass::low_precision::LayerTransformation::Params> FullyConnectedTransformationParams; + ov::pass::low_precision::LayerTransformation::Params, + ov::element::Type, + FullyConnectedParams, + std::string> FullyConnectedTransformationParams; namespace LayerTestsDefinitions { diff --git a/src/tests/functional/plugin/shared/src/low_precision_transformations/fully_connected_transformation.cpp b/src/tests/functional/plugin/shared/src/low_precision_transformations/fully_connected_transformation.cpp index f72f6d90333613..5de4424019ae4e 100644 --- a/src/tests/functional/plugin/shared/src/low_precision_transformations/fully_connected_transformation.cpp +++ b/src/tests/functional/plugin/shared/src/low_precision_transformations/fully_connected_transformation.cpp @@ -5,38 +5,51 @@ #include "low_precision_transformations/fully_connected_transformation.hpp" #include +#include #include #include -#include #include "common_test_utils/common_utils.hpp" +#include "openvino/util/common_util.hpp" #include "ov_lpt_models/mat_mul.hpp" namespace LayerTestsDefinitions { std::string FullyConnectedTransformation::getTestCaseName(const testing::TestParamInfo& obj) { ov::element::Type precision; - MatMulShapes shapes; + FullyConnectedShapes shapes; std::string targetDevice; ov::pass::low_precision::LayerTransformation::Params params; - std::tie(precision, shapes, targetDevice, params) = obj.param; + ov::element::Type weightsType; + FullyConnectedParams activation; + std::string expectedPrimitiveType; + std::tie(precision, shapes, targetDevice, params, weightsType, activation, expectedPrimitiveType) = obj.param; std::ostringstream result; result << - get_test_case_name_by_params(precision, shapes.inputA, targetDevice, params) << - shapes.inputB << "_" << - shapes.transposeA << "_" << - shapes.transposeB; + get_test_case_name_by_params(precision, shapes.inputA, targetDevice, params) << + shapes.inputB << "_" << + "transposeA=" << shapes.transposeA << "_" << + "transposeB=" << shapes.transposeB << "_" << + weightsType << "_" << + "Activation=" << activation.activation << "_" << + "perChannelWeights=" << activation.perChannelWeights << "_" << + "FQ=" << activation.fq << "_" << + activation.originalLayersNames << "_" << + expectedPrimitiveType; return result.str(); } void FullyConnectedTransformation::SetUp() { ov::element::Type precision; - MatMulShapes shapes; + FullyConnectedShapes shapes; ov::pass::low_precision::LayerTransformation::Params params; - std::tie(precision, shapes, targetDevice, params) = this->GetParam(); + ov::element::Type weightsType; + FullyConnectedParams activation; + std::string expectedPrimitiveType; + std::tie(precision, shapes, targetDevice, params, weightsType, activation, expectedPrimitiveType) = this->GetParam(); init_input_shapes({ shapes.inputA, shapes.inputB }); @@ -45,12 +58,32 @@ void FullyConnectedTransformation::SetUp() { shapes.inputA, shapes.inputB, shapes.transposeA, - shapes.transposeB); + shapes.transposeB, + weightsType == ov::element::i8, + activation.perChannelWeights, + activation.activation, + activation.fq); } TEST_P(FullyConnectedTransformation, CompareWithRefImpl) { SKIP_IF_CURRENT_TEST_IS_DISABLED(); run(); + + const auto& activation = std::get<5>(GetParam()); + if (!activation.originalLayersNames.empty()) { + const auto originalLayersNames = get_property_by_type("FullyConnected", "originalLayersNames"); + EXPECT_EQ(ov::util::to_lower(activation.originalLayersNames), originalLayersNames); + } + + const auto& actualPrecision = get_runtime_precision_by_type("FullyConnected"); + const auto expectedPrecision = std::get<4>(GetParam()); + EXPECT_EQ(actualPrecision, expectedPrecision.to_string()); + + const auto& expectedPrimitiveType = std::get<6>(GetParam()); + if (!expectedPrimitiveType.empty()) { + const std::string actualPrimitiveType = get_property_by_type("FullyConnected", "primitiveType"); + EXPECT_EQ(expectedPrimitiveType, actualPrimitiveType); + } }; } // namespace LayerTestsDefinitions diff --git a/src/tests/functional/shared_test_classes/include/shared_test_classes/base/low_precision_transformations/layer_transformation.hpp b/src/tests/functional/shared_test_classes/include/shared_test_classes/base/low_precision_transformations/layer_transformation.hpp index 10a70f3bc04ee0..b9da9ff8af4833 100644 --- a/src/tests/functional/shared_test_classes/include/shared_test_classes/base/low_precision_transformations/layer_transformation.hpp +++ b/src/tests/functional/shared_test_classes/include/shared_test_classes/base/low_precision_transformations/layer_transformation.hpp @@ -49,6 +49,8 @@ class LayerTransformation : virtual public ov::test::SubgraphBaseTest { // get runtime precision by operation type std::string get_runtime_precision_by_type(const std::string& layerType); + std::string get_property_by_type(const std::string& layerTypeName, const std::string& propertyName); + // get runtime precision by operation friendly name which can be fused std::string get_runtime_precision_by_fused_name(const std::string& layerName); diff --git a/src/tests/functional/shared_test_classes/src/base/low_precision_transformations/layer_transformation.cpp b/src/tests/functional/shared_test_classes/src/base/low_precision_transformations/layer_transformation.cpp index 49e7b0581cae76..a1781c5826e590 100644 --- a/src/tests/functional/shared_test_classes/src/base/low_precision_transformations/layer_transformation.cpp +++ b/src/tests/functional/shared_test_classes/src/base/low_precision_transformations/layer_transformation.cpp @@ -7,6 +7,7 @@ #include #include +#include "openvino/util/common_util.hpp" namespace LayerTestsUtils { ov::pass::low_precision::LayerTransformation::Params LayerTransformationParamsNGraphFactory::createParamsU8I8AndI8() { @@ -60,15 +61,15 @@ std::string LayerTransformation::get_test_case_name_by_params( namespace { template -std::string find_node_by_runtime_precision(const ov::CompiledModel& execNet, IsNodeF is_node_f) { +std::string find_node_by_runtime_property(const ov::CompiledModel& execNet, IsNodeF is_node_f, const std::string& propertyName = "runtimePrecision") { const std::shared_ptr& execFunction = execNet.get_runtime_model(); for (const auto& op : execFunction->get_ops()) { if (!is_node_f(op)) continue; const ov::RTMap& rtInfo = op->get_rt_info(); - const auto& it = rtInfo.find("runtimePrecision"); - OPENVINO_ASSERT(it != rtInfo.end(), "Runtime precision is not found for node: ", op->get_friendly_name()); + const auto& it = rtInfo.find(propertyName); + OPENVINO_ASSERT(it != rtInfo.end(), "Runtime property \"", propertyName, "\" is not found for node: ", op->get_friendly_name()); return it->second.as(); } @@ -80,7 +81,7 @@ std::string LayerTransformation::get_runtime_precision(const std::string& layerN auto is_node_f = [layerName](const std::shared_ptr& op) { return op->get_friendly_name() == layerName; }; - return find_node_by_runtime_precision(compiledModel, is_node_f); + return find_node_by_runtime_property(compiledModel, is_node_f); } std::string LayerTransformation::get_runtime_precision_by_type(const std::string& layerType) { @@ -91,7 +92,18 @@ std::string LayerTransformation::get_runtime_precision_by_type(const std::string OPENVINO_ASSERT(typeIt != rtInfo.end(), "Layer is not found for type: ", layerType); return typeIt->second.as() == layerType; }; - return find_node_by_runtime_precision(compiledModel, is_node_f); + return find_node_by_runtime_property(compiledModel, is_node_f); +} + +std::string LayerTransformation::get_property_by_type(const std::string& layerTypeName, const std::string& propertyName) { + auto is_node_f = [&layerTypeName](const std::shared_ptr& op) { + const auto& rtInfo = op->get_rt_info(); + const auto& typeIt = rtInfo.find("layerType"); + + OPENVINO_ASSERT(typeIt != rtInfo.end(), "Layer is not found for type: ", layerTypeName); + return typeIt->second.as() == layerTypeName; + }; + return ov::util::to_lower(find_node_by_runtime_property(compiledModel, is_node_f, propertyName)); } namespace { @@ -116,7 +128,7 @@ std::string LayerTransformation::get_runtime_precision_by_fused_name(const std:: OPENVINO_ASSERT(nameIt != rtInfo.end(), "originalLayersNames is not found for node: ", layerName); return has_layer(nameIt->second.as(), layerName); }; - return find_node_by_runtime_precision(compiledModel, is_node_f); + return find_node_by_runtime_property(compiledModel, is_node_f); } bool LayerTransformation::check_execution_order(const std::vector& orderedOpsTypes) { diff --git a/src/tests/ov_helpers/ov_lpt_models/include/ov_lpt_models/mat_mul.hpp b/src/tests/ov_helpers/ov_lpt_models/include/ov_lpt_models/mat_mul.hpp index 787e1f6ebe8bd4..9a4006b917bec6 100644 --- a/src/tests/ov_helpers/ov_lpt_models/include/ov_lpt_models/mat_mul.hpp +++ b/src/tests/ov_helpers/ov_lpt_models/include/ov_lpt_models/mat_mul.hpp @@ -24,17 +24,22 @@ class MatMulFunction { static std::shared_ptr getOriginal( const ov::element::Type precision, - const ov::PartialShape inputShape1, - const ov::PartialShape inputShape2, + const ov::PartialShape& inputShape1, + const ov::PartialShape& inputShape2, const bool transpose1, - const bool transpose2); + const bool transpose2, + const bool signedWeights, + const bool perChannelWeightsDequantization, + const bool relu, + const bool fq); static std::shared_ptr getOriginal( const ov::element::Type precision, const ov::Shape& inputShape1, const FakeQuantizeOnData& fqOnData1, const ov::Shape& inputShape2, - const FakeQuantizeOnData& fqOnData2); + const FakeQuantizeOnData& fqOnData2, + const bool requantization = false); static std::shared_ptr getOriginal(const ov::element::Type netPrecision, const ov::PartialShape& inputShape1, diff --git a/src/tests/ov_helpers/ov_lpt_models/src/mat_mul.cpp b/src/tests/ov_helpers/ov_lpt_models/src/mat_mul.cpp index 1b1351ef1b3399..5f8d6086e98849 100644 --- a/src/tests/ov_helpers/ov_lpt_models/src/mat_mul.cpp +++ b/src/tests/ov_helpers/ov_lpt_models/src/mat_mul.cpp @@ -49,36 +49,108 @@ std::shared_ptr MatMulFunction::getOriginal( return function; } +namespace { +template +std::vector generate_values(const ov::Shape& shape, float delimiter = 1.f) { + std::vector values(ov::shape_size(shape)); + for (size_t i = 0; i < values.size(); ++i) { + values[i] = static_cast(static_cast(i) / delimiter); + } + return values; +} + +std::vector generate_dequantization_values( + const ov::Shape& shape, + const size_t levels, + const bool low) { + const auto shape_size = ov::shape_size(shape); + std::vector values(shape_size); + for (size_t i = 0; i < shape_size; ++i) { + values[i] = low ? -128.f / (static_cast(i) + 1.f) : 127.f / (static_cast(i) + 1.f); + } + return values; +} +} // namespace + std::shared_ptr MatMulFunction::getOriginal( - const ov::element::Type precision, - const ov::PartialShape inputShape1, - const ov::PartialShape inputShape2, - const bool transpose1, - const bool transpose2) { + const ov::element::Type precision, + const ov::PartialShape& inputShape1, + const ov::PartialShape& inputShape2, + const bool transpose1, + const bool transpose2, + const bool signedOnWeights, + const bool perChannelWeightsDequantization, + const bool relu, + const bool fq) { const auto paramNode = std::make_shared(precision, inputShape1); const std::vector constShapes(inputShape1.rank().get_length(), 1ul); - const auto fakeQuantizeOnAcitvations = ov::test::utils::make_fake_quantize( - paramNode, precision, 256ul, constShapes, - { 0.f }, { 255.f / 4.f }, { 0.f }, { 255.f / 4.f }); + const auto fakeQuantizeOnAcitvations = signedOnWeights ? + ov::test::utils::make_fake_quantize( + paramNode, precision, 256ul, constShapes, + { -128.f / 4.f }, { 127.f / 4.f }, { -128.f / 4.f }, { 127.f / 4.f }) : + ov::test::utils::make_fake_quantize( + paramNode, precision, 256ul, constShapes, + { 0.f }, { 255.f / 4.f }, { 0.f }, { 255.f / 4.f }); fakeQuantizeOnAcitvations->set_friendly_name("fakeQuantizeOnAcitvations"); - auto weightsConst = std::make_shared( - precision, - inputShape2.to_shape(), - std::vector({ 1.f })); - const auto fakeQuantizeOnWeights = ov::test::utils::make_fake_quantize( - weightsConst, precision, 256ul, { 1ul, 1ul }, - { -128.f / 8.f }, { 127.f / 8.f }, { -128.f / 8.f }, { 127.f / 8.f }); - fakeQuantizeOnWeights->set_friendly_name("fakeQuantizeOnWeights"); + const size_t channel = inputShape2[inputShape2.size() - 2].get_length(); + + // fq + std::shared_ptr parentOnWeights; + if (fq) { + auto weightsConst = std::make_shared( + precision, + inputShape2.to_shape(), + generate_values(inputShape2.to_shape(), 10.f)); + + parentOnWeights = perChannelWeightsDequantization ? + ov::test::utils::make_fake_quantize( + weightsConst, precision, 256ul, + Shape{channel, 1}, + generate_dequantization_values(Shape{channel, 1}, 256ul, true), + generate_dequantization_values(Shape{channel, 1}, 256ul, false), + generate_dequantization_values(Shape{channel, 1}, 256ul, true), + generate_dequantization_values(Shape{channel, 1}, 256ul, false)) : + ov::test::utils::make_fake_quantize( + weightsConst, precision, 256ul, {1ul, 1ul}, + {-128.f / 8.f}, {127.f / 8.f}, {-128.f / 8.f}, {127.f / 8.f}); + } else { + Shape shape = inputShape2.to_shape(); + if (transpose2) { + shape[shape.size() - 1ull] = 1; + } else { + shape[shape.size() - 2ull] = 1; + } + + auto weightsConst = std::make_shared( + signedOnWeights ? element::i8 : element::u8, + inputShape2.to_shape(), + generate_values(inputShape2.to_shape())); + + const auto convert = std::make_shared(weightsConst, precision); + + const auto multiplyConst = std::make_shared( + precision, + shape, + generate_values(shape)); + parentOnWeights = std::make_shared(convert, multiplyConst); + } - const std::shared_ptr fullyConnected = std::make_shared( + parentOnWeights->set_friendly_name("fakeQuantizeOnWeights"); + + std::shared_ptr parent = std::make_shared( fakeQuantizeOnAcitvations->output(0), - fakeQuantizeOnWeights->output(0), + parentOnWeights->output(0), transpose1, transpose2); - fullyConnected->set_friendly_name("fullyConnected"); + parent->set_friendly_name("fullyConnected"); - ov::ResultVector results{ std::make_shared(fullyConnected) }; + if (relu) { + parent = std::make_shared(parent); + parent->set_friendly_name("relu"); + } + + ov::ResultVector results{ std::make_shared(parent) }; std::shared_ptr function = std::make_shared( results, ov::ParameterVector{ paramNode }, @@ -93,21 +165,40 @@ std::shared_ptr MatMulFunction::getOriginal( const ov::Shape& inputShape1, const FakeQuantizeOnData& fqOnData1, const ov::Shape& inputShape2, - const FakeQuantizeOnData& fqOnData2) { + const FakeQuantizeOnData& fqOnData2, + const bool requantization) { const std::shared_ptr input1 = std::make_shared(precision, inputShape1); input1->set_friendly_name("input1"); const std::shared_ptr input2 = std::make_shared(precision, inputShape2); input2->set_friendly_name("input2"); - const std::shared_ptr matMul = std::make_shared( - makeFakeQuantize(input1, precision, fqOnData1), - makeFakeQuantize(input2, precision, fqOnData2), + std::shared_ptr parent1 = input1; + if (!fqOnData1.empty()) { + parent1 = makeFakeQuantize(parent1, precision, fqOnData1); + } + + std::shared_ptr parent2 = input2; + if (!fqOnData2.empty()) { + parent2 = makeFakeQuantize(parent2, precision, fqOnData2); + } + + std::shared_ptr parent = std::make_shared( + parent1, + parent2, false, false); - matMul->set_friendly_name("matMul"); + parent->set_friendly_name("matMul"); + + if (requantization) { + parent = makeFakeQuantize(parent, precision, fqOnData1); + parent = std::make_shared( + parent, + std::make_shared(ov::element::f32, Shape{1}, std::vector{0.f})); + parent->set_friendly_name("prelu"); + } - std::shared_ptr result = std::make_shared(matMul); + std::shared_ptr result = std::make_shared(parent); std::shared_ptr function = std::make_shared( ov::ResultVector{ result },