Skip to content

Commit

Permalink
Feat/slice intersect multi series (#2592)
Browse files Browse the repository at this point in the history
* Add a function to find the intersection of multiple time series

* Add empty sequence treatment.

* Used logic of slice_intersect for faster code.

* Improved complexity of slice_intersect.

* Intersect time indexes to avoid creating new TimeSeries for optimization purposes.

* Add early exit in case of empty intersection.

* Modified slice_intersect for empty TimeSeries treatment.

* Added an entry in the changelog.

* last updates for pr

---------

Co-authored-by: madtoinou <[email protected]>
Co-authored-by: dennisbader <[email protected]>
  • Loading branch information
3 people authored Jan 24, 2025
1 parent c52ec83 commit 1d7f0d1
Show file tree
Hide file tree
Showing 4 changed files with 85 additions and 21 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ but cannot always guarantee backwards compatibility. Changes that may **break co
- New model: `StatsForecastAutoTBATS`. This model offers the [AutoTBATS](https://nixtlaverse.nixtla.io/statsforecast/src/core/models.html#autotbats) model from Nixtla's `statsforecast` library. [#2611](https://github.com/unit8co/darts/pull/2611) by [He Weilin](https://github.com/cnhwl).
- Added the `title` attribute to `TimeSeries.plot()`. This allows setting a title for the plot. [#2639](https://github.com/unit8co/darts/pull/2639) by [Jonathan Koch](https://github.com/jonathankoch99).
- Added parameter `component_wise` to `show_anomalies()` to separately plot each component in multivariate series. [#2544](https://github.com/unit8co/darts/pull/2544) by [He Weilin](https://github.com/cnhwl).
- Added general function `darts.slice_intersect()` to intersect a sequence of `TimeSeries` along the time index. [#2592](https://github.com/unit8co/darts/pull/2592) by [Yoav Matzkevich](https://github.com/ymatzkevich).

**Fixed**
- Fixed a bug when performing optimized historical forecasts with `stride=1` using a `RegressionModel` with `output_chunk_shift>=1` and `output_chunk_length=1`, where the forecast time index was not properly shifted. [#2634](https://github.com/unit8co/darts/pull/2634) by [Mattias De Charleroy](https://github.com/MattiasDC).
Expand Down
4 changes: 2 additions & 2 deletions darts/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import matplotlib as mpl
from matplotlib import cycler

from darts.timeseries import TimeSeries, concatenate
from darts.timeseries import TimeSeries, concatenate, slice_intersect

__version__ = "0.32.0"

Expand Down Expand Up @@ -42,4 +42,4 @@
if os.getenv("DARTS_CONFIGURE_MATPLOTLIB", "1") != "0":
mpl.rcParams.update(u8plots_mplstyle)

__all__ = ["TimeSeries", "concatenate"]
__all__ = ["TimeSeries", "concatenate", "slice_intersect"]
66 changes: 50 additions & 16 deletions darts/tests/test_timeseries.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import xarray as xr
from scipy.stats import kurtosis, skew

from darts import TimeSeries, concatenate
from darts import TimeSeries, concatenate, slice_intersect
from darts.utils.timeseries_generation import constant_timeseries, linear_timeseries
from darts.utils.utils import expand_arr, freqs, generate_index

Expand Down Expand Up @@ -603,6 +603,11 @@ def check_intersect(other, start_, end_, freq_):
s_int_idx = series.slice_intersect_times(other, copy=False)
assert s_int.time_index.equals(s_int_idx)

assert slice_intersect([series, other]) == [
series.slice_intersect(other),
other.slice_intersect(series),
]

# slice with exact range
startA = start
endA = end
Expand All @@ -611,11 +616,11 @@ def check_intersect(other, start_, end_, freq_):
check_intersect(seriesA, startA, endA, freq_expected)

# entire slice within the range
startA = start + freq
endA = startA + 6 * freq_other
idxA = generate_index(startA, endA, freq=freq_other)
seriesA = TimeSeries.from_series(pd.Series(range(len(idxA)), index=idxA))
check_intersect(seriesA, startA, endA, freq_expected)
startB = start + freq
endB = startB + 6 * freq_other
idxB = generate_index(startB, endB, freq=freq_other)
seriesB = TimeSeries.from_series(pd.Series(range(len(idxB)), index=idxB))
check_intersect(seriesB, startB, endB, freq_expected)

# start outside of range
startC = start - 4 * freq
Expand All @@ -625,11 +630,11 @@ def check_intersect(other, start_, end_, freq_):
check_intersect(seriesC, start, endC, freq_expected)

# end outside of range
startC = start + 4 * freq
endC = end + 4 * freq_other
idxC = generate_index(startC, endC, freq=freq_other)
seriesC = TimeSeries.from_series(pd.Series(range(len(idxC)), index=idxC))
check_intersect(seriesC, startC, end, freq_expected)
startD = start + 4 * freq
endD = end + 4 * freq_other
idxD = generate_index(startD, endD, freq=freq_other)
seriesD = TimeSeries.from_series(pd.Series(range(len(idxD)), index=idxD))
check_intersect(seriesD, startD, end, freq_expected)

# small intersect
startE = start + (n_steps - 1) * freq
Expand All @@ -639,12 +644,41 @@ def check_intersect(other, start_, end_, freq_):
check_intersect(seriesE, startE, end, freq_expected)

# No intersect
startG = end + 3 * freq
endG = startG + 6 * freq_other
idxG = generate_index(startG, endG, freq=freq_other)
seriesG = TimeSeries.from_series(pd.Series(range(len(idxG)), index=idxG))
startF = end + 3 * freq
endF = startF + 6 * freq_other
idxF = generate_index(startF, endF, freq=freq_other)
seriesF = TimeSeries.from_series(pd.Series(range(len(idxF)), index=idxF))
# for empty slices, we expect the original freq
check_intersect(seriesG, None, None, freq)
check_intersect(seriesF, None, None, freq)

# sequence with zero or one element
assert slice_intersect([]) == []
assert slice_intersect([series]) == [series]

# sequence with more than 2 elements
intersected_series = slice_intersect([series, seriesA, seriesE])
s1_int = intersected_series[0]
s2_int = intersected_series[1]
s3_int = intersected_series[2]

assert s1_int.time_index.equals(s2_int.time_index) and s1_int.time_index.equals(
s3_int.time_index
)
assert s1_int.start_time() == startE
assert s1_int.end_time() == endA

# check treatment of different time index types
if series.has_datetime_index:
seriesF = TimeSeries.from_series(
pd.Series(range(len(idxF)), index=pd.to_numeric(idxF))
)
else:
seriesF = TimeSeries.from_series(
pd.Series(range(len(idxF)), index=pd.to_datetime(idxF))
)

with pytest.raises(IndexError):
slice_intersect([series, seriesF])

@staticmethod
def helper_test_shift(test_case, test_series: TimeSeries):
Expand Down
35 changes: 32 additions & 3 deletions darts/timeseries.py
Original file line number Diff line number Diff line change
Expand Up @@ -2496,7 +2496,7 @@ def slice_intersect(self, other: Self) -> Self:
"""
if other.has_same_time_as(self):
return self.__class__(self._xa)
if other.freq == self.freq:
elif other.freq == self.freq and len(self) and len(other):
start, end = self._slice_intersect_bounds(other)
return self[start:end]
else:
Expand Down Expand Up @@ -2815,9 +2815,9 @@ def has_same_time_as(self, other: Self) -> bool:
"""
if len(other) != len(self):
return False
if other.freq != self.freq:
elif other.freq != self.freq:
return False
if other.start_time() != self.start_time():
elif other.start_time() != self.start_time():
return False
else:
return True
Expand Down Expand Up @@ -5662,6 +5662,35 @@ def concatenate(
return TimeSeries.from_xarray(da_concat, fill_missing_dates=False)


def slice_intersect(series: Sequence[TimeSeries]) -> list[TimeSeries]:
    """Returns a list of ``TimeSeries``, where all `series` have been intersected along the time index.

    The first series is successively intersected with every other series using
    ``TimeSeries.slice_intersect()``. Each of the remaining series is then intersected with
    that result.

    Parameters
    ----------
    series : Sequence[TimeSeries]
        Sequence of ``TimeSeries`` to intersect.

    Returns
    -------
    list[TimeSeries]
        The intersected series, in the same order as the input `series`. An empty input
        sequence gives an empty list.
    """
    if not series:
        return []

    others = series[1:]

    # find global intersection on first series
    intersection = series[0]
    for other in others:
        intersection = intersection.slice_intersect(other)

    # intersect all other series with the global intersection; the first one is already done
    return [intersection] + [other.slice_intersect(intersection) for other in others]


def _finite_rows_boundaries(
values: np.ndarray, how: str = "all"
) -> tuple[Optional[int], Optional[int]]:
Expand Down

0 comments on commit 1d7f0d1

Please sign in to comment.