Skip to content

Commit

Permalink
Used common method to convert types.
Browse files Browse the repository at this point in the history
  • Loading branch information
popovaan committed Apr 18, 2024
1 parent ac6191a commit cc44edd
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 30 deletions.
25 changes: 24 additions & 1 deletion src/bindings/python/src/openvino/frontend/tensorflow/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from typing import List, Dict, Union

import numpy as np
from openvino.runtime import PartialShape, Dimension
from openvino.runtime import PartialShape, Dimension, Type


# TODO: reuse this method in ovc and remove duplication
Expand Down Expand Up @@ -435,3 +435,26 @@ def model_is_graph_iterator(model):
except:
return False
return isinstance(model, GraphIteratorTFGraph)


def tf_type_to_ov_type(val):
import tensorflow as tf # pylint: disable=import-error
if not isinstance(val, tf.dtypes.DType):
raise Exception("The provided type is not a TF type {}.".format(val))

tf_to_ov_type = {
tf.float32: Type.f32,
tf.float16: Type.f16,
tf.float64: Type.f64,
tf.bfloat16: Type.bf16,
tf.uint8: Type.u8,
tf.int8: Type.i8,
tf.int16: Type.i16,
tf.int32: Type.i32,
tf.int64: Type.i64,
tf.bool: Type.boolean,
tf.string: Type.string
}
if val not in tf_to_ov_type:
raise Exception("The provided data type is not supported by OpenVino {}.".format(val))
return tf_to_ov_type[val]
18 changes: 6 additions & 12 deletions tools/ovc/openvino/tools/ovc/convert_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
from pathlib import Path
from typing import Iterable, Callable

import numpy as np

try:
import openvino_telemetry as tm
Expand All @@ -23,6 +22,7 @@
from openvino.tools.ovc.moc_frontend.check_config import any_extensions_used
from openvino.tools.ovc.moc_frontend.pipeline import moc_pipeline
from openvino.tools.ovc.moc_frontend.moc_emit_ir import moc_emit_ir
from openvino.tools.ovc.moc_frontend.type_utils import to_ov_type
from openvino.tools.ovc.cli_parser import get_available_front_ends, get_common_cli_options, depersonalize, \
get_mo_convert_params, input_to_input_cut_info, parse_inputs
from openvino.tools.ovc.help import get_convert_model_help_specifics
Expand All @@ -40,7 +40,7 @@
# pylint: disable=no-name-in-module,import-error
from openvino.frontend import FrontEndManager, OpConversionFailure, TelemetryExtension
from openvino.runtime import get_version as get_rt_version
from openvino.runtime import Type, PartialShape
from openvino.runtime import PartialShape

try:
from openvino.frontend.tensorflow.utils import create_tf_graph_iterator, type_supported_by_tf_fe, \
Expand Down Expand Up @@ -343,11 +343,8 @@ def normalize_inputs(argv: argparse.Namespace):
else:
shape_dict[inp.name] = None
if inp.type is not None:
# Convert type to numpy type for uniformity of stored values
if isinstance(inp.type, (np.dtype, str)):
data_type_dict[inp.name] = Type(inp.type)
else:
data_type_dict[inp.name] = inp.type
# Convert type to ov.Type for uniformity of stored values
data_type_dict[inp.name] = to_ov_type(inp.type)
argv.placeholder_shapes = shape_dict if shape_dict else None
argv.placeholder_data_types = data_type_dict if data_type_dict else {}
else:
Expand All @@ -359,11 +356,8 @@ def normalize_inputs(argv: argparse.Namespace):
# Wrap shape to PartialShape for uniformity of stored values
shape_list.append(PartialShape(inp.shape))
if inp.type is not None:
# Convert type to numpy type for uniformity of stored values
if isinstance(inp.type, (np.dtype, str)):
data_type_list.append(Type(inp.type))
else:
data_type_list.append(inp.type)
# Convert type to ov.Type for uniformity of stored values
data_type_list.append(to_ov_type(inp.type))
argv.placeholder_shapes = shape_list if shape_list else None
argv.placeholder_data_types = data_type_list if data_type_list else {}
if hasattr(argv, "framework") and argv.framework == "pytorch" and getattr(argv, "example_input", None) is not None:
Expand Down
22 changes: 5 additions & 17 deletions tools/ovc/openvino/tools/ovc/moc_frontend/type_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@

import sys

import numpy as np

import openvino as ov
from openvino.runtime import Type

Expand All @@ -28,27 +30,13 @@ def is_type(val):
def to_ov_type(val):
if isinstance(val, Type):
return val
if isinstance(val, type):
if isinstance(val, (type, str, np.dtype)):
return Type(val)
if 'tensorflow' in sys.modules:
import tensorflow as tf # pylint: disable=import-error
if isinstance(val, tf.dtypes.DType):
tf_to_ov_type = {
tf.float32: ov.Type.f32,
tf.float16: ov.Type.f16,
tf.float64: ov.Type.f64,
tf.bfloat16: ov.Type.bf16,
tf.uint8: ov.Type.u8,
tf.int8: ov.Type.i8,
tf.int16: ov.Type.i16,
tf.int32: ov.Type.i32,
tf.int64: ov.Type.i64,
tf.bool: ov.Type.boolean,
tf.string: ov.Type.string
}
if val not in tf_to_ov_type:
raise Exception("The provided data time is not supported {}.".format(val))
return tf_to_ov_type[val]
from openvino.frontend.tensorflow.utils import tf_type_to_ov_type # pylint: disable=no-name-in-module,import-error
return tf_type_to_ov_type(val)
if 'torch' in sys.modules:
import torch

Expand Down

0 comments on commit cc44edd

Please sign in to comment.