Skip to content

Commit

Permalink
feat: add checks for scenario config
Browse files Browse the repository at this point in the history
  • Loading branch information
FabianHofmann committed Jan 26, 2025
1 parent ae02d82 commit a848ea3
Show file tree
Hide file tree
Showing 2 changed files with 105 additions and 2 deletions.
31 changes: 29 additions & 2 deletions scripts/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,22 @@ def path_provider(dir, rdir, shared_resources, exclude_from_shared):
)


def check_deprecated_config(config: dict, deprecations_file: str) -> None:
def _check_scenarios(config: dict, check_fn: Callable, fn_kwargs: dict) -> None:
"""Helper function to check configuration in scenario files"""
scenarios = config.get("run", {}).get("scenarios", {})
if scenarios.get("enable"):
with open(scenarios["file"]) as f:
scenario_config = yaml.safe_load(f)

for run in scenario_config:
# Disable recursive scenario checking to avoid infinite loops
fn_kwargs["check_scenarios"] = False
check_fn(scenario_config[run], **fn_kwargs)


def check_deprecated_config(
config: dict, deprecations_file: str, check_scenarios: bool = True
) -> None:
"""Check config against deprecations and warn users"""

with open(deprecations_file) as f:
Expand Down Expand Up @@ -219,8 +234,15 @@ def set_by_path(root, path, value):

warnings.warn(msg, DeprecationConfigWarning)

if check_scenarios:
_check_scenarios(
config, check_deprecated_config, {"deprecations_file": deprecations_file}
)


def check_invalid_config(config: dict, config_default_fn: str) -> None:
def check_invalid_config(
config: dict, config_default_fn: str, check_scenarios: bool = True
) -> None:
"""Check if config contains entries that are not supported by the default config"""

with open(config_default_fn) as f:
Expand All @@ -241,6 +263,11 @@ def check_keys(config, config_default, path=""):

check_keys(config, config_default)

if check_scenarios:
_check_scenarios(
config, check_invalid_config, {"config_default_fn": config_default_fn}
)


def get_opt(opts, expr, flags=None):
"""
Expand Down
76 changes: 76 additions & 0 deletions test/test_config_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,22 @@
- old_entry: "example:old_key"
new_entry: "example:new_key"
message: "Custom warning message"
- old_entry: "clustering:deprecated_option"
message: "Custom warning message for deprecated clustering option"
"""

SAMPLE_SCENARIOS = """
scenario_1:
example:
old_key: "test_value"
invalid_section:
some_key: "value"
scenario_2:
clustering:
invalid_option: "bad"
deprecated_option: "old_value"
"""


Expand Down Expand Up @@ -107,3 +123,63 @@ def test_config_invalid_entries():

# Verify warning types
assert all(isinstance(w.message, InvalidConfigWarning) for w in captured_warnings)


def test_config_scenario_checks():
"""Test that configuration checks are performed on scenario files"""
config = {
"run": {
"scenarios": {
"enable": True,
"file": "scenarios.yaml",
}
}
}

# Setup mock files
mock_files = {
"deprecations.yaml": SAMPLE_DEPRECATIONS,
"config.default.yaml": """
run:
scenarios:
enable: true
file: scenarios.yaml
example:
new_key: "default"
clustering:
temporal:
resolution: 1
""",
"scenarios.yaml": SAMPLE_SCENARIOS,
}

def mock_open(filename, *args, **kwargs):
return StringIO(mock_files[filename.split("/")[-1]])

with warnings.catch_warnings(record=True) as captured_warnings:
with patch("builtins.open", mock_open):
# Check both deprecated and invalid entries in scenarios
check_deprecated_config(config, "deprecations.yaml")
check_invalid_config(config, "config.default.yaml")

warning_messages = [str(w.message) for w in captured_warnings]

# Verify warnings from scenario_1
assert any("'example:old_key' is deprecated" in msg for msg in warning_messages)
assert any("'invalid_section' is not supported" in msg for msg in warning_messages)

# Verify warnings from scenario_2
assert any(
"'clustering:invalid_option' is not supported" in msg
for msg in warning_messages
)

# Verify warning types
deprecation_warnings = [
w for w in captured_warnings if isinstance(w.message, DeprecationConfigWarning)
]
invalid_warnings = [
w for w in captured_warnings if isinstance(w.message, InvalidConfigWarning)
]
assert len(deprecation_warnings) == 2
assert len(invalid_warnings) == 4

0 comments on commit a848ea3

Please sign in to comment.