diff --git a/ski_bearings/tests/test_bearing.py b/ski_bearings/tests/test_bearing.py
index bfe7c7cd31..a3bfcc3ddf 100644
--- a/ski_bearings/tests/test_bearing.py
+++ b/ski_bearings/tests/test_bearing.py
@@ -1,3 +1,4 @@
+import itertools
 from dataclasses import dataclass
 from typing import Literal
 
@@ -9,7 +10,7 @@
     analyze_all_ski_areas,
 )
 from ski_bearings.bearing import get_bearing_summary_stats
-from ski_bearings.openskimap_utils import load_runs
+from ski_bearings.openskimap_utils import get_ski_area_to_runs, load_runs
 from ski_bearings.osmnx_utils import create_networkx_with_metadata
 
 
@@ -145,7 +146,12 @@ def test_get_bearing_summary_stats_repeated_aggregation() -> None:
     """
     # aggregate all runs at once
     all_runs = load_runs()
-    combined_graph = create_networkx_with_metadata(all_runs, ski_area_metadata={})
+    # we cannot create networkx graph directly from all runs because get_ski_area_to_runs performs some filtering
+    ski_area_to_runs = get_ski_area_to_runs(all_runs)
+    all_runs_filtered = list(itertools.chain.from_iterable(ski_area_to_runs.values()))
+    combined_graph = create_networkx_with_metadata(
+        all_runs_filtered, ski_area_metadata={}
+    )
     single_pass = combined_graph.graph
     # aggregate runs by ski area and then aggregate ski areas
     analyze_all_ski_areas()
@@ -159,6 +165,8 @@ def test_get_bearing_summary_stats_repeated_aggregation() -> None:
     )
     double_pass = hemisphere_pl.row(by_predicate=pl.lit(True), named=True)
     for key in [
+        "run_count",
+        "run_count_filtered",
         "mean_bearing",
         "mean_bearing_strength",
         "poleward_affinity",
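
A minimal sketch of the flattening pattern the patched test uses, assuming get_ski_area_to_runs returns a mapping from ski area to its filtered runs; the ski area keys and run values below are illustrative placeholders, not real openskimap data:

    import itertools

    # hypothetical stand-in for get_ski_area_to_runs(all_runs): ski area -> its filtered runs
    ski_area_to_runs = {
        "ski_area_a": ["run_1", "run_2"],
        "ski_area_b": ["run_3"],
    }
    # flatten the grouped values back into one list, mirroring what the test
    # feeds to create_networkx_with_metadata for the single-pass aggregation
    all_runs_filtered = list(itertools.chain.from_iterable(ski_area_to_runs.values()))
    assert all_runs_filtered == ["run_1", "run_2", "run_3"]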