Skip to content

Commit

Permalink
SlurmGCP. Fix type annotation problems
Browse files Browse the repository at this point in the history
  • Loading branch information
mr0re1 committed Jan 20, 2025
1 parent 8b711f6 commit 5192a6c
Show file tree
Hide file tree
Showing 7 changed files with 94 additions and 90 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
import util
from google.api_core import exceptions, retry
from google.cloud import bigquery as bq
from google.cloud.bigquery import SchemaField
from google.cloud.bigquery import SchemaField # type: ignore
from util import lookup, run

SACCT = "sacct"
Expand Down Expand Up @@ -175,7 +175,8 @@ def schema_field(field_name, data_type, description, required=False):
# creating the job rows
job_schema = {field.name: field for field in schema_fields}
# Order is important here, as that is how they are parsed from sacct output
Job = namedtuple("Job", job_schema.keys())
Job = namedtuple("Job", job_schema.keys()) # type: ignore
# ... see https://github.com/python/mypy/issues/848

client = bq.Client(
project=lookup().cfg.project,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
import shutil
from pathlib import Path
from concurrent.futures import as_completed
from addict import Dict as NSDict
from addict import Dict as NSDict # type: ignore

import util
from util import lookup, run, dirs, separate
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,12 +53,12 @@ def order(paths: List[List[str]]) -> List[str]:
if not paths: return []
class Vert:
"Represents a vertex in a *network* tree."
def __init__(self, name: str, parent: "Vert"):
def __init__(self, name: str, parent: Optional["Vert"]):
self.name = name
self.parent = parent
# Use `OrderedDict` to preserve insertion order
# TODO: once we move to Python 3.7+ use regular `dict` since it has the same guarantee
self.children = OrderedDict()
self.children: OrderedDict = OrderedDict()

# build a tree, children are ordered by insertion order
root = Vert("", None)
Expand Down Expand Up @@ -107,7 +107,7 @@ def to_hostnames(nodelist: str) -> List[str]:
return [n.decode("utf-8") for n in out.splitlines()]


def get_instances(node_names: List[str]) -> Dict[str, object]:
def get_instances(node_names: List[str]) -> Dict[str, Optional[Instance]]:
fmt = (
"--format=csv[no-heading,separator=','](zone,resourceStatus.physicalHost,name)"
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from mock import Mock
from common import TstNodeset, TstCfg, TstMachineConf, TstTemplateInfo

import addict
import addict # type: ignore
import conf
import util

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,6 @@ def test_gen_topology_conf_update():
(["z/n-0", "z/n-1", "z/n-2", "z/n-3", "z/n-4", "z/n-10"], ['n-0', 'n-1', 'n-2', 'n-3', 'n-4', 'n-10']),
(["y/n-0", "z/n-1", "x/n-2", "x/n-3", "y/n-4", "g/n-10"], ['n-0', 'n-4', 'n-1', 'n-2', 'n-3', 'n-10']),
])
def test_sort_nodes_order(paths: list[list[str]], expected: list[str]) -> None:
paths = [l.split("/") for l in paths]
assert sort_nodes.order(paths) == expected
def test_sort_nodes_order(paths: list[str], expected: list[str]) -> None:
    """Check that order() sorts nodes by their position in the topology tree."""
    # Each entry is a '/'-separated path; split into components before ordering.
    split_paths = [path.split("/") for path in paths]
    assert sort_nodes.order(split_paths) == expected
Original file line number Diff line number Diff line change
Expand Up @@ -317,7 +317,7 @@ def test_parse_job_info(job_info, expected_job):
@pytest.mark.parametrize(
"node,state,want",
[
("c-n-2", NodeState("DOWN", {}), NodeState("DOWN", {})), # happy scenario
("c-n-2", NodeState("DOWN", frozenset([])), NodeState("DOWN", frozenset([]))), # happy scenario
("c-d-vodoo", None, None), # dynamic nodeset
("c-x-44", None, None), # unknown(removed) nodeset
("c-n-7", None, None), # Out of bounds: c-n-[0-4] - downsized nodeset
Expand All @@ -340,7 +340,8 @@ def test_node_state(node: str, state: Optional[NodeState], want: NodeState | Non
"d": TstNodeset()},
)
lkp = util.Lookup(cfg)
lkp.slurm_nodes = lambda: {node: state} if state else {}
lkp.slurm_nodes = lambda: {node: state} if state else {} # type: ignore[assignment]
# ... see https://github.com/python/typeshed/issues/6347

if type(want) is type and issubclass(want, Exception):
with pytest.raises(want):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Iterable, List, Tuple, Optional, Any, Dict, Sequence
from typing import Iterable, List, Tuple, Optional, Any, Dict, Sequence, Type
import argparse
import base64
from dataclasses import dataclass
from dataclasses import dataclass, field
from datetime import timedelta, datetime
import hashlib
import inspect
Expand All @@ -43,24 +43,26 @@
from pathlib import Path
from time import sleep, time


# TODO: remove "type: ignore" once moved to newer version of libraries
from google.cloud import secretmanager
from google.cloud import storage
from google.cloud import storage # type: ignore

import google.auth # noqa: E402
from google.oauth2 import service_account # noqa: E402
import googleapiclient.discovery # noqa: E402
import google_auth_httplib2 # noqa: E402
from googleapiclient.http import set_user_agent # noqa: E402
from google.api_core.client_options import ClientOptions # noqa: E402
import httplib2 # noqa: E402
import google.auth
from google.oauth2 import service_account
import googleapiclient.discovery # type: ignore
import google_auth_httplib2 # type: ignore
from googleapiclient.http import set_user_agent # type: ignore
from google.api_core.client_options import ClientOptions
import httplib2

import google.api_core.exceptions as gExceptions # noqa: E402
import google.api_core.exceptions as gExceptions

from requests import get as get_url # noqa: E402
from requests.exceptions import RequestException # noqa: E402
from requests import get as get_url
from requests.exceptions import RequestException

import yaml # noqa: E402
from addict import Dict as NSDict # noqa: E402
import yaml
from addict import Dict as NSDict # type: ignore


USER_AGENT = "Slurm_GCP_Scripts/1.5 (GPN:SchedMD)"
Expand All @@ -83,40 +85,28 @@ def mkdirp(path: Path) -> None:

# load all directories as Paths into a dict-like namespace
dirs = NSDict(
{
n: Path(p)
for n, p in dict.items(
{
"home": "/home",
"apps": "/opt/apps",
"slurm": "/slurm",
"scripts": scripts_dir,
"custom_scripts": "/slurm/custom_scripts",
"munge": "/etc/munge",
"secdisk": "/mnt/disks/sec",
"log": "/var/log/slurm",
}
)
}
home = Path("/home"),
apps = Path("/opt/apps"),
slurm = Path("/slurm"),
scripts = scripts_dir,
custom_scripts = Path("/slurm/custom_scripts"),
munge = Path("/etc/munge"),
secdisk = Path("/mnt/disks/sec"),
log = Path("/var/log/slurm"),
)

slurmdirs = NSDict(
{
n: Path(p)
for n, p in dict.items(
{
"prefix": "/usr/local",
"etc": "/usr/local/etc/slurm",
"state": "/var/spool/slurm",
}
)
}
prefix = Path("/usr/local"),
etc = Path("/usr/local/etc/slurm"),
state = Path("/var/spool/slurm"),
)


# TODO: Remove this hack (relies on undocumented behavior of PyYAML)
# No need to represent NSDict and Path once we move to properly typed & serializable config.
yaml.SafeDumper.yaml_representers[
None
] = lambda self, data: yaml.representer.SafeRepresenter.represent_str(self, str(data))
None # type: ignore
] = lambda self, data: yaml.representer.SafeRepresenter.represent_str(self, str(data)) # type: ignore


class ApiEndpoint(Enum):
Expand Down Expand Up @@ -481,48 +471,59 @@ def _fill_cfg_defaults(cfg: NSDict) -> NSDict:
netstore.server_ip = cfg.slurm_control_host
return cfg

def _list_config_blobs() -> Tuple[Any, str]:
@dataclass
class _ConfigBlobs:
    """
    "Private" class that represents a collection of GCS blobs for configuration.

    `core` is the mandatory top-level config.yaml blob; the remaining fields
    hold the per-kind config blobs (each list may be empty).
    """
    core: storage.Blob
    partition: List[storage.Blob] = field(default_factory=list)
    nodeset: List[storage.Blob] = field(default_factory=list)
    nodeset_dyn: List[storage.Blob] = field(default_factory=list)
    nodeset_tpu: List[storage.Blob] = field(default_factory=list)

    @property
    def hash(self) -> str:
        """MD5 digest over the md5_hash of every blob, independent of listing order."""
        h = hashlib.md5()
        # renamed from `all` to avoid shadowing the builtin
        blobs = [self.core] + self.partition + self.nodeset + self.nodeset_dyn + self.nodeset_tpu
        # sort blobs so hash is consistent
        for blob in sorted(blobs, key=lambda b: b.name):
            h.update(blob.md5_hash.encode("utf-8"))
        return h.hexdigest()

def _list_config_blobs() -> _ConfigBlobs:
    """List the bucket's config blobs and group them by kind.

    Returns:
        _ConfigBlobs with the core config.yaml blob plus any
        partition/nodeset/nodeset_dyn/nodeset_tpu config blobs.

    Raises:
        DeffetiveStoredConfigError: if config.yaml is not found in the bucket.
    """
    _, common_prefix = _get_bucket_and_common_prefix()

    core: Optional[storage.Blob] = None
    rest: Dict[str, List[storage.Blob]] = {
        "partition": [],
        "nodeset": [],
        "nodeset_dyn": [],
        "nodeset_tpu": [],
    }

    for blob in blob_list(prefix=""):
        if blob.name == f"{common_prefix}/config.yaml":
            core = blob
            continue  # the core blob cannot also match a per-kind prefix
        for key in rest:
            if blob.name.startswith(f"{common_prefix}/{key}_configs/"):
                rest[key].append(blob)
                break  # the per-kind prefixes are mutually exclusive

    if core is None:
        raise DeffetiveStoredConfigError(f"{common_prefix}/config.yaml not found in bucket")
    return _ConfigBlobs(core=core, **rest)

def _fetch_config(old_hash: Optional[str]) -> Optional[Tuple[NSDict, str]]:
    """Fetch config from bucket, returns None if no changes are detected.

    Compares the hash of the currently stored blobs against `old_hash`;
    on change, downloads and assembles the full config.

    Returns:
        (assembled config, new hash) tuple, or None when unchanged.
    """
    blobs = _list_config_blobs()
    if old_hash == blobs.hash:
        return None  # nothing changed since the caller's last fetch

    def _download(bs) -> List[Any]:
        # Each blob holds one YAML document; parse each into a Python object.
        return [yaml.safe_load(b.download_as_text()) for b in bs]

    return _assemble_config(
        core=_download([blobs.core])[0],
        partitions=_download(blobs.partition),
        nodesets=_download(blobs.nodeset),
        nodesets_dyn=_download(blobs.nodeset_dyn),
        nodesets_tpu=_download(blobs.nodeset_tpu),
    ), blobs.hash

def _assemble_config(
core: Any,
Expand Down Expand Up @@ -742,7 +743,7 @@ def cached_property(f):
return property(lru_cache()(f))


def retry(max_retries: int, init_wait_time: float, warn_msg: str, exc_type: Exception):
def retry(max_retries: int, init_wait_time: float, warn_msg: str, exc_type: Type[Exception]):
"""Retries functions that raises the exception exc_type.
Retry time is increased by a factor of two for every iteration.
Expand Down Expand Up @@ -938,7 +939,9 @@ def to_hostlist(names: Iterable[str]) -> str:
pref = defaultdict(list)
tokenizer = re.compile(r"^(.*?)(\d*)$")
for name in filter(None, names):
p, s = tokenizer.match(name).groups()
matches = tokenizer.match(name)
assert matches, name
p, s = matches.groups()
pref[p].append(s)

def _compress_suffixes(ss: List[str]) -> List[str]:
Expand Down Expand Up @@ -1430,10 +1433,9 @@ def is_static_node(self, node_name: str) -> bool:
return idx < self.node_nodeset(node_name).node_count_static

@lru_cache(maxsize=None)
def slurm_nodes(self):

def make_node_tuple(node_line):
"""turn node,state line to (node, NodeState(state))"""
def slurm_nodes(self) -> Dict[str, NodeState]:
def parse_line(node_line) -> Tuple[str, NodeState]:
"""turn node,state line to (node, NodeState)"""
# state flags include: CLOUD, COMPLETING, DRAIN, FAIL, POWERED_DOWN,
# POWERING_DOWN
node, fullstate = node_line.split(",")
Expand All @@ -1449,7 +1451,7 @@ def make_node_tuple(node_line):
node_lines = run(cmd, shell=True).stdout.rstrip().splitlines()
nodes = {
node: state
for node, state in map(make_node_tuple, node_lines)
for node, state in map(parse_line, node_lines)
if "CLOUD" in state.flags or "DYNAMIC_NORM" in state.flags
}
return nodes
Expand Down

0 comments on commit 5192a6c

Please sign in to comment.