You signed in with another tab or window. Reload to refresh your session. You signed out in another tab or window. Reload to refresh your session. You switched accounts on another tab or window. Reload to refresh your session. Dismiss alert
# Reproduction script: load a pretrained CompressAI model from the model zoo
# and export it to an ONNX file with torch.onnx.export.
import torch
from compressai.zoo import models

# net = models["bmshj2018-factorized"](quality=1, metric="mse", pretrained=True)
# net = cheng2020_anchor(quality=5, pretrained=True).to(device)
net = models["cheng2020-anchor"](quality=1, metric="mse", pretrained=True)

# Some dummy input
# NOTE(review): the reported traceback ends in a 16-vs-14 size mismatch in
# torch.cat inside the model's forward pass. Presumably the 224x224 input is
# not a spatial size the model accepts without padding (e.g. a multiple of 64
# such as 256x256) -- TODO confirm against the CompressAI documentation.
# The size is kept at 224 here so the script still reproduces the report.
x = torch.randn(1, 3, 224, 224, requires_grad=True)

# Export the model
torch.onnx.export(
    net,  # model being run
    x,  # model input (or a tuple for multiple inputs)
    "cheng2020.onnx",  # where to save the model (can be a file or file-like object)
    export_params=True,  # store the trained parameter weights inside the model file
    opset_version=11,  # the ONNX version to export the model to
    do_constant_folding=True,  # whether to execute constant folding for optimization
    input_names=['input'],  # the model's input names
    output_names=['output'],  # the model's output names
    dynamic_axes={
        'input': {0: 'batch_size'},  # variable length axes
        'output': {0: 'batch_size'},
    },
)
Error occurs:
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
[<ipython-input-48-bddb317a9b45>](https://localhost:8080/#) in <cell line: 12>()
10
11 # Export the model
---> 12 torch.onnx.export(net, # model being run
13 x, # model input (or a tuple for multiple inputs)
14 "cheng2020.onnx", # where to save the model (can be a file or file-like object)
15 frames
[/usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py](https://localhost:8080/#) in export(model, args, f, export_params, verbose, training, input_names, output_names, operator_export_type, opset_version, do_constant_folding, dynamic_axes, keep_initializers_as_inputs, custom_opsets, export_modules_as_functions, autograd_inlining)
514 """ 515 --> 516 _export( 517 model, 518 args,[/usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py](https://localhost:8080/#) in _export(model, args, f, export_params, verbose, training, input_names, output_names, operator_export_type, export_type, opset_version, do_constant_folding, dynamic_axes, keep_initializers_as_inputs, fixed_batch_size, custom_opsets, add_node_names, onnx_shape_inference, export_modules_as_functions, autograd_inlining) 1610 _validate_dynamic_axes(dynamic_axes, model, input_names, output_names) 1611 -> 1612 graph, params_dict, torch_out = _model_to_graph( 1613 model, 1614 args,[/usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py](https://localhost:8080/#) in _model_to_graph(model, args, verbose, input_names, output_names, operator_export_type, do_constant_folding, _disable_torch_constant_prop, fixed_batch_size, training, dynamic_axes) 1132 1133 model = _pre_trace_quant_model(model, args)-> 1134 graph, params, torch_out, module = _create_jit_graph(model, args) 1135 params_dict = _get_named_param_dict(graph, params) 1136 [/usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py](https://localhost:8080/#) in _create_jit_graph(model, args) 1008 return graph, params, torch_out, None 1009 -> 1010 graph, torch_out = _trace_and_get_graph_from_model(model, args) 1011 _C._jit_pass_onnx_lint(graph) 1012 state_dict = torch.jit._unique_state_dict(model)[/usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py](https://localhost:8080/#) in _trace_and_get_graph_from_model(model, args) 912 prev_autocast_cache_enabled = torch.is_autocast_cache_enabled() 913 torch.set_autocast_cache_enabled(False)--> 914 trace_graph, torch_out, inputs_states = torch.jit._get_trace_graph( 915 model, 916 args,[/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py](https://localhost:8080/#) in _fn(*args, **kwargs) 449 prior = set_eval_frame(callback) 450 try:--> 451 return fn(*args, **kwargs) 452 finally: 453 
set_eval_frame(prior)[/usr/local/lib/python3.10/dist-packages/torch/_dynamo/external_utils.py](https://localhost:8080/#) in inner(*args, **kwargs) 34 @functools.wraps(fn) 35 def inner(*args, **kwargs):---> 36 return fn(*args, **kwargs) 37 38 return inner[/usr/local/lib/python3.10/dist-packages/torch/jit/_trace.py](https://localhost:8080/#) in _get_trace_graph(f, args, kwargs, strict, _force_outplace, return_inputs, _return_inputs_states) 1308 if not isinstance(args, tuple): 1309 args = (args,)-> 1310 outs = ONNXTracedModule( 1311 f, strict, _force_outplace, return_inputs, _return_inputs_states 1312 )(*args, **kwargs)[/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py](https://localhost:8080/#) in _wrapped_call_impl(self, *args, **kwargs) 1530 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc] 1531 else:-> 1532 return self._call_impl(*args, **kwargs) 1533 1534 def _call_impl(self, *args, **kwargs):[/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py](https://localhost:8080/#) in _call_impl(self, *args, **kwargs) 1539 or _global_backward_pre_hooks or _global_backward_hooks 1540 or _global_forward_hooks or _global_forward_pre_hooks):-> 1541 return forward_call(*args, **kwargs) 1542 1543 try:[/usr/local/lib/python3.10/dist-packages/torch/jit/_trace.py](https://localhost:8080/#) in forward(self, *args) 136 return tuple(out_vars) 137 --> 138 graph, out = torch._C._create_graph_by_tracing( 139 wrapper, 140 in_vars + module_state,[/usr/local/lib/python3.10/dist-packages/torch/jit/_trace.py](https://localhost:8080/#) in wrapper(*args) 127 if self._return_inputs_states: 128 inputs_states.append(_unflatten(in_args, in_desc))--> 129 outs.append(self.inner(*trace_inputs)) 130 if self._return_inputs_states: 131 inputs_states[0] = (inputs_states[0], trace_inputs)[/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py](https://localhost:8080/#) in _wrapped_call_impl(self, *args, **kwargs) 1530 return 
self._compiled_call_impl(*args, **kwargs) # type: ignore[misc] 1531 else:-> 1532 return self._call_impl(*args, **kwargs) 1533 1534 def _call_impl(self, *args, **kwargs):[/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py](https://localhost:8080/#) in _call_impl(self, *args, **kwargs) 1539 or _global_backward_pre_hooks or _global_backward_hooks 1540 or _global_forward_hooks or _global_forward_pre_hooks):-> 1541 return forward_call(*args, **kwargs) 1542 1543 try:[/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py](https://localhost:8080/#) in _slow_forward(self, *input, **kwargs) 1520 recording_scopes = False 1521 try:-> 1522 result = self.forward(*input, **kwargs) 1523 finally: 1524 if recording_scopes:[/usr/local/lib/python3.10/dist-packages/compressai/models/google.py](https://localhost:8080/#) in forward(self, x) 543 ctx_params = self.context_prediction(y_hat) 544 gaussian_params = self.entropy_parameters(--> 545 torch.cat((params, ctx_params), dim=1) 546 ) 547 scales_hat, means_hat = gaussian_params.chunk(2, 1)RuntimeError: Sizes of tensors must match except in dimension 1. Expected size 16 but got size 14 for tensor number 1 in the list.
The text was updated successfully, but these errors were encountered:
Feature
Support exporting the Cheng2020 model to ONNX format.
Motivation
To deploy the model on a variety of hardware platforms.
Additional context
This is my conversion code:
Error occurs:
The text was updated successfully, but these errors were encountered: