From f8470e8034fb18abd0aff64272c93ad845d37558 Mon Sep 17 00:00:00 2001 From: ck Date: Thu, 24 Aug 2023 02:38:31 +0800 Subject: [PATCH] feat:add retry config & ignore exit code setting (#361) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- dpdispatcher/dp_cloud_server.py | 27 +++++++++++++++++--- dpdispatcher/dp_cloud_server_context.py | 24 +++++++++++++++--- dpdispatcher/dpcloudserver/config.py | 1 - dpdispatcher/machine.py | 12 +++++++++ dpdispatcher/openapi.py | 33 +++++++++++++++++++------ dpdispatcher/ssh_context.py | 1 - dpdispatcher/submission.py | 7 ++++-- 7 files changed, 87 insertions(+), 18 deletions(-) diff --git a/dpdispatcher/dp_cloud_server.py b/dpdispatcher/dp_cloud_server.py index 09ff9eca..b0ccdae2 100644 --- a/dpdispatcher/dp_cloud_server.py +++ b/dpdispatcher/dp_cloud_server.py @@ -31,6 +31,8 @@ def __init__(self, context): phone = context.remote_profile.get("phone", None) username = context.remote_profile.get("username", None) password = context.remote_profile.get("password", None) + self.retry_count = context.remote_profile.get("retry_count", 3) + self.ignore_exit_code = context.remote_profile.get("ignore_exit_code", True) ticket = os.environ.get("BOHR_TICKET", None) if ticket: @@ -110,7 +112,6 @@ def do_submit(self, job): # oss_task_zip = 'indicate/' + job.job_hash + '/' + zip_filename oss_task_zip = self._gen_oss_path(job, zip_filename) job_resources = ALI_OSS_BUCKET_URL + oss_task_zip - input_data = self.input_data.copy() if not input_data.get("job_resources"): @@ -187,7 +188,9 @@ def check_status(self, job): f"cannot find job information in bohrium for job {job.job_id} {check_return} {retry_return}" ) - job_state = self.map_dp_job_state(dp_job_status) + job_state = self.map_dp_job_state( + dp_job_status, check_return.get("exitCode", 0), self.ignore_exit_code + ) if job_state == JobStatus.finished: job_log = self.api.get_log(job_id) if self.input_data.get("output_log"): @@ -232,7 +235,7 @@ def check_if_recover(self, submission): # pass @staticmethod - def map_dp_job_state(status): + def map_dp_job_state(status, exit_code, ignore_exit_code=True): if isinstance(status, JobStatus): return status map_dict = { @@ -244,10 +247,13 @@ def map_dp_job_state(status): 4: JobStatus.running, 5: JobStatus.terminated, 6: JobStatus.running, + 9: JobStatus.waiting, } if status not in map_dict: dlog.error(f"unknown job status {status}") return JobStatus.unknown + if status == -1 and exit_code != 0 and ignore_exit_code: + return JobStatus.finished return map_dict[status] def kill(self, job): @@ -261,6 +267,21 @@ def kill(self, job): job_id = job.job_id self.api.kill(job_id) + def get_exit_code(self, job) -> int: + job_id = self._parse_job_id(job.job_id) + if job_id <= 0: + raise RuntimeError(f"cannot parse job id {job.job_id}") + + check_return = self._get_job_detail(job_id, self.group_id) + return check_return.get("exitCode", -999) # type: ignore + + def _parse_job_id(self, str_job_id: str) -> int: + job_id = 0 + if "job_group_id" in str_job_id: + ids = str_job_id.split(":job_group_id:") + job_id, _ = int(ids[0]), int(ids[1]) + return job_id + DpCloudServer = Bohrium Lebesgue = Bohrium diff --git a/dpdispatcher/dp_cloud_server_context.py b/dpdispatcher/dp_cloud_server_context.py index b69c000b..db23f44d 100644 --- a/dpdispatcher/dp_cloud_server_context.py +++ b/dpdispatcher/dp_cloud_server_context.py @@ -10,6 +10,7 @@ from dpdispatcher import dlog from dpdispatcher.base_context import BaseContext +from dpdispatcher.dpcloudserver.config import ALI_STS_BUCKET_NAME, ALI_STS_ENDPOINT # from dpdispatcher.submission import Machine # from . import dlog @@ -20,8 +21,6 @@ DP_CLOUD_SERVER_HOME_DIR = os.path.join( os.path.expanduser("~"), ".dpdispatcher/", "dp_cloud_server/" ) -ENDPOINT = "http://oss-cn-shenzhen.aliyuncs.com" -BUCKET_NAME = os.environ.get("BUCKET_NAME", "dpcloudserver") class BohriumContext(BaseContext): @@ -124,7 +123,9 @@ def upload_job(self, job, common_files=None): upload_zip = zip_file.zip_file_list( self.local_root, zip_task_file, file_list=upload_file_list ) - result = self.api.upload(oss_task_zip, upload_zip, ENDPOINT, BUCKET_NAME) + result = self.api.upload( + oss_task_zip, upload_zip, ALI_STS_ENDPOINT, ALI_STS_BUCKET_NAME + ) retry_count = 0 self._backup(self.local_root, upload_zip) @@ -285,6 +286,9 @@ def machine_subfields(cls) -> List[Argument]: doc_remote_profile = ( "The information used to maintain the connection with remote machine." ) + doc_retry_count = "The retry count when a job is terminated" + doc_ignore_exit_code = """The job state will be marked as finished if the exit code is non-zero when set to True. Otherwise, + the job state will be designated as terminated.""" return [ Argument( "remote_profile", @@ -299,6 +303,20 @@ def machine_subfields(cls) -> List[Argument]: alias=["project_id"], doc="Program ID", ), + Argument( + "retry_count", + [int, type(None)], + optional=True, + default=3, + doc=doc_retry_count, + ), + Argument( + "ignore_exit_code", + bool, + optional=True, + default=True, + doc=doc_ignore_exit_code, + ), Argument( "keep_backup", bool, diff --git a/dpdispatcher/dpcloudserver/config.py b/dpdispatcher/dpcloudserver/config.py index 39a8d994..b0391671 100644 --- a/dpdispatcher/dpcloudserver/config.py +++ b/dpdispatcher/dpcloudserver/config.py @@ -14,4 +14,3 @@ "DPDISPATCHER_LEBESGUE_ALI_OSS_BUCKET_URL", "https://dpcloudserver.oss-cn-shenzhen.aliyuncs.com/", ) -# ALI_OSS_BUCKET_URL = 'https://dpcloudserver.oss-cn-shenzhen.aliyuncs.com/ diff --git a/dpdispatcher/machine.py b/dpdispatcher/machine.py index 727d4cef..f09bd647 100644 --- a/dpdispatcher/machine.py +++ b/dpdispatcher/machine.py @@ -453,3 +453,15 @@ def kill(self, job): job """ dlog.warning("Job %s should be manually killed" % job.job_id) + + def get_exit_code(self, job): + """Get exit code of the job. + + Parameters + ---------- + job : Job + job + """ + raise NotImplementedError( + "abstract method get_exit_code should be implemented by derived class" + ) diff --git a/dpdispatcher/openapi.py b/dpdispatcher/openapi.py index 8d34747c..89b4f417 100644 --- a/dpdispatcher/openapi.py +++ b/dpdispatcher/openapi.py @@ -31,6 +31,8 @@ def __init__(self, context): self.remote_profile = context.remote_profile.copy() self.grouped = self.remote_profile.get("grouped", True) + self.retry_count = self.remote_profile.get("retry_count", 3) + self.ignore_exit_code = context.remote_profile.get("ignore_exit_code", True) self.client = Client() self.job = Job(client=self.client) self.storage = Storage(client=self.client) @@ -80,8 +82,9 @@ def do_submit(self, job): "out_files": self._gen_backward_files_list(job), "platform": self.remote_profile.get("platform", "ali"), "image_address": self.remote_profile.get("image_address", ""), - "job_id": job.job_id, } + if job.job_state == JobStatus.unsubmitted: + openapi_params["job_id"] = job.job_id data = self.job.insert(**openapi_params) @@ -126,12 +129,13 @@ def check_status(self, job): f"cannot find job information in bohrium for job {job.job_id} {check_return} {retry_return}" ) - job_state = self.map_dp_job_state(dp_job_status) + job_state = self.map_dp_job_state( + dp_job_status, check_return.get("exitCode", 0), self.ignore_exit_code # type: ignore + ) if job_state == JobStatus.finished: job_log = self.job.log(job_id) if self.remote_profile.get("output_log"): print(job_log, end="") - # print(job.job_id) self._download_job(job) elif self.remote_profile.get("output_log") and job_state == JobStatus.running: job_log = self.job.log(job_id) @@ -140,7 +144,6 @@ def check_status(self, job): def _download_job(self, job): data = self.job.detail(job.job_id) - # print(data) job_url = data["jobFiles"]["outFiles"][0]["url"] # type: ignore if not job_url: return @@ -174,7 +177,7 @@ def check_if_recover(self, submission): # pass @staticmethod - def map_dp_job_state(status): + def map_dp_job_state(status, exit_code, ignore_exit_code=True): if isinstance(status, JobStatus): return status map_dict = { @@ -191,6 +194,8 @@ def map_dp_job_state(status): if status not in map_dict: dlog.error(f"unknown job status {status}") return JobStatus.unknown + if status == -1 and exit_code != 0 and ignore_exit_code: + return JobStatus.finished return map_dict[status] def kill(self, job): @@ -204,6 +209,18 @@ def kill(self, job): job_id = job.job_id self.job.kill(job_id) - # def check_finish_tag(self, job): - # job_tag_finished = job.job_hash + '_job_tag_finished' - # return self.context.check_file_exists(job_tag_finished) + def get_exit_code(self, job): + """Get exit code of the job. + + Parameters + ---------- + job : Job + job + + Returns + ------- + int + exit code + """ + check_return = self.job.detail(job.job_id) + return check_return.get("exitCode", -999) # type: ignore diff --git a/dpdispatcher/ssh_context.py b/dpdispatcher/ssh_context.py index be08648c..ce546caa 100644 --- a/dpdispatcher/ssh_context.py +++ b/dpdispatcher/ssh_context.py @@ -320,7 +320,6 @@ def arginfo(): doc_look_for_keys = ( "enable searching for discoverable private key files in ~/.ssh/" ) - ssh_remote_profile_args = [ Argument("hostname", str, optional=False, doc=doc_hostname), Argument("username", str, optional=False, doc=doc_username), diff --git a/dpdispatcher/submission.py b/dpdispatcher/submission.py index 2a85df7d..9b8301aa 100644 --- a/dpdispatcher/submission.py +++ b/dpdispatcher/submission.py @@ -744,7 +744,6 @@ def __init__( # self.job_work_base = job_work_base self.resources = resources self.machine = machine - self.job_state = None # JobStatus.unsubmitted self.job_id = "" self.fail_count = 0 @@ -839,7 +838,11 @@ def handle_unexpected_job_state(self): f"job: {self.job_hash} {self.job_id} terminated;" f"fail_cout is {self.fail_count}; resubmitting job" ) - if (self.fail_count) > 0 and (self.fail_count % 3 == 0): + retry_count = 3 + assert self.machine is not None + if hasattr(self.machine, "retry_count") and self.machine.retry_count > 0: + retry_count = self.machine.retry_count + if (self.fail_count) > 0 and (self.fail_count % retry_count == 0): raise RuntimeError( f"job:{self.job_hash} {self.job_id} failed {self.fail_count} times.job_detail:{self}" )