WIP
antonvor committed Jul 26, 2023
1 parent 6d59a62 commit 2afbd44
Showing 7 changed files with 96 additions and 9 deletions.
21 changes: 21 additions & 0 deletions src/plugins/intel_cpu/src/graph_optimizer.cpp
@@ -91,6 +91,10 @@ void GraphOptimizer::ApplyCommonGraphOptimizations(Graph &graph) {
FuseFCAndConvertOnWeights(graph);
graph.RemoveDroppedNodes();

OV_ITT_SCOPE_NEXT(FIRST_INFERENCE, taskChain, "FuseFCAndTransposeOnWeights");
FuseFCAndTransposeOnWeights(graph);
graph.RemoveDroppedNodes();

OV_ITT_SCOPE_NEXT(FIRST_INFERENCE, taskChain, "FuseDeconvolutionAndSimpleOperation");
FuseDeconvolutionAndSimpleOperation(graph);
graph.RemoveDroppedNodes();
@@ -696,6 +700,7 @@ void GraphOptimizer::MergeConvertAndScaleShift(Graph& graph) {
}
}

// TODO: unify FuseFCAndConvertOnWeights and FuseFCAndTransposeOnWeights into a single weights-fusing pass?
void GraphOptimizer::FuseFCAndConvertOnWeights(Graph& graph) {
// This optimization fuses Convert (fp16 -> bf16/fp32) on weights directly to FC input to allow precision conversion handling based on internal logic
// (e.g. fuse conversion with weights reordering)
@@ -710,6 +715,22 @@ void GraphOptimizer::FuseFCAndConvertOnWeights(Graph& graph) {
}
}

void GraphOptimizer::FuseFCAndTransposeOnWeights(Graph& graph) {
    // This optimization detects a constant Transpose on the FC weights and marks it as fake,
    // so the transposition is handled as part of the FC weight reordering instead of being executed as a separate node
auto& graphNodes = graph.GetNodes();

for (auto parent : graphNodes) {
        // TODO: add an order and ndims check (only a plain 2D [1, 0] transpose should be fused)?
if (parent->getType() == Type::Transpose && parent->isConstant() && parent->getChildEdgeAt(0)->getChild()->getType() == Type::FullyConnected) {
auto fcNode = std::dynamic_pointer_cast<FullyConnected>(parent->getChildEdgeAt(0)->getChild());
fcNode->setTransposeWeights(true);
auto transposeNode = std::dynamic_pointer_cast<Transpose>(parent);
transposeNode->setFakeTranspose(true);
// graph.DropNode(parent);
}
}
}

void GraphOptimizer::FuseConvolutionAndZeroPoints(Graph &graph) {
auto& graphNodes = graph.GetNodes();

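The TODO about an order and ndims check in FuseFCAndTransposeOnWeights points at the missing guard: the fusion is only safe when the constant Transpose is a plain 2D [1, 0] permutation of the weight matrix, because that is the only case the transposed-weights path in FullyConnected handles. A minimal sketch of such a guard, with a hypothetical helper name and the order/rank passed in explicitly (neither is part of this patch):

    #include <cstddef>
    #include <vector>

    static bool isPlain2DWeightTranspose(const std::vector<size_t>& order, size_t rank) {
        // only a 2D weight matrix with its axes swapped ([1, 0]) maps onto the
        // transposed-strides reorder performed by prepareTransposedWeightMemory
        return rank == 2 && order.size() == 2 && order[0] == 1 && order[1] == 0;
    }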
1 change: 1 addition & 0 deletions src/plugins/intel_cpu/src/graph_optimizer.h
@@ -26,6 +26,7 @@ class GraphOptimizer {
void FuseMultiplyAndAdd(Graph &graph);
void MergeConvertAndScaleShift(Graph& graph);
void FuseFCAndConvertOnWeights(Graph& graph);
void FuseFCAndTransposeOnWeights(Graph& graph);
void FuseFullyConnectedAndSimpleOperation(Graph &graph);
void FuseMatMulAndSimpleOperation(Graph &graph);
void FuseConvolutionAndSimpleOperationThroughMaxPool(Graph &graph);
15 changes: 8 additions & 7 deletions src/plugins/intel_cpu/src/node.h
@@ -694,6 +694,14 @@ class Node {

std::shared_ptr<IShapeInfer> shapeInference;

    // We cannot rely on the per-NUMA weightCache for caching weights because:
    // 1. it may not exist (in a single-stream configuration);
    // 2. it only holds weak references, so the life-cycle of a cached item
    //    is still controlled by strong references outside of the cache.
    // privateWeightCache holds strong references to constant weight
    // copies of the same content in different layouts.
std::unordered_map<std::string, MemoryPtr> privateWeightCache;

private:
std::vector<EdgeWeakPtr> parentEdges;
std::vector<EdgeWeakPtr> childEdges;
@@ -723,13 +731,6 @@ class Node {
ConstantType checkConstant(LOOK look, std::vector<NodePtr>& checkNodes);
// Hold output scales
std::vector<float> DQScales;
// we cannot rely on per-NUMA weightCache for caching weights because:
// 1.it may not exist(in single stream configuration)
// 2.it only holds weak references, the life-cycle of cached item
// is still under control of strong references outside of cache.
// privateWeightCache is for holding strong references to constant weight
// copies of same content with different layouts.
std::unordered_map<std::string, MemoryPtr> privateWeightCache;

CPU_DEBUG_CAP_ENABLE(friend class Verbose);
};
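The comment moved to the protected section above describes a two-level caching scheme: the shared per-socket weightsCache only stores weak references, so a node also pins its reordered copy with a strong reference in privateWeightCache. A minimal sketch of that lookup pattern, which FullyConnected inlines below (the free-standing helper name is hypothetical; MemoryPtr and privateWeightCache are the types/members from the surrounding code):

    #include <functional>
    #include <string>
    #include <unordered_map>

    // illustrative only; the real nodes inline this pattern in their prepare*WeightMemory() methods
    MemoryPtr lookupOrCreateWeightCopy(std::unordered_map<std::string, MemoryPtr>& privateWeightCache,
                                       const std::string& format,
                                       const std::function<MemoryPtr()>& create) {
        auto it = privateWeightCache.find(format);
        if (it != privateWeightCache.end())
            return it->second;             // strong reference already held by this node
        MemoryPtr ptr = create();          // or weightsCache->findOrCreate(hash, create) when a shared cache exists
        privateWeightCache[format] = ptr;  // pin the copy for the node's lifetime
        return ptr;
    }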
50 changes: 49 additions & 1 deletion src/plugins/intel_cpu/src/nodes/fullyconnected.cpp
@@ -410,7 +410,11 @@ void FullyConnected::prepareParams() {
}

if (!prevExecPtr || !execPtr->getWeightDesc()->isCompatible(*(prevExecPtr->getWeightDesc()))) {
primArgs[DNNL_ARG_WEIGHTS] = prepareWeightMemory(execPtr->getWeightDesc())->getPrimitive();
if (transposeWeights) {
primArgs[DNNL_ARG_WEIGHTS] = prepareTransposedWeightMemory(execPtr->getWeightDesc())->getPrimitive();
} else {
primArgs[DNNL_ARG_WEIGHTS] = prepareWeightMemory(execPtr->getWeightDesc())->getPrimitive();
}
}
// changed shapes may also cause the kernel type changed
selected_pd->setImplementationType(execPtr->getImplementationType());
@@ -928,6 +932,50 @@ bool FullyConnected::useSparseWeightsDecompression() {

return true;
}

MemoryPtr FullyConnected::prepareTransposedWeightMemory(DnnlMemoryDescPtr weightDesc) {
if (!getParentEdgeAt(1)->getParent()->isConstant())
IE_THROW() << "Weight input is not const for node " << getName() << ".";
auto edgeMem = getParentEdgeAt(1)->getMemoryPtr();
if (!edgeMem)
IE_THROW() << "Cannot get const weights edgeMem for node " << getName() << ".";

auto constDnnlMemOutDesc = edgeMem->getDescWithType<DnnlMemoryDesc>();
auto weightSrcDesc = constDnnlMemOutDesc->getDnnlDesc();
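    // describe the same constant buffer with swapped strides (format_tag::ba) so the reorder below reads the data transposed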
weightSrcDesc = {weightSrcDesc.get_dims(), weightSrcDesc.get_data_type(), memory::format_tag::ba};
weightSrcDesc = weightSrcDesc.reshape(weightDesc->getDnnlDesc().get_dims());

auto create = [&] () {
auto newSrcDesc = DnnlExtensionUtils::makeDescriptor(weightSrcDesc);

Memory srcMemory{ getEngine(), newSrcDesc, edgeMem->getData() };
MemoryPtr _ptr = std::make_shared<Memory>(getEngine(), weightDesc);
node::Reorder::reorderData(srcMemory, *_ptr, context->getParamsCache());

return _ptr;
};

MemoryPtr ptr;
const auto& format = weightDesc->serializeFormat();
auto itr = privateWeightCache.find(format);
if (privateWeightCache.end() != itr) {
ptr = itr->second;
} else {
auto weightCache = context->getWeightsCache();
if (weightCache != nullptr) {
const std::string string_hash = getName() + "_" + format
+ "_" + std::to_string(edgeMem->getSize())
+ "_" + std::to_string(reinterpret_cast<uint64_t>(edgeMem->getData()));

ptr = *weightCache->findOrCreate(string_hash, create);
} else {
ptr = create();
}
privateWeightCache[format] = ptr;
}

return ptr;
}
} // namespace node
} // namespace intel_cpu
} // namespace ov
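For readers unfamiliar with the format_tag::ba trick used in prepareTransposedWeightMemory, here is a self-contained sketch of the idea, assuming oneDNN 3.x headers; the tiny main() and its values are illustrative, not part of the patch. Describing a row-major buffer with tag ba lets a plain reorder produce the transposed copy without ever running a Transpose:

    #include <vector>
    #include <oneapi/dnnl/dnnl.hpp>

    int main() {
        using namespace dnnl;
        engine eng(engine::kind::cpu, 0);
        stream strm(eng);

        // a 2x3 matrix stored row-major: {1 2 3; 4 5 6}
        std::vector<float> src_data{1, 2, 3, 4, 5, 6};

        // view the same buffer as a 3x2 matrix with tag ba (strides {1, 3}),
        // i.e. the transpose of the row-major data, without copying anything
        memory::desc src_md({3, 2}, memory::data_type::f32, memory::format_tag::ba);
        memory::desc dst_md({3, 2}, memory::data_type::f32, memory::format_tag::ab);

        memory src_mem(src_md, eng, src_data.data());
        memory dst_mem(dst_md, eng);

        // the reorder materializes the transposed copy: {1 4; 2 5; 3 6}
        reorder(src_mem, dst_mem).execute(strm, src_mem, dst_mem);
        strm.wait();
        return 0;
    }

prepareTransposedWeightMemory does the same thing, except the destination descriptor is whatever weight format the FC primitive requests, so the transposition and the weight reordering collapse into a single reorder.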
9 changes: 9 additions & 0 deletions src/plugins/intel_cpu/src/nodes/fullyconnected.h
@@ -56,6 +56,9 @@ class FullyConnected : public Node {
void prepareParams() override;
void executeDynamicImpl(dnnl::stream strm) override;
bool canBeExecutedInInt8() const override;
void setTransposeWeights(bool transpose) {
transposeWeights = transpose;
}

private:
void createDescriptorInternal(const dnnl::memory::desc &inputDesc,
@@ -100,6 +103,12 @@ class FullyConnected : public Node {
float weiSparseRate = 0.f;
bool useSparseWeightsDecompression();
VectorDims expectedBiasDims {};

    // FC with transposed weights: the graph optimizer fuses a constant Transpose on the weight input (WIP)
    bool transposeWeights = false;
    // prepares the weight memory with the transposition applied as part of the weights reorder
MemoryPtr prepareTransposedWeightMemory(DnnlMemoryDescPtr weightDesc);
};

} // namespace node
4 changes: 3 additions & 1 deletion src/plugins/intel_cpu/src/nodes/transpose.cpp
@@ -140,7 +140,7 @@ void Transpose::initSupportedPrimitiveDescriptors() {
config.inConfs[INPUT_ORDER_IDX].constant(isInputOrderConst);
config.inConfs[INPUT_ORDER_IDX].setMemDesc(creatorsMap.at(LayoutType::ncsp)->createSharedDesc(
Precision::I32, getInputShapeAtPort(INPUT_ORDER_IDX)));
config.outConfs[0].inPlace(-1);
config.outConfs[0].inPlace(fakeTranspose ? 0 : -1);
config.outConfs[0].constant(false);
transpose_context = std::make_shared<ExecutorContext>(context, getImplPriority());

Expand Down Expand Up @@ -287,6 +287,8 @@ void Transpose::createPrimitive() {
}

void Transpose::execute(dnnl::stream strm) {
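    // a fake transpose is a no-op: its output shares the input memory (in-place) and the consumer performs the transposition itself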
if (fakeTranspose)
return;
if (prim) {
prim.execute(strm, primArgs);
} else if (execPtr) {
5 changes: 5 additions & 0 deletions src/plugins/intel_cpu/src/nodes/transpose.h
@@ -38,6 +38,10 @@ class Transpose : public Node {
bool needPrepareParams() const override;
void prepareParams() override;

void setFakeTranspose(bool fake) {
fakeTranspose = fake;
}

protected:
void executeDynamicImpl(dnnl::stream strm) override;
std::shared_ptr<ExecutorContext> transpose_context;
@@ -56,6 +60,7 @@
static constexpr size_t INPUT_ORDER_IDX = 1lu;

bool performAsReorder = false;
bool fakeTranspose = false;
};

} // namespace node
