Skip to content

Commit

Permalink
[CPU][EXPERIMENTAL] FullyConnected: enabled sparsity weights decompre…
Browse files Browse the repository at this point in the history
…ssion
  • Loading branch information
antonvor committed Nov 30, 2022
1 parent 9478062 commit 9bb8a83
Show file tree
Hide file tree
Showing 13 changed files with 130 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,9 @@ void regmodule_properties(py::module m) {

// Submodule intel_cpu property
wrap_property_RW(m_intel_cpu, ov::intel_cpu::denormals_optimization, "denormals_optimization");
wrap_property_RW(m_intel_cpu,
ov::intel_cpu::sparse_weights_decompression_rate,
"sparse_weights_decompression_rate");

// Submodule device
py::module m_device =
Expand Down
2 changes: 2 additions & 0 deletions src/inference/include/ie/cpu/cpu_config.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,5 +40,7 @@ namespace CPUConfigParams {
*/
DECLARE_CPU_CONFIG_KEY(DENORMALS_OPTIMIZATION);

/**
 * @brief [EXPERIMENTAL] Minimal sparse rate of FullyConnected weights at which the CPU plugin
 * may use sparse weights decompression.
 *
 * The value is parsed by the plugin as a float and must be in range [0.0f, 1.0f].
 * 1.0f (the plugin default) switches the feature off.
 */
DECLARE_CPU_CONFIG_KEY(SPARSE_WEIGHTS_DECOMPRESSION_RATE);

} // namespace CPUConfigParams
} // namespace InferenceEngine
Original file line number Diff line number Diff line change
Expand Up @@ -47,5 +47,7 @@ namespace intel_cpu {
*/
static constexpr Property<bool> denormals_optimization{"CPU_DENORMALS_OPTIMIZATION"};

/**
 * @brief [EXPERIMENTAL] Minimal sparse rate of FullyConnected weights at which the CPU plugin
 * may use sparse weights decompression.
 *
 * Accepted values are floats in range [0.0f, 1.0f]; 1.0f (the plugin default) switches the
 * feature off. The property name carries the "CPU_" prefix so that it matches the legacy
 * KEY_CPU_SPARSE_WEIGHTS_DECOMPRESSION_RATE key that the plugin configuration parser checks,
 * in the same way as denormals_optimization does.
 */
static constexpr Property<float> sparse_weights_decompression_rate{"CPU_SPARSE_WEIGHTS_DECOMPRESSION_RATE"};

} // namespace intel_cpu
} // namespace ov
14 changes: 14 additions & 0 deletions src/plugins/intel_cpu/src/config.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,20 @@ void Config::readProperties(const std::map<std::string, std::string> &prop) {
// zero and any negative value will be treated
// as default batch size
batchLimit = std::max(val_i, 0);
} else if (key == CPUConfigParams::KEY_CPU_SPARSE_WEIGHTS_DECOMPRESSION_RATE) {
float val_f = 0.0f;
try {
val_f = std::stof(val);
} catch (const std::exception&) {
IE_THROW() << "Wrong value for property key " << CPUConfigParams::KEY_CPU_SPARSE_WEIGHTS_DECOMPRESSION_RATE
<< ". Expected only float numbers";
}
if (val_f < 0.f || val_f > 1.f) {
IE_THROW() << "Wrong value for property key " << CPUConfigParams::KEY_CPU_SPARSE_WEIGHTS_DECOMPRESSION_RATE
<< ". Sparse rate must be in range [0.0f,1.0f]";
} else {
fcSparseWeiDecompressionRate = val_f;
}
} else if (key == PluginConfigParams::KEY_PERF_COUNT) {
if (val == PluginConfigParams::YES) collectPerfCounters = true;
else if (val == PluginConfigParams::NO) collectPerfCounters = false;
Expand Down
1 change: 1 addition & 0 deletions src/plugins/intel_cpu/src/config.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ struct Config {
bool enableDynamicBatch = false;
std::string dumpToDot = "";
int batchLimit = 0;
float fcSparseWeiDecompressionRate = 1.0f;
size_t rtCacheCapacity = 5000ul;
InferenceEngine::IStreamsExecutor::Config streamExecutorConfig;
InferenceEngine::PerfHintsConfig perfHintsConfig;
Expand Down
12 changes: 12 additions & 0 deletions src/plugins/intel_cpu/src/graph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#include <nodes/reorder.h>
#include "nodes/convert.h"
#include "nodes/subgraph.h"
#include "nodes/fullyconnected.h"

#include <ie_algorithm.hpp>
#include <blob_factory.hpp>
Expand Down Expand Up @@ -341,6 +342,9 @@ void Graph::Replicate(const CNNNetwork &network, const ExtensionManager::Ptr& ex
if (config.enforceBF16)
EnforceBF16();

if (config.fcSparseWeiDecompressionRate < 1.0f)
setMinSparseRate(config.fcSparseWeiDecompressionRate);

auto hasSubgraphConsumers = [] (const NodePtr& node) -> bool {
const auto & childEdges = node->getChildEdges();
return std::any_of(childEdges.begin(), childEdges.end(),
Expand Down Expand Up @@ -1454,6 +1458,14 @@ void Graph::EnforceBF16() {
}
}

void Graph::setMinSparseRate(float minSparseRate) {
for (const auto &node : graphNodes) {
if (auto fcNodePtr = std::dynamic_pointer_cast<node::FullyConnected>(node)) {
fcNodePtr->setMinSparseRate(minSparseRate);
}
}
}

std::shared_ptr<ngraph::Function> Graph::dump() const {
return dump_graph_as_ie_ngraph_net(*this);
}
Expand Down
1 change: 1 addition & 0 deletions src/plugins/intel_cpu/src/graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,7 @@ class Graph {
DnnlScratchPadPtr rtScratchPad;

void EnforceBF16();
void setMinSparseRate(float minSparseRate);
};

} // namespace intel_cpu
Expand Down
1 change: 1 addition & 0 deletions src/plugins/intel_cpu/src/node.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -418,6 +418,7 @@ std::string Node::getPrimitiveDescriptorType() {
SEARCH_TYPE(uni);

SEARCH_TYPE(winograd);
SEARCH_TYPE(sparse);
SEARCH_TYPE(_dw);
SEARCH_TYPE(_1x1);

Expand Down
82 changes: 78 additions & 4 deletions src/plugins/intel_cpu/src/nodes/fullyconnected.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

#include "fullyconnected.h"
#include "eltwise.h"
#include "input.h"
#include "fake_quantize.h"
#include "input.h"
#include "reorder.h"
Expand All @@ -22,6 +23,7 @@
#include <common/primitive_desc.hpp>
#include <common/primitive_desc_iface.hpp>
#include "onednn/dnnl.h"
#include "cpu/x64/cpu_isa_traits.hpp"

using namespace dnnl;
using namespace InferenceEngine;
Expand Down Expand Up @@ -172,6 +174,8 @@ void FullyConnected::getSupportedDescriptors() {
if (getChildEdges().empty())
IE_THROW()<< errorPrefix << " has incorrect number of output edges";

useSparseWeights = useSparseWeightsDecompression();

auto inputDataType = DnnlExtensionUtils::IEPrecisionToDataType(getOriginalInputPrecisionAtPort(DATA_ID));
outputDataType = DnnlExtensionUtils::IEPrecisionToDataType(getOriginalOutputPrecisionAtPort(DATA_ID));

Expand Down Expand Up @@ -360,6 +364,10 @@ void FullyConnected::prepareParams() {
}
// changed shapes may also cause the kernel type changed
selected_pd->setImplementationType(execPtr->getImplementationType());
// WA: We update implType to know whether weights decompression was used inside the kernel
if (selected_pd->getImplementationType() == ov::intel_cpu::brgemm_avx512_amx && useSparseWeights) {
selected_pd->setImplementationType(ov::intel_cpu::brgemm_sparse_avx512_amx);
}
// maybe expected 1x1 conv is not created, update the flag depends on the real type
useConv1x1 = execPtr->getImplementationType() == brgconv_avx512_1x1;

Expand Down Expand Up @@ -503,6 +511,7 @@ bool FullyConnected::created() const {
const std::vector<impl_desc_type>& FullyConnected::getPrimitivesPriority() {
std::vector<impl_desc_type> priorities = {
impl_desc_type::unknown,
impl_desc_type::brgemm_sparse_avx512_amx,
impl_desc_type::brgemm_avx512_amx,
impl_desc_type::brgemm_avx512,
impl_desc_type::gemm_blas,
Expand Down Expand Up @@ -578,9 +587,15 @@ void FullyConnected::createDescriptorInternal(const dnnl::memory::desc &inputDes
DnnlExtensionUtils::GetPlainFormatByRank(normalizedOutDims.size()));
}

dnnl::memory::desc wgh_candidate(DnnlExtensionUtils::convertToDnnlDims(getInputShapeAtPort(WEIGHTS_ID).getStaticDims()),
wdt, dnnl::memory::format_tag::any);

// We need to explicitly specify the memory descriptor to use sparse weights decompression
dnnl::memory::desc wgh_candidate;
if (useSparseWeights) {
wgh_candidate = { DnnlExtensionUtils::convertToDnnlDims(getInputShapeAtPort(WEIGHTS_ID).getStaticDims()),
wdt, memory::desc::packed(nnzCount) };
} else {
wgh_candidate = { DnnlExtensionUtils::convertToDnnlDims(getInputShapeAtPort(WEIGHTS_ID).getStaticDims()),
wdt, dnnl::memory::format_tag::any };
}
if (withBiases) {
dnnl::memory::desc bias_candidate(DnnlExtensionUtils::convertToDnnlDims(getInputShapeAtPort(BIAS_ID).getStaticDims()), bdt,
dnnl::memory::format_tag::any);
Expand Down Expand Up @@ -634,7 +649,7 @@ void FullyConnected::initSupportedPrimitiveDescriptors() {
portConfig.inPlace(-1);
portConfig.constant(false);
auto desc = getSrcMemDesc(itpd, i);
if (supportsUndefStridesAndOffset()) {
if (supportsUndefStridesAndOffset() && !(i == WEIGHTS_ID && useSparseWeights)) {
portConfig.setMemDesc(std::dynamic_pointer_cast<BlockedMemoryDesc>(desc), BLOCKED_DESC_EMPTY_MASK);
} else {
portConfig.setMemDesc(desc);
Expand Down Expand Up @@ -868,6 +883,65 @@ MemoryPtr FullyConnected::prepareWeightMemory(DnnlMemoryDescPtr weightDesc) {
return ptr;
}

/**
 * Decides whether sparse weights decompression should be used for this node.
 *
 * The feature is used only when all of the following hold:
 *  - it is not switched off (minSparseRate != 1);
 *  - the CPU reports avx512_core_amx support;
 *  - weights are 2D and both dimensions are multiples of 64;
 *  - input precision is U8 or I8 and weights precision is I8;
 *  - weights come from a constant Input node;
 *  - the share of zero elements in the weights is at least minSparseRate.
 *
 * Side effects: updates nnzCount and weiSparseRate once the weights have been scanned,
 * and raises minSparseRate to 0.5 when a lower value was requested.
 *
 * @return true if sparse weights decompression should be used, false otherwise.
 * @throws via IE_THROW when the constant weights memory cannot be obtained.
 */
bool FullyConnected::useSparseWeightsDecompression() {
    // minSparseRate == 1 means that sparse feature is switched off
    if (minSparseRate == 1.f) {
        return false;
    }

    if (!impl::cpu::x64::mayiuse(impl::cpu::x64::avx512_core_amx))
        return false;

    const auto weiDims = getInputShapeAtPort(WEIGHTS_ID).getStaticDims();
    if (weiDims.size() != 2 || weiDims[0] % 64 != 0 || weiDims[1] % 64 != 0) {
        return false;
    }

    const auto inputPrecision = getOriginalInputPrecisionAtPort(DATA_ID);
    const auto weightsPrecision = getOriginalInputPrecisionAtPort(WEIGHTS_ID);
    if (!one_of(inputPrecision , Precision::U8, Precision::I8) || weightsPrecision != Precision::I8) {
        return false;
    }

    // calculate sparse rate
    const auto constNode = std::dynamic_pointer_cast<Input>(getParentEdgeAt(WEIGHTS_ID)->getParent());
    if (!constNode) {
        return false;
    }
    const auto blb = constNode->getMemoryPtr();
    if (blb == nullptr)
        IE_THROW() << "Cannot get const blob for node " << getName() << ".";

    // Weights precision was verified to be I8 above, so the blob is scanned as int8_t.
    const auto weightsData = reinterpret_cast<const int8_t*>(blb->GetPtr());
    const size_t elementsCount = blb->GetDescWithType<BlockedMemoryDesc>()->getPaddedElementsCount();

    // Empty weights pass the "% 64" check above (0 % 64 == 0). Without this guard the rate
    // below would be 0/0 = NaN, and every comparison with NaN is false, which would
    // erroneously enable the sparse path.
    if (elementsCount == 0) {
        return false;
    }

    // The index type matches the element count type: no signed/unsigned comparison and
    // no overflow of a 32-bit signed counter on very large blobs.
    size_t zerosCounts = 0;
    for (size_t i = 0; i < elementsCount; ++i) {
        if (weightsData[i] == 0) {
            zerosCounts++;
        }
    }
    nnzCount = static_cast<int>(elementsCount - zerosCounts);

    DEBUG_LOG(getName(), ", weightsData.size() = ", elementsCount, ", zerosCounts = ",
              zerosCounts, ", nnzCount = ", nnzCount);

    weiSparseRate = static_cast<float>(zerosCounts) / static_cast<float>(elementsCount);

    // [av] WA: there is no point in using sparse decompression when the sparse rate is low
    // todo: add heuristic
    if (minSparseRate < 0.5f)
        minSparseRate = 0.5f;

    DEBUG_LOG(getName(), " | sparse rate = ", weiSparseRate * 100, "%, min sparse rate = ",
              minSparseRate * 100, "%, use sparse weights = ", weiSparseRate >= minSparseRate);

    return weiSparseRate >= minSparseRate;
}

} // namespace node
} // namespace intel_cpu
} // namespace ov
10 changes: 10 additions & 0 deletions src/plugins/intel_cpu/src/nodes/fullyconnected.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ class FullyConnected : public Node {

void initSupportedPrimitiveDescriptors() override;
void initOptimalPrimitiveDescriptor() override;
// void createPrimitive() override;
std::shared_ptr<MemoryDesc> getSrcMemDesc(dnnl::primitive_desc_iterator &primitive_desc_it, size_t idx) override;
std::shared_ptr<MemoryDesc> getDstMemDesc(dnnl::primitive_desc_iterator &primitive_desc_it, size_t idx) override;

Expand All @@ -58,6 +59,8 @@ class FullyConnected : public Node {

void setDynamicBatchLim(int lim) override;

// Stores the minimal weights sparse rate in minSparseRate.
void setMinSparseRate(float sparseRate) {
    minSparseRate = sparseRate;
}

private:
void createDescriptorInternal(const dnnl::memory::desc &inputDesc,
const dnnl::memory::desc &outputDesc);
Expand Down Expand Up @@ -106,6 +109,13 @@ class FullyConnected : public Node {

bool canBeExecutedInConv1x1() const;
MemoryPtr prepareWeightMemory(const DnnlMemoryDescPtr weightDesc);

// sparse weights
bool useSparseWeights = false;
int nnzCount = -1;
float minSparseRate = 1.f;
float weiSparseRate = 0.f;
bool useSparseWeightsDecompression();
};

} // namespace node
Expand Down
2 changes: 2 additions & 0 deletions src/plugins/intel_cpu/src/onednn/iml_type_mapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ impl_desc_type parse_impl_name(std::string impl_desc_name) {
SEARCH_WORD(_1x1);
SEARCH_WORD(_dw);
SEARCH_WORD(reorder);
SEARCH_WORD(sparse);
if ((res & impl_desc_type::avx2) != impl_desc_type::avx2 &&
(res & impl_desc_type::avx512) != impl_desc_type::avx512)
SEARCH_WORD(avx);
Expand Down Expand Up @@ -108,6 +109,7 @@ const char* impl_type_to_string(impl_desc_type type) {
CASE(brgemm_sse42);
CASE(brgemm_uni);
CASE(brgemm_avx512_amx);
CASE(brgemm_sparse_avx512_amx);

#undef CASE
return "unknown";
Expand Down
3 changes: 3 additions & 0 deletions src/plugins/intel_cpu/src/onednn/iml_type_mapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ enum impl_desc_type {
reorder = 1<<22,
// winograd
winograd = 1<<23,
// sparse
sparse = 1<<24,

// real types
ref_any = ref | any,
Expand Down Expand Up @@ -90,6 +92,7 @@ enum impl_desc_type {
brgemm_sse42 = brgemm | sse42,
brgemm_uni = brgemm | uni,
brgemm_avx512_amx = brgemm | avx512 | amx,
brgemm_sparse_avx512_amx = brgemm | sparse | avx512 | amx,
};

const char * impl_type_to_string(impl_desc_type type);
Expand Down
2 changes: 1 addition & 1 deletion src/plugins/intel_cpu/thirdparty/onednn
Submodule onednn updated 67 files
+95 −0 include/oneapi/dnnl/dnnl.h
+159 −3 include/oneapi/dnnl/dnnl.hpp
+1 −0 include/oneapi/dnnl/dnnl_debug.h
+44 −0 include/oneapi/dnnl/dnnl_types.h
+8 −1 scripts/generate_dnnl_debug.py
+13 −0 src/common/c_types_map.hpp
+13 −0 src/common/dnnl_debug_autogenerated.cpp
+205 −6 src/common/memory.cpp
+46 −6 src/common/memory.hpp
+111 −8 src/common/memory_desc_wrapper.hpp
+5 −1 src/common/primitive.hpp
+4 −3 src/common/primitive_exec_types.cpp
+3 −3 src/common/primitive_exec_types.hpp
+22 −0 src/common/primitive_hashing_utils.cpp
+9 −0 src/common/primitive_hashing_utils.hpp
+62 −8 src/common/type_helpers.hpp
+2 −2 src/common/utils.hpp
+6 −0 src/common/verbose.cpp
+3 −0 src/cpu/reorder/cpu_reorder.hpp
+4 −0 src/cpu/reorder/cpu_reorder_regular_f32_s8.cpp
+4 −0 src/cpu/reorder/cpu_reorder_regular_s8.cpp
+248 −0 src/cpu/reorder/simple_sparse_reorder.hpp
+6 −0 src/cpu/x64/brgemm/brgemm_types.hpp
+4 −4 src/cpu/x64/brgemm/jit_brdgmm_kernel.cpp
+1 −1 src/cpu/x64/brgemm/jit_brgemm_amx_uker.cpp
+1 −1 src/cpu/x64/brgemm/jit_brgemm_kernel.cpp
+2 −2 src/cpu/x64/gemm_bf16_convolution.cpp
+615 −394 src/cpu/x64/injectors/jit_uni_binary_injector.cpp
+94 −115 src/cpu/x64/injectors/jit_uni_binary_injector.hpp
+1 −1 src/cpu/x64/jit_avx2_1x1_conv_kernel_f32.cpp
+1 −1 src/cpu/x64/jit_avx2_conv_kernel_f32.cpp
+1 −1 src/cpu/x64/jit_avx512_common_1x1_conv_kernel.cpp
+1 −1 src/cpu/x64/jit_avx512_common_conv_kernel.cpp
+4 −2 src/cpu/x64/jit_avx512_core_amx_1x1_conv_kernel.cpp
+1 −0 src/cpu/x64/jit_avx512_core_amx_1x1_conv_kernel.hpp
+5 −2 src/cpu/x64/jit_avx512_core_amx_conv_kernel.cpp
+1 −0 src/cpu/x64/jit_avx512_core_amx_conv_kernel.hpp
+1 −1 src/cpu/x64/jit_avx512_core_bf16_1x1_conv_kernel.cpp
+1 −1 src/cpu/x64/jit_avx512_core_bf16_conv_kernel.cpp
+2 −2 src/cpu/x64/jit_avx512_core_bf16_dw_conv_kernel.cpp
+1 −1 src/cpu/x64/jit_avx512_core_fork_bf16_dw_conv_kernel.cpp
+1 −1 src/cpu/x64/jit_avx512_core_x8s8s32x_1x1_conv_kernel.cpp
+1 −1 src/cpu/x64/jit_avx512_core_x8s8s32x_conv_kernel.cpp
+1 −1 src/cpu/x64/jit_avx512_core_x8s8s32x_deconvolution.cpp
+106 −0 src/cpu/x64/jit_brgemm_decompress_kernel.cpp
+80 −0 src/cpu/x64/jit_brgemm_decompress_kernel.hpp
+33 −6 src/cpu/x64/jit_brgemm_inner_product.cpp
+8 −0 src/cpu/x64/jit_brgemm_inner_product.hpp
+16 −1 src/cpu/x64/jit_brgemm_inner_product_utils.cpp
+1 −1 src/cpu/x64/jit_brgemm_post_ops.hpp
+3 −0 src/cpu/x64/jit_brgemm_primitive_conf.hpp
+1 −1 src/cpu/x64/jit_gemm_convolution_utils.cpp
+1 −1 src/cpu/x64/jit_gemm_inner_product_utils.cpp
+1 −1 src/cpu/x64/jit_gemm_x8s8s32x_convolution_utils.cpp
+2 −2 src/cpu/x64/jit_sse41_1x1_conv_kernel_f32.cpp
+2 −2 src/cpu/x64/jit_sse41_conv_kernel_f32.cpp
+4 −3 src/cpu/x64/jit_uni_binary_kernel.cpp
+1 −1 src/cpu/x64/jit_uni_dw_conv_kernel_f32.cpp
+1 −1 src/cpu/x64/jit_uni_fork_dw_conv_kernel_f32.cpp
+1 −1 src/cpu/x64/jit_uni_i8i8_pooling.cpp
+1 −1 src/cpu/x64/jit_uni_pool_kernel.cpp
+3 −3 src/cpu/x64/jit_uni_reduction_kernel.cpp
+1 −0 src/cpu/x64/jit_uni_reduction_kernel.hpp
+1 −1 src/cpu/x64/jit_uni_resampling_kernel.cpp
+1 −1 src/cpu/x64/jit_uni_x8s8s32x_1x1_conv_kernel.cpp
+1 −1 src/cpu/x64/jit_uni_x8s8s32x_conv_kernel.cpp
+1 −1 src/cpu/x64/jit_uni_x8s8s32x_deconvolution.cpp

0 comments on commit 9bb8a83

Please sign in to comment.