Skip to content

Commit

Permalink
test: Adding further testing and linting
Browse files Browse the repository at this point in the history
  • Loading branch information
pesap committed Jan 14, 2025
1 parent 26be706 commit 4b581d6
Show file tree
Hide file tree
Showing 7 changed files with 144 additions and 60 deletions.
17 changes: 17 additions & 0 deletions src/r2x/config_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,23 @@ class Models(StrEnum):
PRAS = "PRAS"


class ParserModels(StrEnum):
    """Enum of valid parser models supported."""

    # Member values are the canonical string identifiers accepted when
    # selecting the input (parser) model.
    INFRASYS = "INFRASYS"
    # NOTE(review): value "REEDS-US" intentionally differs from the member
    # name — presumably the dataset-specific identifier; confirm against
    # how this value is matched by callers.
    REEDS = "REEDS-US"
    PLEXOS = "PLEXOS"
    SIENNA = "SIENNA"

class ExporterModels(StrEnum):
    """Enum of valid exporter models supported."""

    # Member values are the canonical string identifiers accepted when
    # selecting the output (exporter) model.
    PLEXOS = "PLEXOS"
    SIENNA = "SIENNA"
    INFRASYS = "INFRASYS"


MODEL_CONFIGS = {
Models.REEDS: ReEDSConfig,
Models.PLEXOS: PlexosConfig,
Expand Down
32 changes: 20 additions & 12 deletions src/r2x/config_scenario.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,11 @@ class Scenario:
Attributes
----------
name
Name for the scenario
Name of the translation scenario
run_folder
Path for the scenario inputs
Path to input model files
output_folder
Path for output exports
Path for placing the output translation files
input_model
Model to translate from
output_model
Expand All @@ -57,6 +57,8 @@ class Scenario:
Dictionary with experimental features
plugins
List of plugins enabled
user_dict
Dictionary with user configuration.
See Also
--------
Expand All @@ -80,7 +82,11 @@ def __post_init__(self) -> None:
self._load_plugins()
self._load_model_config()

# Overload configuration if the key appear on the config file
# Overload default initialization of the translation scenario if the user pass a `user_dict`. We allow
# overriding the `input_config.defaults` and/or `output_config.defaults`. Also, we allow the user
# to change the default file mapping by passing a `fmap` key and passing new key value pairs for
# predetermined files that we read for each model (see `{model}_fmap.json`). The override only happens
# if the exact key appears on the `user_dict`.
if self.user_dict and self.input_config:
self.input_config.defaults = update_dict(self.input_config.defaults, self.user_dict)
self.input_config.fmap = update_dict(self.input_config.fmap, self.user_dict.get("fmap", {}))
Expand Down Expand Up @@ -187,13 +193,11 @@ def from_kwargs(

@dataclass
class Configuration:
"""r2x.config_scenariouration manager that wraps multiple Scenario instances.
This class parses either the cases_*.csv file or reads the inputs from the CLI.
"""Configuration manager that wraps multiple Scenario instances.
Attributes
----------
scenarios_list
scenarios
Dictionary of scenarios to be translated.
See Also
Expand Down Expand Up @@ -237,13 +241,17 @@ def list_scenarios(self):
return self.scenarios.keys()

@classmethod
def from_cli(cls, cli_args: dict, user_dict: dict | None = None, **kwargs):
"""Create scenario from the CLI arguments.
def from_cli(cls, cli_args: dict[str, Any], user_dict: dict[str, Any] | None = None, **kwargs):
"""Create a `Scenario` from the CLI arguments.
It saves the created scenario in the scenario_list` and scenario_names
It saves the created scenario using CLI input. It can also be overwritten if a `user_dict` is passed.
Parameters
----------
cli_args
Arguments for constructing the scenario.
user_dict
Configuration for the translation.
kwargs
Arguments for constructing the scenario.
Expand All @@ -260,7 +268,7 @@ def from_cli(cls, cli_args: dict, user_dict: dict | None = None, **kwargs):
return instance

@classmethod
def from_scenarios(cls, cli_args: dict, user_dict: dict, **kwargs):
def from_scenarios(cls, cli_args: dict, user_dict: dict):
"""Create scenario from scenarios on the config file.
This method takes the `user_dict['scenarios']` key which is a list of dicts to create the different
Expand Down
17 changes: 12 additions & 5 deletions src/r2x/config_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import inspect
from loguru import logger

from .config_models import BaseModelConfig, MODEL_CONFIGS, Models
from .config_models import BaseModelConfig, MODEL_CONFIGS, Models, ExporterModels, ParserModels

from .utils import read_json, read_fmap

Expand All @@ -27,7 +27,11 @@ def get_input_defaults(model_enum: Models) -> dict:
defaults_dict = defaults_dict | read_json("r2x/defaults/plexos_input.json")
logger.debug("Returning input_model {} defaults", model_enum)
case _:
logger.warning("No input model passed")
msg = (
f"Unsupported input model: {model_enum}. "
f"Supported models: {[str(model) for model in ParserModels]}"
)
raise ValueError(msg)
return defaults_dict


Expand All @@ -46,8 +50,12 @@ def get_output_defaults(model_enum: Models) -> dict:
defaults_dict = read_json("r2x/defaults/sienna_config.json")
logger.debug("Returning sienna defaults")
case _:
msg = f"Unsupported model: {model_enum}. Supported models: {list(MODEL_CONFIGS.keys())}"
msg = (
f"Unsupported input model: {model_enum}. "
f"Supported models: {[str(model) for model in ExporterModels]}"
)
raise ValueError(msg)

if not defaults_dict:
return {}

Expand All @@ -66,8 +74,7 @@ def get_input_model_fmap(model_enum: Models) -> dict:
case Models.PLEXOS:
fmap = read_fmap("r2x/defaults/plexos_mapping.json")
case _:
logger.error("Input model {} not recognized", model_enum)
raise KeyError(f"Input model {model_enum=} not valid")
raise ValueError(f"Input model {model_enum=} not valid")
return fmap


Expand Down
51 changes: 33 additions & 18 deletions src/r2x/exporter/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,41 +310,56 @@ def apply_flatten_key(d: dict[str, Any], keys_to_flatten: set[str]) -> dict[str,
"""
flattened_dict = {}

for key, val in d.items():
if key in keys_to_flatten and isinstance(val, dict):
for sub_key, sub_val in val.items():
flattened_dict[f"{key}_{sub_key}"] = sub_val
for key, value in d.items():
if key in keys_to_flatten and isinstance(value, dict):
for inner_key, inner_value in value.items():
flattened_dict[f"{key}_{inner_key}"] = inner_value
else:
flattened_dict[key] = val
flattened_dict[key] = value

return flattened_dict


def apply_extract_key(d: dict[str, Any], key: str, keys_to_extract: set[str]) -> dict[str, Any]:
"""Extract keys from a nested dictionary and put it in first level.
"""Extract keys from a nested dictionary and put them in the first level if specific conditions are met.
Parameters
----------
d : dict
The input dictionary, where some values are dictionaries to be flattened.
key: dict
Key that has a dictionary
keys_to_extract : list of str
The keys in the nested dictionary that will be extracted
The input dictionary that may contain nested dictionaries.
key : str
The key in the input dictionary whose value should be a nested dictionary.
keys_to_extract : set[str]
The set of keys to extract from the nested dictionary.
Returns
-------
dict
A new dictionary with the selected keys flattened. Other keys remain unchanged.
If conditions are met: a new dictionary containing all original key-value pairs
plus the extracted key-value pairs at the top level.
If conditions are not met: returns the input dictionary unchanged.
Notes
-----
The function will return the input dictionary unchanged if any of these conditions are met:
- The specified key is not in the input dictionary
- Any of keys_to_extract already exist in the top level of input dictionary
- None of keys_to_extract exist in the nested dictionary
Examples
--------
>>> d = {"x": {"min": 1, "max": 2}, "y": {"min": 5, "max": 10}, "z": 42}
>>> flatten_selected_keys(d, ["x"])
{'x_min': 1, 'x_max': 2, 'y': {'min': 5, 'max': 10}, 'z': 42}
>>> flatten_selected_keys(d, ["y"])
{'x': {'min': 1, 'max': 2}, 'y_min': 5, 'y_max': 10, 'z': 42}
>>> component = {"name": "Gen01", "rating": {"up": 100, "down": -100}}
>>> apply_extract_key(component, "rating", {"up"})
{'name': 'Gen01', 'rating': {'up': 100, 'down': -100}, 'up': 100}
>>> # Returns unchanged when extracted key already exists at top level
>>> d = {"name": "Gen01", "rating": {"up": 100}, "up": 50}
>>> apply_extract_key(d, "rating", {"up"})
{'name': 'Gen01', 'rating': {'up': 100}, 'up': 50}
>>> # Returns unchanged when no keys_to_extract exist in nested dict
>>> apply_extract_key(component, "rating", {"middle"})
{'name': 'Gen01', 'rating': {'up': 100, 'down': -100}}
"""
if key not in d.keys() or any(k in d.keys() for k in keys_to_extract):
return d
Expand Down
9 changes: 0 additions & 9 deletions src/r2x/parser/reeds.py
Original file line number Diff line number Diff line change
Expand Up @@ -705,9 +705,6 @@ def _construct_hydro_budgets(self) -> None:
hydro_cf,
month_hrs,
)
# month_of_hour = np.array(
# [dt.astype("datetime64[M]").astype(int) % 12 + 1 for dt in self.hourly_time_index]
# )
month_of_day = np.array(
[dt.astype("datetime64[M]").astype(int) % 12 + 1 for dt in self.daily_time_index]
)
Expand All @@ -721,8 +718,6 @@ def _construct_hydro_budgets(self) -> None:
region = generator.bus.name
hydro_ratings = hydro_data.filter((pl.col("tech") == tech) & (pl.col("region") == region))

# hourly_time_series = np.zeros(len(month_of_hour), dtype=float)
# if self.config.feature_flags.get("daily-budgets", None):
hourly_time_series = np.zeros(len(month_of_day), dtype=float)

for row in hydro_ratings.iter_rows(named=True):
Expand All @@ -733,12 +728,8 @@ def _construct_hydro_budgets(self) -> None:
month_max_budget = (
generator.active_power * Percentage(row["hydro_cf"], "") * Time(row["hrs"], "h")
)
# if self.config.feature_flags.get("daily-budgets", None):
daily_max_budget = month_max_budget / (row["hrs"] / 24)
hourly_time_series[month_of_day == month] = daily_max_budget.magnitude
# else:
# month_indices = month_of_hour == month
# hourly_time_series[month_indices] = month_max_budget.magnitude

ts = SingleTimeSeries.from_array(
Energy(hourly_time_series / 1e3, "GWh"),
Expand Down
33 changes: 17 additions & 16 deletions src/r2x/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,23 +16,17 @@
from importlib.resources import files
from pathlib import Path
from itertools import islice
from typing import Any, Hashable, Sequence

# Third-party packages
import numpy as np
import pandas as pd
import polars as pl
import pyarrow as pa
import pyarrow.dataset as ds
import pyarrow.parquet as pq
from tables import file
import yaml
from jsonschema import validate
from loguru import logger
import pint
from pint import UndefinedUnitError
from infrasys.base_quantity import BaseQuantity
from r2x.models import Generator
from r2x.units import ureg


Expand Down Expand Up @@ -116,7 +110,8 @@ def get_mean_data(
rename_dict: Dictionary passed to pd.DataFrame().rename(columns=rename_dict),
categories: Keys to aggregate the data.
Returns:
Returns
-------
Aggregated dataframe
"""
if not rename_dict:
Expand Down Expand Up @@ -178,12 +173,12 @@ def read_user_dict(fname: str) -> dict:
def _load_file(fname: str, loader) -> dict:
    """Load a file with the given loader callable (e.g. ``json.load`` or a YAML loader).

    Parameters
    ----------
    fname
        Path to the file to read.
    loader
        Callable that takes an open text-mode file object and returns the
        parsed content.

    Returns
    -------
    dict
        Parsed content of the file.

    Raises
    ------
    FileNotFoundError
        If ``fname`` does not exist.
    OSError
        If the file exists but cannot be read.
    """
    try:
        with open(fname) as f:
            return loader(f)
    except FileNotFoundError:
        # Re-raise with a clearer message; suppress the redundant original
        # traceback since it adds no information.
        raise FileNotFoundError(f"File {fname} not found.") from None
    except OSError as e:
        # Chain the original error so the root cause stays visible.
        raise OSError(f"Error reading the file {fname}: {e}") from e


def read_json(fname: str):
Expand Down Expand Up @@ -214,7 +209,8 @@ def get_missing_columns(fpath: str, column_names: list) -> list:
fpath: Path to the csv file
column_names: list of columns to verify
Returns:
Returns
-------
A list of missing columns or empty list
"""
try:
Expand All @@ -237,7 +233,8 @@ def get_missing_files(project_folder: str, file_list: Iterable, max_depth: int =
file_list: Iterable of files to check
max_depth: Level of subfolders to look.
Returns:
Returns
-------
A list with the missing files or empty list
"""
all_files = set()
Expand Down Expand Up @@ -271,7 +268,8 @@ def read_csv(fname: str, package_data: str = "r2x.defaults", **kwargs) -> pl.Laz
package_data: Location of file in package. Default location is r2x.defaults
**kwargs: Additional keys passed to pandas read_csv function
Returns:
Returns
-------
A pandas dataframe of the csv requested
"""
csv_file = files(package_data).joinpath(fname).read_text(encoding="utf-8-sig")
Expand All @@ -283,7 +281,8 @@ def get_timeindex(
) -> pd.DatetimeIndex:
"""ReEDS time indices are in EST, and leap years drop Dec 31 instead of Feb 29.
Notes:
Notes
-----
- Function courtesy of P. Brown.
Args:
Expand Down Expand Up @@ -348,7 +347,8 @@ def check_file_exists(
default_folders: Default location to look for files
mandatory: Flag to identify needed files
Returns:
Returns
-------
fpath or None
"""
run_folder = Path(run_folder)
Expand Down Expand Up @@ -394,7 +394,8 @@ def get_csv(fpath: str, fname: str, fmap: dict[str, str | dict | list] = {}, **k
fmap: File mapping values
kwargs: Additional key arguments for pandas mostly
Attributes:
Attributes
----------
data: ReEDS parsed data for PCM.
"""
logger.debug(f"Attempting to read {fname}")
Expand Down
Loading

0 comments on commit 4b581d6

Please sign in to comment.