Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix lags_past_covariates dict type breaks lags_future_covariates when output_chunk_shift>0 #2655

Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ but cannot always guarantee backwards compatibility. Changes that may **break co

**Improved**

- Fixed a bug when initiating a `RegressionModel` with `lags_past_covariates` as dict and `lags_future_covariates` as some other type (not dict) and `output_chunk_shift>0`, [#2652
](https://github.com/unit8co/darts/issues/2652) by [Jules Authier](https://github.com/authierj).
- New model: `StatsForecastAutoTBATS`. This model offers the [AutoTBATS](https://nixtlaverse.nixtla.io/statsforecast/src/core/models.html#autotbats) model from Nixtla's `statsforecasts` library. [#2611](https://github.com/unit8co/darts/pull/2611) by [He Weilin](https://github.com/cnhwl).
- Added the `title` attribute to `TimeSeries.plot()`. This allows to set a title for the plot. [#2639](https://github.com/unit8co/darts/pull/2639) by [Jonathan Koch](https://github.com/jonathankoch99).
- Added parameter `component_wise` to `show_anomalies()` to separately plot each component in multivariate series. [#2544](https://github.com/unit8co/darts/pull/2544) by [He Weilin](https://github.com/cnhwl).
Expand Down
4 changes: 3 additions & 1 deletion darts/models/forecasting/regression_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,7 +393,9 @@ def _generate_lags(
processed_lags[lags_abbrev] = [
lag_ + output_chunk_shift for lag_ in processed_lags[lags_abbrev]
]
if processed_component_lags:
if processed_component_lags and list(tmp_components_lags.keys()) != [
"default_lags"
]:
processed_component_lags[lags_abbrev] = {
comp_: [lag_ + output_chunk_shift for lag_ in lags_]
for comp_, lags_ in processed_component_lags[
Expand Down
Loading