Skip to content

Commit

Permalink
Merge branch 'master' into sbm
Browse files Browse the repository at this point in the history
  • Loading branch information
iefode authored Jul 3, 2024
2 parents 032f62b + 6e9ee07 commit 7465aeb
Show file tree
Hide file tree
Showing 27 changed files with 830 additions and 75 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/linux.yml
Original file line number Diff line number Diff line change
Expand Up @@ -687,7 +687,7 @@ jobs:
Overall_Status:
name: ci/gha_overall_status
needs: [Smart_CI, Build, Debian_Packages, Samples, Conformance, ONNX_Runtime, CXX_Unit_Tests, Python_Unit_Tests, TensorFlow_Layer_Tests,
CPU_Functional_Tests, TensorFlow_Models_Tests_Precommit, PyTorch_Models_Tests, NVIDIA_Plugin, Openvino_tokenizers]
CPU_Functional_Tests, TensorFlow_Models_Tests_Precommit, PyTorch_Models_Tests, NVIDIA_Plugin, Openvino_tokenizers, iGPU]
if: ${{ always() }}
runs-on: ubuntu-latest
steps:
Expand Down
16 changes: 16 additions & 0 deletions src/bindings/python/src/pyopenvino/graph/model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -733,6 +733,22 @@ void regclass_graph_Model(py::module m) {
:return: Index for value referencing it.
:rtype: int
)");
model.def(
"get_result_index",
[](const ov::Model& model, const ov::op::v0::Result& result) {
return model.get_result_index(result.get_default_output());
},
py::arg("result"),
R"(
Return index of result.
Return -1 if `result` not matched.
:param result: Result operation
:type result: op.Result
:return: Index for result referencing it.
:rtype: int
)");

model.def("get_name",
&ov::Model::get_name,
Expand Down
45 changes: 30 additions & 15 deletions src/bindings/python/tests/test_runtime/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,22 +157,37 @@ def test_get_result_index_invalid():
assert model.get_result_index(invalid_output) == -1


def test_parameter_index():
input_shape = PartialShape([1])
param = ops.parameter(input_shape, dtype=np.float32, name="data")
relu = ops.relu(param, name="relu")
model = Model(relu, [param], "TestModel")
assert model.get_parameter_index(param) == 0


def test_parameter_index_invalid():
shape1 = PartialShape([1])
param1 = ops.parameter(shape1, dtype=np.float32, name="data1")
@pytest.mark.parametrize(("shapes", "relu_names", "model_name", "expected_outputs_length", "is_invalid", "expected_result_index"), [
([PartialShape([1])], ["relu"], "TestModel", 1, False, 0),
([PartialShape([1]), PartialShape([4])], ["relu1", "relu2"], "TestModel1", 1, True, -1)
])
def test_result_index(shapes, relu_names, model_name, expected_outputs_length, is_invalid, expected_result_index):
params = [ops.parameter(shape, dtype=np.float32, name=f"data{i+1}") for i, shape in enumerate(shapes)]
relus = [ops.relu(param, name=relu_name) for param, relu_name in zip(params, relu_names)]

model = Model(relus[0], [params[0]], model_name)
assert len(model.outputs) == expected_outputs_length
if is_invalid:
invalid_result_node = ops.result(relus[1].outputs()[0])
assert model.get_result_index(invalid_result_node) == expected_result_index
else:
assert model.get_result_index(model.get_results()[0]) == expected_result_index


@pytest.mark.parametrize(("shapes", "param_names", "model_name", "expected_index", "is_invalid"), [
([PartialShape([1]), None], ["data", None], "TestModel", 0, False),
([PartialShape([1]), PartialShape([2])], ["data1", "data2"], "TestModel", -1, True)
])
def test_parameter_index(shapes, param_names, model_name, expected_index, is_invalid):
param1 = ops.parameter(shapes[0], dtype=np.float32, name=param_names[0])
relu = ops.relu(param1, name="relu")
model = Model(relu, [param1], "TestModel")
shape2 = PartialShape([2])
param2 = ops.parameter(shape2, dtype=np.float32, name="data2")
assert model.get_parameter_index(param2) == -1
model = Model(relu, [param1], model_name)

if is_invalid:
param2 = ops.parameter(shapes[1], dtype=np.float32, name=param_names[1])
assert model.get_parameter_index(param2) == expected_index
else:
assert model.get_parameter_index(param1) == expected_index


def test_replace_parameter():
Expand Down
2 changes: 2 additions & 0 deletions src/plugins/intel_cpu/src/cpu_types.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,7 @@ static const TypeToNameMap& get_type_to_name_tbl() {
{"NV12toBGR", Type::ColorConvert},
{"I420toRGB", Type::ColorConvert},
{"I420toBGR", Type::ColorConvert},
{"Col2Im", Type::Col2Im},
{"MVN", Type::MVN},
{"NormalizeL2", Type::NormalizeL2},
{"ScatterUpdate", Type::ScatterUpdate},
Expand Down Expand Up @@ -305,6 +306,7 @@ std::string NameFromType(const Type type) {
CASE(MVN);
CASE(TensorIterator);
CASE(Convert);
CASE(Col2Im);
CASE(ColorConvert);
CASE(NormalizeL2);
CASE(ScatterUpdate);
Expand Down
1 change: 1 addition & 0 deletions src/plugins/intel_cpu/src/cpu_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ enum class Type {
TensorIterator,
Convert,
ColorConvert,
Col2Im,
MVN,
NormalizeL2,
ScatterUpdate,
Expand Down
110 changes: 110 additions & 0 deletions src/plugins/intel_cpu/src/nodes/col2im.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "col2im.h"
#include "openvino/reference/col2im.hpp"
#include "openvino/op/col2im.hpp"

namespace ov {
namespace intel_cpu {
namespace node {
Col2Im::Col2Im(const std::shared_ptr<ov::Node>& op, const GraphContext::CPtr context)
: Node(op, context, NgraphShapeInferFactory(op, EMPTY_PORT_MASK)) {
std::string errorMessage;
if (!isSupportedOperation(op, errorMessage)) {
OPENVINO_THROW_NOT_IMPLEMENTED(errorMessage);
}
const auto col2Im = ov::as_type_ptr<const ov::op::v15::Col2Im>(op);
strides = col2Im->get_strides();
dilations = col2Im->get_dilations();
padsBegin = col2Im->get_pads_begin();
padsEnd = col2Im->get_pads_end();
}

bool Col2Im::isSupportedOperation(const std::shared_ptr<const ov::Node>& op, std::string& errorMessage) noexcept {
try {
if (!ov::is_type<ov::op::v15::Col2Im>(op)) {
errorMessage = "Only opset15 Col2Im operation is supported";
return false;
}
} catch (...) {
return false;
}
return true;
}

void Col2Im::getSupportedDescriptors() {
// Validation is already done in the ov::opset15::Col2Im.
}

void Col2Im::initSupportedPrimitiveDescriptors() {
if (!supportedPrimitiveDescriptors.empty())
return;
ov::element::Type dataPrecision = getOriginalInputPrecisionAtPort(0);
addSupportedPrimDesc(
{{LayoutType::ncsp, dataPrecision}, {LayoutType::ncsp, ov::element::i32}, {LayoutType::ncsp, ov::element::i32}},
{{LayoutType::ncsp, dataPrecision}},
impl_desc_type::ref);
}

bool Col2Im::created() const {
return getType() == Type::Col2Im;
}

bool Col2Im::needPrepareParams() const {
return false;
}

void Col2Im::executeDynamicImpl(dnnl::stream strm) {
execute(strm);
}

template <class T, class T_idx>
void Col2Im::executeImpl() {
ov::reference::col2im<T, T_idx>(
getSrcDataAtPortAs<const T>(0),
ov::Shape{getSrcMemoryAtPort(0)->getStaticDims()},
getSrcDataAtPortAs<const T_idx>(1),
getSrcDataAtPortAs<const T_idx>(2),
getDstDataAtPortAs<T>(0),
strides,
dilations,
padsBegin,
padsEnd);
}

namespace {
struct Col2ImContext {
Col2Im &node;
};
}

template<typename T>
struct Col2Im::Col2ImExecute {
using TData = typename std::tuple_element<0, T>::type;
using TIndex = typename std::tuple_element<1, T>::type;

void operator()(Col2ImContext & ctx) {
ctx.node.executeImpl<TData, TIndex>();
}
};
void Col2Im::execute(dnnl::stream strm) {
auto dataPrecision = getParentEdgeAt(0)->getMemory().getDesc().getPrecision();
auto indexPrecision = getParentEdgeAt(1)->getMemory().getDesc().getPrecision();

Col2ImContext ctx = {
*this
};

OV_SWITCH(intel_cpu, Col2ImExecute, ctx, std::tie(dataPrecision, indexPrecision),
OV_CASE2(ov::element::f32, ov::element::i32, float, int32_t),
OV_CASE2(ov::element::f16, ov::element::i32, ov::float16, int32_t),
OV_CASE2(ov::element::bf16, ov::element::i32, ov::bfloat16, int32_t),
OV_CASE2(ov::element::i32, ov::element::i32, int32_t, int32_t),
OV_CASE2(ov::element::i8, ov::element::i32, int8_t, int32_t),
OV_CASE2(ov::element::u8, ov::element::i32, uint8_t, int32_t))
}
} // namespace node
} // namespace intel_cpu
} // namespace ov
40 changes: 40 additions & 0 deletions src/plugins/intel_cpu/src/nodes/col2im.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "node.h"

namespace ov {
namespace intel_cpu {
namespace node {

class Col2Im : public Node {
public:
Col2Im(const std::shared_ptr<ov::Node>& op, const GraphContext::CPtr context);

static bool isSupportedOperation(const std::shared_ptr<const ov::Node>& op, std::string& errorMessage) noexcept;
void getSupportedDescriptors() override;
void initSupportedPrimitiveDescriptors() override;
void execute(dnnl::stream strm) override;
bool created() const override;
bool needPrepareParams() const override;
void executeDynamicImpl(dnnl::stream strm) override;

private:
template <class OV_DATA_TYPE, class OV_INDEX_TYPE>
void executeImpl();

template<typename T>
struct Col2ImExecute;

ov::Strides strides;
ov::Strides dilations;
ov::Shape padsBegin;
ov::Shape padsEnd;
};

} // namespace node
} // namespace intel_cpu
} // namespace ov
2 changes: 2 additions & 0 deletions src/plugins/intel_cpu/src/nodes_factory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include "nodes/bin_conv.h"
#include "nodes/broadcast.h"
#include "nodes/bucketize.h"
#include "nodes/col2im.h"
#include "nodes/color_convert.h"
#include "nodes/concat.h"
#include "nodes/conv.h"
Expand Down Expand Up @@ -160,6 +161,7 @@ Node::NodesFactory::NodesFactory() : Factory("NodesFactory") {
INTEL_CPU_NODE(Math, Type::Math);
INTEL_CPU_NODE(MultiClassNms, Type::MulticlassNms);
INTEL_CPU_NODE(Convert, Type::Convert);
INTEL_CPU_NODE(Col2Im, Type::Col2Im);
INTEL_CPU_NODE(ColorConvert, Type::ColorConvert);
INTEL_CPU_NODE(EmbeddingBagOffset, Type::EmbeddingBagOffsetsSum);
INTEL_CPU_NODE(EmbeddingBagOffset, Type::EmbeddingBagOffsets);
Expand Down
Loading

0 comments on commit 7465aeb

Please sign in to comment.