Skip to content

Commit

Permalink
update to_onnx
Browse files Browse the repository at this point in the history
  • Loading branch information
dennisbader committed Feb 19, 2025
1 parent d523544 commit 52345bd
Showing 1 changed file with 13 additions and 24 deletions.
37 changes: 13 additions & 24 deletions darts/models/forecasting/torch_forecasting_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -646,12 +646,7 @@ def _verify_past_future_covariates(self, past_covariates, future_covariates):
logger=logger,
)

def to_onnx(
self,
path: Optional[str] = None,
input_sample: Optional[tuple] = None,
**kwargs,
):
def to_onnx(self, path: Optional[str] = None, **kwargs):
"""Export model to ONNX format for optimized inference, wrapping around PyTorch Lightning's
:func:`torch.onnx.export` method (`official documentation <https://lightning.ai/docs/pytorch/
stable/common/lightning_module.html#to-onnx>`_).
Expand All @@ -663,13 +658,12 @@ def to_onnx(
.. highlight:: python
.. code-block:: python
from darts.datasets import AirPassengersDataset
from darts.models import DLinearModel
from darts import TimeSeries
import numpy as np
train_ts = TimeSeries.from_values(np.arange(0,100))
series = AirPassengersDataset().load()
model = DLinearModel(input_chunk_length=4, output_chunk_length=1)
model.fit(train_ts, epochs=1)
model.fit(series, epochs=1)
model.to_onnx("my_model.onnx")
..
Expand All @@ -678,14 +672,10 @@ def to_onnx(
path
Path under which to save the model at its current state. If no path is specified, the model
is automatically saved under ``"{ModelClass}_{YYYY-mm-dd_HH_MM_SS}.onnx"``.
input_sample
Tuple of Tensor corresponding to the inputs of the model forward pass. In order to avoid data leakage,
it's recommended to randomize the values as only the shape is important.
**kwargs
Additional kwargs for PyTorch's :func:`torch.onnx.export` method, such as ``verbose`` prints a
description of the model being exported to stdout.
For more information, read the `official documentation <https://pytorch.org/docs/master/
onnx.html#torch.onnx.export>`_.
Additional kwargs for PyTorch's :func:`torch.onnx.export` method (except parameters ``file_path``,
``input_sample``, ``input_name``). For more information, read the `official documentation
<https://pytorch.org/docs/master/onnx.html#torch.onnx.export>`_.
"""
if not self._fit_called:
raise_log(
Expand All @@ -695,13 +685,12 @@ def to_onnx(
if path is None:
path = self._default_save_path() + ".onnx"

if not input_sample:
# last dimension in train_sample_shape is the expected target
mock_batch = tuple(
torch.rand((1,) + shape, dtype=self.model.dtype) if shape else None
for shape in self.model.train_sample_shape[:-1]
)
input_sample = self.model._process_input_batch(mock_batch)
# last dimension in train_sample_shape is the expected target
mock_batch = tuple(
torch.rand((1,) + shape, dtype=self.model.dtype) if shape else None
for shape in self.model.train_sample_shape[:-1]
)
input_sample = self.model._process_input_batch(mock_batch)

# torch models necessarily use historic target values as features in current implementation
input_names = ["x_past"]
Expand Down

0 comments on commit 52345bd

Please sign in to comment.