Skip to content

Commit

Permalink
[core] Migrate OneHot operator to new API (openvinotoolkit#21038)
Browse files Browse the repository at this point in the history
* Drop ngraph remains

* Use ov::Tensor

instaed of ngraph::HostTensor

* Set output shape
  • Loading branch information
t-jankowski authored Nov 15, 2023
1 parent 00d0c0a commit 3558f09
Show file tree
Hide file tree
Showing 2 changed files with 82 additions and 78 deletions.
4 changes: 1 addition & 3 deletions src/core/include/openvino/op/one_hot.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,7 @@ class OPENVINO_API OneHot : public Op {
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
void validate_and_infer_types() override;

OPENVINO_SUPPRESS_DEPRECATED_START
bool evaluate(const HostTensorVector& output_values, const HostTensorVector& input_values) const override;
OPENVINO_SUPPRESS_DEPRECATED_END
bool evaluate(TensorVector& outputs, const TensorVector& inputs) const override;
bool has_evaluate() const override;

/// \return The index of the one-hot axis.
Expand Down
156 changes: 81 additions & 75 deletions src/core/src/op/one_hot.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,32 +2,56 @@
// SPDX-License-Identifier: Apache-2.0
//

#include "ngraph/op/one_hot.hpp"

#include <one_hot_shape_inference.hpp>
#include "openvino/op/one_hot.hpp"

#include "element_visitor.hpp"
#include "itt.hpp"
#include "ngraph/attribute_visitor.hpp"
#include "ngraph/op/util/op_types.hpp"
#include "ngraph/validation_util.hpp"
#include "one_hot_shape_inference.hpp"
#include "openvino/core/attribute_visitor.hpp"
#include "openvino/op/util/precision_sensitive_attribute.hpp"
#include "openvino/reference/one_hot.hpp"

using namespace std;
using namespace ngraph;
namespace ov {
namespace op {
namespace one_hot {
struct Evaluate : element::NoAction<bool> {
using element::NoAction<bool>::visit;

template <element::Type_t ET, class T = fundamental_type_for<ET>>
static result_type visit(const Tensor& indices,
const Shape& indices_shape,
char* const output_data,
const size_t output_et_size,
const int64_t one_hot_axis,
const char* const on_value,
const char* const off_value,
const int64_t axis) {
reference::one_hot(indices.data<const T>(),
indices_shape,
output_data,
output_et_size,
one_hot_axis,
axis,
on_value,
off_value);
return true;
}
};
} // namespace one_hot

op::v1::OneHot::OneHot(const Output<Node>& indices,
const Output<Node>& depth,
const Output<Node>& on_value,
const Output<Node>& off_value,
int64_t axis)
namespace v1 {
OneHot::OneHot(const Output<Node>& indices,
const Output<Node>& depth,
const Output<Node>& on_value,
const Output<Node>& off_value,
int64_t axis)
: Op({indices, depth, on_value, off_value}),
m_axis(axis) {
ov::mark_as_precision_sensitive(input(1));
mark_as_precision_sensitive(input(1));
constructor_validate_and_infer_types();
}

void op::v1::OneHot::validate_and_infer_types() {
void OneHot::validate_and_infer_types() {
OV_OP_SCOPE(v1_OneHot_validate_and_infer_types);
const auto& indices_et = get_input_element_type(0);
const auto& depth_et = get_input_element_type(1);
Expand Down Expand Up @@ -58,86 +82,68 @@ void op::v1::OneHot::validate_and_infer_types() {
set_output_type(0, on_value_et, output_shapes[0]);
}

bool ngraph::op::v1::OneHot::visit_attributes(AttributeVisitor& visitor) {
bool OneHot::visit_attributes(AttributeVisitor& visitor) {
OV_OP_SCOPE(v1_OneHot_visit_attributes);
visitor.on_attribute("axis", m_axis);
return true;
}

shared_ptr<Node> op::v1::OneHot::clone_with_new_inputs(const OutputVector& new_args) const {
std::shared_ptr<Node> OneHot::clone_with_new_inputs(const OutputVector& new_args) const {
OV_OP_SCOPE(v1_OneHot_clone_with_new_inputs);
check_new_args_count(this, new_args);
return make_shared<v1::OneHot>(new_args.at(0), new_args.at(1), new_args.at(2), new_args.at(3), m_axis);
return std::make_shared<v1::OneHot>(new_args.at(0), new_args.at(1), new_args.at(2), new_args.at(3), m_axis);
}

OPENVINO_SUPPRESS_DEPRECATED_START
namespace one_hot {
namespace {
template <element::Type_t T>
bool evaluate(const HostTensorVector& output_values, const HostTensorVector& input_values, const int64_t axis) {
using INPUT_TYPE = typename element_type_traits<T>::value_type;
const auto& indices = input_values[0];
const auto& on_value = input_values[2];
const auto& off_value = input_values[3];
const auto& out = output_values[0];
ov::reference::one_hot<INPUT_TYPE>(indices->get_data_ptr<INPUT_TYPE>(),
indices->get_shape(),
out->get_data_ptr<char>(),
out->get_element_type().size(),
out->get_shape()[axis],
axis,
on_value->get_data_ptr<char>(),
off_value->get_data_ptr<char>());
return true;
}
bool evaluate_onehot(const HostTensorVector& output_values, const HostTensorVector& input_values, const int64_t axis) {
bool rc = true;
const auto& indices = input_values[0];
switch (indices->get_element_type()) {
OPENVINO_TYPE_CASE(evaluate_onehot, i32, output_values, input_values, axis);
OPENVINO_TYPE_CASE(evaluate_onehot, i64, output_values, input_values, axis);
default:
rc = false;
}
return rc;
}
} // namespace
} // namespace one_hot

bool op::v1::OneHot::evaluate(const HostTensorVector& output_values, const HostTensorVector& input_values) const {
bool OneHot::evaluate(TensorVector& outputs, const TensorVector& inputs) const {
OV_OP_SCOPE(v1_OneHot_evaluate);
OPENVINO_SUPPRESS_DEPRECATED_START
OPENVINO_ASSERT(validate_host_tensor_vector(input_values, 4));
OPENVINO_ASSERT(validate_host_tensor_vector(output_values, 1));
OPENVINO_SUPPRESS_DEPRECATED_END

const auto& ind_Pshape = input_values[0]->get_partial_shape();
const auto& out_Pshape = output_values[0]->get_partial_shape();
OPENVINO_ASSERT(ind_Pshape.is_static() && out_Pshape.is_static(), "Only static input/output shapes are supported");
const auto out_shape = out_Pshape.get_shape();
const int64_t axis = get_axis();
OPENVINO_ASSERT(axis >= 0 && static_cast<size_t>(axis) < out_shape.size(), "Invalid axis value.");
const auto depth = std::make_shared<op::v0::Constant>(input_values[1])->cast_vector<int64_t>()[0];
const auto ind_shape = ind_Pshape.get_shape();
OPENVINO_ASSERT(shape_size(ind_shape) * depth == shape_size(out_shape),
OPENVINO_ASSERT(inputs.size() == 4 && outputs.size() == 1);

const auto output_shape =
shape_infer(this, ov::util::get_tensors_partial_shapes(inputs), make_tensor_accessor(inputs))
.front()
.to_shape();
const auto axis = get_axis();
OPENVINO_ASSERT(axis >= 0 && static_cast<size_t>(axis) < output_shape.size(), "Invalid axis value.");

const auto depth = v0::Constant{inputs[1]}.cast_vector<int64_t>()[0];
OPENVINO_ASSERT(static_cast<int64_t>(output_shape[axis]) == depth, "Incompatible axis and depth values.");

const auto& indices = inputs[0];
const auto& indices_shape = indices.get_shape();
OPENVINO_ASSERT(shape_size(indices_shape) * depth == shape_size(output_shape),
"Incompatible I/O shapes or wrong depth value.");
OPENVINO_ASSERT(static_cast<int64_t>(out_shape[axis]) == depth, "Incompatible axis and depth values.");
return one_hot::evaluate_onehot(output_values, input_values, axis);

const auto on_value = static_cast<const char*>(inputs[2].data());
const auto off_value = static_cast<const char*>(inputs[3].data());
auto& output = outputs[0];
output.set_shape(output_shape);
using namespace ov::element;
return IfTypeOf<i32, i64>::apply<one_hot::Evaluate>(indices.get_element_type(),
indices,
indices_shape,
static_cast<char*>(output.data()),
output.get_element_type().size(),
output.get_shape()[axis],
on_value,
off_value,
axis);
}

bool op::v1::OneHot::has_evaluate() const {
bool OneHot::has_evaluate() const {
OV_OP_SCOPE(v1_OneHot_has_evaluate);
switch (get_input_element_type(0)) {
case ngraph::element::i32:
case ngraph::element::i64:
case element::i32:
case element::i64:
return true;
default:
break;
return false;
}
return false;
}

void op::v1::OneHot::set_axis(int64_t axis) {
void OneHot::set_axis(int64_t axis) {
m_axis = axis;
resolve_axis(this);
}
} // namespace v1
} // namespace op
} // namespace ov

0 comments on commit 3558f09

Please sign in to comment.