diff --git a/src/plugins/intel_cpu/src/nodes/fullyconnected.cpp b/src/plugins/intel_cpu/src/nodes/fullyconnected.cpp
index fcb112c368131c..117290ebcb0e52 100644
--- a/src/plugins/intel_cpu/src/nodes/fullyconnected.cpp
+++ b/src/plugins/intel_cpu/src/nodes/fullyconnected.cpp
@@ -52,6 +52,7 @@ struct FCKey {
     dnnl::primitive_attr attr;
     impl_desc_type implType;
     bool useConv1x1;
+    bool useSparseWeights;
 
     size_t hash() const;
     bool operator==(const FCKey& rhs) const;
@@ -72,6 +73,7 @@ size_t FCKey::hash() const {
    seed = hash_combine(seed, get_attr_hash(*attr.get()));
    seed = hash_combine(seed, implType);
    seed = hash_combine(seed, useConv1x1);
+   seed = hash_combine(seed, useSparseWeights);
    return seed;
 }
 
@@ -90,7 +92,7 @@ bool FCKey::operator==(const FCKey &rhs) const {
         retVal = retVal && out && rhs.out && out->getDnnlDesc() == rhs.out->getDnnlDesc();
     }
     retVal = retVal && *attr.get() == *rhs.attr.get() &&
-             implType == rhs.implType && useConv1x1 == rhs.useConv1x1;
+             implType == rhs.implType && useConv1x1 == rhs.useConv1x1 && useSparseWeights == rhs.useSparseWeights;
     return retVal;
 }
 
@@ -416,15 +418,20 @@ static dnnl::primitive_desc createPrimitiveDesc(const FCKey& key, const dnnl::en
         auto normalizedOutDims = { outDims[0] * outDims[1], outDims[2] };
         outDesc = outDesc.reshape(normalizedOutDims);
     }
-    auto wghDescAny = dnnl::memory::desc(DnnlExtensionUtils::convertToDnnlDims(key.inp1->getShape().getStaticDims()),
-                                         key.inp1->getDataType(), memory::format_tag::any);
+    dnnl::memory::desc weiDesc;
+    if (key.useSparseWeights) {
+        weiDesc = key.inp1->getDnnlDesc();
+    } else {
+        weiDesc = dnnl::memory::desc(DnnlExtensionUtils::convertToDnnlDims(key.inp1->getShape().getStaticDims()),
+                                     key.inp1->getDataType(), memory::format_tag::any);
+    }
     dnnl::inner_product_forward::primitive_desc prim_desc;
     if (key.bias) {
         prim_desc = dnnl::inner_product_forward::primitive_desc(
             engine,
             dnnl::prop_kind::forward_inference,
             inDesc,
-            wghDescAny,
+            weiDesc,
             key.bias->getDnnlDesc(),
             outDesc,
             key.attr);
@@ -433,7 +440,7 @@ static dnnl::primitive_desc createPrimitiveDesc(const FCKey& key, const dnnl::en
             engine,
             dnnl::prop_kind::forward_inference,
             inDesc,
-            wghDescAny,
+            weiDesc,
             outDesc,
             key.attr);
     }
@@ -542,7 +549,8 @@ void FullyConnected::prepareParams() {
                  outDesc,
                  attr,
                  implementationTypeIP,
-                 useConv1x1};
+                 useConv1x1,
+                 useSparseWeights};
 
     auto& engine = getEngine();
 
@@ -597,7 +605,8 @@ void FullyConnected::prepareParams() {
         // changed shapes may also cause the kernel type changed
         selected_pd->setImplementationType(execPtr->getImplementationType());
         // WA: We update implType to know whether weights decompression was used inside the kernel
-        if (selected_pd->getImplementationType() == ov::intel_cpu::brgemm_avx512_amx && useSparseWeights) {
+        if (selected_pd->getImplementationType() == ov::intel_cpu::brgemm_avx512_amx &&
+            execPtr->getDnnlWeightDesc().get_format_kind() == memory::format_kind::sparsed) {
             selected_pd->setImplementationType(ov::intel_cpu::brgemm_sparse_avx512_amx);
         }
         // maybe expected 1x1 conv is not created, update the flag depends on the real type
@@ -960,7 +969,7 @@ std::shared_ptr<MemoryDesc> FullyConnected::getSrcMemDesc(const dnnl::primitive_
     if (getInputShapeAtPort(idx).getRank() == 3
         // report original plain layout for weight since it needs to be reordered dynamically at runtime
-        || idx == 1) {
+        || (idx == 1 && !useSparseWeights)) {
         return std::make_shared<CpuBlockedMemoryDesc>(
             DnnlExtensionUtils::DataTypeToIEPrecision(desc.get_data_type()), getInputShapeAtPort(idx));
     }
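
Context for the change above: the fully-connected primitive is looked up in a cache keyed by `FCKey`, so any parameter that alters how the primitive descriptor is built must participate in both `hash()` and `operator==`. This patch threads `useSparseWeights` through the key and then selects the weight descriptor accordingly (the exact sparse descriptor from `inp1` instead of `format_tag::any`). The sketch below is illustrative only, not OpenVINO code: `Key`, `KeyHash`, and the boost-style `hash_combine` are hypothetical stand-ins for the plugin's real types. It shows how omitting the new flag from the key would make sparse and dense configurations collide in the cache.

```cpp
// Minimal sketch of the cache-key pattern used by FCKey (names are stand-ins).
#include <cstddef>
#include <functional>
#include <iostream>
#include <string>
#include <unordered_map>

template <typename T>
size_t hash_combine(size_t seed, const T& v) {
    // boost-style mixing; the real plugin helper may differ
    return seed ^ (std::hash<T>{}(v) + 0x9e3779b9 + (seed << 6) + (seed >> 2));
}

struct Key {
    int implType;           // stand-in for impl_desc_type
    bool useConv1x1;
    bool useSparseWeights;  // the field this patch adds

    size_t hash() const {
        size_t seed = 0;
        seed = hash_combine(seed, implType);
        seed = hash_combine(seed, useConv1x1);
        seed = hash_combine(seed, useSparseWeights);  // omit this and sparse/dense keys collide
        return seed;
    }
    bool operator==(const Key& rhs) const {
        return implType == rhs.implType && useConv1x1 == rhs.useConv1x1 &&
               useSparseWeights == rhs.useSparseWeights;
    }
};

struct KeyHash {
    size_t operator()(const Key& k) const { return k.hash(); }
};

int main() {
    std::unordered_map<Key, std::string, KeyHash> cache;
    cache[{0, false, false}] = "dense executor";
    cache[{0, false, true}]  = "sparse executor";
    // Without useSparseWeights in hash()/operator==, the second insert would
    // overwrite the first, and a dense primitive could be returned for sparse weights.
    std::cout << cache.size() << " distinct entries\n";  // prints: 2 distinct entries
}
```

Note also the workaround hunk at `@@ -597,7 +605,8 @@`: instead of trusting the requested `useSparseWeights` flag, it now inspects the executor's actual weight descriptor (`get_format_kind() == memory::format_kind::sparsed`), presumably so the reported implType reflects what oneDNN really selected even if it fell back to a dense layout.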