forked from openvinotoolkit/openvino
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
27 changed files
with
830 additions
and
75 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -61,6 +61,7 @@ enum class Type { | |
TensorIterator, | ||
Convert, | ||
ColorConvert, | ||
Col2Im, | ||
MVN, | ||
NormalizeL2, | ||
ScatterUpdate, | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,110 @@ | ||
// Copyright (C) 2018-2024 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
|
||
#include "col2im.h" | ||
#include "openvino/reference/col2im.hpp" | ||
#include "openvino/op/col2im.hpp" | ||
|
||
namespace ov { | ||
namespace intel_cpu { | ||
namespace node { | ||
Col2Im::Col2Im(const std::shared_ptr<ov::Node>& op, const GraphContext::CPtr context) | ||
: Node(op, context, NgraphShapeInferFactory(op, EMPTY_PORT_MASK)) { | ||
std::string errorMessage; | ||
if (!isSupportedOperation(op, errorMessage)) { | ||
OPENVINO_THROW_NOT_IMPLEMENTED(errorMessage); | ||
} | ||
const auto col2Im = ov::as_type_ptr<const ov::op::v15::Col2Im>(op); | ||
strides = col2Im->get_strides(); | ||
dilations = col2Im->get_dilations(); | ||
padsBegin = col2Im->get_pads_begin(); | ||
padsEnd = col2Im->get_pads_end(); | ||
} | ||
|
||
bool Col2Im::isSupportedOperation(const std::shared_ptr<const ov::Node>& op, std::string& errorMessage) noexcept { | ||
try { | ||
if (!ov::is_type<ov::op::v15::Col2Im>(op)) { | ||
errorMessage = "Only opset15 Col2Im operation is supported"; | ||
return false; | ||
} | ||
} catch (...) { | ||
return false; | ||
} | ||
return true; | ||
} | ||
|
||
void Col2Im::getSupportedDescriptors() { | ||
// Validation is already done in the ov::opset15::Col2Im. | ||
} | ||
|
||
void Col2Im::initSupportedPrimitiveDescriptors() { | ||
if (!supportedPrimitiveDescriptors.empty()) | ||
return; | ||
ov::element::Type dataPrecision = getOriginalInputPrecisionAtPort(0); | ||
addSupportedPrimDesc( | ||
{{LayoutType::ncsp, dataPrecision}, {LayoutType::ncsp, ov::element::i32}, {LayoutType::ncsp, ov::element::i32}}, | ||
{{LayoutType::ncsp, dataPrecision}}, | ||
impl_desc_type::ref); | ||
} | ||
|
||
bool Col2Im::created() const { | ||
return getType() == Type::Col2Im; | ||
} | ||
|
||
bool Col2Im::needPrepareParams() const { | ||
return false; | ||
} | ||
|
||
void Col2Im::executeDynamicImpl(dnnl::stream strm) { | ||
execute(strm); | ||
} | ||
|
||
template <class T, class T_idx> | ||
void Col2Im::executeImpl() { | ||
ov::reference::col2im<T, T_idx>( | ||
getSrcDataAtPortAs<const T>(0), | ||
ov::Shape{getSrcMemoryAtPort(0)->getStaticDims()}, | ||
getSrcDataAtPortAs<const T_idx>(1), | ||
getSrcDataAtPortAs<const T_idx>(2), | ||
getDstDataAtPortAs<T>(0), | ||
strides, | ||
dilations, | ||
padsBegin, | ||
padsEnd); | ||
} | ||
|
||
namespace { | ||
struct Col2ImContext { | ||
Col2Im &node; | ||
}; | ||
} | ||
|
||
template<typename T> | ||
struct Col2Im::Col2ImExecute { | ||
using TData = typename std::tuple_element<0, T>::type; | ||
using TIndex = typename std::tuple_element<1, T>::type; | ||
|
||
void operator()(Col2ImContext & ctx) { | ||
ctx.node.executeImpl<TData, TIndex>(); | ||
} | ||
}; | ||
void Col2Im::execute(dnnl::stream strm) { | ||
auto dataPrecision = getParentEdgeAt(0)->getMemory().getDesc().getPrecision(); | ||
auto indexPrecision = getParentEdgeAt(1)->getMemory().getDesc().getPrecision(); | ||
|
||
Col2ImContext ctx = { | ||
*this | ||
}; | ||
|
||
OV_SWITCH(intel_cpu, Col2ImExecute, ctx, std::tie(dataPrecision, indexPrecision), | ||
OV_CASE2(ov::element::f32, ov::element::i32, float, int32_t), | ||
OV_CASE2(ov::element::f16, ov::element::i32, ov::float16, int32_t), | ||
OV_CASE2(ov::element::bf16, ov::element::i32, ov::bfloat16, int32_t), | ||
OV_CASE2(ov::element::i32, ov::element::i32, int32_t, int32_t), | ||
OV_CASE2(ov::element::i8, ov::element::i32, int8_t, int32_t), | ||
OV_CASE2(ov::element::u8, ov::element::i32, uint8_t, int32_t)) | ||
} | ||
} // namespace node | ||
} // namespace intel_cpu | ||
} // namespace ov |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
// Copyright (C) 2018-2024 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
|
||
#pragma once | ||
|
||
#include "node.h" | ||
|
||
namespace ov { | ||
namespace intel_cpu { | ||
namespace node { | ||
|
||
class Col2Im : public Node { | ||
public: | ||
Col2Im(const std::shared_ptr<ov::Node>& op, const GraphContext::CPtr context); | ||
|
||
static bool isSupportedOperation(const std::shared_ptr<const ov::Node>& op, std::string& errorMessage) noexcept; | ||
void getSupportedDescriptors() override; | ||
void initSupportedPrimitiveDescriptors() override; | ||
void execute(dnnl::stream strm) override; | ||
bool created() const override; | ||
bool needPrepareParams() const override; | ||
void executeDynamicImpl(dnnl::stream strm) override; | ||
|
||
private: | ||
template <class OV_DATA_TYPE, class OV_INDEX_TYPE> | ||
void executeImpl(); | ||
|
||
template<typename T> | ||
struct Col2ImExecute; | ||
|
||
ov::Strides strides; | ||
ov::Strides dilations; | ||
ov::Shape padsBegin; | ||
ov::Shape padsEnd; | ||
}; | ||
|
||
} // namespace node | ||
} // namespace intel_cpu | ||
} // namespace ov |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.