Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Request ONNX export support for Cheng2020 model #296

Open
ZhangYuef opened this issue Jun 4, 2024 · 1 comment
Open

Request ONNX export support for Cheng2020 model #296

ZhangYuef opened this issue Jun 4, 2024 · 1 comment

Comments

@ZhangYuef
Copy link

ZhangYuef commented Jun 4, 2024

Feature

Support exporting Cheng2020 model to onnx format.

Motivation

To deploy the model on various hardwares.

Additional context

This is my convertion code:

import torch
from compressai.zoo import models

# net = models["bmshj2018-factorized"](quality=1, metric="mse", pretrained=True)
# net = cheng2020_anchor(quality=5, pretrained=True).to(device)
net = models["cheng2020-anchor"](quality=1, metric="mse", pretrained=True)

# Some dummy input
x = torch.randn(1, 3, 224, 224, requires_grad=True)

# Export the model
torch.onnx.export(net,                       # model being run
                  x,                         # model input (or a tuple for multiple inputs)
                  "cheng2020.onnx",              # where to save the model (can be a file or file-like object)
                  export_params=True,        # store the trained parameter weights inside the model file
                  opset_version=11,          # the ONNX version to export the model to
                  do_constant_folding=True,  # whether to execute constant folding for optimization
                  input_names = ['input'],   # the model's input names
                  output_names = ['output'], # the model's output names
                  dynamic_axes={'input': {0 : 'batch_size'},    # variable length axes
                                'output': {0 : 'batch_size'}}
                 )

Error occurs:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
[<ipython-input-48-bddb317a9b45>](https://localhost:8080/#) in <cell line: 12>()
     10 
     11 # Export the model
---> 12 torch.onnx.export(net,                       # model being run
     13                   x,                         # model input (or a tuple for multiple inputs)
     14                   "cheng2020.onnx",              # where to save the model (can be a file or file-like object)

15 frames
[/usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py](https://localhost:8080/#) in export(model, args, f, export_params, verbose, training, input_names, output_names, operator_export_type, opset_version, do_constant_folding, dynamic_axes, keep_initializers_as_inputs, custom_opsets, export_modules_as_functions, autograd_inlining)
    514     """
    515 
--> 516     _export(
    517         model,
    518         args,

[/usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py](https://localhost:8080/#) in _export(model, args, f, export_params, verbose, training, input_names, output_names, operator_export_type, export_type, opset_version, do_constant_folding, dynamic_axes, keep_initializers_as_inputs, fixed_batch_size, custom_opsets, add_node_names, onnx_shape_inference, export_modules_as_functions, autograd_inlining)
   1610             _validate_dynamic_axes(dynamic_axes, model, input_names, output_names)
   1611 
-> 1612             graph, params_dict, torch_out = _model_to_graph(
   1613                 model,
   1614                 args,

[/usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py](https://localhost:8080/#) in _model_to_graph(model, args, verbose, input_names, output_names, operator_export_type, do_constant_folding, _disable_torch_constant_prop, fixed_batch_size, training, dynamic_axes)
   1132 
   1133     model = _pre_trace_quant_model(model, args)
-> 1134     graph, params, torch_out, module = _create_jit_graph(model, args)
   1135     params_dict = _get_named_param_dict(graph, params)
   1136 

[/usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py](https://localhost:8080/#) in _create_jit_graph(model, args)
   1008         return graph, params, torch_out, None
   1009 
-> 1010     graph, torch_out = _trace_and_get_graph_from_model(model, args)
   1011     _C._jit_pass_onnx_lint(graph)
   1012     state_dict = torch.jit._unique_state_dict(model)

[/usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py](https://localhost:8080/#) in _trace_and_get_graph_from_model(model, args)
    912     prev_autocast_cache_enabled = torch.is_autocast_cache_enabled()
    913     torch.set_autocast_cache_enabled(False)
--> 914     trace_graph, torch_out, inputs_states = torch.jit._get_trace_graph(
    915         model,
    916         args,

[/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py](https://localhost:8080/#) in _fn(*args, **kwargs)
    449             prior = set_eval_frame(callback)
    450             try:
--> 451                 return fn(*args, **kwargs)
    452             finally:
    453                 set_eval_frame(prior)

[/usr/local/lib/python3.10/dist-packages/torch/_dynamo/external_utils.py](https://localhost:8080/#) in inner(*args, **kwargs)
     34     @functools.wraps(fn)
     35     def inner(*args, **kwargs):
---> 36         return fn(*args, **kwargs)
     37 
     38     return inner

[/usr/local/lib/python3.10/dist-packages/torch/jit/_trace.py](https://localhost:8080/#) in _get_trace_graph(f, args, kwargs, strict, _force_outplace, return_inputs, _return_inputs_states)
   1308     if not isinstance(args, tuple):
   1309         args = (args,)
-> 1310     outs = ONNXTracedModule(
   1311         f, strict, _force_outplace, return_inputs, _return_inputs_states
   1312     )(*args, **kwargs)

[/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py](https://localhost:8080/#) in _wrapped_call_impl(self, *args, **kwargs)
   1530             return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1531         else:
-> 1532             return self._call_impl(*args, **kwargs)
   1533 
   1534     def _call_impl(self, *args, **kwargs):

[/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py](https://localhost:8080/#) in _call_impl(self, *args, **kwargs)
   1539                 or _global_backward_pre_hooks or _global_backward_hooks
   1540                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1541             return forward_call(*args, **kwargs)
   1542 
   1543         try:

[/usr/local/lib/python3.10/dist-packages/torch/jit/_trace.py](https://localhost:8080/#) in forward(self, *args)
    136                 return tuple(out_vars)
    137 
--> 138         graph, out = torch._C._create_graph_by_tracing(
    139             wrapper,
    140             in_vars + module_state,

[/usr/local/lib/python3.10/dist-packages/torch/jit/_trace.py](https://localhost:8080/#) in wrapper(*args)
    127             if self._return_inputs_states:
    128                 inputs_states.append(_unflatten(in_args, in_desc))
--> 129             outs.append(self.inner(*trace_inputs))
    130             if self._return_inputs_states:
    131                 inputs_states[0] = (inputs_states[0], trace_inputs)

[/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py](https://localhost:8080/#) in _wrapped_call_impl(self, *args, **kwargs)
   1530             return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1531         else:
-> 1532             return self._call_impl(*args, **kwargs)
   1533 
   1534     def _call_impl(self, *args, **kwargs):

[/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py](https://localhost:8080/#) in _call_impl(self, *args, **kwargs)
   1539                 or _global_backward_pre_hooks or _global_backward_hooks
   1540                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1541             return forward_call(*args, **kwargs)
   1542 
   1543         try:

[/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py](https://localhost:8080/#) in _slow_forward(self, *input, **kwargs)
   1520                 recording_scopes = False
   1521         try:
-> 1522             result = self.forward(*input, **kwargs)
   1523         finally:
   1524             if recording_scopes:

[/usr/local/lib/python3.10/dist-packages/compressai/models/google.py](https://localhost:8080/#) in forward(self, x)
    543         ctx_params = self.context_prediction(y_hat)
    544         gaussian_params = self.entropy_parameters(
--> 545             torch.cat((params, ctx_params), dim=1)
    546         )
    547         scales_hat, means_hat = gaussian_params.chunk(2, 1)

RuntimeError: Sizes of tensors must match except in dimension 1. Expected size 16 but got size 14 for tensor number 1 in the list.
@ZhangYuef
Copy link
Author

Related ISSUE #87.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

1 participant