Merge pull request #3537 from mr0re1/mypy_slurmsync1
SlurmGCP: fix mypy errors and a bug found along the way
mr0re1 authored Jan 16, 2025
2 parents 959dde1 + 7d89351 commit 035d63a
Showing 4 changed files with 31 additions and 21 deletions.
File 1 of 4
@@ -15,7 +15,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

-from typing import List, Optional
+from typing import List, Optional, Dict
import argparse
from datetime import timedelta
import shlex
@@ -245,9 +245,9 @@ def group_nodes_bulk(nodes: List[str], resume_data: Optional[ResumeData], lkp: u
if resume_data is None: # all nodes will be considered jobless
resume_data = ResumeData(jobs=[])

-nodes = set(nodes) # turn into set to simplify intersection
-non_excl = nodes.copy()
-groups = {} # excl_job_id|none -> PlacementAndNodes
+nodes_set = set(nodes) # turn into set to simplify intersection
+non_excl = nodes_set.copy()
+groups : Dict[Optional[int], List[PlacementAndNodes]] = {} # excl_job_id|none -> PlacementAndNodes

# expand all exclusive job nodelists
for job in resume_data.jobs:
@@ -261,7 +261,7 @@ def group_nodes_bulk(nodes: List[str], resume_data: Optional[ResumeData], lkp: u
PlacementAndNodes(
placement=pn.placement,
#... but we only want to handle nodes in nodes_resume in this run.
-nodes = sorted(set(pn.nodes) & nodes)
+nodes = sorted(set(pn.nodes) & nodes_set)
))
non_excl.difference_update(job.nodes_alloc)
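Note: the `nodes` → `nodes_set` rename is the usual fix for mypy's objection to re-binding a variable to a different type: the parameter is annotated `List[str]`, so assigning a `Set[str]` back to the same name is an error. A minimal sketch of the pattern (standalone, not the repository's code):

```python
from typing import List, Set

def before(nodes: List[str]) -> None:
    # mypy: error: Incompatible types in assignment
    # (expression has type "Set[str]", variable has type "List[str]")
    nodes = set(nodes)

def after(nodes: List[str]) -> Set[str]:
    nodes_set = set(nodes)  # a fresh name gets its own inferred type
    return nodes_set
```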

@@ -396,21 +396,21 @@ def _handle_bulk_insert_op(op: object, nodes: List[str], resume_data: Optional[R
lambda op: "+".join(err["code"] for err in op["error"]["errors"]),
)
for code, failed_ops in by_error_inserts:
-failed_nodes = {trim_self_link(op["targetLink"]): op for op in failed_ops}
+failed_ops = list(failed_ops)
+failed_nodes = [trim_self_link(op["targetLink"]) for op in failed_ops]
hostlist = util.to_hostlist(failed_nodes)
-count = len(failed_nodes)
log.error(
f"{count} instances failed to start: {code} ({hostlist}) operationGroupId={group_id}"
f"{len(failed_nodes)} instances failed to start: {code} ({hostlist}) operationGroupId={group_id}"
)
-failed_node, failed_op = next(iter(failed_nodes.items()))

msg = "; ".join(
f"{err['code']}: {err['message'] if 'message' in err else 'no message'}"
-for err in failed_op["error"]["errors"]
+for err in failed_ops[0]["error"]["errors"]
)
if code != "RESOURCE_ALREADY_EXISTS":
down_nodes_notify_jobs(failed_nodes, f"GCP Error: {msg}", resume_data)
log.error(
f"errors from insert for node '{failed_node}' ({failed_op['name']}): {msg}"
f"errors from insert for node '{failed_nodes[0]}' ({failed_ops[0]['name']}): {msg}"
)

ready_nodes = {trim_self_link(op["targetLink"]) for op in successful_inserts}
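Note: the inserted `failed_ops = list(failed_ops)` matters if `by_error_inserts` is an `itertools.groupby`-style grouping, as the key `lambda` above suggests: each group is a single-use iterator, so it has to be materialized before it is traversed once for `failed_nodes` and then indexed as `failed_ops[0]`. An illustration of the pitfall, under that assumption and with hypothetical sample data:

```python
from itertools import groupby

ops = [
    {"name": "op-1", "code": "QUOTA_EXCEEDED"},
    {"name": "op-2", "code": "QUOTA_EXCEEDED"},
]
for code, group in groupby(sorted(ops, key=lambda o: o["code"]),
                           key=lambda o: o["code"]):
    group = list(group)                  # groupby groups are one-shot iterators
    names = [o["name"] for o in group]   # first traversal
    first = group[0]                     # indexing only works on a real list
    print(code, names, first["name"])
```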
@@ -430,9 +430,9 @@ def down_nodes_notify_jobs(nodes: List[str], reason: str, resume_data: Optional[
log.warning("Cannot update and notify jobs with API failures as no valid resume file is present.")
return

-nodes = set(nodes) # turn into set to speed up intersection
+nodes_set = set(nodes) # turn into set to speed up intersection
for job in resume_data.jobs:
-if not (set(job.nodes_alloc) & nodes):
+if not (set(job.nodes_alloc) & nodes_set):
continue
run(f"{lookup().scontrol} update jobid={job.job_id} admincomment='{reason_quoted}'")
run(f"{lookup().scontrol} notify {job.job_id} '{reason_quoted}'")
File 2 of 4
@@ -531,14 +531,17 @@ def main():
try:
main()
except subprocess.TimeoutExpired as e:
+stdout = (e.stdout or b"").decode().strip()
+stderr = (e.stderr or b"").decode().strip()
+
log.error(
f"""TimeoutExpired:
command={e.cmd}
timeout={e.timeout}
stdout:
-{e.stdout.strip()}
+{stdout}
stderr:
-{e.stderr.strip()}
+{stderr}
"""
)
log.error("Aborting setup...")
File 3 of 4
@@ -149,10 +149,13 @@ def _find_dynamic_node_status() -> NodeAction:
# * delete orphaned instances
return NodeActionUnchanged() # don't touch dynamic nodes

-def get_fr_action(fr: FutureReservation, nodename:str, state:NodeState) -> Optional[NodeAction]:
+def get_fr_action(fr: FutureReservation, state:Optional[NodeState]) -> Optional[NodeAction]:
now = datetime.utcnow()
+if state is None:
+return None # handle like any other node
if fr.start_time < now < fr.end_time:
return None # handle like any other node

if state.base == "DOWN":
return NodeActionUnchanged()
if fr.start_time >= now:
@@ -227,7 +230,8 @@ def get_node_action(nodename: str) -> NodeAction:

if lookup().node_is_fr(nodename):
fr = lookup().future_reservation(lookup().node_nodeset(nodename))
-if action := get_fr_action(fr, nodename, state):
+assert fr
+if action := get_fr_action(fr, state):
return action

if lookup().node_is_dyn(nodename):
@@ -242,7 +246,7 @@ def get_node_action(nodename: str) -> NodeAction:
("POWER_DOWN", "POWERING_UP", "POWERING_DOWN", "POWERED_DOWN")
) & (state.flags if state is not None else set())

-if inst is None:
+if (inst is None) and (state is not None):
if "POWERING_UP" in state.flags:
return NodeActionUnchanged()
if state.base == "DOWN" and "POWERED_DOWN" in state.flags:
File 4 of 4
@@ -278,6 +278,7 @@ def parse_bucket_uri(uri: str):
"""
pattern = re.compile(r"gs://(?P<bucket>[^/\s]+)/(?P<path>([^/\s]+)(/[^/\s]+)*)")
matches = pattern.match(uri)
+assert matches, f"Unexpected bucket URI: '{uri}'"
return matches.group("bucket"), matches.group("path")
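Note: `re.Pattern.match` returns `Optional[re.Match[str]]`, so calling `.group(...)` unconditionally is a mypy `union-attr` error; the added `assert` both narrows the type and fails loudly on a malformed URI. A self-contained version with a simplified pattern:

```python
import re
from typing import Tuple

def parse_bucket_uri(uri: str) -> Tuple[str, str]:
    pattern = re.compile(r"gs://(?P<bucket>[^/\s]+)/(?P<path>\S+)")
    matches = pattern.match(uri)
    # Without the assert, mypy reports:
    # Item "None" of "Optional[Match[str]]" has no attribute "group"
    assert matches, f"Unexpected bucket URI: '{uri}'"
    return matches.group("bucket"), matches.group("path")

print(parse_bucket_uri("gs://my-bucket/path/to/object"))  # ('my-bucket', 'path/to/object')
```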


@@ -1394,7 +1395,7 @@ def nodelist_range(self, nodeset_name: str, start: int, count: int) -> str:
return f"{pref}-{start}"
return f"{pref}-[{start}-{start + count - 1}]"

-def static_dynamic_sizes(self, nodeset: object) -> int:
+def static_dynamic_sizes(self, nodeset: object) -> Tuple[int, int]:
return (nodeset.node_count_static or 0, nodeset.node_count_dynamic_max or 0)

def nodelist(self, nodeset) -> str:
@@ -1626,16 +1627,18 @@ def future_reservation(self, nodeset:object) -> Optional[FutureReservation]:

active_reservation = None
match = re.search(r'^projects/(?P<project>[^/]+)/zones/(?P<zone>[^/]+)/futureReservations/(?P<name>[^/]+)(/.*)?$', nodeset.future_reservation)
assert match, f"Invalid future reservation name '{nodeset.future_reservation}'"
project, zone, name = match.group("project","zone","name")
fr = self._get_future_reservation(project,zone,name)

# TODO: Remove this "hack" of trimming the Z from timestamps once we move to Python 3.11 (context: https://discuss.python.org/t/parse-z-timezone-suffix-in-datetime/2220/30)
start_time = datetime.fromisoformat(fr["timeWindow"]["startTime"][:-1])
end_time = datetime.fromisoformat(fr["timeWindow"]["endTime"][:-1])

if "autoCreatedReservations" in fr["status"] and (fr_res:=fr["status"]["autoCreatedReservations"][0]):
if "autoCreatedReservations" in fr["status"] and (res:=fr["status"]["autoCreatedReservations"][0]):
if (start_time<=datetime.utcnow()<=end_time):
-match = re.search(r'projects/(?P<project>[^/]+)/zones/(?P<zone>[^/]+)/reservations/(?P<name>[^/]+)(/.*)?$',fr_res)
+match = re.search(r'projects/(?P<project>[^/]+)/zones/(?P<zone>[^/]+)/reservations/(?P<name>[^/]+)(/.*)?$',res)
+assert match, f"Unexpected reservation name '{res}'"
res_name = match.group("name")
bulk_insert_name = f"projects/{project}/reservations/{res_name}"
active_reservation = self.get_reservation_details(project, zone, res_name, bulk_insert_name)
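Note: the `[:-1]` trimming referenced in the TODO above exists because `datetime.fromisoformat` only accepts a trailing `Z` (UTC) suffix from Python 3.11 onward; dropping it parses the timestamp as a naive datetime, which compares cleanly with the equally naive `datetime.utcnow()`:

```python
from datetime import datetime

ts = "2025-01-16T12:00:00Z"

# datetime.fromisoformat(ts) raises ValueError on Python < 3.11.
start_time = datetime.fromisoformat(ts[:-1])  # trim the "Z", parse as naive UTC

# utcnow() is also naive, so range checks like start <= utcnow() <= end stay consistent.
print(start_time <= datetime.utcnow())
```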
