Skip to content

Commit

Permalink
Fixed pytorch utils.
Browse files Browse the repository at this point in the history
  • Loading branch information
popovaan committed Apr 18, 2024
1 parent 3515cab commit ac6191a
Showing 1 changed file with 4 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -130,9 +130,9 @@ def extract_input_info_from_example(args, inputs):
dtype = getattr(example_input, "dtype", type(example_input))
example_dtype = pt_to_ov_type_map.get(str(dtype))
user_dtype = get_value_from_list_or_dict(data_types, input_name, input_id)
if user_dtype is not None and example_dtype is not None and example_dtype.to_dtype() != user_dtype:
if user_dtype is not None and example_dtype is not None and example_dtype != user_dtype:
raise Error(
f"Defined input type {user_dtype} is not equal to provided example_input type {example_dtype.to_dtype()}")
f"Defined input type {user_dtype} is not equal to provided example_input type {example_dtype}")

data_rank = getattr(example_input, "ndim", 0)
user_input_shape = get_value_from_list_or_dict(input_shapes, input_name, input_id)
Expand All @@ -143,7 +143,7 @@ def extract_input_info_from_example(args, inputs):

input_shape = user_input_shape if user_input_shape is not None else PartialShape([-1] * data_rank)
update_list_or_dict(data_types, input_name, input_id,
example_dtype.to_dtype() if example_dtype is not None else None)
example_dtype if example_dtype is not None else None)
update_list_or_dict(input_shapes, input_name, input_id, input_shape)
else:
for input_id, example_input in enumerate(list_inputs):
Expand All @@ -153,7 +153,7 @@ def extract_input_info_from_example(args, inputs):
input_shape = PartialShape([-1] * data_rank)
input_name = input_names[input_id] if input_names else None
update_list_or_dict(input_shapes, input_name, input_id, input_shape)
update_list_or_dict(data_types, input_name, input_id, ov_dtype.to_dtype() if ov_dtype is not None else None)
update_list_or_dict(data_types, input_name, input_id, ov_dtype if ov_dtype is not None else None)

args.placeholder_data_types = data_types
args.placeholder_shapes = input_shapes
Expand Down

0 comments on commit ac6191a

Please sign in to comment.