diff --git a/scripts/_helpers.py b/scripts/_helpers.py index 73de2f92c..5babccaae 100644 --- a/scripts/_helpers.py +++ b/scripts/_helpers.py @@ -175,7 +175,22 @@ def path_provider(dir, rdir, shared_resources, exclude_from_shared): ) -def check_deprecated_config(config: dict, deprecations_file: str) -> None: +def _check_scenarios(config: dict, check_fn: Callable, fn_kwargs: dict) -> None: + """Helper function to check configuration in scenario files""" + scenarios = config.get("run", {}).get("scenarios", {}) + if scenarios.get("enable"): + with open(scenarios["file"]) as f: + scenario_config = yaml.safe_load(f) + + for run in scenario_config: + # Disable recursive scenario checking to avoid infinite loops + fn_kwargs["check_scenarios"] = False + check_fn(scenario_config[run], **fn_kwargs) + + +def check_deprecated_config( + config: dict, deprecations_file: str, check_scenarios: bool = True +) -> None: """Check config against deprecations and warn users""" with open(deprecations_file) as f: @@ -219,8 +234,15 @@ def set_by_path(root, path, value): warnings.warn(msg, DeprecationConfigWarning) + if check_scenarios: + _check_scenarios( + config, check_deprecated_config, {"deprecations_file": deprecations_file} + ) + -def check_invalid_config(config: dict, config_default_fn: str) -> None: +def check_invalid_config( + config: dict, config_default_fn: str, check_scenarios: bool = True +) -> None: """Check if config contains entries that are not supported by the default config""" with open(config_default_fn) as f: @@ -241,6 +263,11 @@ def check_keys(config, config_default, path=""): check_keys(config, config_default) + if check_scenarios: + _check_scenarios( + config, check_invalid_config, {"config_default_fn": config_default_fn} + ) + def get_opt(opts, expr, flags=None): """ diff --git a/test/test_config_checks.py b/test/test_config_checks.py index f0384334d..e18ee3bdf 100644 --- a/test/test_config_checks.py +++ b/test/test_config_checks.py @@ -22,6 +22,22 @@ - old_entry: "example:old_key" new_entry: "example:new_key" message: "Custom warning message" + +- old_entry: "clustering:deprecated_option" + message: "Custom warning message for deprecated clustering option" +""" + +SAMPLE_SCENARIOS = """ +scenario_1: + example: + old_key: "test_value" + invalid_section: + some_key: "value" + +scenario_2: + clustering: + invalid_option: "bad" + deprecated_option: "old_value" """ @@ -107,3 +123,63 @@ def test_config_invalid_entries(): # Verify warning types assert all(isinstance(w.message, InvalidConfigWarning) for w in captured_warnings) + + +def test_config_scenario_checks(): + """Test that configuration checks are performed on scenario files""" + config = { + "run": { + "scenarios": { + "enable": True, + "file": "scenarios.yaml", + } + } + } + + # Setup mock files + mock_files = { + "deprecations.yaml": SAMPLE_DEPRECATIONS, + "config.default.yaml": """ + run: + scenarios: + enable: true + file: scenarios.yaml + example: + new_key: "default" + clustering: + temporal: + resolution: 1 + """, + "scenarios.yaml": SAMPLE_SCENARIOS, + } + + def mock_open(filename, *args, **kwargs): + return StringIO(mock_files[filename.split("/")[-1]]) + + with warnings.catch_warnings(record=True) as captured_warnings: + with patch("builtins.open", mock_open): + # Check both deprecated and invalid entries in scenarios + check_deprecated_config(config, "deprecations.yaml") + check_invalid_config(config, "config.default.yaml") + + warning_messages = [str(w.message) for w in captured_warnings] + + # Verify warnings from scenario_1 + assert any("'example:old_key' is deprecated" in msg for msg in warning_messages) + assert any("'invalid_section' is not supported" in msg for msg in warning_messages) + + # Verify warnings from scenario_2 + assert any( + "'clustering:invalid_option' is not supported" in msg + for msg in warning_messages + ) + + # Verify warning types + deprecation_warnings = [ + w for w in captured_warnings if isinstance(w.message, DeprecationConfigWarning) + ] + invalid_warnings = [ + w for w in captured_warnings if isinstance(w.message, InvalidConfigWarning) + ] + assert len(deprecation_warnings) == 2 + assert len(invalid_warnings) == 4