diff --git a/dev_scripts/checks.py b/dev_scripts/checks.py new file mode 100644 index 0000000..265a864 --- /dev/null +++ b/dev_scripts/checks.py @@ -0,0 +1,42 @@ +import pandas as pd +import portfolyo as pf +from portfolyo.core.shared import concat + + +def get_idx( + startdate: str, starttime: str, tz: str, freq: str, enddate: str +) -> pd.DatetimeIndex: + # Empty index. + if startdate is None: + return pd.DatetimeIndex([], freq=freq, tz=tz) + # Normal index. + ts_start = pd.Timestamp(f"{startdate} {starttime}", tz=tz) + ts_end = pd.Timestamp(f"{enddate} {starttime}", tz=tz) + return pd.date_range(ts_start, ts_end, freq=freq, inclusive="left") + + +index = pd.date_range("2020", "2024", freq="QS", inclusive="left") +# index2 = pd.date_range("2023", "2025", freq="QS", inclusive="left") +# pfl = pf.dev.get_flatpfline(index) +# pfl2 = pf.dev.get_flatpfline(index2) +# print(pfl) +# print(pfl2) + +# pfs = pf.dev.get_pfstate(index) + +# pfs2 = pf.dev.get_pfstate(index2) +# pfl3 = concat.general(pfl, pfl2) +# print(pfl3) + +# print(index) +# print(index2) + +whole_pfl = pf.dev.get_nestedpfline(index) +pfl_a = whole_pfl.slice[:"2021"] + +pfl_b = whole_pfl.slice["2021":"2022"] +pfl_c = whole_pfl.slice["2022":] +result = concat.concat_pflines(pfl_a, pfl_b, pfl_c) +result2 = concat.concat_pflines(pfl_b, pfl_c, pfl_a) +print(result) +print(result2) diff --git a/docs/core/pfline.rst b/docs/core/pfline.rst index 5a81555..e4ccb54 100644 --- a/docs/core/pfline.rst +++ b/docs/core/pfline.rst @@ -271,6 +271,29 @@ Another slicing method is implemented with the ``.slice[]`` property. The improv +Concatenation +============= + +Portfolio lines can be concatenated with the ``portfolio.concat()`` function. This only works if the input portfolio lines have contain compatible information (the same frequency, timezone, start-of-day, kind, etc) and, crucially, their indices are gapless and without overlap. To remove any overlap, use the ``.slice[]`` property. + +.. exec_code:: + + # --- hide: start --- + import portfolyo as pf, pandas as pd + index = pd.date_range('2024', freq='AS', periods=3) + input_df = pd.DataFrame({'w':[200, 220, 300], 'p': [100, 150, 200]}, index) + pfl = pf.PfLine(input_df) + # --- hide: stop --- + # continuation of previous code example + index2 = pd.date_range('2025', freq='AS', periods=3) # 2 years' overlap with pfl + pfl2 = pf.PfLine(pd.DataFrame({'w':[22, 30, 40], 'p': [15, 20, 21]}, index)) + # first two datapoints (until/excl 2026) from pfl, last two datapoints (from/incl 2026) from pfl2 + pf.concat([pfl.slice[:'2026'], pfl2.slice['2026':]]) + # --- hide: start --- + print(pf.concat([pfl.slice[:'2026'], pfl2.slice['2026':]])) + # --- hide: stop --- + + Volume-only, price-only or revenue-only ======================================= diff --git a/docs/savefig/fig_plot_pfl.png b/docs/savefig/fig_plot_pfl.png index 03698fa..a379a06 100644 Binary files a/docs/savefig/fig_plot_pfl.png and b/docs/savefig/fig_plot_pfl.png differ diff --git a/docs/savefig/fig_plot_pfs.png b/docs/savefig/fig_plot_pfs.png index a00e5de..8ac43f3 100644 Binary files a/docs/savefig/fig_plot_pfs.png and b/docs/savefig/fig_plot_pfs.png differ diff --git a/portfolyo/__init__.py b/portfolyo/__init__.py index 4f87da2..504d5e3 100644 --- a/portfolyo/__init__.py +++ b/portfolyo/__init__.py @@ -3,7 +3,7 @@ from . import _version, dev, tools from .core import extendpandas # extend functionalty of pandas from .core import suppresswarnings -from .core.mixins.plot import plot_pfstates +from .core.shared.plot import plot_pfstates from .core.pfline import Kind, PfLine, Structure, create from .core.pfstate import PfState from .prices.hedge import hedge @@ -14,6 +14,9 @@ from .tools.tzone import force_agnostic, force_aware from .tools.unit import Q_, ureg, Unit from .tools.wavg import general as wavg +from .core.shared.concat import general as concat + +# from .core.shared.concat import general as concat VOLUME = Kind.VOLUME PRICE = Kind.PRICE diff --git a/portfolyo/core/pfline/classes.py b/portfolyo/core/pfline/classes.py index a2aa855..740bea2 100644 --- a/portfolyo/core/pfline/classes.py +++ b/portfolyo/core/pfline/classes.py @@ -8,7 +8,7 @@ import pandas as pd from ... import tools -from ..mixins import ExcelClipboardOutput, PfLinePlot, PfLineText +from ..shared import ExcelClipboardOutput, PfLinePlot, PfLineText from ..ndframelike import NDFrameLike from . import ( create, diff --git a/portfolyo/core/pfstate/pfstate.py b/portfolyo/core/pfstate/pfstate.py index ed127a4..cd634ac 100644 --- a/portfolyo/core/pfstate/pfstate.py +++ b/portfolyo/core/pfstate/pfstate.py @@ -12,7 +12,7 @@ import pandas as pd from ... import tools -from ..mixins import ExcelClipboardOutput, PfStatePlot, PfStateText +from ..shared import ExcelClipboardOutput, PfStatePlot, PfStateText from ..ndframelike import NDFrameLike from ..pfline import PfLine, create from . import pfstate_helper diff --git a/portfolyo/core/mixins/__init__.py b/portfolyo/core/shared/__init__.py similarity index 100% rename from portfolyo/core/mixins/__init__.py rename to portfolyo/core/shared/__init__.py diff --git a/portfolyo/core/shared/concat.py b/portfolyo/core/shared/concat.py new file mode 100644 index 0000000..5f105d2 --- /dev/null +++ b/portfolyo/core/shared/concat.py @@ -0,0 +1,149 @@ +# import pandas as pd +# import portfolyo as pf +from __future__ import annotations +from typing import Iterable +import pandas as pd +from portfolyo import tools + +from ..pfstate import PfState +from ..pfline.enums import Structure + +from ..pfline import PfLine, create +from .. import pfstate + + +def general(pfl_or_pfs: Iterable[PfLine | PfState]) -> None: + """ + Based on passed parameters calls either concat_pflines() or concat_pfstates(). + + Parameters + ---------- + pfl_or_pfs: Iterable[PfLine | PfState] + The input values. Can be either a list of Pflines or PfStates to concatenate. + + Returns + ------- + None + + Notes + ----- + Input portfolio lines must contain compatible information, i.e., same frequency, + timezone, start-of-day, and kind. Their indices must be gapless and without overlap. + + For nested pflines, the number and names of their children must match; concatenation + is done on a name-by-name basis. + + Concatenation returns the same result regardless of input order. + + """ + if all(isinstance(item, PfLine) for item in pfl_or_pfs): + return concat_pflines(pfl_or_pfs) + elif all(isinstance(item, PfState) for item in pfl_or_pfs): + return concat_pfstates(pfl_or_pfs) + else: + raise NotImplementedError( + "Concatenation is implemented only for PfState or PfLine." + ) + + +def concat_pflines(pfls: Iterable[PfLine]) -> PfLine: + """ + Concatenate porfolyo lines along their index. + + Parameters + ---------- + pfls: Iterable[PfLine] + The input values. + + Returns + ------- + PfLine + Concatenated version of PfLines. + + Notes + ----- + Input portfolio lines must contain compatible information, i.e., same frequency, + timezone, start-of-day, and kind. Their indices must be gapless and without overlap. + + For nested pflines, the number and names of their children must match; concatenation + is done on a name-by-name basis. + + Concatenation returns the same result regardless of input order. + """ + if len(pfls) < 2: + raise NotImplementedError( + "Cannot perform operation with less than 2 portfolio lines." + ) + if len({pfl.kind for pfl in pfls}) != 1: + raise TypeError("Not possible to concatenate PfLines of different kinds.") + if len({pfl.index.freq for pfl in pfls}) != 1: + raise TypeError("Not possible to concatenate PfLines of different frequencies.") + if len({pfl.index.tz for pfl in pfls}) != 1: + raise TypeError("Not possible to concatenate PfLines of different time zones.") + if len({tools.startofday.get(pfl.index, "str") for pfl in pfls}) != 1: + raise TypeError( + "Not possible to concatenate PfLines of different start_of_day." + ) + # we can concatenate only pflines of the same type: nested of flat + # with this test and check whether pfls are the same types and they have the same number of children + if len({pfl.structure for pfl in pfls}) != 1: + raise TypeError("Not possible to concatenate PfLines of different structures.") + if pfls[0].structure is Structure.NESTED: + child_names = pfls[0].children.keys() + for pfl in pfls: + diffs = set(child_names) ^ set(pfl.children.keys()) + if len(diffs) != 0: + raise TypeError( + "Not possible to concatenate PfLines with different children names." + ) + # If we reach here, all pfls have same kind, same number and names of children. + + # concat(a,b) and concat(b,a) should give the same result: + sorted_pfls = sorted(pfls, key=lambda pfl: pfl.index[0]) + if pfls[0].structure is Structure.FLAT: + # create flat dataframe of parent + dataframes_flat = [pfl.df for pfl in sorted_pfls] + # concatenate dataframes into one + concat_data = pd.concat(dataframes_flat, axis=0) + try: + # Call create.flatpfline() and catch any ValueError + return create.flatpfline(concat_data) + except ValueError as e: + # Handle the error + raise ValueError( + "Error by creating PfLine. PfLine is either not gapless or has overlaps" + ) from e + child_data = {} + child_names = pfls[0].children.keys() + for cname in child_names: + # for every name in children need to concatenate elements + child_values = [pfl.children[cname] for pfl in sorted_pfls] + child_data[cname] = concat_pflines(child_values) + + # create pfline from dataframes: -> + # call the constructor of pfl to check check gaplesnes and overplap + return create.nestedpfline(child_data) + + +def concat_pfstates(pfss: Iterable[PfState]) -> PfState: + """ + Concatenate porfolyo states along their index. + + Parameters + ---------- + pfss: Iterable[PfState] + The input values. + + Returns + ------- + PfState + Concatenated version of PfStates. + + """ + if len(pfss) < 2: + print("Concatenate needs at least two elements.") + return + offtakevolume = concat_pflines([pfs.offtakevolume for pfs in pfss]) + sourced = concat_pflines([pfs.sourced for pfs in pfss]) + unsourcedprice = concat_pflines([pfs.unsourcedprice for pfs in pfss]) + return pfstate.PfState(offtakevolume, unsourcedprice, sourced) diff --git a/portfolyo/core/mixins/excelclipboard.py b/portfolyo/core/shared/excelclipboard.py similarity index 100% rename from portfolyo/core/mixins/excelclipboard.py rename to portfolyo/core/shared/excelclipboard.py diff --git a/portfolyo/core/mixins/plot.py b/portfolyo/core/shared/plot.py similarity index 100% rename from portfolyo/core/mixins/plot.py rename to portfolyo/core/shared/plot.py diff --git a/portfolyo/core/mixins/text.py b/portfolyo/core/shared/text.py similarity index 100% rename from portfolyo/core/mixins/text.py rename to portfolyo/core/shared/text.py diff --git a/tests/core/shared/test_concat_error_cases.py b/tests/core/shared/test_concat_error_cases.py new file mode 100644 index 0000000..a5eb551 --- /dev/null +++ b/tests/core/shared/test_concat_error_cases.py @@ -0,0 +1,137 @@ +"""Test different error cases for concatenation of PfStates and PfLines.""" + +import pandas as pd +import pytest + + +from portfolyo import dev +from portfolyo.core.pfline.enums import Kind +from portfolyo.core.pfstate.pfstate import PfState +from portfolyo.core.shared import concat + + +def test_general(): + """Test if concatenating PfLine with PfState raises error.""" + index = pd.date_range("2020", "2024", freq="QS", inclusive="left") + index2 = pd.date_range("2024", "2025", freq="QS", inclusive="left") + pfl = dev.get_flatpfline(index) + pfs = dev.get_pfstate(index2) + with pytest.raises(NotImplementedError): + _ = concat.general([pfl, pfs]) + + +def test_diff_freq(): + """Test if concatenating of two flat PfLines with different freq raises error.""" + index = pd.date_range("2020", "2024", freq="QS", inclusive="left") + index2 = pd.date_range("2024", "2025", freq="AS", inclusive="left") + pfl = dev.get_flatpfline(index) + pfl2 = dev.get_flatpfline(index2) + with pytest.raises(TypeError): + _ = concat.concat_pflines([pfl, pfl2]) + + +def test_diff_sod(): + """Test if concatenating of two flat PfLines with different sod raises error.""" + index = pd.date_range("2020-01-01 00:00", "2024", freq="QS", inclusive="left") + index2 = pd.date_range("2024-01-01 06:00", "2025", freq="QS", inclusive="left") + pfl = dev.get_flatpfline(index) + pfl2 = dev.get_flatpfline(index2) + with pytest.raises(TypeError): + _ = concat.concat_pflines([pfl, pfl2]) + + +def test_slice_not_sod(): + """Test if concatenating of two flat PfLines with different sod raises error.""" + index = pd.date_range("2020-01-01 00:00", "2020-03-01", freq="H", inclusive="left") + index2 = pd.date_range( + "2020-02-01 06:00", "2020-04-01 06:00", freq="H", inclusive="left" + ) + pfl_a = dev.get_flatpfline(index) + pfl_b = dev.get_flatpfline(index2) + with pytest.raises(TypeError): + _ = concat.concat_pflines([pfl_a, pfl_b]) + + +def test_diff_tz(): + """Test if concatenating of two flat PfLines with different tz raises error.""" + index = pd.date_range( + "2020-01-01", "2024", freq="QS", tz="Europe/Berlin", inclusive="left" + ) + index2 = pd.date_range("2024-01-01", "2025", freq="QS", tz=None, inclusive="left") + pfl = dev.get_flatpfline(index) + pfl2 = dev.get_flatpfline(index2) + with pytest.raises(TypeError): + _ = concat.concat_pflines([pfl, pfl2]) + + +def test_diff_kind(): + """Test if concatenating of two flat PfLines with different kind raises error.""" + index = pd.date_range("2020-01-01", "2024", freq="QS", inclusive="left") + index2 = pd.date_range("2024-01-01", "2025", freq="QS", inclusive="left") + pfl = dev.get_flatpfline(index, kind=Kind.COMPLETE) + pfl2 = dev.get_flatpfline(index2, kind=Kind.VOLUME) + with pytest.raises(TypeError): + _ = concat.concat_pflines([pfl, pfl2]) + + +def test_app_lenght(): + """Test if concatenatination raises error if we pass only one parameter.""" + index = pd.date_range("2020-01-01", "2024", freq="QS", inclusive="left") + pfl = dev.get_flatpfline(index) + with pytest.raises(NotImplementedError): + _ = concat.concat_pflines([pfl]) + + +def test_concat_with_overlap(): + """Test if concatenatination raises error if there is overlap in indices of PfLines.""" + index = pd.date_range("2020-01-01", "2024", freq="QS", inclusive="left") + index2 = pd.date_range("2020-01-01", "2023", freq="QS", inclusive="left") + pfl = dev.get_flatpfline(index) + pfl2 = dev.get_flatpfline(index2) + with pytest.raises(ValueError): + _ = concat.concat_pflines([pfl, pfl2]) + + +def test_concat_with_gaps(): + """Test if concatenatination raises error if there is a gap in indices of PfLines.""" + index = pd.date_range("2020-01-01", "2023", freq="QS", inclusive="left") + index2 = pd.date_range("2024-01-01", "2025", freq="QS", inclusive="left") + pfl = dev.get_flatpfline(index) + pfl2 = dev.get_flatpfline(index2) + with pytest.raises(ValueError): + _ = concat.concat_pflines([pfl, pfl2]) + + +def test_concat_children(): + """Test if concatenating of flat PfLine with nested PfLine raises error.""" + index = pd.date_range("2020-01-01", "2024", freq="QS", inclusive="left") + index2 = pd.date_range("2024-01-01", "2025", freq="QS", inclusive="left") + pfl = dev.get_flatpfline(index) + pfl2 = dev.get_nestedpfline(index2) + with pytest.raises(TypeError): + _ = concat.concat_pflines([pfl, pfl2]) + + +def test_concat_diff_children(): + """Test if concatenating of two nested PfLines with different children raises error.""" + index = pd.date_range("2020-01-01", "2024", freq="QS", inclusive="left") + index2 = pd.date_range("2024-01-01", "2025", freq="QS", inclusive="left") + pfl = dev.get_nestedpfline(index) + pfl2 = dev.get_nestedpfline(index2).drop_child(name="a") + with pytest.raises(TypeError): + _ = concat.concat_pflines([pfl, pfl2]) + + +def test_concat_pfss(): + """Test if concatenating of Pfstate with "nested" PfState + (meaning that offtakevolume, sourced and unsourcedprice are nested Pflines) raises error. + """ + index = pd.date_range("2020-01-01", "2024", freq="QS", inclusive="left") + index2 = pd.date_range("2024-01-01", "2025", freq="QS", inclusive="left") + pfs1 = dev.get_pfstate(index) + offtakevolume = dev.get_nestedpfline(index2, kind=Kind.VOLUME) + sourced = dev.get_nestedpfline(index2, kind=Kind.COMPLETE) + unsourcedprice = dev.get_nestedpfline(index2, kind=Kind.PRICE) + pfs2 = PfState(offtakevolume, unsourcedprice, sourced) + with pytest.raises(TypeError): + _ = concat.concat_pfstates([pfs1, pfs2]) diff --git a/tests/core/shared/test_concat_pfline.py b/tests/core/shared/test_concat_pfline.py new file mode 100644 index 0000000..f066602 --- /dev/null +++ b/tests/core/shared/test_concat_pfline.py @@ -0,0 +1,151 @@ +"""Test if concatenation of PfLines works properly with different test cases.""" + +import pandas as pd +import pytest +from portfolyo import dev +from portfolyo.core.shared import concat + + +TESTCASES2 = [ # whole idx, freq, where + ( + ("2020-01-01", "2023-04-01"), + "QS", + "2022-04-01", + ), + (("2020", "2022"), "AS", "2021-01-01"), + ( + ("2020-05-01", "2023-04-01"), + "MS", + "2022-11-01", + ), + (("2022-03-20", "2022-07-28"), "D", "2022-05-28"), +] + +TESTCASES3 = [ # whole idx, freq, where + ( + ("2020-01-01", "2023-04-01"), + "QS", + ("2022-04-01", "2023-01-01"), + ), + (("2020", "2023"), "AS", ("2021-01-01", "2022-01-01")), + ( + ("2020-05-01", "2023-04-01"), + "MS", + ("2022-11-01", "2023-01-01"), + ), + (("2022-03-20", "2022-07-28"), "D", ("2022-04-28", "2022-05-15")), +] + + +def get_idx( + startdate: str, starttime: str, tz: str, freq: str, enddate: str +) -> pd.DatetimeIndex: + # Empty index. + if startdate is None: + return pd.DatetimeIndex([], freq=freq, tz=tz) + # Normal index. + ts_start = pd.Timestamp(f"{startdate} {starttime}", tz=tz) + ts_end = pd.Timestamp(f"{enddate} {starttime}", tz=tz) + return pd.date_range(ts_start, ts_end, freq=freq, inclusive="left") + + +@pytest.mark.parametrize("tz", [None, "Europe/Berlin", "Asia/Kolkata"]) +@pytest.mark.parametrize("starttime", ["00:00", "06:00"]) +@pytest.mark.parametrize(("whole_idx", "freq", "where"), TESTCASES2) +@pytest.mark.parametrize("test_fn", ["general", "concat_pflines"]) +def test_concat_flat_pflines( + whole_idx: str, starttime: str, tz: str, freq: str, where: str, test_fn: str +): + """Test that two flat pflines with the same attributes (i.e., same frequency, + timezone, start-of-day, and kind) get concatenated properly.""" + idx = get_idx(whole_idx[0], starttime, tz, freq, whole_idx[1]) + whole_pfl = dev.get_flatpfline(idx) + pfl_a = whole_pfl.slice[:where] + pfl_b = whole_pfl.slice[where:] + fn = concat.general if test_fn == "general" else concat.concat_pflines + result = fn([pfl_a, pfl_b]) + result2 = fn([pfl_b, pfl_a]) + assert whole_pfl == result + assert whole_pfl == result2 + + +@pytest.mark.parametrize("tz", [None, "Europe/Berlin", "Asia/Kolkata"]) +@pytest.mark.parametrize("starttime", ["00:00", "06:00"]) +@pytest.mark.parametrize(("whole_idx", "freq", "where"), TESTCASES2) +@pytest.mark.parametrize("test_fn", ["general", "concat_pflines"]) +def test_concat_nested_pflines( + whole_idx: str, + starttime: str, + tz: str, + freq: str, + where: str, + test_fn: str, +): + """Test that two nested pflines with the same attributes (i.e., same frequency, + timezone, start-of-day, and kind) and the same number of children and children names + get concatenated properly.""" + idx = get_idx(whole_idx[0], starttime, tz, freq, whole_idx[1]) + whole_pfl = dev.get_nestedpfline(idx) + pfl_a = whole_pfl.slice[:where] + pfl_b = whole_pfl.slice[where:] + fn = concat.general if test_fn == "general" else concat.concat_pflines + result = fn([pfl_a, pfl_b]) + result2 = fn([pfl_b, pfl_a]) + assert whole_pfl == result + assert whole_pfl == result2 + + +@pytest.mark.parametrize("tz", [None, "Europe/Berlin", "Asia/Kolkata"]) +@pytest.mark.parametrize("starttime", ["00:00", "06:00"]) +@pytest.mark.parametrize(("whole_idx", "freq", "where"), TESTCASES3) +@pytest.mark.parametrize("test_fn", ["general", "concat_pflines"]) +def test_concat_three_flatpflines( + whole_idx: str, + starttime: str, + tz: str, + freq: str, + where: str, + test_fn: str, +): + """Test that three flat pflines with the same attributes (i.e., same frequency, + timezone, start-of-day, and kind) get concatenated properly.""" + idx = get_idx(whole_idx[0], starttime, tz, freq, whole_idx[1]) + whole_pfl = dev.get_flatpfline(idx) + split_one = where[0] + split_two = where[1] + pfl_a = whole_pfl.slice[:split_one] + pfl_b = whole_pfl.slice[split_one:split_two] + pfl_c = whole_pfl.slice[split_two:] + fn = concat.general if test_fn == "general" else concat.concat_pflines + result = fn([pfl_a, pfl_b, pfl_c]) + result2 = fn([pfl_b, pfl_c, pfl_a]) + assert whole_pfl == result + assert whole_pfl == result2 + + +@pytest.mark.parametrize("tz", [None, "Europe/Berlin", "Asia/Kolkata"]) +@pytest.mark.parametrize("starttime", ["00:00", "06:00"]) +@pytest.mark.parametrize(("whole_idx", "freq", "where"), TESTCASES3) +@pytest.mark.parametrize("test_fn", ["general", "concat_pflines"]) +def test_concat_three_nestedpflines( + whole_idx: str, + starttime: str, + tz: str, + freq: str, + where: str, + test_fn: str, +): + """Test that three nested pflines with the same attributes ( aka kind, freq, sod, etc.) + and the same number of children and children names get concatenated properly.""" + idx = get_idx(whole_idx[0], starttime, tz, freq, whole_idx[1]) + whole_pfl = dev.get_nestedpfline(idx) + split_one = where[0] + split_two = where[1] + pfl_a = whole_pfl.slice[:split_one] + pfl_b = whole_pfl.slice[split_one:split_two] + pfl_c = whole_pfl.slice[split_two:] + fn = concat.general if test_fn == "general" else concat.concat_pflines + result = fn([pfl_a, pfl_b, pfl_c]) + result2 = fn([pfl_b, pfl_c, pfl_a]) + assert whole_pfl == result + assert whole_pfl == result2 diff --git a/tests/core/shared/test_concat_pfstate.py b/tests/core/shared/test_concat_pfstate.py new file mode 100644 index 0000000..3ce923e --- /dev/null +++ b/tests/core/shared/test_concat_pfstate.py @@ -0,0 +1,100 @@ +"""Test if concatenation of PfStates works properly with different test cases.""" + +import pandas as pd +import pytest +from portfolyo import dev +from portfolyo.core.shared import concat + + +TESTCASES2 = [ # whole idx, freq, where + ( + ("2020-01-01", "2023-04-01"), + "QS", + "2022-04-01", + ), + (("2020", "2022"), "AS", "2021-01-01"), + ( + ("2020-05-01", "2023-04-01"), + "MS", + "2022-11-01", + ), + (("2022-03-20", "2022-07-28"), "D", "2022-05-28"), +] + +TESTCASES3 = [ # whole idx, freq, where + ( + ("2020-01-01", "2023-04-01"), + "QS", + ("2022-04-01", "2023-01-01"), + ), + (("2020", "2023"), "AS", ("2021-01-01", "2022-01-01")), + ( + ("2020-05-01", "2023-04-01"), + "MS", + ("2022-11-01", "2023-01-01"), + ), + (("2022-03-20", "2022-07-28"), "D", ("2022-04-28", "2022-05-15")), +] + + +def get_idx( + startdate: str, starttime: str, tz: str, freq: str, enddate: str +) -> pd.DatetimeIndex: + # Empty index. + if startdate is None: + return pd.DatetimeIndex([], freq=freq, tz=tz) + # Normal index. + ts_start = pd.Timestamp(f"{startdate} {starttime}", tz=tz) + ts_end = pd.Timestamp(f"{enddate} {starttime}", tz=tz) + return pd.date_range(ts_start, ts_end, freq=freq, inclusive="left") + + +@pytest.mark.parametrize("tz", [None, "Europe/Berlin", "Asia/Kolkata"]) +@pytest.mark.parametrize("starttime", ["00:00", "06:00"]) +@pytest.mark.parametrize(("whole_idx", "freq", "where"), TESTCASES2) +@pytest.mark.parametrize("test_fn", ["general", "concat_pfstates"]) +def test_concat_pfstates( + whole_idx: str, + starttime: str, + tz: str, + freq: str, + where: str, + test_fn: str, +): + """Test that two PfStates get concatenated properly.""" + idx = get_idx(whole_idx[0], starttime, tz, freq, whole_idx[1]) + whole_pfs = dev.get_pfstate(idx) + pfs_a = whole_pfs.slice[:where] + pfs_b = whole_pfs.slice[where:] + fn = concat.general if test_fn == "general" else concat.concat_pfstates + result = fn([pfs_a, pfs_b]) + result2 = fn([pfs_b, pfs_a]) + assert whole_pfs == result + assert whole_pfs == result2 + + +@pytest.mark.parametrize("tz", [None, "Europe/Berlin", "Asia/Kolkata"]) +@pytest.mark.parametrize("starttime", ["00:00", "06:00"]) +@pytest.mark.parametrize(("whole_idx", "freq", "where"), TESTCASES3) +@pytest.mark.parametrize("test_fn", ["general", "concat_pfstates"]) +def test_concat_three_pfstates( + whole_idx: str, + starttime: str, + tz: str, + freq: str, + where: str, + test_fn: str, +): + """Test that three PfStates get concatenated properly.""" + idx = get_idx(whole_idx[0], starttime, tz, freq, whole_idx[1]) + whole_pfs = dev.get_pfstate(idx) + split_one = where[0] + split_two = where[1] + pfs_a = whole_pfs.slice[:split_one] + pfs_b = whole_pfs.slice[split_one:split_two] + pfs_c = whole_pfs.slice[split_two:] + fn = concat.general if test_fn == "general" else concat.concat_pfstates + result = fn([pfs_a, pfs_b, pfs_c]) + result2 = fn([pfs_b, pfs_c, pfs_a]) + assert whole_pfs == result + assert whole_pfs == result2