Skip to content

Commit

Permalink
[OVC] Fail with unsupported message when output argument is used for …
Browse files Browse the repository at this point in the history
…pytorch (#27255)

### Details:
- *`output` argument cannot be used for pytorch as it is unclear what
behavior is expected in this case.*

### Tickets:
 - *#26457 *

---------

Co-authored-by: Roman Kazantsev <[email protected]>
  • Loading branch information
mvafin and rkazants authored Oct 27, 2024
1 parent 4a33ad8 commit afa9231
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 6 deletions.
30 changes: 30 additions & 0 deletions tests/layer_tests/ovc_python_api_tests/test_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -1181,6 +1181,19 @@ def forward(self, a, b):
)}


def create_pytorch_module_with_output(tmp_dir):
class PTModel(torch.nn.Module):
def forward(self, a, b):
return a + b

net = PTModel()
return net, None, {
"example_input": (
torch.tensor([5, 6], dtype=torch.float32),
torch.tensor([5, 6], dtype=torch.float32),
), "output": "some_name"}


class TestMoConvertPyTorch(CommonMOConvertTest):
test_data = [
'create_pytorch_nn_module_case1',
Expand Down Expand Up @@ -1255,6 +1268,23 @@ def test_mo_import_from_memory(self, create_model, ie_device, precision, ir_vers
self._test_by_ref_graph(temp_dir, test_params,
graph_ref, compare_tensor_names=False)

@pytest.mark.parametrize("create_model,exception", [
('create_pytorch_module_with_output', AssertionError)
])
@pytest.mark.nightly
@pytest.mark.precommit
def test_mo_import_from_memory_negative(self, create_model, exception,
ie_device, precision, ir_version,
temp_dir, use_legacy_frontend):
fw_model, graph_ref, mo_params = eval(create_model)(temp_dir)

test_params = {'input_model': fw_model}
if mo_params is not None:
test_params.update(mo_params)
with pytest.raises(exception):
self._test_by_ref_graph(temp_dir, test_params,
graph_ref, compare_tensor_names=False)


def create_pt_model_with_custom_op():
#
Expand Down
9 changes: 3 additions & 6 deletions tools/ovc/openvino/tools/ovc/moc_frontend/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,16 +126,13 @@ def merge_inputs(inputs, to_set_list):
res.append(p)
return res
iplaces = merge_inputs(model_inputs, iplaces)
# Currently this only work to reorder inputs/outputs
oplaces = []
# Currently this only work to reorder inputs
to_override_all_inputs = check_places_are_same(model_inputs, [{"node": p} for p in iplaces])
to_override_all_outputs = False
if argv.output:
oplaces = []
_outputs = fe_output_user_data_repack(input_model, argv.output, moc_front_end.get_name())
for out_desc in _outputs:
oplaces.append(out_desc["name"])
model_outputs = input_model.get_outputs()
to_override_all_outputs = check_places_are_same(model_outputs, [{"node": p} for p in oplaces])
assert len(_outputs) == 0, "`output` argument is not supported for PyTorch"
if to_override_all_inputs and to_override_all_outputs:
input_model.extract_subgraph(iplaces, oplaces)
elif to_override_all_inputs:
Expand Down

0 comments on commit afa9231

Please sign in to comment.