Commit

Workflow sketch

koparasy committed Aug 13, 2024
1 parent 1c9f28c commit ef92644
Showing 3 changed files with 283 additions and 5 deletions.
160 changes: 156 additions & 4 deletions src/AMSWorkflow/ams/ams_jobs.py
@@ -4,6 +4,8 @@
from typing import Optional
from dataclasses import dataclass, fields
from ams import util
from typing import Dict, List, Union, Optional, Mapping
from pathlib import Path


def constuct_cli_cmd(executable, *args, **kwargs):
@@ -36,7 +38,7 @@ class AMSJob:
"""

@classmethod
-def generate_formatting(self, store):
+def generate_formatting(cls, store):
return {"AMS_STORE_PATH": store.root_path}

def __init__(
@@ -233,7 +235,7 @@ def __init__(self, domain_names, stage_dir, *args, **kwargs):
super().__init__(*args, **kwargs)

@classmethod
-def from_descr(cls, stage_dir, descr):
+def from_descr(cls, descr, stage_dir=None):
domain_job_resources = AMSJobResources(**descr["resources"])
return cls(
name=descr["name"],
@@ -293,7 +295,158 @@ def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)


-class AMSFSStageJob(AMSJob):
+class AMSStageJob(AMSJob):
def __init__(
self,
resources: Union[Dict[str, Union[str, int]], AMSJobResources],
dest: str,
persistent_db_path: str,
store: bool = True,
db_type: str = "dhdf5",
policy: str = "process",
prune_module_path: Optional[str] = None,
prune_class: Optional[str] = None,
environ: Optional[Mapping[str, str]] = None,
stdout: Optional[str] = None,
stderr: Optional[str] = None,
cli_args: List[str] = [],
cli_kwargs: Mapping[str, str] = {},
):
_cli_args = list(cli_args)
if store:
_cli_args.append("--store")
else:
_cli_args.append("--no-store")

_cli_kwargs = dict(cli_kwargs)
_cli_kwargs["--dest"] = dest
_cli_kwargs["--persistent-db-path"] = persistent_db_path
_cli_kwargs["--db-type"] = db_type
_cli_kwargs["--policy"] = policy

if prune_module_path is not None:
assert Path(prune_module_path).exists(), "Path to the user pruning module does not exist"
assert prune_class is not None, "A pruning module requires a pruning class to be defined"
_cli_kwargs["--load"] = prune_module_path
_cli_kwargs["--class"] = prune_class

super().__init__(
name="AMSStageJob",
executable="AMSDBStage",
environ=environ,
resources=resources,
stdout=stdout,
stderr=stderr,
cli_args=_cli_args,
cli_kwargs=_cli_kwargs,
)
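
For illustration, a minimal usage sketch of the base stager job; the resource values and paths below are hypothetical, and the trailing comment lists the flags the constructor above assembles:

resources = AMSJobResources(nodes=1, tasks_per_node=1, cores_per_task=4, exclusive=False, gpus_per_task=None)
job = AMSStageJob(
    resources=resources,
    dest="/scratch/ams/candidates",         # hypothetical destination directory
    persistent_db_path="/scratch/ams/db",   # hypothetical persistent store path
)
# Resulting CLI options: --store --dest ... --persistent-db-path ... --db-type dhdf5 --policy process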


class AMSFSStageJob(AMSStageJob):
def __init__(
self,
resources: Union[Dict[str, Union[str, int]], AMSJobResources],
dest: str,
persistent_db_path: str,
src: str,
store: bool = True,
db_type="dhf5",
pattern="*.h5",
src_type: str = "shdf5",
prune_module_path: Optional[str] = None,
prune_class: Optional[str] = None,
environ: Optional[Mapping[str, str]] = None,
stdout: Optional[str] = None,
stderr: Optional[str] = None,
cli_args: List[str] = [],
cli_kwargs: Mapping[str, str] = {},
):

_cli_args = list(cli_args)
_cli_kwargs = dict(cli_kwargs)
_cli_kwargs["--src"] = src
_cli_kwargs["--src-type"] = src_type
_cli_kwargs["--pattern"] = pattern
_cli_kwargs["--mechanism"] = "fs"

super().__init__(
resources,
dest,
persistent_db_path,
store,
db_type,
environ=environ,
stdout=stdout,
stderr=stderr,
prune_module_path=prune_module_path,
prune_class=prune_class,
cli_args=_cli_args,
cli_kwargs=_cli_kwargs,
)
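
The filesystem variant adds the source arguments and pins the mechanism. A sketch with hypothetical paths, reusing the resources object from the sketch above:

fs_job = AMSFSStageJob(
    resources=resources,
    dest="/scratch/ams/candidates",
    persistent_db_path="/scratch/ams/db",
    src="/scratch/ams/raw",  # hypothetical directory scanned for '*.h5' files
)
# Appends --src, --src-type shdf5 and --pattern '*.h5', and forces --mechanism fs.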


class AMSNetworkStageJob(AMSStageJob):
def __init__(
self,
resources: Union[Dict[str, Union[str, int]], AMSJobResources],
dest: str,
persistent_db_path: str,
creds: str,
store: bool = True,
db_type: str = "dhdf5",
update_models: bool = False,
prune_module_path: Optional[str] = None,
prune_class: Optional[str] = None,
environ: Optional[Mapping[str, str]] = None,
stdout: Optional[str] = None,
stderr: Optional[str] = None,
cli_args: List[str] = [],
cli_kwargs: Mapping[str, str] = {},
):
_cli_args = list(cli_args)
if update_models:
_cli_args.append("--update-rmq-models")
_cli_kwargs = dict(cli_kwargs)
_cli_kwargs["--creds"] = creds
_cli_kwargs["--mechanism"] = "network"

super().__init__(
resources,
dest,
persistent_db_path,
store,
db_type,
environ=environ,
stdout=stdout,
stderr=stderr,
prune_module_path=prune_module_path,
prune_class=prune_class,
cli_args=_cli_args,
cli_kwargs=_cli_kwargs,
)

@classmethod
def from_descr(cls, descr, dest, persistent_db_path, creds, num_nodes, cores_per_node, gpus_per_node):
cores_per_instance = 5
total_cores = num_nodes * cores_per_node
instances = descr.pop("instances", 1)
requires_gpu = descr.pop("requires_gpu", False)

assert instances == 1, "We are missing support for multi-instance execution"
assert not requires_gpu, "We are missing support for GPU stager execution"

resources = AMSJobResources(
nodes=1,
tasks_per_node=1,
cores_per_task=cores_per_instance,
exclusive=False,
gpus_per_task=None,
)

return cls(resources, dest, persistent_db_path, creds, **descr)
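
A sketch of a description dictionary this factory accepts; 'instances' and 'requires_gpu' are consumed here, and any remaining keys are forwarded to __init__ ('update_models' is a real constructor parameter, all values are illustrative):

descr = {"instances": 1, "requires_gpu": False, "update_models": True}
job = AMSNetworkStageJob.from_descr(
    descr,
    dest="/scratch/ams/candidates",        # hypothetical
    persistent_db_path="/scratch/ams/db",  # hypothetical
    creds="rmq-creds.json",                # hypothetical RabbitMQ credentials file
    num_nodes=2,
    cores_per_node=32,
    gpus_per_node=4,
)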


class AMSFSTempStageJob(AMSJob):
def __init__(
self,
store_dir,
@@ -377,5 +530,4 @@ def get_echo_job(message):
cores_per_task=1,
gpus_per_task=0,
exclusive=True,
)
return jobspec
2 changes: 1 addition & 1 deletion src/AMSWorkflow/ams/stage.py
@@ -537,7 +537,7 @@ class Pipeline(ABC):
supported_policies = {"sequential", "thread", "process"}
supported_writers = {"shdf5", "dhdf5", "csv"}

-def __init__(self, db_dir, store, dest_dir=None, stage_dir=None, db_type="hdf5"):
+def __init__(self, db_dir, store, dest_dir=None, stage_dir=None, db_type="dhdf5"):
"""
initializes the Pipeline class to write the final data in the 'dest_dir' using a file writer of type 'db_type'
and optionally caching the data in the 'stage_dir' before making them available in the cache store.
126 changes: 126 additions & 0 deletions src/AMSWorkflow/ams/wf_manager.py
@@ -0,0 +1,126 @@
from ams.ams_jobs import AMSDomainJob, AMSFSStageJob, AMSNetworkStageJob
from ams.store import AMSDataStore
import flux
import flux.resource  # import explicitly; the resource submodule may not load with `import flux` alone
import json

from typing import Tuple, Dict, List, Optional
from ams.ams_jobs import AMSJob
from dataclasses import dataclass, fields
from pathlib import Path


def get_allocation_resources(uri: str) -> Tuple[int, int, int]:
"""
@brief Returns the resources of a flux allocation
:param uri: A Flux URI to query the resources from
:return: A tuple of (nnodes, cores_per_node, gpus_per_node)
"""
flux_instance = flux.Flux(uri)
resources = flux.resource.resource_list(flux_instance).get()["all"]
cores_per_node = int(resources.ncores / resources.nnodes)
gpus_per_node = int(resources.ngpus / resources.nnodes)
return resources.nnodes, cores_per_node, gpus_per_node
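
A usage sketch, assuming a running Flux instance; the URI below is hypothetical:

nnodes, cores_per_node, gpus_per_node = get_allocation_resources("local:///run/flux/local")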


@dataclass
class Partition:
uri: str
nnodes: int
cores_per_node: int
gpus_per_node: int

@classmethod
def from_uri(cls, uri):
res = get_allocation_resources(uri)
return cls(uri=uri, nnodes=res[0], cores_per_node=res[1], gpus_per_node=res[2])
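
Building on the helper above, with the same hypothetical URI:

partition = Partition.from_uri("local:///run/flux/local")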


class JobList(list):
"""
@brief A list of 'AMSJobs'
"""

def append(self, job: AMSJob):
if not isinstance(job, AMSJob):
raise TypeError("{self.__classs__.__name__} expects an item of a job")

super().append(job)

def __getitem__(self, index):
return super().__getitem__(index)

def __setitem__(self, index, value):
if not isinstance(value, AMSJob):
raise TypeError("{self.__classs__.__name__} expects an item of a job")

super().__setitem__(index, value)
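
A short sketch of the type guard; 'job' stands in for any AMSJob (or subclass) instance such as the stager jobs above:

jobs = JobList()
jobs.append(job)          # accepted: an AMSJob instance
jobs.append("not-a-job")  # rejected: raises TypeError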


class WorkflowManager:
"""
@brief Manages all job submissions of the current execution.
"""

def __init__(self, kosh_path: str, store_name: str, db_name: str, jobs: Dict[str, JobList]):
self._kosh_path = kosh_path
self._store_name = store_name
self._db_name = db_name
self._jobs = jobs

@classmethod
def from_json(
cls,
domain_resources: Partition,
stage_resources: Partition,
train_resources: Partition,
json_file: str,
creds: Optional[str] = None,
):

def create_domain_list(domains: List[Dict]) -> JobList:
jobs = JobList()
for job_descr in domains:
jobs.append(AMSDomainJob.from_descr(job_descr))
return jobs

if not Path(json_file).exists():
raise RuntimeError(f"Workflow description file {json_file} does not exist")

with open(json_file, "r") as fd:
data = json.load(fd)

if "db" not in data:
raise KeyError("Workflow decsription file misses 'db' description")

if not all(key in data["db"] for key in {"kosh-path", "name", "store-name"}):
raise KeyError("Workflow description files misses entries in 'db'")

store = AMSDataStore(data["db"]["kosh-path"], data["db"]["store-name"], data["db"]["name"])

if "domain-jobs" not in data:
raise KeyError("Workflow description files misses 'domain-jobs' entry")

if len(data["domain-jobs"]) == 0:
raise RuntimeError("There are no jobs described in workflow description file")

domain_jobs = create_domain_list(data["domain-jobs"])

if "stage-job" not in data:
raise RuntimeError("There is no description for a stage-job")

stage_type = data["stage-job"].pop("type", "rmq")
num_instances = data["stage-job"].pop("instances", 1)

assert num_instances == 1, "We only support 1 instance at the moment"
assert stage_type == "rmq", "We only support 'rmq' stagers"

stage_job = AMSNetworkStageJob.from_descr(
data["stage-job"],
store.get_candidate_path(),
store.root_path,
creds,
stage_resources.nnodes,
stage_resources.cores_per_node,
stage_resources.gpus_per_node,
)
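
For reference, a hypothetical workflow description that satisfies the checks above. The keys mirror the parsing code; the values are illustrative, and per-domain keys elided from the diff may also be required:

{
  "db": {"kosh-path": "/scratch/ams/kosh", "store-name": "ams-store", "name": "ams-db"},
  "domain-jobs": [
    {
      "name": "example-domain",
      "resources": {"nodes": 1, "tasks_per_node": 1, "cores_per_task": 4, "exclusive": true, "gpus_per_task": 0}
    }
  ],
  "stage-job": {"type": "rmq", "instances": 1}
}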
