Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat/Pass kwargs to the underlying models fit functions #2460

Open
wants to merge 8 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,10 @@ but cannot always guarantee backwards compatibility. Changes that may **break co
**Improved**
- Added `IQRDetector`, that allows to detect anomalies using the interquartile range algorithm. [#2441](https://github.com/unit8co/darts/pull/2441) by [Igor Urbanik](https://github.com/u8-igor).
- Made README's forecasting model support table more colorblind-friendly. [#2433](https://github.com/unit8co/darts/pull/2433)
- Allowed passing of kwargs to the `fit` functions of `Prophet` and `AutoARIMA`
- 🔴 Restructured the signatures of `ExponentialSmoothing` `__init__` and `fit` functions so that the passing of additional parameters is consistent with other models
- Keyword arguments to be passed to the underlying model's constructor must now be passed as keyword arguments instead of a `dict` to the `ExponentialSmoothing` constructor
- Keyword arguments to be passed to the underlying model's `fit` function must now be passed to the `ExponentialSmoothing.fit` function instead of the constructor
Comment on lines +17 to +20
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Once all suggestions have been addressed, we could formulate it as below :)

Suggested change
- Allowed passing of kwargs to the `fit` functions of `Prophet` and `AutoARIMA`
- 🔴 Restructured the signatures of `ExponentialSmoothing` `__init__` and `fit` functions so that the passing of additional parameters is consistent with other models
- Keyword arguments to be passed to the underlying model's constructor must now be passed as keyword arguments instead of a `dict` to the `ExponentialSmoothing` constructor
- Keyword arguments to be passed to the underlying model's `fit` function must now be passed to the `ExponentialSmoothing.fit` function instead of the constructor
- Improvements to `ForecastingModel` : [#2460](https://github.com/unit8co/darts/pull/2460) by [DavidKleindienst](https://github.com/DavidKleindienst).
- All forecasting models now support keyword arguments `**kwargs` when calling `fit()` that will be passed to the underlying model's fit function.
- 🔴 Changes to `ExponentialSmoothing` for a unified API:
- Removed `fit_kwargs` from `__init__()`. They must now be passed as keyword arguments to `fit()`.
- Parameters to be passed to the underlying model's constructor must now be passed as keyword arguments (instead of an explicit `kwargs` parameter).


**Fixed**
- Fixed a bug when using `historical_forecasts()` with a pre-trained `RegressionModel` that has no target lags `lags=None` but uses static covariates. [#2426](https://github.com/unit8co/darts/pull/2426) by [Dennis Bader](https://github.com/dennisbader).
Expand Down
8 changes: 5 additions & 3 deletions darts/logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,9 @@ def __exit__(self, *_):
os.close(fd)


def execute_and_suppress_output(function, logger, suppression_threshold_level, *args):
def execute_and_suppress_output(
function, logger, suppression_threshold_level, *args, **kwargs
):
"""
This function conditionally executes the given function with the given arguments
based on whether the current level of 'logger' is below, above or equal to
Expand All @@ -207,9 +209,9 @@ def execute_and_suppress_output(function, logger, suppression_threshold_level, *
"""
if logger.level >= suppression_threshold_level:
with SuppressStdoutStderr():
return_value = function(*args)
return_value = function(*args, **kwargs)
else:
return_value = function(*args)
return_value = function(*args, **kwargs)
return return_value


Expand Down
11 changes: 9 additions & 2 deletions darts/models/forecasting/auto_arima.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,12 +95,19 @@ def encode_year(idx):
def supports_multivariate(self) -> bool:
return False

def _fit(self, series: TimeSeries, future_covariates: Optional[TimeSeries] = None):
def _fit(
self,
series: TimeSeries,
future_covariates: Optional[TimeSeries] = None,
**fit_kwargs,
):
super()._fit(series, future_covariates)
self._assert_univariate(series)
series = self.training_series
self.model.fit(
series.values(), X=future_covariates.values() if future_covariates else None
series.values(),
X=future_covariates.values() if future_covariates else None,
**fit_kwargs,
)
return self

Expand Down
4 changes: 3 additions & 1 deletion darts/models/forecasting/croston.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,9 @@ def encode_year(idx):
def supports_multivariate(self) -> bool:
return False

def _fit(self, series: TimeSeries, future_covariates: Optional[TimeSeries] = None):
def _fit(
self, series: TimeSeries, future_covariates: Optional[TimeSeries] = None, **_
):
super()._fit(series, future_covariates)
self._assert_univariate(series)
series = self.training_series
Expand Down
34 changes: 22 additions & 12 deletions darts/models/forecasting/exponential_smoothing.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
---------------------
"""

from typing import Any, Dict, Optional
from typing import Optional

import numpy as np
import statsmodels.tsa.holtwinters as hw
Expand All @@ -24,8 +24,7 @@ def __init__(
seasonal: Optional[SeasonalityMode] = SeasonalityMode.ADDITIVE,
seasonal_periods: Optional[int] = None,
random_state: int = 0,
kwargs: Optional[Dict[str, Any]] = None,
**fit_kwargs,
**kwargs,
):
"""Exponential Smoothing

Expand Down Expand Up @@ -66,11 +65,6 @@ def __init__(
:func:`statsmodels.tsa.holtwinters.ExponentialSmoothing()`.
See `the documentation
<https://www.statsmodels.org/stable/generated/statsmodels.tsa.holtwinters.ExponentialSmoothing.html>`_.
fit_kwargs
Some optional keyword arguments that will be used to call
:func:`statsmodels.tsa.holtwinters.ExponentialSmoothing.fit()`.
See `the documentation
<https://www.statsmodels.org/stable/generated/statsmodels.tsa.holtwinters.ExponentialSmoothing.fit.html>`_.

Examples
--------
Expand All @@ -96,12 +90,28 @@ def __init__(
self.seasonal = seasonal
self.infer_seasonal_periods = seasonal_periods is None
self.seasonal_periods = seasonal_periods
self.constructor_kwargs = dict() if kwargs is None else kwargs
self.fit_kwargs = fit_kwargs
self.constructor_kwargs = kwargs
self.model = None
np.random.seed(random_state)

def fit(self, series: TimeSeries):
def fit(self, series: TimeSeries, **fit_kwargs):
"""Fit/train the model on the (single) provided series.

Parameters
----------
series
The model will be trained to forecast this time series.
fit_kwargs
Some optional keyword arguments that will be used to call
:func:`statsmodels.tsa.holtwinters.ExponentialSmoothing.fit()`.
See `the documentation
<https://www.statsmodels.org/stable/generated/statsmodels.tsa.holtwinters.ExponentialSmoothing.fit.html>`_.

Returns
-------
self
Fitted model.
"""
super().fit(series)
self._assert_univariate(series)
series = self.training_series
Expand All @@ -128,7 +138,7 @@ def fit(self, series: TimeSeries):
dates=series.time_index if series.has_datetime_index else None,
**self.constructor_kwargs,
)
hw_results = hw_model.fit(**self.fit_kwargs)
hw_results = hw_model.fit(**fit_kwargs)
self.model = hw_results

if self.infer_seasonal_periods:
Expand Down
12 changes: 10 additions & 2 deletions darts/models/forecasting/forecasting_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2902,7 +2902,12 @@ class FutureCovariatesLocalForecastingModel(LocalForecastingModel, ABC):
All implementations must implement the :func:`_fit()` and :func:`_predict()` methods.
"""

def fit(self, series: TimeSeries, future_covariates: Optional[TimeSeries] = None):
def fit(
self,
series: TimeSeries,
future_covariates: Optional[TimeSeries] = None,
**fit_kwargs,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would call it kwargs since we already have this for regression models.

Also, should we add this to all forecasting models? E.g. ForecastingModel, and all its children (local models, ensemble models, regression models, torch models)? In some cases like TorchForecastingModel, it will just not do anything, but at least there is a uniform API.

Suggested change
**fit_kwargs,
**kwargs,

Copy link
Contributor Author

@DavidKleindienst DavidKleindienst Jul 19, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The main concern I have here is the EnsembleModel
In principle I see 3 possibilities of dealing with the kwargs in this case

  1. Do nothing and ignore the kwargs
    • That could be confusing because I might have an EnsembleModel of a set of closely related models which all support the same keyword argument. Then it would be pretty confusing for the User to have kwargs in the function signature but not have them passed to the models
  2. Pass kwargs to all of the models
    • That's pretty impractical because some of my models may accept a certain argument and others won't
  3. instead define kwargs: list[dict] which needs to correspond to the number of models. EnsembleModel.forecasting_model[0].fit gets passed kwargs[0] and so on
    • That option makes most sense to me, but of course we won't really have a unified API in this case

@dennisbader What do you think?

Copy link
Contributor Author

@DavidKleindienst DavidKleindienst Jul 19, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looking through the code, there are some more models where I find it difficult to deal with the kwargs, namely models that are mainly implemented within darts rather than wrapped from another library, such as Theta, FFT or the Baseline Models.
Let's take for example the Theta model:
There are 1 or 2 calls to hw.SimpleExpSmoothing inside the Theta.fit function, so we could pass the kwargs in those calls (hw.SimpleExpSmoothing seems to allow for meaningful keyword arguments), but I know too little about the Theta model to judge if that's a meaningful expansion

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The main concern I have here is the EnsembleModel In principle I see 3 possibilities of dealing with the kwargs in this case

  1. Do nothing and ignore the kwargs

    • That could be confusing because I might have an EnsembleModel of a set of closely related models which all support the same keyword argument. Then it would be pretty confusing for the User to have kwargs in the function signature but not have them passed to the models
  2. Pass kwargs to all of the models

    • That's pretty impractical because some of my models may accept a certain argument and others won't
  3. instead define kwargs: list[dict] which needs to correspond to the number of models. EnsembleModel.forecasting_model[0].fit gets passed kwargs[0] and so on

    • That option makes most sense to me, but of course we won't really have a unified API in this case

@dennisbader What do you think?

Yes, I see your point. We could do something like this:

  • **kwargs should be passed only to the ensemble model itself (not the underlying forecasting_models). Currently this will only have an effect for RegressionEnsembleModel, where the ensemble model is one of Darts RegressionModels.
  • add a new parameter like your option 3 that expects a single dict or a list of dicts (Optional[Union[Dict[str, Any], List[Dict[str, Any]]]]) that are passed to the forecasting models fit().
    • if a single dict, pass the same to all models
    • if a list of dicts, the length must match the number of forecasting models. Pass each dict to the corresponding fc model.

What do you think?

):
"""Fit/train the model on the (single) provided series.

Optionally, a future covariates series can be provided as well.
Expand All @@ -2915,6 +2920,9 @@ def fit(self, series: TimeSeries, future_covariates: Optional[TimeSeries] = None
A time series of future-known covariates. This time series will not be forecasted, but can be used by
some models as an input. It must contain at least the same time steps/indices as the target `series`.
If it is longer than necessary, it will be automatically trimmed.
fit_kwargs
Optional keyword arguments that will be passed to the fit function of the underlying model if supported
by the underlying model.

Returns
-------
Expand Down Expand Up @@ -2946,7 +2954,7 @@ def fit(self, series: TimeSeries, future_covariates: Optional[TimeSeries] = None

super().fit(series)

return self._fit(series, future_covariates=future_covariates)
return self._fit(series, future_covariates=future_covariates, **fit_kwargs)

@abstractmethod
def _fit(self, series: TimeSeries, future_covariates: Optional[TimeSeries] = None):
Expand Down
11 changes: 8 additions & 3 deletions darts/models/forecasting/prophet_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,12 @@ def encode_year(idx):
# Use 0 as default value
self._floor = 0

def _fit(self, series: TimeSeries, future_covariates: Optional[TimeSeries] = None):
def _fit(
self,
series: TimeSeries,
future_covariates: Optional[TimeSeries] = None,
**fit_kwargs,
):
super()._fit(series, future_covariates)
self._assert_univariate(series)
series = self.training_series
Expand Down Expand Up @@ -249,10 +254,10 @@ def _fit(self, series: TimeSeries, future_covariates: Optional[TimeSeries] = Non

if self.suppress_stdout_stderr:
self._execute_and_suppress_output(
self.model.fit, logger, logging.WARNING, fit_df
self.model.fit, logger, logging.WARNING, fit_df, **fit_kwargs
)
else:
self.model.fit(fit_df)
self.model.fit(fit_df, **fit_kwargs)

return self

Expand Down
4 changes: 3 additions & 1 deletion darts/models/forecasting/sf_auto_arima.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,9 @@ def encode_year(idx):
super().__init__(add_encoders=add_encoders)
self.model = SFAutoARIMA(*autoarima_args, **autoarima_kwargs)

def _fit(self, series: TimeSeries, future_covariates: Optional[TimeSeries] = None):
def _fit(
self, series: TimeSeries, future_covariates: Optional[TimeSeries] = None, **_
):
super()._fit(series, future_covariates)
self._assert_univariate(series)
series = self.training_series
Expand Down
4 changes: 3 additions & 1 deletion darts/models/forecasting/sf_auto_ets.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,9 @@ def encode_year(idx):
self.model = SFAutoETS(*autoets_args, **autoets_kwargs)
self._linreg = None

def _fit(self, series: TimeSeries, future_covariates: Optional[TimeSeries] = None):
def _fit(
self, series: TimeSeries, future_covariates: Optional[TimeSeries] = None, **_
):
super()._fit(series, future_covariates)
self._assert_univariate(series)
series = self.training_series
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def test_constructor_kwargs(self):
"initial_trend": 0.2,
"initial_seasonal": np.arange(1, 25),
}
model = ExponentialSmoothing(kwargs=constructor_kwargs)
model = ExponentialSmoothing(**constructor_kwargs)
model.fit(self.series)
# must be checked separately, name is not consistent
np.testing.assert_array_almost_equal(
Expand All @@ -70,22 +70,19 @@ def test_fit_kwargs(self):
# using default optimization method
model = ExponentialSmoothing()
model.fit(self.series)
assert model.fit_kwargs == {}
pred = model.predict(n=2)

model_bis = ExponentialSmoothing()
model_bis.fit(self.series)
assert model_bis.fit_kwargs == {}
pred_bis = model_bis.predict(n=2)

# two methods with the same parameters should yield the same forecasts
assert pred.time_index.equals(pred_bis.time_index)
np.testing.assert_array_almost_equal(pred.values(), pred_bis.values())

# change optimization method
model_ls = ExponentialSmoothing(method="least_squares")
model_ls.fit(self.series)
assert model_ls.fit_kwargs == {"method": "least_squares"}
model_ls = ExponentialSmoothing()
model_ls.fit(self.series, method="least_squares")
pred_ls = model_ls.predict(n=2)

# forecasts should be slightly different
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -647,7 +647,7 @@ def test_model_str_call(self, config):
(
ExponentialSmoothing(),
"ExponentialSmoothing(trend=ModelMode.ADDITIVE, damped=False, seasonal=SeasonalityMode.ADDITIVE, "
+ "seasonal_periods=None, random_state=0, kwargs=None)",
+ "seasonal_periods=None, random_state=0)",
), # no params changed
(
ARIMA(1, 1, 1),
Expand Down
Loading