diff --git a/CHANGELOG.md b/CHANGELOG.md index 0956a0a4a0..40fa7065eb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,6 +14,7 @@ but cannot always guarantee backwards compatibility. Changes that may **break co - New model: `StatsForecastAutoTBATS`. This model offers the [AutoTBATS](https://nixtlaverse.nixtla.io/statsforecast/src/core/models.html#autotbats) model from Nixtla's `statsforecasts` library. [#2611](https://github.com/unit8co/darts/pull/2611) by [He Weilin](https://github.com/cnhwl). - Added the `title` attribute to `TimeSeries.plot()`. This allows to set a title for the plot. [#2639](https://github.com/unit8co/darts/pull/2639) by [Jonathan Koch](https://github.com/jonathankoch99). - Added parameter `component_wise` to `show_anomalies()` to separately plot each component in multivariate series. [#2544](https://github.com/unit8co/darts/pull/2544) by [He Weilin](https://github.com/cnhwl). +- Added general function `darts.slice_intersect()` to intersect a sequence of `TimeSeries` along the time index. [#2592](https://github.com/unit8co/darts/pull/2592) by [Yoav Matzkevich](https://github.com/ymatzkevich). **Fixed** - Fixed a bug when performing optimized historical forecasts with `stride=1` using a `RegressionModel` with `output_chunk_shift>=1` and `output_chunk_length=1`, where the forecast time index was not properly shifted. [#2634](https://github.com/unit8co/darts/pull/2634) by [Mattias De Charleroy](https://github.com/MattiasDC). 
diff --git a/darts/__init__.py b/darts/__init__.py index cb0b70632e..61b4e98f6a 100644 --- a/darts/__init__.py +++ b/darts/__init__.py @@ -8,7 +8,7 @@ import matplotlib as mpl from matplotlib import cycler -from darts.timeseries import TimeSeries, concatenate +from darts.timeseries import TimeSeries, concatenate, slice_intersect __version__ = "0.32.0" @@ -42,4 +42,4 @@ if os.getenv("DARTS_CONFIGURE_MATPLOTLIB", "1") != "0": mpl.rcParams.update(u8plots_mplstyle) -__all__ = ["TimeSeries", "concatenate"] +__all__ = ["TimeSeries", "concatenate", "slice_intersect"] diff --git a/darts/tests/test_timeseries.py b/darts/tests/test_timeseries.py index 10cd710ba9..b34bb3a789 100644 --- a/darts/tests/test_timeseries.py +++ b/darts/tests/test_timeseries.py @@ -9,7 +9,7 @@ import xarray as xr from scipy.stats import kurtosis, skew -from darts import TimeSeries, concatenate +from darts import TimeSeries, concatenate, slice_intersect from darts.utils.timeseries_generation import constant_timeseries, linear_timeseries from darts.utils.utils import expand_arr, freqs, generate_index @@ -603,6 +603,11 @@ def check_intersect(other, start_, end_, freq_): s_int_idx = series.slice_intersect_times(other, copy=False) assert s_int.time_index.equals(s_int_idx) + assert slice_intersect([series, other]) == [ + series.slice_intersect(other), + other.slice_intersect(series), + ] + # slice with exact range startA = start endA = end @@ -611,11 +616,11 @@ def check_intersect(other, start_, end_, freq_): check_intersect(seriesA, startA, endA, freq_expected) # entire slice within the range - startA = start + freq - endA = startA + 6 * freq_other - idxA = generate_index(startA, endA, freq=freq_other) - seriesA = TimeSeries.from_series(pd.Series(range(len(idxA)), index=idxA)) - check_intersect(seriesA, startA, endA, freq_expected) + startB = start + freq + endB = startB + 6 * freq_other + idxB = generate_index(startB, endB, freq=freq_other) + seriesB = TimeSeries.from_series(pd.Series(range(len(idxB)), 
index=idxB)) + check_intersect(seriesB, startB, endB, freq_expected) # start outside of range startC = start - 4 * freq @@ -625,11 +630,11 @@ def check_intersect(other, start_, end_, freq_): check_intersect(seriesC, start, endC, freq_expected) # end outside of range - startC = start + 4 * freq - endC = end + 4 * freq_other - idxC = generate_index(startC, endC, freq=freq_other) - seriesC = TimeSeries.from_series(pd.Series(range(len(idxC)), index=idxC)) - check_intersect(seriesC, startC, end, freq_expected) + startD = start + 4 * freq + endD = end + 4 * freq_other + idxD = generate_index(startD, endD, freq=freq_other) + seriesD = TimeSeries.from_series(pd.Series(range(len(idxD)), index=idxD)) + check_intersect(seriesD, startD, end, freq_expected) # small intersect startE = start + (n_steps - 1) * freq @@ -639,12 +644,41 @@ def check_intersect(other, start_, end_, freq_): check_intersect(seriesE, startE, end, freq_expected) # No intersect - startG = end + 3 * freq - endG = startG + 6 * freq_other - idxG = generate_index(startG, endG, freq=freq_other) - seriesG = TimeSeries.from_series(pd.Series(range(len(idxG)), index=idxG)) + startF = end + 3 * freq + endF = startF + 6 * freq_other + idxF = generate_index(startF, endF, freq=freq_other) + seriesF = TimeSeries.from_series(pd.Series(range(len(idxF)), index=idxF)) # for empty slices, we expect the original freq - check_intersect(seriesG, None, None, freq) + check_intersect(seriesF, None, None, freq) + + # sequence with zero or one element + assert slice_intersect([]) == [] + assert slice_intersect([series]) == [series] + + # sequence with more than 2 elements + intersected_series = slice_intersect([series, seriesA, seriesE]) + s1_int = intersected_series[0] + s2_int = intersected_series[1] + s3_int = intersected_series[2] + + assert s1_int.time_index.equals(s2_int.time_index) and s1_int.time_index.equals( + s3_int.time_index + ) + assert s1_int.start_time() == startE + assert s1_int.end_time() == endA + + # check 
def slice_intersect(series: Sequence[TimeSeries]) -> list[TimeSeries]:
    """Intersect all ``TimeSeries`` in `series` along the time index.

    The first series is successively intersected with each of the remaining series using
    ``TimeSeries.slice_intersect()``. The outcome of that fold is then used to slice every
    other series in turn.

    Parameters
    ----------
    series : Sequence[TimeSeries]
        The ``TimeSeries`` to intersect. Must support truth testing, indexing and slicing.

    Returns
    -------
    list[TimeSeries]
        The intersected series, in the same order as the input. An empty input gives an empty
        list; a single series is handed back as is (the same object, nothing is sliced).
    """
    if not series:
        return []

    head, tail = series[0], series[1:]

    # narrow the first series down step by step against every other series
    common = head
    for other in tail:
        common = common.slice_intersect(other)

    # the narrowed first series comes first; all others are sliced against it
    return [common] + [other.slice_intersect(common) for other in tail]