
[CPU] Transpose constant folding on cpu plug-in side for MatMul node only
antonvor committed Aug 2, 2023
1 parent 7bc4fad commit 095ab95
Showing 8 changed files with 274 additions and 43 deletions.
23 changes: 21 additions & 2 deletions src/plugins/intel_cpu/src/graph_optimizer.cpp
@@ -91,6 +91,10 @@ void GraphOptimizer::ApplyCommonGraphOptimizations(Graph &graph) {
FuseFCAndConvertOnWeights(graph);
graph.RemoveDroppedNodes();

OV_ITT_SCOPE_NEXT(FIRST_INFERENCE, taskChain, "FuseFCAndTransposeOnWeights");
FuseFCAndTransposeOnWeights(graph);
graph.RemoveDroppedNodes();

OV_ITT_SCOPE_NEXT(FIRST_INFERENCE, taskChain, "FuseDeconvolutionAndSimpleOperation");
FuseDeconvolutionAndSimpleOperation(graph);
graph.RemoveDroppedNodes();
@@ -792,13 +796,13 @@ void GraphOptimizer::MergeConvertAndScaleShift(Graph& graph) {
}

void GraphOptimizer::FuseFCAndConvertOnWeights(Graph& graph) {
- // This optimization fuses Convert (fp16 -> bf16/fp32) on weights directly to FC input to allow precision conversion handling based on internal logic
+ // This optimization fuses Convert (fp16/u8 -> bf16/fp32) on weights directly to FC input to allow precision conversion handling based on internal logic
// (e.g. fuse conversion with weights reordering)
auto& graphNodes = graph.GetNodes();

for (auto parent : graphNodes) {
if (parent->getType() == Type::Convert && parent->isConstant() && parent->getChildEdgeAt(0)->getChild()->getType() == Type::FullyConnected
- && parent->getOriginalInputPrecisionAtPort(0) == Precision::FP16
+ && one_of(parent->getOriginalInputPrecisionAtPort(0), Precision::FP16, Precision::U8)
&& one_of(parent->getOriginalOutputPrecisionAtPort(0), Precision::FP32, Precision::BF16)) {
auto childNode = parent->getChildEdgeAt(0)->getChild();
// set correct weight precision
@@ -808,6 +812,21 @@ void GraphOptimizer::FuseFCAndConvertOnWeights(Graph& graph) {
}
}

void GraphOptimizer::FuseFCAndTransposeOnWeights(Graph& graph) {
// This optimization allows us to skip executing the Transpose node on constant weights and perform the transposition directly along with the weight reordering in the FC node
auto& graphNodes = graph.GetNodes();

for (auto parent : graphNodes) {
if (parent->getType() == Type::Transpose && parent->isConstant() && parent->getChildEdgeAt(0)->getChild()->getType() == Type::FullyConnected
&& parent->getOutputShapeAtPort(0).getRank() == 2) {
auto fcNode = std::dynamic_pointer_cast<FullyConnected>(parent->getChildEdgeAt(0)->getChild());
fcNode->setTransposeWeights(true);
auto transposeNode = std::dynamic_pointer_cast<Transpose>(parent);
transposeNode->setFakeTranspose(true);
}
}
}

void GraphOptimizer::FuseConvolutionAndZeroPoints(Graph &graph) {
auto& graphNodes = graph.GetNodes();

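For reference, the source pattern this pass ultimately serves (per the commit title, a MatMul whose constant weights go through a 2D Transpose; the common pipeline lowers such a MatMul to the plugin's FullyConnected node) can be sketched with the public OpenVINO C++ API roughly as below. Shapes and names are illustrative assumptions, not taken from the commit.

// Illustrative only: builds the Constant -> Transpose -> MatMul(weights) pattern
// that is expected to appear as a constant Transpose feeding FullyConnected
// on the CPU plugin graph.
#include <openvino/openvino.hpp>
#include <openvino/opsets/opset10.hpp>
#include <memory>
#include <vector>

int main() {
    using namespace ov::opset10;
    auto input = std::make_shared<Parameter>(ov::element::f32, ov::Shape{1, 16});
    // weights stored as [32, 16]; the Transpose turns them into the [16, 32] MatMul operand
    auto weights = Constant::create(ov::element::f32, ov::Shape{32, 16},
                                    std::vector<float>(32 * 16, 0.0f));
    auto order = Constant::create(ov::element::i32, ov::Shape{2}, std::vector<int32_t>{1, 0});
    auto transposed = std::make_shared<Transpose>(weights, order);
    auto matmul = std::make_shared<MatMul>(input, transposed, false, false);
    auto model = std::make_shared<ov::Model>(ov::OutputVector{matmul->output(0)},
                                             ov::ParameterVector{input});
    return model ? 0 : 1;
}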
1 change: 1 addition & 0 deletions src/plugins/intel_cpu/src/graph_optimizer.h
@@ -27,6 +27,7 @@ class GraphOptimizer {
void FuseMultiplyAndAdd(Graph &graph);
void MergeConvertAndScaleShift(Graph& graph);
void FuseFCAndConvertOnWeights(Graph& graph);
void FuseFCAndTransposeOnWeights(Graph& graph);
void FuseFullyConnectedAndSimpleOperation(Graph &graph);
void FuseMatMulAndSimpleOperation(Graph &graph);
void FuseConvolutionAndSimpleOperationThroughMaxPool(Graph &graph);
15 changes: 8 additions & 7 deletions src/plugins/intel_cpu/src/node.h
@@ -694,6 +694,14 @@ class Node {

std::shared_ptr<IShapeInfer> shapeInference;

// we cannot rely on per-NUMA weightCache for caching weights because:
// 1.it may not exist(in single stream configuration)
// 2.it only holds weak references, the life-cycle of cached item
// is still under control of strong references outside of cache.
// privateWeightCache is for holding strong references to constant weight
// copies of same content with different layouts.
std::unordered_map<std::string, MemoryPtr> privateWeightCache;

private:
std::vector<EdgeWeakPtr> parentEdges;
std::vector<EdgeWeakPtr> childEdges;
@@ -723,13 +731,6 @@
ConstantType checkConstant(LOOK look, std::vector<NodePtr>& checkNodes);
// Hold output scales
std::vector<float> DQScales;
- // we cannot rely on per-NUMA weightCache for caching weights because:
- // 1.it may not exist(in single stream configuration)
- // 2.it only holds weak references, the life-cycle of cached item
- // is still under control of strong references outside of cache.
- // privateWeightCache is for holding strong references to constant weight
- // copies of same content with different layouts.
- std::unordered_map<std::string, MemoryPtr> privateWeightCache;

CPU_DEBUG_CAP_ENABLE(friend class Verbose);
};
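The privateWeightCache member and its comment are moved out of the private section here so that FullyConnected::prepareTransposedWeightMemory (added below) can use it directly. As the comment explains, the shared weightCache may be missing (single-stream configuration) and holds only weak references, so each node additionally pins its reordered weight copies. A rough standalone sketch of that two-level lookup, using hypothetical names rather than the plugin's real classes:

// Hypothetical sketch: a shared cache of weak references backed by a per-node
// map of strong references, so a reordered weight blob stays alive once created.
#include <functional>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

using Blob = std::vector<float>;

std::unordered_map<std::string, std::weak_ptr<Blob>> sharedCache;     // weak refs, may be absent in reality
std::unordered_map<std::string, std::shared_ptr<Blob>> privateCache;  // per-node strong refs

std::shared_ptr<Blob> getOrCreate(const std::string& key,
                                  const std::function<std::shared_ptr<Blob>()>& create) {
    auto it = privateCache.find(key);
    if (it != privateCache.end())
        return it->second;                         // already pinned by this node
    std::shared_ptr<Blob> blob = sharedCache[key].lock();
    if (!blob) {                                   // absent or expired in the shared cache
        blob = create();
        sharedCache[key] = blob;
    }
    privateCache[key] = blob;                      // keep a strong reference, like privateWeightCache
    return blob;
}

int main() {
    auto blob = getOrCreate("fc1_abI8_1024", [] { return std::make_shared<Blob>(1024, 0.0f); });
    return blob->size() == 1024 ? 0 : 1;
}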
56 changes: 53 additions & 3 deletions src/plugins/intel_cpu/src/nodes/fullyconnected.cpp
@@ -337,12 +337,14 @@ void FullyConnected::prepackMLASWeight() {
MemoryPtr ptr;
auto create = [&]() {
float* weightPtr = reinterpret_cast<float*>(weightsMem->getData());
- size_t ldb = K;
+ // with transposeWeights the constant kept its original [K, N] layout (the Transpose
+ // node was folded into FC), so the leading dimension of the buffer is N rather than K
+ size_t ldb = transposeWeights ? N : K;
MemoryPtr _ptr =
std::make_shared<Memory>(getEngine(),
intel_cpu::CpuBlockedMemoryDesc(Precision::I8, intel_cpu::Shape{packedBsize}));
float* prepackedDst = reinterpret_cast<float*>(_ptr->getData());
- mlas_sgemm_pack("T", N, K, ldb, weightPtr, prepackedDst);
+ // the folded Transpose cancels the usual "T" flag (a double transpose is a no-op):
+ // the untransposed [K, N] constant is already the plain B matrix expected by the packer
+ mlas_sgemm_pack(transposeWeights ? "F" : "T", N, K, ldb, weightPtr, prepackedDst);
return _ptr;
};
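The ldb switch above is plain leading-dimension bookkeeping for the GEMM packer. A quick standalone check (no MLAS involved) of why the untransposed [K, N] constant pairs with ldb = N while the usual [N, K] weights pair with the "T" flag and ldb = K:

// Standalone check: for C = A(MxK) * B(KxN), a row-major K x N buffer is the
// plain B with ldb = N, and the same matrix stored N x K is the transposed B
// with ldb = K; both address the identical logical element B(k, n).
#include <cassert>
#include <cstddef>
#include <vector>

int main() {
    const size_t K = 2, N = 3;
    std::vector<float> kxn = {1, 2, 3,
                              4, 5, 6};   // layout of the constant when the Transpose is folded
    std::vector<float> nxk = {1, 4,
                              2, 5,
                              3, 6};      // layout FC normally receives as weights
    for (size_t k = 0; k < K; ++k)
        for (size_t n = 0; n < N; ++n)
            assert(kxn[k * N + n] == nxk[n * K + k]);  // ldb = N vs ldb = K
    return 0;
}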

@@ -513,7 +515,11 @@ void FullyConnected::prepareParams() {
}

if (!prevExecPtr || !execPtr->getWeightDesc()->isCompatible(*(prevExecPtr->getWeightDesc()))) {
- primArgs[DNNL_ARG_WEIGHTS] = prepareWeightMemory(execPtr->getWeightDesc())->getPrimitive();
+ if (transposeWeights) {
+ primArgs[DNNL_ARG_WEIGHTS] = prepareTransposedWeightMemory(execPtr->getWeightDesc())->getPrimitive();
+ } else {
+ primArgs[DNNL_ARG_WEIGHTS] = prepareWeightMemory(execPtr->getWeightDesc())->getPrimitive();
+ }
}
// changed shapes may also cause the kernel type changed
selected_pd->setImplementationType(execPtr->getImplementationType());
@@ -1106,6 +1112,50 @@ void FullyConnected::fuseDecompressionConstant(const NodePtr& constData, std::ve
Precision::FP32,
elementsCount);
}

MemoryPtr FullyConnected::prepareTransposedWeightMemory(DnnlMemoryDescPtr weightDesc) {
if (!getParentEdgeAt(1)->getParent()->isConstant())
IE_THROW() << "Weight input is not const for node " << getName() << ".";
auto edgeMem = getParentEdgeAt(1)->getMemoryPtr();
if (!edgeMem)
IE_THROW() << "Cannot get const weights edgeMem for node " << getName() << ".";

auto constDnnlMemOutDesc = edgeMem->getDescWithType<DnnlMemoryDesc>();
auto weightSrcDesc = constDnnlMemOutDesc->getDnnlDesc();
weightSrcDesc = {weightSrcDesc.get_dims(), weightSrcDesc.get_data_type(), memory::format_tag::ba};
weightSrcDesc = weightSrcDesc.reshape(weightDesc->getDnnlDesc().get_dims());

auto create = [&] () {
auto newSrcDesc = DnnlExtensionUtils::makeDescriptor(weightSrcDesc);

Memory srcMemory{ getEngine(), newSrcDesc, edgeMem->getData() };
MemoryPtr _ptr = std::make_shared<Memory>(getEngine(), weightDesc);
node::Reorder::reorderData(srcMemory, *_ptr, context->getParamsCache());

return _ptr;
};

MemoryPtr ptr;
const auto& format = weightDesc->serializeFormat();
auto itr = privateWeightCache.find(format);
if (privateWeightCache.end() != itr) {
ptr = itr->second;
} else {
auto weightCache = context->getWeightsCache();
if (weightCache != nullptr) {
const std::string string_hash = getName() + "_" + format
+ "_" + std::to_string(edgeMem->getSize())
+ "_" + std::to_string(reinterpret_cast<uint64_t>(edgeMem->getData()));

ptr = *weightCache->findOrCreate(string_hash, create);
} else {
ptr = create();
}
privateWeightCache[format] = ptr;
}

return ptr;
}
} // namespace node
} // namespace intel_cpu
} // namespace ov
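prepareTransposedWeightMemory above relies on oneDNN being able to describe the same buffer with swapped strides (format_tag::ba) and let an ordinary reorder emit the layout the FC primitive wants, so the weights are transposed and reordered in a single pass. A standalone sketch of that mechanism against the public dnnl.hpp API; the engine kind, shapes, and the plain ab destination tag are illustrative assumptions (the real code reorders straight into the executor's weight descriptor):

// Illustrative only: a K x N row-major buffer viewed as an N x K tensor via
// format_tag::ba, then reordered into a plain N x K (ab) weight blob.
#include <dnnl.hpp>
#include <vector>

int main() {
    using namespace dnnl;
    engine eng(engine::kind::cpu, 0);
    stream strm(eng);

    const memory::dim N = 4, K = 8;
    std::vector<float> kxn(static_cast<size_t>(K * N), 1.0f);  // constant before the (skipped) Transpose

    // Same bytes, swapped strides: element (n, k) of this view sits at offset k * N + n.
    memory::desc src_md({N, K}, memory::data_type::f32, memory::format_tag::ba);
    memory::desc dst_md({N, K}, memory::data_type::f32, memory::format_tag::ab);

    memory src_mem(src_md, eng, kxn.data());
    memory dst_mem(dst_md, eng);   // oneDNN allocates the destination buffer

    reorder(src_mem, dst_mem).execute(strm, src_mem, dst_mem);
    strm.wait();
    return 0;
}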
8 changes: 8 additions & 0 deletions src/plugins/intel_cpu/src/nodes/fullyconnected.h
@@ -56,6 +56,9 @@ class FullyConnected : public Node {
void prepareParams() override;
void executeDynamicImpl(dnnl::stream strm) override;
bool canBeExecutedInInt8() const override;
void setTransposeWeights(bool transpose) {
transposeWeights = transpose;
}

void fuseDecompressionMultiply(const NodePtr& constData);
const std::vector<float>& getDecompressionMultiply() const { return decompressionMultiply; }
@@ -117,6 +120,11 @@

std::vector<float> decompressionSubtract;
std::vector<float> decompressionMultiply;

// FC with transposed weights
bool transposeWeights = false;
// prepares the weight memory when the weights still have to be transposed
MemoryPtr prepareTransposedWeightMemory(DnnlMemoryDescPtr weightDesc);
};

} // namespace node
4 changes: 3 additions & 1 deletion src/plugins/intel_cpu/src/nodes/transpose.cpp
@@ -140,7 +140,7 @@ void Transpose::initSupportedPrimitiveDescriptors() {
config.inConfs[INPUT_ORDER_IDX].constant(isInputOrderConst);
config.inConfs[INPUT_ORDER_IDX].setMemDesc(creatorsMap.at(LayoutType::ncsp)->createSharedDesc(
Precision::I32, getInputShapeAtPort(INPUT_ORDER_IDX)));
- config.outConfs[0].inPlace(-1);
+ config.outConfs[0].inPlace(fakeTranspose ? 0 : -1);
config.outConfs[0].constant(false);
transpose_context = std::make_shared<ExecutorContext>(context, getImplPriority());

@@ -287,6 +287,8 @@ void Transpose::createPrimitive() {
}

void Transpose::execute(dnnl::stream strm) {
if (fakeTranspose)
return;
if (prim) {
prim.execute(strm, primArgs);
} else if (execPtr) {
5 changes: 5 additions & 0 deletions src/plugins/intel_cpu/src/nodes/transpose.h
@@ -38,6 +38,10 @@ class Transpose : public Node {
bool needPrepareParams() const override;
void prepareParams() override;

void setFakeTranspose(bool fake) {
fakeTranspose = fake;
}

protected:
void executeDynamicImpl(dnnl::stream strm) override;
std::shared_ptr<ExecutorContext> transpose_context;
@@ -56,6 +60,7 @@
static constexpr size_t INPUT_ORDER_IDX = 1lu;

bool performAsReorder = false;
bool fakeTranspose = false;
};

} // namespace node