From dfc66779df1dff83a1a0a56c82029e61f2458598 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=98yvind=20Eide?= Date: Fri, 17 Jan 2025 11:58:22 +0100 Subject: [PATCH 1/6] Convert from dataclass to BaseModel --- src/ert/config/queue_config.py | 65 +++++++++---------- src/everest/config/simulator_config.py | 2 +- .../config/config_dict_generator.py | 53 +++++++++------ .../unit_tests/config/test_queue_config.py | 5 +- 4 files changed, 66 insertions(+), 59 deletions(-) diff --git a/src/ert/config/queue_config.py b/src/ert/config/queue_config.py index c0b58482324..3df9b97c549 100644 --- a/src/ert/config/queue_config.py +++ b/src/ert/config/queue_config.py @@ -5,10 +5,10 @@ import re import shutil from abc import abstractmethod -from dataclasses import asdict, field, fields from typing import Annotated, Any, Literal, no_type_check import pydantic +from pydantic import BaseModel, Field from pydantic.dataclasses import dataclass from ._get_num_cpu import get_num_cpu_from_data_file @@ -37,20 +37,18 @@ def activate_script() -> str: return "" -@pydantic.dataclasses.dataclass( - config={ - "extra": "forbid", - "validate_assignment": True, - "use_enum_values": True, - "validate_default": True, - } -) -class QueueOptions: +class QueueOptions( + BaseModel, + validate_assignment=True, + extra="forbid", + use_enum_values=True, + validate_default=True, +): name: QueueSystem max_running: pydantic.NonNegativeInt = 0 submit_sleep: pydantic.NonNegativeFloat = 0.0 project_code: str | None = None - activate_script: str = field(default_factory=activate_script) + activate_script: str = Field(default_factory=activate_script) @staticmethod def create_queue_options( @@ -78,12 +76,12 @@ def create_queue_options( return None def add_global_queue_options(self, config_dict: ConfigDict) -> None: - for generic_option in fields(QueueOptions): + for name, generic_option in QueueOptions.model_fields.items(): if ( - generic_value := config_dict.get(generic_option.name.upper(), None) # type: ignore - ) and self.__dict__[generic_option.name] == generic_option.default: + generic_value := config_dict.get(name.upper(), None) # type: ignore + ) and self.__dict__[name] == generic_option.default: try: - setattr(self, generic_option.name, generic_value) + setattr(self, name, generic_value) except pydantic.ValidationError as exception: for error in exception.errors(): _throw_error_or_warning( @@ -98,7 +96,6 @@ def driver_options(self) -> dict[str, Any]: """Translate the queue options to the key-value API provided by each driver""" -@pydantic.dataclasses.dataclass class LocalQueueOptions(QueueOptions): name: Literal[QueueSystem.LOCAL] = QueueSystem.LOCAL @@ -107,7 +104,6 @@ def driver_options(self) -> dict[str, Any]: return {} -@pydantic.dataclasses.dataclass class LsfQueueOptions(QueueOptions): name: Literal[QueueSystem.LSF] = QueueSystem.LSF bhist_cmd: NonEmptyString | None = None @@ -120,17 +116,13 @@ class LsfQueueOptions(QueueOptions): @property def driver_options(self) -> dict[str, Any]: - driver_dict = asdict(self) - driver_dict.pop("name") + driver_dict = self.model_dump(exclude={"name", "submit_sleep", "max_running"}) driver_dict["exclude_hosts"] = driver_dict.pop("exclude_host") driver_dict["queue_name"] = driver_dict.pop("lsf_queue") driver_dict["resource_requirement"] = driver_dict.pop("lsf_resource") - driver_dict.pop("submit_sleep") - driver_dict.pop("max_running") return driver_dict -@pydantic.dataclasses.dataclass class TorqueQueueOptions(QueueOptions): name: Literal[QueueSystem.TORQUE] = QueueSystem.TORQUE qsub_cmd: NonEmptyString | None = None @@ -143,15 +135,19 @@ class TorqueQueueOptions(QueueOptions): @property def driver_options(self) -> dict[str, Any]: - driver_dict = asdict(self) - driver_dict.pop("name") + driver_dict = self.model_dump( + exclude={ + "name", + "max_running", + "submit_sleep", + "qstat_options", + "queue_query_timeout", + } + ) driver_dict["queue_name"] = driver_dict.pop("queue") - driver_dict.pop("max_running") - driver_dict.pop("submit_sleep") return driver_dict -@pydantic.dataclasses.dataclass class SlurmQueueOptions(QueueOptions): name: Literal[QueueSystem.SLURM] = QueueSystem.SLURM sbatch: NonEmptyString = "sbatch" @@ -167,8 +163,7 @@ class SlurmQueueOptions(QueueOptions): @property def driver_options(self) -> dict[str, Any]: - driver_dict = asdict(self) - driver_dict.pop("name") + driver_dict = self.model_dump(exclude={"name", "max_running", "submit_sleep"}) driver_dict["sbatch_cmd"] = driver_dict.pop("sbatch") driver_dict["scancel_cmd"] = driver_dict.pop("scancel") driver_dict["scontrol_cmd"] = driver_dict.pop("scontrol") @@ -177,8 +172,6 @@ def driver_options(self) -> dict[str, Any]: driver_dict["exclude_hosts"] = driver_dict.pop("exclude_host") driver_dict["include_hosts"] = driver_dict.pop("include_host") driver_dict["queue_name"] = driver_dict.pop("partition") - driver_dict.pop("max_running") - driver_dict.pop("submit_sleep") return driver_dict @@ -203,12 +196,12 @@ def validate(self, mem_str_format: str | None) -> bool: ) valid_options: dict[str, list[str]] = { - QueueSystem.LOCAL: [field.name.upper() for field in fields(LocalQueueOptions)], - QueueSystem.LSF: [field.name.upper() for field in fields(LsfQueueOptions)], - QueueSystem.SLURM: [field.name.upper() for field in fields(SlurmQueueOptions)], - QueueSystem.TORQUE: [field.name.upper() for field in fields(TorqueQueueOptions)], + QueueSystem.LOCAL: [field.upper() for field in LocalQueueOptions.model_fields], + QueueSystem.LSF: [field.upper() for field in LsfQueueOptions.model_fields], + QueueSystem.SLURM: [field.upper() for field in SlurmQueueOptions.model_fields], + QueueSystem.TORQUE: [field.upper() for field in TorqueQueueOptions.model_fields], QueueSystemWithGeneric.GENERIC: [ - field.name.upper() for field in fields(QueueOptions) + field.upper() for field in QueueOptions.model_fields ], } diff --git a/src/everest/config/simulator_config.py b/src/everest/config/simulator_config.py index 9d277e7f2a6..126ea7217b5 100644 --- a/src/everest/config/simulator_config.py +++ b/src/everest/config/simulator_config.py @@ -29,7 +29,7 @@ def check_removed_config(queue_system): } if isinstance(queue_system, str) and queue_system in queue_systems: raise ValueError( - f"Queue system configuration has changed, valid options for {queue_system} are: {list(queue_systems[queue_system].__dataclass_fields__.keys())}" + f"Queue system configuration has changed, valid options for {queue_system} are: {list(queue_systems[queue_system].model_fields.keys())}" ) diff --git a/tests/ert/unit_tests/config/config_dict_generator.py b/tests/ert/unit_tests/config/config_dict_generator.py index 82b85e9a4fa..ae2064cb399 100644 --- a/tests/ert/unit_tests/config/config_dict_generator.py +++ b/tests/ert/unit_tests/config/config_dict_generator.py @@ -4,9 +4,9 @@ import os.path import stat from collections import defaultdict -from dataclasses import dataclass, fields +from dataclasses import dataclass from pathlib import Path -from typing import Any, Literal +from typing import Any, Literal, get_args, get_origin from warnings import filterwarnings import hypothesis.strategies as st @@ -127,37 +127,52 @@ def memory_with_unit_lsf(draw): def valid_queue_options(queue_system: str): return [ - field.name.upper() - for field in fields( - queue_systems_and_options[QueueSystemWithGeneric(queue_system)] - ) - if field.name != "name" + name.upper() + for name in queue_systems_and_options[ + QueueSystemWithGeneric(queue_system) + ].model_fields + if name != "name" ] +def has_base_type( + field_type, base_type: type[int] | bool | type[str] | type[float] +) -> bool: + if field_type is base_type: + return True + origin = get_origin(field_type) + if origin: + args = get_args(field_type) + if any(arg is base_type for arg in args): + return True + return any(has_base_type(arg, base_type) for arg in args) + return False + + queue_options_by_type: dict[str, dict[str, list[str]]] = defaultdict(dict) for system, options in queue_systems_and_options.items(): queue_options_by_type["string"][system.name] = [ - field.name.upper() - for field in fields(options) - if ("String" in field.type or "str" in field.type) - and "memory" not in field.name + name.upper() + for name, field in options.model_fields.items() + if has_base_type(field.annotation, str) and "memory" not in name ] queue_options_by_type["bool"][system.name] = [ - field.name.upper() for field in fields(options) if field.type == "bool" + name.upper() + for name, field in options.model_fields.items() + if has_base_type(field.annotation, bool) ] queue_options_by_type["posint"][system.name] = [ - field.name.upper() - for field in fields(options) - if "PositiveInt" in field.type or "NonNegativeInt" in field.type + name.upper() + for name, field in options.model_fields.items() + if has_base_type(field.annotation, int) ] queue_options_by_type["posfloat"][system.name] = [ - field.name.upper() - for field in fields(options) - if "NonNegativeFloat" in field.type or "PositiveFloat" in field.type + name.upper() + for name, field in options.model_fields.items() + if has_base_type(field.annotation, float) ] queue_options_by_type["memory"][system.name] = [ - field.name.upper() for field in fields(options) if "memory" in field.name + name.upper() for name in options.model_fields if "memory" in name ] diff --git a/tests/ert/unit_tests/config/test_queue_config.py b/tests/ert/unit_tests/config/test_queue_config.py index b138c8a15e5..a908bf9d01b 100644 --- a/tests/ert/unit_tests/config/test_queue_config.py +++ b/tests/ert/unit_tests/config/test_queue_config.py @@ -15,7 +15,6 @@ from ert.config.queue_config import ( LocalQueueOptions, LsfQueueOptions, - QueueOptions, SlurmQueueOptions, TorqueQueueOptions, ) @@ -422,7 +421,7 @@ def test_default_activate_script_generation(expected, monkeypatch, venv): def test_conda_activate_script_generation(expected, monkeypatch, env): monkeypatch.setenv("VIRTUAL_ENV", "") monkeypatch.setenv("CONDA_ENV", env) - options = QueueOptions(name="local") + options = LocalQueueOptions(name="local") assert options.activate_script == expected @@ -433,7 +432,7 @@ def test_conda_activate_script_generation(expected, monkeypatch, env): def test_multiple_activate_script_generation(expected, monkeypatch, env): monkeypatch.setenv("VIRTUAL_ENV", env) monkeypatch.setenv("CONDA_ENV", env) - options = QueueOptions(name="local") + options = LocalQueueOptions(name="local") assert options.activate_script == expected From 137398ada8bdbaf2da5212b5cd7a36abb4254357 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=98yvind=20Eide?= Date: Fri, 17 Jan 2025 12:00:07 +0100 Subject: [PATCH 2/6] Add function to propagate context to validation --- src/ert/config/parsing/__init__.py | 2 ++ src/ert/config/parsing/base_model_context.py | 26 ++++++++++++++++ src/ert/config/queue_config.py | 20 +++++++++++-- src/everest/config/everest_config.py | 15 ++++++++-- src/everest/config/server_config.py | 12 +------- src/everest/config/simulator_config.py | 5 ---- tests/everest/test_detached.py | 31 +++++++++++++++++++- 7 files changed, 89 insertions(+), 22 deletions(-) create mode 100644 src/ert/config/parsing/base_model_context.py diff --git a/src/ert/config/parsing/__init__.py b/src/ert/config/parsing/__init__.py index 63d993073cc..ad2d5043a0e 100644 --- a/src/ert/config/parsing/__init__.py +++ b/src/ert/config/parsing/__init__.py @@ -1,3 +1,4 @@ +from .base_model_context import BaseModelWithContextSupport from .config_dict import ConfigDict from .config_errors import ConfigValidationError, ConfigWarning from .config_keywords import ConfigKeys @@ -18,6 +19,7 @@ from .workflow_schema import init_workflow_schema __all__ = [ + "BaseModelWithContextSupport", "ConfigDict", "ConfigKeys", "ConfigValidationError", diff --git a/src/ert/config/parsing/base_model_context.py b/src/ert/config/parsing/base_model_context.py new file mode 100644 index 00000000000..29bdf17d4cb --- /dev/null +++ b/src/ert/config/parsing/base_model_context.py @@ -0,0 +1,26 @@ +from collections.abc import Iterator +from contextlib import contextmanager +from contextvars import ContextVar +from typing import Any + +from pydantic import BaseModel + +init_context_var = ContextVar("_init_context_var", default=None) + + +@contextmanager +def init_context(value: dict[str, Any]) -> Iterator[None]: + token = init_context_var.set(value) # type: ignore + try: + yield + finally: + init_context_var.reset(token) + + +class BaseModelWithContextSupport(BaseModel): + def __init__(__pydantic_self__, **data: Any) -> None: + __pydantic_self__.__pydantic_validator__.validate_python( + data, + self_instance=__pydantic_self__, + context=init_context_var.get(), + ) diff --git a/src/ert/config/queue_config.py b/src/ert/config/queue_config.py index 3df9b97c549..d8e04734968 100644 --- a/src/ert/config/queue_config.py +++ b/src/ert/config/queue_config.py @@ -8,11 +8,13 @@ from typing import Annotated, Any, Literal, no_type_check import pydantic -from pydantic import BaseModel, Field +from pydantic import Field, field_validator from pydantic.dataclasses import dataclass +from pydantic_core.core_schema import ValidationInfo from ._get_num_cpu import get_num_cpu_from_data_file from .parsing import ( + BaseModelWithContextSupport, ConfigDict, ConfigKeys, ConfigValidationError, @@ -38,7 +40,7 @@ def activate_script() -> str: class QueueOptions( - BaseModel, + BaseModelWithContextSupport, validate_assignment=True, extra="forbid", use_enum_values=True, @@ -48,7 +50,19 @@ class QueueOptions( max_running: pydantic.NonNegativeInt = 0 submit_sleep: pydantic.NonNegativeFloat = 0.0 project_code: str | None = None - activate_script: str = Field(default_factory=activate_script) + activate_script: str | None = Field(default=None, validate_default=True) + + @field_validator("activate_script", mode="before") + @classmethod + def inject_site_config_script(cls, v: str, info: ValidationInfo) -> str: + # User value gets highest priority + if isinstance(v, str): + return v + # Use from plugin system if user has not specified + plugin_script = None + if info.context: + plugin_script = info.context.get(info.field_name) + return plugin_script or activate_script() # Return default value @staticmethod def create_queue_options( diff --git a/src/everest/config/everest_config.py b/src/everest/config/everest_config.py index 4f0df35f9af..e2ae2c6a30b 100644 --- a/src/everest/config/everest_config.py +++ b/src/everest/config/everest_config.py @@ -16,7 +16,6 @@ from pydantic import ( AfterValidator, - BaseModel, ConfigDict, Field, ValidationError, @@ -26,6 +25,9 @@ from ruamel.yaml import YAML, YAMLError from ert.config import ErtConfig +from ert.config.parsing import BaseModelWithContextSupport +from ert.config.parsing.base_model_context import init_context +from ert.plugins import ErtPluginManager from everest.config.control_variable_config import ControlVariableGuessListConfig from everest.config.install_template_config import InstallTemplateConfig from everest.config.server_config import ServerConfig @@ -117,7 +119,7 @@ class HasName(Protocol): name: str -class EverestConfig(BaseModel): # type: ignore +class EverestConfig(BaseModelWithContextSupport): # type: ignore controls: Annotated[list[ControlConfig], AfterValidator(unique_items)] = Field( description="""Defines a list of controls. Controls should have unique names each control defines @@ -763,6 +765,15 @@ def load_file(config_file: str) -> "EverestConfig": break raise exp from error + @classmethod + def with_plugins(cls, config_dict): + context = {} + activate_script = ErtPluginManager().activate_script() + if activate_script: + context["activate_script"] = ErtPluginManager().activate_script() + with init_context(context): + return cls(**config_dict) + @staticmethod def load_file_with_argparser( config_path, parser: ArgumentParser diff --git a/src/everest/config/server_config.py b/src/everest/config/server_config.py index f4f7bd9b27a..de4c9691100 100644 --- a/src/everest/config/server_config.py +++ b/src/everest/config/server_config.py @@ -2,7 +2,7 @@ import os from typing import Any -from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator +from pydantic import BaseModel, ConfigDict, Field, model_validator from ert.config.queue_config import ( LocalQueueOptions, @@ -10,7 +10,6 @@ SlurmQueueOptions, TorqueQueueOptions, ) -from ert.plugins import ErtPluginManager from ..strings import ( CERTIFICATE_DIR, @@ -38,15 +37,6 @@ class ServerConfig(BaseModel): # type: ignore extra="forbid", ) - @field_validator("queue_system", mode="before") - @classmethod - def default_local_queue(cls, v): - if v is None: - return v - elif "activate_script" not in v and ErtPluginManager().activate_script(): - v["activate_script"] = ErtPluginManager().activate_script() - return v - @model_validator(mode="before") @classmethod def check_old_config(cls, data: Any) -> Any: diff --git a/src/everest/config/simulator_config.py b/src/everest/config/simulator_config.py index 126ea7217b5..c82454536b0 100644 --- a/src/everest/config/simulator_config.py +++ b/src/everest/config/simulator_config.py @@ -15,7 +15,6 @@ SlurmQueueOptions, TorqueQueueOptions, ) -from ert.plugins import ErtPluginManager simulator_example = {"queue_system": {"name": "local", "max_running": 3}} @@ -93,10 +92,6 @@ class SimulatorConfig(BaseModel, extra="forbid"): # type: ignore def default_local_queue(cls, v): if v is None: return LocalQueueOptions(max_running=8) - if "activate_script" not in v and ( - active_script := ErtPluginManager().activate_script() - ): - v["activate_script"] = active_script return v @model_validator(mode="before") diff --git a/tests/everest/test_detached.py b/tests/everest/test_detached.py index 464f58038ac..8b63ce23c0d 100644 --- a/tests/everest/test_detached.py +++ b/tests/everest/test_detached.py @@ -280,7 +280,7 @@ def test_generate_queue_options_use_simulator_values( queue_options, expected_result, monkeypatch ): monkeypatch.setattr( - everest.config.server_config.ErtPluginManager, + everest.config.everest_config.ErtPluginManager, "activate_script", MagicMock(return_value=activate_script()), ) @@ -288,6 +288,35 @@ def test_generate_queue_options_use_simulator_values( assert config.server.queue_system == expected_result +@pytest.mark.parametrize("use_plugin", (True, False)) +@pytest.mark.parametrize( + "queue_options", + [ + {"name": "slurm", "activate_script": "From user"}, + {"name": "slurm"}, + ], +) +def test_queue_options_site_config(queue_options, use_plugin, monkeypatch, min_config): + plugin_result = "From plugin" + if "activate_script" in queue_options: + expected_result = queue_options["activate_script"] + elif use_plugin: + expected_result = plugin_result + else: + expected_result = activate_script() + + if use_plugin: + monkeypatch.setattr( + everest.config.everest_config.ErtPluginManager, + "activate_script", + MagicMock(return_value=plugin_result), + ) + config = EverestConfig.with_plugins( + {"simulator": {"queue_system": queue_options}} | min_config + ) + assert config.server.queue_system.activate_script == expected_result + + @pytest.mark.timeout(5) # Simulation might not finish @pytest.mark.integration_test @pytest.mark.xdist_group(name="starts_everest") From 4d18d53f4d022cdb0c04606d83b5e8af4af0b558 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=98yvind=20Eide?= Date: Thu, 16 Jan 2025 09:00:53 +0100 Subject: [PATCH 3/6] Fix bug where queue settings were not taken from site config --- src/everest/config/everest_config.py | 8 ++++-- src/everest/config/simulator_config.py | 12 ++++++--- tests/everest/test_detached.py | 31 ++++++++++++++++++++++-- tests/everest/test_res_initialization.py | 26 ++++++++++++++++++++ 4 files changed, 69 insertions(+), 8 deletions(-) diff --git a/src/everest/config/everest_config.py b/src/everest/config/everest_config.py index e2ae2c6a30b..9b9392a369b 100644 --- a/src/everest/config/everest_config.py +++ b/src/everest/config/everest_config.py @@ -22,9 +22,10 @@ field_validator, model_validator, ) +from pydantic_core.core_schema import ValidationInfo from ruamel.yaml import YAML, YAMLError -from ert.config import ErtConfig +from ert.config import ErtConfig, QueueConfig from ert.config.parsing import BaseModelWithContextSupport from ert.config.parsing.base_model_context import init_context from ert.plugins import ErtPluginManager @@ -251,7 +252,7 @@ def validate_queue_system(self) -> Self: # pylint: disable=E0213 return self @model_validator(mode="after") - def validate_forward_model_job_name_installed(self) -> Self: # pylint: disable=E0213 + def validate_forward_model_job_name_installed(self, info: ValidationInfo) -> Self: # pylint: disable=E0213 install_jobs = self.install_jobs forward_model_jobs = self.forward_model if install_jobs is None: @@ -769,6 +770,9 @@ def load_file(config_file: str) -> "EverestConfig": def with_plugins(cls, config_dict): context = {} activate_script = ErtPluginManager().activate_script() + site_config = ErtConfig.read_site_config() + if site_config: + context["queue_system"] = QueueConfig.from_dict(site_config).queue_options if activate_script: context["activate_script"] = ErtPluginManager().activate_script() with init_context(context): diff --git a/src/everest/config/simulator_config.py b/src/everest/config/simulator_config.py index c82454536b0..4a04c77dec8 100644 --- a/src/everest/config/simulator_config.py +++ b/src/everest/config/simulator_config.py @@ -1,14 +1,15 @@ from typing import Any from pydantic import ( - BaseModel, Field, NonNegativeInt, PositiveInt, field_validator, model_validator, ) +from pydantic_core.core_schema import ValidationInfo +from ert.config.parsing import BaseModelWithContextSupport from ert.config.queue_config import ( LocalQueueOptions, LsfQueueOptions, @@ -32,7 +33,7 @@ def check_removed_config(queue_system): ) -class SimulatorConfig(BaseModel, extra="forbid"): # type: ignore +class SimulatorConfig(BaseModelWithContextSupport, extra="forbid"): # type: ignore cores_per_node: PositiveInt | None = Field( default=None, description="""defines the number of CPUs when running @@ -89,9 +90,12 @@ class SimulatorConfig(BaseModel, extra="forbid"): # type: ignore @field_validator("queue_system", mode="before") @classmethod - def default_local_queue(cls, v): + def default_local_queue(cls, v, info: ValidationInfo): if v is None: - return LocalQueueOptions(max_running=8) + options = None + if info.context: + options = info.context.get(info.field_name) + return options or LocalQueueOptions(max_running=8) return v @model_validator(mode="before") diff --git a/tests/everest/test_detached.py b/tests/everest/test_detached.py index 8b63ce23c0d..3a8ece18031 100644 --- a/tests/everest/test_detached.py +++ b/tests/everest/test_detached.py @@ -8,7 +8,7 @@ import requests import everest -from ert.config import ErtConfig +from ert.config import ErtConfig, QueueSystem from ert.config.queue_config import ( LocalQueueOptions, LsfQueueOptions, @@ -314,7 +314,34 @@ def test_queue_options_site_config(queue_options, use_plugin, monkeypatch, min_c config = EverestConfig.with_plugins( {"simulator": {"queue_system": queue_options}} | min_config ) - assert config.server.queue_system.activate_script == expected_result + assert config.simulator.queue_system.activate_script == expected_result + + +@pytest.mark.parametrize("use_plugin", (True, False)) +@pytest.mark.parametrize( + "queue_options", + [ + {"queue_system": {"name": "slurm"}}, + {}, + ], +) +def test_simulator_queue_system_site_config( + queue_options, use_plugin, monkeypatch, min_config +): + if queue_options: + expected_result = SlurmQueueOptions # User specified + elif use_plugin: + expected_result = LsfQueueOptions # Mock site config + else: + expected_result = LocalQueueOptions # Default value + if use_plugin: + monkeypatch.setattr( + everest.config.everest_config.ErtConfig, + "read_site_config", + MagicMock(return_value={"QUEUE_SYSTEM": QueueSystem.LSF}), + ) + config = EverestConfig.with_plugins({"simulator": queue_options} | min_config) + assert isinstance(config.simulator.queue_system, expected_result) @pytest.mark.timeout(5) # Simulation might not finish diff --git a/tests/everest/test_res_initialization.py b/tests/everest/test_res_initialization.py index b984bc87616..0baf676b3c6 100644 --- a/tests/everest/test_res_initialization.py +++ b/tests/everest/test_res_initialization.py @@ -359,3 +359,29 @@ def test_user_config_jobs_precedence(tmp_path, monkeypatch): .executable == "echo" ) + + +def test_that_queue_settings_are_taken_from_site_config( + min_config, monkeypatch, tmp_path +): + monkeypatch.chdir(tmp_path) + assert "simulator" not in min_config # Double check + Path("site-config").write_text( + dedent(""" + QUEUE_SYSTEM LSF + QUEUE_OPTION LSF LSF_RESOURCE my_resource + QUEUE_OPTION LSF LSF_QUEUE my_queue + """), + encoding="utf-8", + ) + with open("config.yml", "w", encoding="utf-8") as f: + yaml.dump(min_config, f) + monkeypatch.setenv("ERT_SITE_CONFIG", "site-config") + config = EverestConfig.load_file("config.yml") + assert config.simulator.queue_system == LsfQueueOptions( + lsf_queue="my_queue", lsf_resource="my_resource" + ) + ert_config = everest_to_ert_config(config) + assert ert_config.queue_config.queue_options == LsfQueueOptions( + lsf_queue="my_queue", lsf_resource="my_resource" + ) From e54edca97b6090e7f31f3f784dd4668449982df8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=98yvind=20Eide?= Date: Wed, 15 Jan 2025 08:48:34 +0100 Subject: [PATCH 4/6] Add installed jobs in new context system --- src/everest/config/everest_config.py | 34 ++++++++++++---------------- tests/everest/test_egg_simulation.py | 1 - tests/everest/test_util.py | 8 ------- 3 files changed, 14 insertions(+), 29 deletions(-) diff --git a/src/everest/config/everest_config.py b/src/everest/config/everest_config.py index 9b9392a369b..359517daf8b 100644 --- a/src/everest/config/everest_config.py +++ b/src/everest/config/everest_config.py @@ -70,18 +70,6 @@ from pydantic_core import ErrorDetails -def _dummy_ert_config(): - site_config = ErtConfig.read_site_config() - dummy_config = {"NUM_REALIZATIONS": 1, "ENSPATH": "."} - dummy_config.update(site_config) - return ErtConfig.with_plugins().from_dict(config_dict=dummy_config) - - -def get_system_installed_jobs(): - """Returns list of all system installed job names""" - return list(_dummy_ert_config().installed_forward_model_steps.keys()) - - class EverestValidationError(ValueError): def __init__(self): super().__init__() @@ -179,7 +167,7 @@ class EverestConfig(BaseModelWithContextSupport): # type: ignore default=None, description="A list of output constraints with unique names." ) install_jobs: list[InstallJobConfig] | None = Field( - default=None, description="A list of jobs to install" + default=None, description="A list of jobs to install", validate_default=True ) install_workflow_jobs: list[InstallJobConfig] | None = Field( default=None, description="A list of workflow jobs to install" @@ -261,7 +249,8 @@ def validate_forward_model_job_name_installed(self, info: ValidationInfo) -> Sel return self installed_jobs_name = [job.name for job in install_jobs] installed_jobs_name += list(script_names) # default jobs - installed_jobs_name += get_system_installed_jobs() # system jobs + if info.context: # Add plugin jobs + installed_jobs_name += info.context.get("install_jobs", {}).keys() errors = [] for fm_job in forward_model_jobs: @@ -721,7 +710,7 @@ def with_defaults(cls, **kwargs): "model": {"realizations": [0]}, } - return cls.model_validate({**defaults, **kwargs}) + return cls.with_plugins({**defaults, **kwargs}) @staticmethod def lint_config_dict(config: dict) -> list["ErrorDetails"]: @@ -738,8 +727,8 @@ def lint_config_dict_with_raise(config: dict): # more understandable EverestConfig.model_validate(config) - @staticmethod - def load_file(config_file: str) -> "EverestConfig": + @classmethod + def load_file(cls, config_file: str): config_path = os.path.realpath(config_file) if not os.path.isfile(config_path): @@ -747,7 +736,7 @@ def load_file(config_file: str) -> "EverestConfig": config_dict = yaml_file_to_substituted_config_dict(config_path) try: - return EverestConfig.model_validate(config_dict) + return cls.with_plugins(config_dict) except ValidationError as error: exp = EverestValidationError() file_content = [] @@ -768,9 +757,14 @@ def load_file(config_file: str) -> "EverestConfig": @classmethod def with_plugins(cls, config_dict): - context = {} - activate_script = ErtPluginManager().activate_script() site_config = ErtConfig.read_site_config() + ert_config: ErtConfig = ErtConfig.with_plugins().from_dict( + config_dict=site_config + ) + context = { + "install_jobs": ert_config.installed_forward_model_steps, + } + activate_script = ErtPluginManager().activate_script() if site_config: context["queue_system"] = QueueConfig.from_dict(site_config).queue_options if activate_script: diff --git a/tests/everest/test_egg_simulation.py b/tests/everest/test_egg_simulation.py index 70408129d94..1a4354ab12d 100644 --- a/tests/everest/test_egg_simulation.py +++ b/tests/everest/test_egg_simulation.py @@ -594,7 +594,6 @@ def test_opm_fail_default_summary_keys(copy_egg_test_data_to_tmp): config = EverestConfig.load_file(CONFIG_FILE) # The Everest config file will fail to load as an Eclipse data file config.model.data_file = os.path.realpath(CONFIG_FILE) - assert len(EverestConfig.lint_config_dict(config.to_dict())) == 0 ert_config = _everest_to_ert_config_dict(config) diff --git a/tests/everest/test_util.py b/tests/everest/test_util.py index 95b26150e35..fa99bdf2003 100644 --- a/tests/everest/test_util.py +++ b/tests/everest/test_util.py @@ -7,7 +7,6 @@ from everest import util from everest.bin.utils import report_on_previous_run from everest.config import EverestConfig, ServerConfig -from everest.config.everest_config import get_system_installed_jobs from everest.detached import ServerStatus from everest.strings import SERVER_STATUS from tests.everest.utils import ( @@ -131,13 +130,6 @@ def test_get_everserver_status_path(copy_math_func_test_data_to_tmp): assert path == expected_path -def test_get_system_installed_job_names(): - job_names = get_system_installed_jobs() - assert job_names is not None - assert isinstance(job_names, list) - assert len(job_names) > 0 - - @patch( "everest.bin.utils.everserver_status", return_value={"status": ServerStatus.failed, "message": "mock error"}, From c08a3173fc44ab67807a1853bbe3e78044dd1313 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=98yvind=20Eide?= Date: Fri, 17 Jan 2025 10:01:21 +0100 Subject: [PATCH 5/6] Simplify test --- tests/everest/test_detached.py | 52 +++++----------------------------- 1 file changed, 7 insertions(+), 45 deletions(-) diff --git a/tests/everest/test_detached.py b/tests/everest/test_detached.py index 3a8ece18031..4756b8e516b 100644 --- a/tests/everest/test_detached.py +++ b/tests/everest/test_detached.py @@ -6,9 +6,10 @@ import numpy as np import pytest import requests +import yaml import everest -from ert.config import ErtConfig, QueueSystem +from ert.config import QueueSystem from ert.config.queue_config import ( LocalQueueOptions, LsfQueueOptions, @@ -21,7 +22,6 @@ from everest.config.server_config import ServerConfig from everest.config.simulator_config import SimulatorConfig from everest.detached import ( - _EVERSERVER_JOB_PATH, PROXY, ServerStatus, everserver_status, @@ -33,12 +33,6 @@ wait_for_server, wait_for_server_to_stop, ) -from everest.strings import ( - DEFAULT_OUTPUT_DIR, - DETACHED_NODE_DIR, - EVEREST_SERVER_CONFIG, - SIMULATION_DIR, -) from everest.util import makedirs_if_needed @@ -153,43 +147,11 @@ def test_wait_for_server(server_is_running_mock, caplog): assert not caplog.messages -def _get_reference_config(): - everest_config = EverestConfig.load_file("config_minimal.yml") - reference_config = ErtConfig.read_site_config() - cwd = os.getcwd() - reference_config.update( - { - "INSTALL_JOB": [(EVEREST_SERVER_CONFIG, _EVERSERVER_JOB_PATH)], - "QUEUE_SYSTEM": "LOCAL", - "JOBNAME": EVEREST_SERVER_CONFIG, - "MAX_SUBMIT": 1, - "NUM_REALIZATIONS": 1, - "RUNPATH": os.path.join( - cwd, - DEFAULT_OUTPUT_DIR, - DETACHED_NODE_DIR, - SIMULATION_DIR, - ), - "FORWARD_MODEL": [ - [ - EVEREST_SERVER_CONFIG, - "--config-file", - os.path.join(cwd, "config_minimal.yml"), - ], - ], - "ENSPATH": os.path.join( - cwd, DEFAULT_OUTPUT_DIR, DETACHED_NODE_DIR, EVEREST_SERVER_CONFIG - ), - "RUNPATH_FILE": os.path.join( - cwd, DEFAULT_OUTPUT_DIR, DETACHED_NODE_DIR, ".res_runpath_list" - ), - } - ) - return everest_config, reference_config - - -def test_detached_mode_config_base(copy_math_func_test_data_to_tmp): - everest_config, _ = _get_reference_config() +def test_detached_mode_config_base(min_config, monkeypatch, tmp_path): + monkeypatch.chdir(tmp_path) + with open("config.yml", "w", encoding="utf-8") as fout: + yaml.dump(min_config, fout) + everest_config = EverestConfig.load_file("config.yml") assert everest_config.simulator.queue_system == LocalQueueOptions(max_running=8) From 13760e068de9e9b247dff081b83186c51c81c4c1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=98yvind=20Eide?= Date: Tue, 4 Feb 2025 15:24:44 +0100 Subject: [PATCH 6/6] Review comments --- src/ert/config/queue_config.py | 4 +--- src/everest/config/everest_config.py | 4 ++-- src/everest/config/simulator_config.py | 2 +- 3 files changed, 4 insertions(+), 6 deletions(-) diff --git a/src/ert/config/queue_config.py b/src/ert/config/queue_config.py index d8e04734968..2a8fc28434e 100644 --- a/src/ert/config/queue_config.py +++ b/src/ert/config/queue_config.py @@ -61,7 +61,7 @@ def inject_site_config_script(cls, v: str, info: ValidationInfo) -> str: # Use from plugin system if user has not specified plugin_script = None if info.context: - plugin_script = info.context.get(info.field_name) + plugin_script = info.context.get("activate_script") return plugin_script or activate_script() # Return default value @staticmethod @@ -154,8 +154,6 @@ def driver_options(self) -> dict[str, Any]: "name", "max_running", "submit_sleep", - "qstat_options", - "queue_query_timeout", } ) driver_dict["queue_name"] = driver_dict.pop("queue") diff --git a/src/everest/config/everest_config.py b/src/everest/config/everest_config.py index 359517daf8b..bf5264c98d9 100644 --- a/src/everest/config/everest_config.py +++ b/src/everest/config/everest_config.py @@ -728,7 +728,7 @@ def lint_config_dict_with_raise(config: dict): EverestConfig.model_validate(config) @classmethod - def load_file(cls, config_file: str): + def load_file(cls, config_file: str) -> Self: config_path = os.path.realpath(config_file) if not os.path.isfile(config_path): @@ -768,7 +768,7 @@ def with_plugins(cls, config_dict): if site_config: context["queue_system"] = QueueConfig.from_dict(site_config).queue_options if activate_script: - context["activate_script"] = ErtPluginManager().activate_script() + context["activate_script"] = activate_script with init_context(context): return cls(**config_dict) diff --git a/src/everest/config/simulator_config.py b/src/everest/config/simulator_config.py index 4a04c77dec8..017e12c4537 100644 --- a/src/everest/config/simulator_config.py +++ b/src/everest/config/simulator_config.py @@ -94,7 +94,7 @@ def default_local_queue(cls, v, info: ValidationInfo): if v is None: options = None if info.context: - options = info.context.get(info.field_name) + options = info.context.get("queue_system") return options or LocalQueueOptions(max_running=8) return v