diff --git a/src/AMSWorkflow/ams/ams_jobs.py b/src/AMSWorkflow/ams/ams_jobs.py
index 88c1a61d..1b56f3cd 100644
--- a/src/AMSWorkflow/ams/ams_jobs.py
+++ b/src/AMSWorkflow/ams/ams_jobs.py
@@ -4,6 +4,8 @@ from typing import Optional
 from dataclasses import dataclass, fields
 
 from ams import util
+from typing import Dict, List, Union, Optional, Mapping
+from pathlib import Path
 
 
 def constuct_cli_cmd(executable, *args, **kwargs):
@@ -36,7 +38,7 @@ class AMSJob:
     """
 
     @classmethod
-    def generate_formatting(self, store):
+    def generate_formatting(cls, store):
         return {"AMS_STORE_PATH": store.root_path}
 
     def __init__(
@@ -233,7 +235,7 @@ def __init__(self, domain_names, stage_dir, *args, **kwargs):
         super().__init__(*args, **kwargs)
 
     @classmethod
-    def from_descr(cls, stage_dir, descr):
+    def from_descr(cls, descr, stage_dir=None):
         domain_job_resources = AMSJobResources(**descr["resources"])
         return cls(
             name=descr["name"],
@@ -293,7 +295,158 @@ def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
 
 
-class AMSFSStageJob(AMSJob):
+class AMSStageJob(AMSJob):
+    def __init__(
+        self,
+        resources: Union[Dict[str, Union[str, int]], AMSJobResources],
+        dest: str,
+        persistent_db_path: str,
+        store: bool = True,
+        db_type: str = "dhdf5",
+        policy: str = "process",
+        prune_module_path: Optional[str] = None,
+        prune_class: Optional[str] = None,
+        environ: Optional[Mapping[str, str]] = None,
+        stdout: Optional[str] = None,
+        stderr: Optional[str] = None,
+        cli_args: List[str] = [],
+        cli_kwargs: Mapping[str, str] = {},
+    ):
+        _cli_args = list(cli_args)
+        if store:
+            _cli_args.append("--store")
+        else:
+            _cli_args.append("--no-store")
+
+        _cli_kwargs = dict(cli_kwargs)
+        _cli_kwargs["--dest"] = dest
+        _cli_kwargs["--persistent-db-path"] = persistent_db_path
+        _cli_kwargs["--db-type"] = db_type
+        _cli_kwargs["--policy"] = policy
+
+        if prune_module_path is not None:
+            assert Path(prune_module_path).exists(), "Module path to user pruner does not exist"
+            assert prune_class is not None, "When defining a pruning module please define the class"
+            _cli_kwargs["--load"] = prune_module_path
+            _cli_kwargs["--class"] = prune_class
+
+        super().__init__(
+            name="AMSStageJob",
+            executable="AMSDBStage",
+            environ=environ,
+            resources=resources,
+            stdout=stdout,
+            stderr=stderr,
+            cli_args=_cli_args,
+            cli_kwargs=_cli_kwargs,
+        )
+
+
+class AMSFSStageJob(AMSStageJob):
+    def __init__(
+        self,
+        resources: Union[Dict[str, Union[str, int]], AMSJobResources],
+        dest: str,
+        persistent_db_path: str,
+        src: str,
+        store: bool = True,
+        db_type="dhdf5",
+        pattern="*.h5",
+        src_type: str = "shdf5",
+        prune_module_path: Optional[str] = None,
+        prune_class: Optional[str] = None,
+        environ: Optional[Mapping[str, str]] = None,
+        stdout: Optional[str] = None,
+        stderr: Optional[str] = None,
+        cli_args: List[str] = [],
+        cli_kwargs: Mapping[str, str] = {},
+    ):
+
+        _cli_args = list(cli_args)
+        _cli_kwargs = dict(cli_kwargs)
+        _cli_kwargs["--src"] = src
+        _cli_kwargs["--src-type"] = src_type
+        _cli_kwargs["--pattern"] = pattern
+        _cli_kwargs["--mechanism"] = "fs"
+
+        super().__init__(
+            resources,
+            dest,
+            persistent_db_path,
+            store,
+            db_type,
+            environ=environ,
+            stdout=stdout,
+            stderr=stderr,
+            prune_module_path=prune_module_path,
+            prune_class=prune_class,
+            cli_args=_cli_args,
+            cli_kwargs=_cli_kwargs,
+        )
+
+
+class AMSNetworkStageJob(AMSStageJob):
+    def __init__(
+        self,
+        resources: Union[Dict[str, Union[str, int]], AMSJobResources],
+        dest: str,
+        persistent_db_path: str,
+        creds: str,
+        store: bool = True,
+        db_type: str = "dhdf5",
+        update_models: bool = False,
+        prune_module_path: Optional[str] = None,
+        prune_class: Optional[str] = None,
+        environ: Optional[Mapping[str, str]] = None,
+        stdout: Optional[str] = None,
+        stderr: Optional[str] = None,
+        cli_args: List[str] = [],
+        cli_kwargs: Mapping[str, str] = {},
+    ):
+        _cli_args = list(cli_args)
+        if update_models:
+            _cli_args.append("--update-rmq-models")
+        _cli_kwargs = dict(cli_kwargs)
+        _cli_kwargs["--creds"] = creds
+        _cli_kwargs["--mechanism"] = "network"
+
+        super().__init__(
+            resources,
+            dest,
+            persistent_db_path,
+            store,
+            db_type,
+            environ=environ,
+            stdout=stdout,
+            stderr=stderr,
+            prune_module_path=prune_module_path,
+            prune_class=prune_class,
+            cli_args=_cli_args,
+            cli_kwargs=_cli_kwargs,
+        )
+
+    @classmethod
+    def from_descr(cls, descr, dest, persistent_db_path, creds, num_nodes, cores_per_node, gpus_per_node):
+        cores_per_instance = 5
+        total_cores = num_nodes * cores_per_node
+        instances = descr.pop("instances", 1)
+        requires_gpu = descr.pop("requires_gpu", False)
+
+        assert instances == 1, "We are missing support for multi-instance execution"
+        assert not requires_gpu, "We are missing support for gpu stager execution"
+
+        resources = AMSJobResources(
+            nodes=1,
+            tasks_per_node=1,
+            cores_per_task=5,
+            exclusive=False,
+            gpus_per_task=None,
+        )
+
+        return cls(resources, dest, persistent_db_path, creds, **descr)
+
+
+class AMSFSTempStageJob(AMSJob):
     def __init__(
         self,
         store_dir,
@@ -377,5 +530,4 @@ def get_echo_job(message):
         cores_per_task=1,
         gpus_per_task=0,
         exclusive=True,
-    )
     return jobspec
diff --git a/src/AMSWorkflow/ams/stage.py b/src/AMSWorkflow/ams/stage.py
index 88e07e38..7fcabc82 100644
--- a/src/AMSWorkflow/ams/stage.py
+++ b/src/AMSWorkflow/ams/stage.py
@@ -537,7 +537,7 @@ class Pipeline(ABC):
     supported_policies = {"sequential", "thread", "process"}
    supported_writers = {"shdf5", "dhdf5", "csv"}
 
-    def __init__(self, db_dir, store, dest_dir=None, stage_dir=None, db_type="hdf5"):
+    def __init__(self, db_dir, store, dest_dir=None, stage_dir=None, db_type="dhdf5"):
         """
         initializes the Pipeline class to write the final data in the 'dest_dir' using a file writer of type 'db_type'
         and optionally caching the data in the 'stage_dir' before making them available in the cache store.
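The new AMSStageJob hierarchy only assembles the command line and resources for the AMSDBStage executable. Below is a minimal usage sketch (not part of the diff); it assumes the AMSWorkflow package from this branch is importable, and every path, directory, and credentials file name is a hypothetical placeholder.

# Usage sketch for the new stager job classes (all paths are hypothetical placeholders).
from ams.ams_jobs import AMSJobResources, AMSFSStageJob, AMSNetworkStageJob

# Mirrors the default resource shape used by AMSNetworkStageJob.from_descr above.
resources = AMSJobResources(nodes=1, tasks_per_node=1, cores_per_task=5, exclusive=False, gpus_per_task=None)

# File-system stager: AMSStageJob contributes --store/--no-store, --dest,
# --persistent-db-path, --db-type and --policy; AMSFSStageJob adds --src,
# --src-type, --pattern and --mechanism fs.
fs_stager = AMSFSStageJob(
    resources=resources,
    dest="/usr/workspace/ams/candidates",        # hypothetical destination directory
    persistent_db_path="/usr/workspace/ams/db",  # hypothetical persistent store path
    src="/tmp/ams-staging",                      # hypothetical directory holding *.h5 files
)

# Network (RabbitMQ) stager: adds --creds and --mechanism network; update_models
# also appends --update-rmq-models to the CLI arguments.
rmq_stager = AMSNetworkStageJob(
    resources=resources,
    dest="/usr/workspace/ams/candidates",
    persistent_db_path="/usr/workspace/ams/db",
    creds="rmq-credentials.json",                # hypothetical credentials file
    update_models=True,
)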
diff --git a/src/AMSWorkflow/ams/wf_manager.py b/src/AMSWorkflow/ams/wf_manager.py
new file mode 100644
index 00000000..f1c6c37a
--- /dev/null
+++ b/src/AMSWorkflow/ams/wf_manager.py
@@ -0,0 +1,126 @@
+from ams.ams_jobs import AMSDomainJob, AMSFSStageJob, AMSNetworkStageJob
+from ams.store import AMSDataStore
+import flux
+import json
+
+from typing import Tuple, Dict, List, Optional
+from ams.ams_jobs import AMSJob
+from dataclasses import dataclass, fields
+from pathlib import Path
+
+
+def get_allocation_resources(uri: str) -> Tuple[int, int, int]:
+    """
+    @brief Returns the resources of a flux allocation
+
+    :param uri: A flux uri to query the resources from
+    :return: A tuple of (nnodes, cores_per_node, gpus_per_node)
+    """
+    flux_instance = flux.Flux(uri)
+    resources = flux.resource.resource_list(flux_instance).get()["all"]
+    cores_per_node = int(resources.ncores / resources.nnodes)
+    gpus_per_node = int(resources.ngpus / resources.nnodes)
+    return resources.nnodes, cores_per_node, gpus_per_node
+
+
+@dataclass
+class Partition:
+    uri: str
+    nnodes: int
+    cores_per_node: int
+    gpus_per_node: int
+
+    @classmethod
+    def from_uri(cls, uri):
+        res = get_allocation_resources(uri)
+        return cls(uri=uri, nnodes=res[0], cores_per_node=res[1], gpus_per_node=res[2])
+
+
+class JobList(list):
+    """
+    @brief A list of 'AMSJobs'
+    """
+
+    def append(self, job: AMSJob):
+        if not isinstance(job, AMSJob):
+            raise TypeError(f"{self.__class__.__name__} expects an item of type AMSJob")
+
+        super().append(job)
+
+    def __getitem__(self, index):
+        return super().__getitem__(index)
+
+    def __setitem__(self, index, value):
+        if not isinstance(value, AMSJob):
+            raise TypeError(f"{self.__class__.__name__} expects an item of type AMSJob")
+
+        super().__setitem__(index, value)
+
+
+class WorkflowManager:
+    """
+    @brief Manages all job submissions of the current execution.
+    """
+
+    def __init__(self, kosh_path: str, store_name: str, db_name: str, jobs: Dict[str, JobList]):
+        self._kosh_path = kosh_path
+        self._store_name = store_name
+        self._db_name = db_name
+        self._jobs = jobs
+
+    @classmethod
+    def from_json(
+        cls,
+        domain_resources: Partition,
+        stage_resources: Partition,
+        train_resources: Partition,
+        json_file: str,
+        creds: Optional[str] = None,
+    ):
+
+        def create_domain_list(domains: List[Dict]) -> JobList:
+            jobs = JobList()
+            for job_descr in domains:
+                jobs.append(AMSDomainJob.from_descr(job_descr))
+            return jobs
+
+        if not Path(json_file).exists():
+            raise RuntimeError(f"Workflow description file {json_file} does not exist")
+
+        with open(json_file, "r") as fd:
+            data = json.load(fd)
+
+        if "db" not in data:
+            raise KeyError("Workflow description file is missing the 'db' entry")
+
+        if not all(key in data["db"] for key in {"kosh-path", "name", "store-name"}):
+            raise KeyError("Workflow description file is missing entries in 'db'")
+
+        store = AMSDataStore(data["db"]["kosh-path"], data["db"]["store-name"], data["db"]["name"])
+
+        if "domain-jobs" not in data:
+            raise KeyError("Workflow description file is missing the 'domain-jobs' entry")
+
+        if len(data["domain-jobs"]) == 0:
+            raise RuntimeError("There are no jobs described in the workflow description file")
+
+        domain_jobs = create_domain_list(data["domain-jobs"])
+
+        if "stage-job" not in data:
+            raise RuntimeError("There is no description for a stage-job")
+
+        stage_type = data["stage-job"].pop("type", "rmq")
+        num_instances = data["stage-job"].pop("instances", 1)
+
+        assert num_instances == 1, "We only support 1 instance at the moment"
+        assert stage_type == "rmq", "We only support 'rmq' stagers"
+
+        stage_job = AMSNetworkStageJob.from_descr(
+            data["stage-job"],
+            store.get_candidate_path(),
+            store.root_path,
+            creds,
+            stage_resources.nnodes,
+            stage_resources.cores_per_node,
+            stage_resources.gpus_per_node,
+        )
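WorkflowManager.from_json validates a small set of keys before creating the domain and stage jobs. The following sketch (not part of the diff) builds a workflow description file that would pass those checks; only the keys tested above ("db" with "kosh-path"/"store-name"/"name", "domain-jobs", "stage-job" with "type" and "instances") come from the code, while every value and the per-domain fields consumed by AMSDomainJob.from_descr are illustrative placeholders.

# Sketch of a workflow description file accepted by WorkflowManager.from_json
# (all values are hypothetical placeholders).
import json

workflow_descr = {
    "db": {
        "kosh-path": "/usr/workspace/ams/kosh",  # hypothetical kosh store location
        "store-name": "ams_store.sql",           # hypothetical store name
        "name": "ams",                           # hypothetical database name
    },
    "domain-jobs": [
        {
            # Forwarded to AMSDomainJob.from_descr; "name" and "resources" are the
            # fields visible in ams_jobs.py, any other fields depend on that class.
            "name": "idealgas",
            "resources": {
                "nodes": 1,
                "tasks_per_node": 1,
                "cores_per_task": 4,
                "exclusive": True,
                "gpus_per_task": 1,
            },
        }
    ],
    "stage-job": {
        "type": "rmq",          # only 'rmq' stagers are accepted at the moment
        "instances": 1,         # only a single stager instance is supported
        "update_models": True,  # remaining keys are forwarded to AMSNetworkStageJob
    },
}

with open("workflow.json", "w") as fd:
    json.dump(workflow_descr, fd, indent=2)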