Skip to content

Commit

Permalink
test: Updated test to comply with new code
Browse files Browse the repository at this point in the history
  • Loading branch information
pesap committed Oct 18, 2024
1 parent a00e374 commit f97922b
Show file tree
Hide file tree
Showing 6 changed files with 48 additions and 68 deletions.
15 changes: 10 additions & 5 deletions src/r2x/exporter/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,8 +264,13 @@ def apply_flatten_key(d: dict[str, Any], keys_to_flatten: set[str]) -> dict[str,
>>> apply_flatten_key(d, {"y"})
{'x': {'min': 1, 'max': 2}, 'y_min': 5, 'y_max': 10, 'z': 42}
"""
return {
f"{key}_{sub_key}" if key in keys_to_flatten and isinstance(val, dict) else key: sub_val
for key, val in d.items()
for sub_key, sub_val in (val.items() if isinstance(val, dict) else [(key, val)])
}
flattened_dict = {}

for key, val in d.items():
if key in keys_to_flatten and isinstance(val, dict):
for sub_key, sub_val in val.items():
flattened_dict[f"{key}_{sub_key}"] = sub_val
else:
flattened_dict[key] = val

return flattened_dict
3 changes: 2 additions & 1 deletion src/r2x/models/generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,8 @@ class Generator(Device):

@field_serializer("active_power_limits")
def serialize_address(self, min_max: MinMax) -> dict[str, Any] | None:
    """Serialize an ``active_power_limits`` MinMax into a plain dict.

    Parameters
    ----------
    min_max
        The MinMax value to serialize; may be ``None`` when the field is unset.

    Returns
    -------
    dict[str, Any] | None
        ``{"min": ..., "max": ...}``, or ``None`` when ``min_max`` is ``None``.
    """
    # Guard against an unset field: previously this raised AttributeError on None.
    if min_max is None:
        return None
    return {"min": min_max.min, "max": min_max.max}


class RenewableGen(Generator):
Expand Down
12 changes: 9 additions & 3 deletions src/r2x/parser/polars_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import polars as pl
from loguru import logger
from polars.lazyframe import LazyFrame

from r2x.parser.plexos_utils import DATAFILE_COLUMNS

Expand Down Expand Up @@ -133,7 +134,7 @@ def pl_rename(
)


def pl_left_multi_join(l_df: pl.LazyFrame, *r_dfs: pl.LazyFrame, **kwargs):
def pl_left_multi_join(l_df: pl.LazyFrame, *r_dfs: pl.DataFrame, **kwargs):
"""Perform a left join on multiple DataFrames.
Parameters
Expand Down Expand Up @@ -164,11 +165,16 @@ def pl_left_multi_join(l_df: pl.LazyFrame, *r_dfs: pl.LazyFrame, **kwargs):
for r_df in r_dfs:
current_keys = set(r_df.collect_schema().names())
current_keys = original_keys.intersection(current_keys)
if isinstance(r_df, LazyFrame):
r_df = r_df.collect()
output_df = output_df.join(r_df, on=list(current_keys), how="left", coalesce=True)

output_df = output_df.collect()
if isinstance(output_df, pl.LazyFrame):
output_df = output_df.collect()

l_df_shape = l_df.collect().shape[0] if isinstance(l_df, pl.LazyFrame) else l_df.shape[0]
assert (
output_df.shape[0] == l_df.collect().shape[0]
output_df.shape[0] == l_df_shape
), f"Merge resulted in less rows. Check the shared keys. {original_keys=} vs {current_keys=}"
return output_df

Expand Down
25 changes: 13 additions & 12 deletions src/r2x/parser/reeds.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def build_system(self) -> System:
# NOTE: Rename to create topology
def _construct_buses(self):
logger.info("Creating bus objects.")
bus_data = self.get_data("hierarchy").collect()
bus_data = self.get_data("hierarchy")

zones = bus_data["transmission_region"].unique()
for zone in zones:
Expand All @@ -122,7 +122,7 @@ def _construct_buses(self):

def _construct_reserves(self):
logger.info("Creating reserves objects.")
bus_data = self.get_data("hierarchy").collect()
bus_data = self.get_data("hierarchy")

reserves = bus_data["transmission_region"].unique()
for reserve in reserves:
Expand Down Expand Up @@ -247,7 +247,7 @@ def _construct_tx_interfaces(self):
def _construct_emissions(self) -> None:
"""Construct emission objects."""
logger.info("Creating emission objects")
emit_rates = self.get_data("emission_rates").collect()
emit_rates = self.get_data("emission_rates")

emit_rates = emit_rates.with_columns(
pl.concat_str([pl.col("tech"), pl.col("tech_vintage"), pl.col("region")], separator="_").alias(
Expand Down Expand Up @@ -472,7 +472,7 @@ def _construct_generators(self) -> None: # noqa: C901
def _construct_load(self):
logger.info("Adding load time series.")

bus_data = self.get_data("hierarchy").collect()
bus_data = self.get_data("hierarchy")
load_df = self.get_data("load").collect()
start = datetime(year=self.weather_year, month=1, day=1)
resolution = timedelta(hours=1)
Expand All @@ -492,8 +492,8 @@ def _construct_load(self):
resolution=resolution,
)
user_dict = {"solve_year": self.config.weather_year}
max_load = np.max(ts.data.to_numpy())
load = PowerLoad(name=f"{bus.name}", bus=bus, max_active_power=max_load * ureg.MW)
max_load = np.max(ts.data)
load = PowerLoad(name=f"{bus.name}", bus=bus, max_active_power=max_load)
self.system.add_component(load)
self.system.add_time_series(ts, load, **user_dict)

Expand All @@ -503,11 +503,11 @@ def _construct_cf_time_series(self):
raise AttributeError("Missing weather year from the configuration class.")

cf_data = self.get_data("cf").collect()
cf_adjustment = self.get_data("cf_adjustment").collect()
cf_adjustment = self.get_data("cf_adjustment")
# NOTE: We take the median of the seasonal adjustment since we
# aggregate the generators by technology vintage
cf_adjustment = cf_adjustment.group_by("tech").agg(pl.col("cf_adj").median())
ilr = self.get_data("ilr").collect()
ilr = self.get_data("ilr")
ilr = dict(
ilr.group_by("tech").agg(pl.col("ilr").sum()).iter_rows()
) # Dict is more useful here than series
Expand Down Expand Up @@ -669,10 +669,11 @@ def _construct_hydro_profiles(self):
hydro_cf = hydro_cf.with_columns(
month=pl.col("month").map_elements(lambda row: month_map.get(row, row), return_dtype=pl.String)
)
hydro_cap_adj = self.get_data("hydro_cap_adj")
hydro_cap_adj = hydro_cap_adj.with_columns(
season=pl.col("season").map_elements(lambda row: season_map.get(row, row), return_dtype=pl.String)
)
# hydro_cap_adj = self.get_data("hydro_cap_adj")
# hydro_cap_adj = hydro_cap_adj.with_columns(
# season=pl.col("season").map_elements(lambda row: season_map.get(row, row),
# return_dtype=pl.String)
# )
hydro_minload = self.get_data("hydro_min_gen")
hydro_minload = hydro_minload.with_columns(
season=pl.col("season").map_elements(lambda row: season_map.get(row, row), return_dtype=pl.String)
Expand Down
57 changes: 12 additions & 45 deletions tests/test_csv_handler.py
Original file line number Diff line number Diff line change
@@ -1,62 +1,29 @@
from pathlib import Path
import pytest

from tempfile import NamedTemporaryFile

import polars as pl
from r2x.parser.parser_helpers import csv_handler
from r2x.parser.plexos import (
PROPERTY_SV_COLUMNS_BASIC,
PROPERTY_SV_COLUMNS_NAMEYEAR,
PROPERTY_TS_COLUMNS_BASIC,
PROPERTY_TS_COLUMNS_MDH,
PROPERTY_TS_COLUMNS_MDP,
PROPERTY_TS_COLUMNS_MONTH_PIVOT,
PROPERTY_TS_COLUMNS_MULTIZONE,
PROPERTY_TS_COLUMNS_PIVOT,
PROPERTY_TS_COLUMNS_YM,
from r2x.parser.handler import csv_handler
from r2x.parser.plexos_utils import (
DATAFILE_COLUMNS,
get_column_enum,
)


@pytest.mark.parametrize(
"columns_case, csv_content",
"expected_enum,csv_content",
[
(PROPERTY_SV_COLUMNS_BASIC, "name,value\nTemp,25.5\nLoad,1200\nPressure,101.3"),
(
PROPERTY_SV_COLUMNS_NAMEYEAR,
"name,year,month,day,period,value\nTemp,2024,10,7,1,22.5\nLoad,2024,10,7,2,1150",
),
# Test case 3: PROPERTY_TS_COLUMNS_BASIC
(
PROPERTY_TS_COLUMNS_BASIC,
"year,month,day,period,value\n2024,10,7,1,23.0\n2024,10,7,2,24.5",
),
# Test case 4: PROPERTY_TS_COLUMNS_MULTIZONE
(PROPERTY_TS_COLUMNS_MULTIZONE, "year,month,day,period\n2024,10,7,1\n2024,10,7,2"),
# Test case 5: PROPERTY_TS_COLUMNS_PIVOT
(
PROPERTY_TS_COLUMNS_PIVOT,
"name,year,month,day,value\nTemp,2024,10,7,23.5\nLoad,2024,10,7,1150",
),
# Test case 6: PROPERTY_TS_COLUMNS_YM
(PROPERTY_TS_COLUMNS_YM, "year,month\n2024,10\n2023,9"),
# Test case 7: PROPERTY_TS_COLUMNS_MDP
(PROPERTY_TS_COLUMNS_MDP, "month,day,period\n10,7,1\n9,5,2"),
# Test case 8: PROPERTY_TS_COLUMNS_MDH
(
PROPERTY_TS_COLUMNS_MDH,
"name,month,day,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24\nRating,10,7,100,110,120,130,140,150,160,170,180,190,200,210,220,230,240,250,260,270,280,290,300,310,320",
),
# Test case 9: PROPERTY_TS_COLUMNS_MONTH_PIVOT
(
PROPERTY_TS_COLUMNS_MONTH_PIVOT,
"name,m01,m02,m03,m04,m05,m06,m07,m08,m09,m10,m11,m12\nRating,100,110,120,130,140,150,160,170,180,190,200,210",
),
(DATAFILE_COLUMNS.NV, "name,value\nTemp,25.5\nLoad,1200\nPressure,101.3"),
(DATAFILE_COLUMNS.TS_NYV, "name,year,value\nTemp,2030,25.5\nLoad,2030,1200\nPressure,2030,101.3"),
],
)
def test_csv_handler(expected_enum, csv_content):
    """Write ``csv_content`` to a temp file, parse it with ``csv_handler``,
    and check that ``get_column_enum`` infers ``expected_enum`` from its columns.
    """
    # delete=False so the file survives the ``with`` block for csv_handler to read;
    # we must therefore unlink it ourselves or every test run leaks a temp file.
    with NamedTemporaryFile(mode="w+", suffix=".csv", delete=False) as temp_file:
        temp_file.write(csv_content)
        temp_file.flush()
    temp_path = Path(temp_file.name)
    try:
        df_csv = csv_handler(temp_path)
        assert isinstance(df_csv, pl.DataFrame)
        column_type = get_column_enum(df_csv.columns)
        assert column_type == expected_enum
    finally:
        temp_path.unlink(missing_ok=True)
4 changes: 2 additions & 2 deletions tests/test_parser_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from pathlib import Path
from polars.testing import assert_frame_equal
from tempfile import NamedTemporaryFile
from r2x.parser.parser_helpers import csv_handler
from r2x.parser.handler import csv_handler


@pytest.fixture
Expand All @@ -17,7 +17,7 @@ def temp_csv_file(sample_csv_basic):
with NamedTemporaryFile(mode="w", delete=False, suffix=".csv") as temp_file:
temp_file.write(sample_csv_basic)
temp_file.flush()
return temp_file.name
return Path(temp_file.name)


def test_csv_handler_basic(temp_csv_file):
Expand Down

0 comments on commit f97922b

Please sign in to comment.