Skip to content

Commit

Permalink
Refactor the stateful model implementation by removing storage from the MemoryInput node and decreasing memory copies from 4 times to 2 times.
Browse files Browse the repository at this point in the history
  • Loading branch information
ceciliapeng2011 committed Oct 7, 2023
1 parent a13b503 commit 86ddb40
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 21 deletions.
12 changes: 6 additions & 6 deletions src/plugins/intel_cpu/src/infer_request.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -107,9 +107,13 @@ void InferRequestBase::PushStates() {
auto cur_id = cur_node->getId();
for (const auto& state : memoryStates) {
if (state->GetName() == cur_id) {
auto storage = cur_node->getStore();
auto state_blob = state->GetState();
if (storage->getData() == state_blob->cbuffer().as<const void *>())
continue; // there is no inferrequest switch

auto state_desc = MemoryDescUtils::convertToDnnlBlockedMemoryDesc(state_blob->getTensorDesc());
Memory state_mem(eng, state_desc, state_blob->cbuffer(), false);
auto state_mem = std::make_shared<Memory>(eng, state_desc, state_blob->cbuffer(), false);
cur_node->storeState(state_mem);
}
}
Expand All @@ -129,11 +133,7 @@ void InferRequestBase::PullStates() {
for (const auto& state : memoryStates) {
if (state->GetName() == cur_id) {
auto storage = cur_node->getStore();

//redefine state
auto blob = make_blob_with_precision(MemoryDescUtils::convertToTensorDesc(storage->getDesc()));
blob->allocate();
cpu_memcpy(blob->buffer(), storage->getData(), storage->getSize());
auto blob = make_blob_with_precision(MemoryDescUtils::convertToTensorDesc(storage->getDesc()), storage->getData());
state->SetState(blob);
}
}
Expand Down
1 change: 1 addition & 0 deletions src/plugins/intel_cpu/src/memory_state.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ class VariableState : public InferenceEngine::IVariableStateInternal {
/// Constructs a variable state called `name`.
///
/// The tensor descriptor of the state is derived from the memory descriptor of
/// `storage`; the constructor body shown here reads only that descriptor and
/// does not keep `storage` itself.
///
/// @param name    State name, forwarded to the IVariableStateInternal base.
/// @param storage Memory object whose descriptor is converted. It is
///                dereferenced unconditionally, so it must not be null.
VariableState(std::string name, MemoryPtr storage)
: InferenceEngine::IVariableStateInternal{name} {
tensor_desc = MemoryDescUtils::convertToTensorDesc(storage->getDesc());
// Reset() is only declared in this view. NOTE(review): presumably it
// (re)initializes the state from tensor_desc, which is why it is called after
// the assignment above — confirm against its definition.
Reset();
}

void Reset() override;
Expand Down
27 changes: 14 additions & 13 deletions src/plugins/intel_cpu/src/nodes/memory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@ namespace ov {
namespace intel_cpu {
namespace node {

inline
static void simple_copy(const IMemory& dst, const IMemory& src);

std::mutex MemoryNodeVirtualEdge::holderMutex;

MemoryNode::MemoryNode(const std::shared_ptr<ngraph::Node>& op) {
Expand Down Expand Up @@ -79,7 +82,15 @@ void MemoryOutput::execute(dnnl::stream strm) {
auto inputMemoryNode = dynamic_cast<MemoryInput*>(inputNode);
IE_ASSERT(inputMemoryNode != nullptr);

inputMemoryNode->storeState(srcMemory);
//redefine storage memory before copy to it
auto storage = inputMemoryNode->getStore();

auto desc = storage->getDescPtr();
const auto new_shape = srcMemory.getStaticDims();
const auto newDesc = desc->cloneWithNewDims(new_shape, true);
storage->redefineDesc(newDesc);

simple_copy(*storage, srcMemory);
}

void MemoryOutput::executeDynamicImpl(dnnl::stream strm) {
Expand Down Expand Up @@ -162,18 +173,8 @@ MemoryPtr MemoryInput::getStore() {
return dataStore;
}

void MemoryInput::storeState(const IMemory &new_state) {
//redefine storage memory before copy to it
const auto &_desc = new_state.getDesc();
auto _shape = _desc.getShape();
VectorDims dims = _shape.getStaticDims();
auto desc = _desc.cloneWithNewDims(dims);
dataStore = std::make_shared<Memory>(getEngine(), desc);

// TODO: Should be next one call:
// dataStore.load(new_state, false);
// But because of performance reason we use simple manual copy
simple_copy(*dataStore, new_state);
/// Replaces the node's state storage with `new_state`.
///
/// The pointer is assigned as-is: no tensor data is copied and no descriptor is
/// redefined here, so afterwards `dataStore` refers to the same memory object
/// as `new_state`.
///
/// @param new_state Memory object to adopt as the store. Nothing here checks
///                  for null, so a null pointer is stored unchanged.
/// NOTE(review): MemoryPtr appears to be a shared-ownership pointer (callers
/// pass the result of std::make_shared) — confirm the lifetime expectations
/// for the buffer it wraps.
void MemoryInput::storeState(const MemoryPtr new_state) {
dataStore = new_state;
}

void MemoryInput::execute(dnnl::stream strm) {
Expand Down
4 changes: 2 additions & 2 deletions src/plugins/intel_cpu/src/nodes/memory.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -105,10 +105,10 @@ class MemoryInput : public Input, public MemoryNode {
void createPrimitive() override;

void setInputNode(Node* node) override {}
void storeState(const IMemory& mem);
void storeState(const MemoryPtr mem);
MemoryPtr getStore();
private:
MemoryPtr dataStore;
MemoryPtr dataStore = nullptr;
MemoryNodeVirtualEdge::Holder* holder = nullptr;
};

Expand Down

0 comments on commit 86ddb40

Please sign in to comment.