Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat/slice intersect multi series #2592

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ but cannot always guarantee backwards compatibility. Changes that may **break co
- New model: `StatsForecastAutoTBATS`. This model offers the [AutoTBATS](https://nixtlaverse.nixtla.io/statsforecast/src/core/models.html#autotbats) model from Nixtla's `statsforecasts` library. [#2611](https://github.com/unit8co/darts/pull/2611) by [He Weilin](https://github.com/cnhwl).
- Added the `title` attribute to `TimeSeries.plot()`. This allows setting a title for the plot. [#2639](https://github.com/unit8co/darts/pull/2639) by [Jonathan Koch](https://github.com/jonathankoch99).
- Added parameter `component_wise` to `show_anomalies()` to separately plot each component in multivariate series. [#2544](https://github.com/unit8co/darts/pull/2544) by [He Weilin](https://github.com/cnhwl).
- Added general function `darts.slice_intersect()` to intersect a sequence of `TimeSeries` along the time index. [#2592](https://github.com/unit8co/darts/pull/2592) by [Yoav Matzkevich](https://github.com/ymatzkevich).

**Fixed**
- Fixed a bug when performing optimized historical forecasts with `stride=1` using a `RegressionModel` with `output_chunk_shift>=1` and `output_chunk_length=1`, where the forecast time index was not properly shifted. [#2634](https://github.com/unit8co/darts/pull/2634) by [Mattias De Charleroy](https://github.com/MattiasDC).
Expand Down
4 changes: 2 additions & 2 deletions darts/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import matplotlib as mpl
from matplotlib import cycler

from darts.timeseries import TimeSeries, concatenate
from darts.timeseries import TimeSeries, concatenate, slice_intersect

__version__ = "0.32.0"

Expand Down Expand Up @@ -42,4 +42,4 @@
if os.getenv("DARTS_CONFIGURE_MATPLOTLIB", "1") != "0":
mpl.rcParams.update(u8plots_mplstyle)

__all__ = ["TimeSeries", "concatenate"]
__all__ = ["TimeSeries", "concatenate", "slice_intersect"]
66 changes: 50 additions & 16 deletions darts/tests/test_timeseries.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import xarray as xr
from scipy.stats import kurtosis, skew

from darts import TimeSeries, concatenate
from darts import TimeSeries, concatenate, slice_intersect
from darts.utils.timeseries_generation import constant_timeseries, linear_timeseries
from darts.utils.utils import expand_arr, freqs, generate_index

Expand Down Expand Up @@ -603,6 +603,11 @@ def check_intersect(other, start_, end_, freq_):
s_int_idx = series.slice_intersect_times(other, copy=False)
assert s_int.time_index.equals(s_int_idx)

assert slice_intersect([series, other]) == [
series.slice_intersect(other),
other.slice_intersect(series),
]

# slice with exact range
startA = start
endA = end
Expand All @@ -611,11 +616,11 @@ def check_intersect(other, start_, end_, freq_):
check_intersect(seriesA, startA, endA, freq_expected)

# entire slice within the range
startA = start + freq
endA = startA + 6 * freq_other
idxA = generate_index(startA, endA, freq=freq_other)
seriesA = TimeSeries.from_series(pd.Series(range(len(idxA)), index=idxA))
check_intersect(seriesA, startA, endA, freq_expected)
startB = start + freq
endB = startB + 6 * freq_other
idxB = generate_index(startB, endB, freq=freq_other)
seriesB = TimeSeries.from_series(pd.Series(range(len(idxB)), index=idxB))
check_intersect(seriesB, startB, endB, freq_expected)

# start outside of range
startC = start - 4 * freq
Expand All @@ -625,11 +630,11 @@ def check_intersect(other, start_, end_, freq_):
check_intersect(seriesC, start, endC, freq_expected)

# end outside of range
startC = start + 4 * freq
endC = end + 4 * freq_other
idxC = generate_index(startC, endC, freq=freq_other)
seriesC = TimeSeries.from_series(pd.Series(range(len(idxC)), index=idxC))
check_intersect(seriesC, startC, end, freq_expected)
startD = start + 4 * freq
endD = end + 4 * freq_other
idxD = generate_index(startD, endD, freq=freq_other)
seriesD = TimeSeries.from_series(pd.Series(range(len(idxD)), index=idxD))
check_intersect(seriesD, startD, end, freq_expected)

# small intersect
startE = start + (n_steps - 1) * freq
Expand All @@ -639,12 +644,41 @@ def check_intersect(other, start_, end_, freq_):
check_intersect(seriesE, startE, end, freq_expected)

# No intersect
startG = end + 3 * freq
endG = startG + 6 * freq_other
idxG = generate_index(startG, endG, freq=freq_other)
seriesG = TimeSeries.from_series(pd.Series(range(len(idxG)), index=idxG))
startF = end + 3 * freq
endF = startF + 6 * freq_other
idxF = generate_index(startF, endF, freq=freq_other)
seriesF = TimeSeries.from_series(pd.Series(range(len(idxF)), index=idxF))
# for empty slices, we expect the original freq
check_intersect(seriesG, None, None, freq)
check_intersect(seriesF, None, None, freq)

# sequence with zero or one element
assert slice_intersect([]) == []
assert slice_intersect([series]) == [series]

# sequence with more than 2 elements
intersected_series = slice_intersect([series, seriesA, seriesE])
s1_int = intersected_series[0]
s2_int = intersected_series[1]
s3_int = intersected_series[2]

assert s1_int.time_index.equals(s2_int.time_index) and s1_int.time_index.equals(
s3_int.time_index
)
assert s1_int.start_time() == startE
assert s1_int.end_time() == endA

# check treatment of different time index types
if series.has_datetime_index:
seriesF = TimeSeries.from_series(
pd.Series(range(len(idxF)), index=pd.to_numeric(idxF))
)
else:
seriesF = TimeSeries.from_series(
pd.Series(range(len(idxF)), index=pd.to_datetime(idxF))
)

with pytest.raises(IndexError):
slice_intersect([series, seriesF])

@staticmethod
def helper_test_shift(test_case, test_series: TimeSeries):
Expand Down
35 changes: 32 additions & 3 deletions darts/timeseries.py
Original file line number Diff line number Diff line change
Expand Up @@ -2496,7 +2496,7 @@ def slice_intersect(self, other: Self) -> Self:
"""
if other.has_same_time_as(self):
return self.__class__(self._xa)
if other.freq == self.freq:
elif other.freq == self.freq and len(self) and len(other):
start, end = self._slice_intersect_bounds(other)
return self[start:end]
else:
Expand Down Expand Up @@ -2815,9 +2815,9 @@ def has_same_time_as(self, other: Self) -> bool:
"""
if len(other) != len(self):
return False
if other.freq != self.freq:
elif other.freq != self.freq:
return False
if other.start_time() != self.start_time():
elif other.start_time() != self.start_time():
return False
else:
return True
Expand Down Expand Up @@ -5662,6 +5662,35 @@ def concatenate(
return TimeSeries.from_xarray(da_concat, fill_missing_dates=False)


def slice_intersect(series: Sequence[TimeSeries]) -> list[TimeSeries]:
    """Intersect all `series` with each other along the time index.

    Parameters
    ----------
    series : Sequence[TimeSeries]
        Sequence of ``TimeSeries`` to intersect.

    Returns
    -------
    list[TimeSeries]
        One entry per input series, in input order. The first entry is the
        first series successively intersected with every other series; each
        remaining entry is the corresponding input series intersected with
        that first entry. An empty input gives an empty list, and a
        single-element input gives a list holding that same element.
    """
    if not series:
        return []

    # slice the tail once; it is needed for both passes below
    first, others = series[0], series[1:]

    # pass 1: narrow the first series down to the span shared with all others
    common = first
    for other in others:
        common = common.slice_intersect(other)

    # pass 2: cut every remaining series to that shared span
    return [common] + [other.slice_intersect(common) for other in others]


def _finite_rows_boundaries(
values: np.ndarray, how: str = "all"
) -> tuple[Optional[int], Optional[int]]:
Expand Down
Loading