diff --git a/timely_beliefs/tests/test_belief_io.py b/timely_beliefs/tests/test_belief_io.py index b968ab45..298d1e05 100644 --- a/timely_beliefs/tests/test_belief_io.py +++ b/timely_beliefs/tests/test_belief_io.py @@ -7,8 +7,8 @@ import pytz import timely_beliefs as tb -from timely_beliefs.beliefs.classes import METADATA from timely_beliefs.examples import example_df +from timely_beliefs.tests.utils import assert_metadata_is_retained @pytest.fixture(scope="module") @@ -360,21 +360,16 @@ def test_converting_between_data_frame_and_series_retains_metadata(): Test whether expanding dimensions of a BeliefsSeries into a BeliefsDataFrame retains the metadata. """ df = example_df - metadata = {md: getattr(example_df, md) for md in METADATA} series = df["event_value"] - for md in metadata: - assert getattr(series, md) == metadata[md] + assert_metadata_is_retained(series, original_df=example_df, is_series=True) df = series.to_frame() - for md in metadata: - assert getattr(df, md) == metadata[md] + assert_metadata_is_retained(df, original_df=example_df) def test_dropping_index_levels_retains_metadata(): df = example_df.copy() - metadata = {md: getattr(example_df, md) for md in METADATA} df.index = df.index.get_level_values("event_start") # drop all other index levels - for md in metadata: - assert getattr(df, md) == metadata[md] + assert_metadata_is_retained(df, original_df=example_df) @pytest.mark.parametrize("drop_level", [True, False]) @@ -383,12 +378,9 @@ def test_slicing_retains_metadata(drop_level): Test whether slicing the index of a BeliefsDataFrame retains the metadata. """ df = example_df - metadata = {md: getattr(example_df, md) for md in METADATA} df = df.xs("2000-01-03 10:00:00+00:00", level="event_start", drop_level=drop_level) print(df) - assert isinstance(df, tb.BeliefsDataFrame) - for md in metadata: - assert getattr(df, md) == metadata[md] + assert_metadata_is_retained(df, original_df=example_df) @pytest.mark.parametrize("resolution", [timedelta(minutes=30), timedelta(hours=2)]) @@ -400,15 +392,13 @@ def test_mean_resampling_retains_metadata(resolution): Succeeds with pandas==1.1.0 """ df = example_df - metadata = {md: getattr(example_df, md) for md in METADATA} df = df.resample(resolution, level="event_start").mean() print(df) - assert isinstance(df, tb.BeliefsDataFrame) - for md in metadata: - # if md == "event_resolution": - # assert df.event_resolution == resolution - # else: # todo: the event_resolution metadata is only updated when resampling using df.resample_events(). A reason to override the original resample method, or otherwise something to document. - assert getattr(df, md) == metadata[md] + assert_metadata_is_retained( + df, + original_df=example_df, + event_resolution=example_df.event_resolution, + ) # todo: the event_resolution metadata is only updated when resampling using df.resample_events(). A reason to override the original resample method, or otherwise something to document. @pytest.mark.parametrize("resolution", [timedelta(minutes=30), timedelta(hours=2)]) @@ -419,7 +409,6 @@ def _test_agg_resampling_retains_metadata(resolution): Fails with pandas==1.1.5 """ df = example_df - metadata = {md: getattr(example_df, md) for md in METADATA} df = df.reset_index(level=["belief_time", "source", "cumulative_probability"]) df = df.resample(resolution).agg( { @@ -431,12 +420,11 @@ def _test_agg_resampling_retains_metadata(resolution): ) df = df.set_index(["belief_time", "source", "cumulative_probability"], append=True) print(df) - assert isinstance(df, tb.BeliefsDataFrame) - for md in metadata: - # if md == "event_resolution": - # assert df.event_resolution == resolution - # else: # todo: the event_resolution metadata is only updated when resampling using df.resample_events(). A reason to override the original resample method, or otherwise something to document. - assert getattr(df, md) == metadata[md] + assert_metadata_is_retained( + df, + original_df=example_df, + event_resolution=example_df.event_resolution, + ) # todo: the event_resolution metadata is only updated when resampling using df.resample_events(). A reason to override the original resample method, or otherwise something to document. def test_groupby_retains_metadata(): @@ -447,19 +435,14 @@ def test_groupby_retains_metadata(): Fixed with pandas==1.1.5 """ df = example_df - metadata = {md: getattr(example_df, md) for md in METADATA} def assert_function(x): print(x) - assert isinstance(x, tb.BeliefsDataFrame) - for md in metadata: - assert getattr(x, md) == metadata[md] + assert_metadata_is_retained(x, original_df=example_df) return x df = df.groupby(level="event_start").apply(lambda x: assert_function(x)) - assert isinstance(df, tb.BeliefsDataFrame) - for md in metadata: - assert getattr(df, md) == metadata[md] + assert_metadata_is_retained(df, original_df=example_df) def test_copy_series_retains_name_and_metadata(): @@ -589,3 +572,19 @@ def _constructor_sliced(self): df2 = getattr(df.groupby("x"), att)(*args) print(df2) assert df2.a == "b" + + +@pytest.mark.parametrize("constant", [1, -1, 3.14, timedelta(hours=1), ["TiledString"]]) +def test_multiplication_with_constant_retains_metadata(constant): + """ Check whether the metadata is still there after multiplication. """ + # GH 35 + df = example_df * constant + assert_metadata_is_retained(df, original_df=example_df) + + # Also check suggested workarounds from GH 35 + if constant == -1: + df = -example_df + assert_metadata_is_retained(df, original_df=example_df) + + df = example_df.abs() + assert_metadata_is_retained(df, original_df=example_df) diff --git a/timely_beliefs/tests/utils.py b/timely_beliefs/tests/utils.py index 05fc7d9c..3e368c8c 100644 --- a/timely_beliefs/tests/utils.py +++ b/timely_beliefs/tests/utils.py @@ -1,7 +1,37 @@ -from typing import Union +from datetime import timedelta +from typing import Optional, Union import numpy as np +import timely_beliefs as tb +from timely_beliefs.beliefs.classes import METADATA + def equal_lists(list_a: Union[list, np.ndarray], list_b: Union[list, np.ndarray]): return all(np.isclose(a, b) for a, b in zip(list_a, list_b)) + + +def assert_metadata_is_retained( + result_df: Union[tb.BeliefsDataFrame, tb.BeliefsSeries], + original_df: tb.BeliefsDataFrame, + is_series: bool = False, + event_resolution: Optional[timedelta] = None, +): + """Fail if result_df is not a BeliefsDataFrame with the same metadata as the original BeliefsDataFrame. + + Can also be used to check for a BeliefsSeries (using is_series=True). + + :param result_df: BeliefsDataFrame or BeliefsSeries to be checked for metadata propagation + :param original_df: BeliefsDataFrame containing the original metadata + :param is_series: if True, we check that the result is a BeliefsSeries rather than a BeliefsDataFrame + :param event_resolution: optional timedelta in case we expect a different event_resolution than the original + """ + metadata = {md: getattr(original_df, md) for md in METADATA} + assert isinstance( + result_df, tb.BeliefsDataFrame if not is_series else tb.BeliefsSeries + ) + for md in metadata: + if md == "event_resolution" and event_resolution is not None: + assert result_df.event_resolution == event_resolution + else: + assert getattr(result_df, md) == metadata[md]