Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix bug where site config was not propagated to Everest config #9719

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/ert/config/parsing/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from .base_model_context import BaseModelWithContextSupport
from .config_dict import ConfigDict
from .config_errors import ConfigValidationError, ConfigWarning
from .config_keywords import ConfigKeys
Expand All @@ -18,6 +19,7 @@
from .workflow_schema import init_workflow_schema

__all__ = [
"BaseModelWithContextSupport",
"ConfigDict",
"ConfigKeys",
"ConfigValidationError",
Expand Down
26 changes: 26 additions & 0 deletions src/ert/config/parsing/base_model_context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from collections.abc import Iterator
from contextlib import contextmanager
from contextvars import ContextVar
from typing import Any

from pydantic import BaseModel

init_context_var = ContextVar("_init_context_var", default=None)


@contextmanager
def init_context(value: dict[str, Any]) -> Iterator[None]:
token = init_context_var.set(value) # type: ignore
try:
yield
finally:
init_context_var.reset(token)


class BaseModelWithContextSupport(BaseModel):
def __init__(__pydantic_self__, **data: Any) -> None:
__pydantic_self__.__pydantic_validator__.validate_python(
data,
self_instance=__pydantic_self__,
context=init_context_var.get(),
)
77 changes: 41 additions & 36 deletions src/ert/config/queue_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,16 @@
import re
import shutil
from abc import abstractmethod
from dataclasses import asdict, field, fields
from typing import Annotated, Any, Literal, no_type_check

import pydantic
from pydantic import Field, field_validator
from pydantic.dataclasses import dataclass
from pydantic_core.core_schema import ValidationInfo

from ._get_num_cpu import get_num_cpu_from_data_file
from .parsing import (
BaseModelWithContextSupport,
ConfigDict,
ConfigKeys,
ConfigValidationError,
Expand All @@ -37,20 +39,30 @@ def activate_script() -> str:
return ""


@pydantic.dataclasses.dataclass(
config={
"extra": "forbid",
"validate_assignment": True,
"use_enum_values": True,
"validate_default": True,
}
)
class QueueOptions:
class QueueOptions(
BaseModelWithContextSupport,
validate_assignment=True,
extra="forbid",
use_enum_values=True,
validate_default=True,
):
name: QueueSystem
max_running: pydantic.NonNegativeInt = 0
submit_sleep: pydantic.NonNegativeFloat = 0.0
project_code: str | None = None
activate_script: str = field(default_factory=activate_script)
activate_script: str | None = Field(default=None, validate_default=True)

@field_validator("activate_script", mode="before")
@classmethod
def inject_site_config_script(cls, v: str, info: ValidationInfo) -> str:
# User value gets highest priority
if isinstance(v, str):
return v
# Use from plugin system if user has not specified
plugin_script = None
if info.context:
plugin_script = info.context.get("activate_script")
return plugin_script or activate_script() # Return default value

@staticmethod
def create_queue_options(
Expand Down Expand Up @@ -78,12 +90,12 @@ def create_queue_options(
return None

def add_global_queue_options(self, config_dict: ConfigDict) -> None:
for generic_option in fields(QueueOptions):
for name, generic_option in QueueOptions.model_fields.items():
if (
generic_value := config_dict.get(generic_option.name.upper(), None) # type: ignore
) and self.__dict__[generic_option.name] == generic_option.default:
generic_value := config_dict.get(name.upper(), None) # type: ignore
) and self.__dict__[name] == generic_option.default:
try:
setattr(self, generic_option.name, generic_value)
setattr(self, name, generic_value)
except pydantic.ValidationError as exception:
for error in exception.errors():
_throw_error_or_warning(
Expand All @@ -98,7 +110,6 @@ def driver_options(self) -> dict[str, Any]:
"""Translate the queue options to the key-value API provided by each driver"""


@pydantic.dataclasses.dataclass
class LocalQueueOptions(QueueOptions):
name: Literal[QueueSystem.LOCAL] = QueueSystem.LOCAL

Expand All @@ -107,7 +118,6 @@ def driver_options(self) -> dict[str, Any]:
return {}


@pydantic.dataclasses.dataclass
class LsfQueueOptions(QueueOptions):
name: Literal[QueueSystem.LSF] = QueueSystem.LSF
bhist_cmd: NonEmptyString | None = None
Expand All @@ -120,17 +130,13 @@ class LsfQueueOptions(QueueOptions):

@property
def driver_options(self) -> dict[str, Any]:
driver_dict = asdict(self)
driver_dict.pop("name")
driver_dict = self.model_dump(exclude={"name", "submit_sleep", "max_running"})
driver_dict["exclude_hosts"] = driver_dict.pop("exclude_host")
driver_dict["queue_name"] = driver_dict.pop("lsf_queue")
driver_dict["resource_requirement"] = driver_dict.pop("lsf_resource")
driver_dict.pop("submit_sleep")
driver_dict.pop("max_running")
return driver_dict


@pydantic.dataclasses.dataclass
class TorqueQueueOptions(QueueOptions):
name: Literal[QueueSystem.TORQUE] = QueueSystem.TORQUE
qsub_cmd: NonEmptyString | None = None
Expand All @@ -143,15 +149,17 @@ class TorqueQueueOptions(QueueOptions):

@property
def driver_options(self) -> dict[str, Any]:
driver_dict = asdict(self)
driver_dict.pop("name")
driver_dict = self.model_dump(
exclude={
"name",
"max_running",
"submit_sleep",
}
)
driver_dict["queue_name"] = driver_dict.pop("queue")
driver_dict.pop("max_running")
driver_dict.pop("submit_sleep")
return driver_dict


@pydantic.dataclasses.dataclass
class SlurmQueueOptions(QueueOptions):
name: Literal[QueueSystem.SLURM] = QueueSystem.SLURM
sbatch: NonEmptyString = "sbatch"
Expand All @@ -167,8 +175,7 @@ class SlurmQueueOptions(QueueOptions):

@property
def driver_options(self) -> dict[str, Any]:
driver_dict = asdict(self)
driver_dict.pop("name")
driver_dict = self.model_dump(exclude={"name", "max_running", "submit_sleep"})
driver_dict["sbatch_cmd"] = driver_dict.pop("sbatch")
driver_dict["scancel_cmd"] = driver_dict.pop("scancel")
driver_dict["scontrol_cmd"] = driver_dict.pop("scontrol")
Expand All @@ -177,8 +184,6 @@ def driver_options(self) -> dict[str, Any]:
driver_dict["exclude_hosts"] = driver_dict.pop("exclude_host")
driver_dict["include_hosts"] = driver_dict.pop("include_host")
driver_dict["queue_name"] = driver_dict.pop("partition")
driver_dict.pop("max_running")
driver_dict.pop("submit_sleep")
return driver_dict


Expand All @@ -203,12 +208,12 @@ def validate(self, mem_str_format: str | None) -> bool:
)

valid_options: dict[str, list[str]] = {
QueueSystem.LOCAL: [field.name.upper() for field in fields(LocalQueueOptions)],
QueueSystem.LSF: [field.name.upper() for field in fields(LsfQueueOptions)],
QueueSystem.SLURM: [field.name.upper() for field in fields(SlurmQueueOptions)],
QueueSystem.TORQUE: [field.name.upper() for field in fields(TorqueQueueOptions)],
QueueSystem.LOCAL: [field.upper() for field in LocalQueueOptions.model_fields],
QueueSystem.LSF: [field.upper() for field in LsfQueueOptions.model_fields],
QueueSystem.SLURM: [field.upper() for field in SlurmQueueOptions.model_fields],
QueueSystem.TORQUE: [field.upper() for field in TorqueQueueOptions.model_fields],
QueueSystemWithGeneric.GENERIC: [
field.name.upper() for field in fields(QueueOptions)
field.upper() for field in QueueOptions.model_fields
],
}

Expand Down
53 changes: 31 additions & 22 deletions src/everest/config/everest_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,19 @@

from pydantic import (
AfterValidator,
BaseModel,
ConfigDict,
Field,
ValidationError,
field_validator,
model_validator,
)
from pydantic_core.core_schema import ValidationInfo
from ruamel.yaml import YAML, YAMLError

from ert.config import ErtConfig
from ert.config import ErtConfig, QueueConfig
from ert.config.parsing import BaseModelWithContextSupport
from ert.config.parsing.base_model_context import init_context
from ert.plugins import ErtPluginManager
from everest.config.control_variable_config import ControlVariableGuessListConfig
from everest.config.install_template_config import InstallTemplateConfig
from everest.config.server_config import ServerConfig
Expand Down Expand Up @@ -67,18 +70,6 @@
from pydantic_core import ErrorDetails


def _dummy_ert_config():
site_config = ErtConfig.read_site_config()
dummy_config = {"NUM_REALIZATIONS": 1, "ENSPATH": "."}
dummy_config.update(site_config)
return ErtConfig.with_plugins().from_dict(config_dict=dummy_config)


def get_system_installed_jobs():
"""Returns list of all system installed job names"""
return list(_dummy_ert_config().installed_forward_model_steps.keys())


class EverestValidationError(ValueError):
def __init__(self):
super().__init__()
Expand Down Expand Up @@ -117,7 +108,7 @@ class HasName(Protocol):
name: str


class EverestConfig(BaseModel): # type: ignore
class EverestConfig(BaseModelWithContextSupport): # type: ignore
controls: Annotated[list[ControlConfig], AfterValidator(unique_items)] = Field(
description="""Defines a list of controls.
Controls should have unique names each control defines
Expand Down Expand Up @@ -176,7 +167,7 @@ class EverestConfig(BaseModel): # type: ignore
default=None, description="A list of output constraints with unique names."
)
install_jobs: list[InstallJobConfig] | None = Field(
default=None, description="A list of jobs to install"
default=None, description="A list of jobs to install", validate_default=True
)
install_workflow_jobs: list[InstallJobConfig] | None = Field(
default=None, description="A list of workflow jobs to install"
Expand Down Expand Up @@ -249,7 +240,7 @@ def validate_queue_system(self) -> Self: # pylint: disable=E0213
return self

@model_validator(mode="after")
def validate_forward_model_job_name_installed(self) -> Self: # pylint: disable=E0213
def validate_forward_model_job_name_installed(self, info: ValidationInfo) -> Self: # pylint: disable=E0213
install_jobs = self.install_jobs
forward_model_jobs = self.forward_model
if install_jobs is None:
Expand All @@ -258,7 +249,8 @@ def validate_forward_model_job_name_installed(self) -> Self: # pylint: disable=
return self
installed_jobs_name = [job.name for job in install_jobs]
installed_jobs_name += list(script_names) # default jobs
installed_jobs_name += get_system_installed_jobs() # system jobs
if info.context: # Add plugin jobs
installed_jobs_name += info.context.get("install_jobs", {}).keys()

errors = []
for fm_job in forward_model_jobs:
Expand Down Expand Up @@ -718,7 +710,7 @@ def with_defaults(cls, **kwargs):
"model": {"realizations": [0]},
}

return cls.model_validate({**defaults, **kwargs})
return cls.with_plugins({**defaults, **kwargs})

@staticmethod
def lint_config_dict(config: dict) -> list["ErrorDetails"]:
Expand All @@ -735,16 +727,16 @@ def lint_config_dict_with_raise(config: dict):
# more understandable
EverestConfig.model_validate(config)

@staticmethod
def load_file(config_file: str) -> "EverestConfig":
@classmethod
def load_file(cls, config_file: str) -> Self:
config_path = os.path.realpath(config_file)

if not os.path.isfile(config_path):
raise FileNotFoundError(f"File not found: {config_path}")

config_dict = yaml_file_to_substituted_config_dict(config_path)
try:
return EverestConfig.model_validate(config_dict)
return cls.with_plugins(config_dict)
except ValidationError as error:
exp = EverestValidationError()
file_content = []
Expand All @@ -763,6 +755,23 @@ def load_file(config_file: str) -> "EverestConfig":
break
raise exp from error

@classmethod
def with_plugins(cls, config_dict):
site_config = ErtConfig.read_site_config()
ert_config: ErtConfig = ErtConfig.with_plugins().from_dict(
config_dict=site_config
)
context = {
"install_jobs": ert_config.installed_forward_model_steps,
}
activate_script = ErtPluginManager().activate_script()
if site_config:
context["queue_system"] = QueueConfig.from_dict(site_config).queue_options
if activate_script:
context["activate_script"] = activate_script
with init_context(context):
return cls(**config_dict)

@staticmethod
def load_file_with_argparser(
config_path, parser: ArgumentParser
Expand Down
12 changes: 1 addition & 11 deletions src/everest/config/server_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,14 @@
import os
from typing import Any

from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
from pydantic import BaseModel, ConfigDict, Field, model_validator

from ert.config.queue_config import (
LocalQueueOptions,
LsfQueueOptions,
SlurmQueueOptions,
TorqueQueueOptions,
)
from ert.plugins import ErtPluginManager

from ..strings import (
CERTIFICATE_DIR,
Expand Down Expand Up @@ -38,15 +37,6 @@ class ServerConfig(BaseModel): # type: ignore
extra="forbid",
)

@field_validator("queue_system", mode="before")
@classmethod
def default_local_queue(cls, v):
if v is None:
return v
elif "activate_script" not in v and ErtPluginManager().activate_script():
v["activate_script"] = ErtPluginManager().activate_script()
return v

@model_validator(mode="before")
@classmethod
def check_old_config(cls, data: Any) -> Any:
Expand Down
Loading
Loading