From 1c60e4713c05bfe7237dccbd0ea366493fa9c861 Mon Sep 17 00:00:00 2001 From: aldbr Date: Thu, 27 Jun 2024 09:21:51 +0200 Subject: [PATCH] feat(Resources): introduce fabric in SSHCE --- environment.yml | 3 + setup.cfg | 2 + .../Computing/SSHBatchComputingElement.py | 118 ++- .../Computing/SSHComputingElement.py | 671 +++++++----------- 4 files changed, 324 insertions(+), 470 deletions(-) diff --git a/environment.yml b/environment.yml index 070192a5364..c3176881822 100644 --- a/environment.yml +++ b/environment.yml @@ -20,11 +20,14 @@ dependencies: - elasticsearch <7.14 - elasticsearch-dsl - opensearch-py + - fabric - fts3 - gitpython >=2.1.0 + - invoke - m2crypto >=0.38.0 - matplotlib - numpy + - paramiko - pexpect >=4.0.1 - pillow - prompt-toolkit >=3,<4 diff --git a/setup.cfg b/setup.cfg index 59fc33c02ed..718c516d39c 100644 --- a/setup.cfg +++ b/setup.cfg @@ -37,8 +37,10 @@ install_requires = gfal2-python importlib_metadata >=4.4 importlib_resources + invoke M2Crypto >=0.36 packaging + paramiko pexpect prompt-toolkit >=3 psutil diff --git a/src/DIRAC/Resources/Computing/SSHBatchComputingElement.py b/src/DIRAC/Resources/Computing/SSHBatchComputingElement.py index 174642b475e..cadad16bfd5 100644 --- a/src/DIRAC/Resources/Computing/SSHBatchComputingElement.py +++ b/src/DIRAC/Resources/Computing/SSHBatchComputingElement.py @@ -1,4 +1,4 @@ -""" SSH (Virtual) Computing Element: For a given list of ip/cores pair it will send jobs +""" SSH (Virtual) Batch Computing Element: For a given list of ip/cores pair it will send jobs directly through ssh """ @@ -13,64 +13,78 @@ class SSHBatchComputingElement(SSHComputingElement): - ############################################################################# def __init__(self, ceUniqueID): """Standard constructor.""" super().__init__(ceUniqueID) - self.ceType = "SSHBatch" - self.sshHost = [] + self.connections = {} self.execution = "SSHBATCH" def _reset(self): """Process CE parameters and make necessary adjustments""" + # Get the Batch System instance result = self._getBatchSystem() if not result["OK"]: return result + + # Get the location of the remote directories self._getBatchSystemDirectoryLocations() - self.user = self.ceParameters["SSHUser"] + # Get the SSH parameters + self.timeout = self.ceParameters.get("Timeout", self.timeout) + self.user = self.ceParameters.get("SSHUser", self.user) + port = self.ceParameters.get("SSHPort", None) + password = self.ceParameters.get("SSHPassword", None) + key = self.ceParameters.get("SSHKey", None) + tunnel = self.ceParameters.get("SSHTunnel", None) + + # Get submission parameters + self.submitOptions = self.ceParameters.get("SubmitOptions", self.submitOptions) + self.preamble = self.ceParameters.get("Preamble", self.preamble) + self.account = self.ceParameters.get("Account", self.account) self.queue = self.ceParameters["Queue"] self.log.info("Using queue: ", self.queue) - self.submitOptions = self.ceParameters.get("SubmitOptions", "") - self.preamble = self.ceParameters.get("Preamble", "") - self.account = self.ceParameters.get("Account", "") - - # Prepare all the hosts - for hPar in self.ceParameters["SSHHost"].strip().split(","): - host = hPar.strip().split("/")[0] - result = self._prepareRemoteHost(host=host) - if result["OK"]: - self.log.info(f"Host {host} registered for usage") - self.sshHost.append(hPar.strip()) + # Get output and error templates + self.outputTemplate = self.ceParameters.get("OutputTemplate", self.outputTemplate) + self.errorTemplate = self.ceParameters.get("ErrorTemplate", 
self.errorTemplate) + + # Prepare the remote hosts + for host in self.ceParameters.get("SSHHost", "").strip().split(","): + hostDetails = host.strip().split("/") + if len(hostDetails) > 1: + hostname = hostDetails[0] + maxJobs = int(hostDetails[1]) else: - self.log.error("Failed to initialize host", host) + hostname = hostDetails[0] + maxJobs = self.ceParameters.get("MaxTotalJobs", 0) + + connection = self._getConnection(hostname, self.user, port, password, key, tunnel) + + result = self._prepareRemoteHost(connection) + if not result["OK"]: return result + self.connections[hostname] = {"connection": connection, "maxJobs": maxJobs} + self.log.info(f"Host {hostname} registered for usage") + return S_OK() ############################################################################# + def submitJob(self, executableFile, proxy, numberOfJobs=1): """Method to submit job""" - # Choose eligible hosts, rank them by the number of available slots rankHosts = {} maxSlots = 0 - for host in self.sshHost: - thost = host.split("/") - hostName = thost[0] - maxHostJobs = 1 - if len(thost) > 1: - maxHostJobs = int(thost[1]) - - result = self._getHostStatus(hostName) + for _, details in self.connections.items(): + result = self._getHostStatus(details["connection"]) if not result["OK"]: continue - slots = maxHostJobs - result["Value"]["Running"] + slots = details["maxJobs"] - result["Value"]["Running"] if slots > 0: rankHosts.setdefault(slots, []) - rankHosts[slots].append(hostName) + rankHosts[slots].append(details["connection"]) if slots > maxSlots: maxSlots = slots @@ -96,18 +110,28 @@ def submitJob(self, executableFile, proxy, numberOfJobs=1): restJobs = numberOfJobs submittedJobs = [] stampDict = {} + batchSystemName = self.batchSystem.__class__.__name__.lower() + for slots in range(maxSlots, 0, -1): if slots not in rankHosts: continue - for host in rankHosts[slots]: - result = self._submitJobToHost(submitFile, min(slots, restJobs), host) + for connection in rankHosts[slots]: + result = self._submitJobToHost(connection, submitFile, min(slots, restJobs)) if not result["OK"]: continue - nJobs = len(result["Value"]) + batchIDs, jobStamps = result["Value"] + + nJobs = len(batchIDs) if nJobs > 0: - submittedJobs.extend(result["Value"]) - stampDict.update(result.get("PilotStampDict", {})) + jobIDs = [ + f"{self.ceType.lower()}{batchSystemName}://{self.ceName}/{connection.host}/{_id}" + for _id in batchIDs + ] + submittedJobs.extend(jobIDs) + for iJob, jobID in enumerate(jobIDs): + stampDict[jobID] = jobStamps[iJob] + restJobs = restJobs - nJobs if restJobs <= 0: break @@ -121,6 +145,8 @@ def submitJob(self, executableFile, proxy, numberOfJobs=1): result["PilotStampDict"] = stampDict return result + ############################################################################# + def killJob(self, jobIDs): """Kill specified jobs""" jobIDList = list(jobIDs) @@ -136,7 +162,7 @@ def killJob(self, jobIDs): failed = [] for host, jobIDList in hostDict.items(): - result = self._killJobOnHost(jobIDList, host) + result = self._killJobOnHost(self.connections[host]["connection"], jobIDList) if not result["OK"]: failed.extend(jobIDList) message = result["Message"] @@ -149,6 +175,8 @@ def killJob(self, jobIDs): return result + ############################################################################# + def getCEStatus(self): """Method to return information on running and pending jobs.""" result = S_OK() @@ -156,9 +184,8 @@ def getCEStatus(self): result["RunningJobs"] = 0 result["WaitingJobs"] = 0 - for host in self.sshHost: 
-            thost = host.split("/")
-            resultHost = self._getHostStatus(thost[0])
+        for _, details in self.connections.items():
+            resultHost = self._getHostStatus(details["connection"])
 
             if resultHost["OK"]:
                 result["RunningJobs"] += resultHost["Value"]["Running"]
@@ -167,6 +194,8 @@ def getCEStatus(self):
 
         return result
 
+    #############################################################################
+
     def getJobStatus(self, jobIDList):
         """Get status of the jobs in the given list"""
         hostDict = {}
@@ -178,7 +207,7 @@ def getJobStatus(self, jobIDList):
         resultDict = {}
         failed = []
         for host, jobIDList in hostDict.items():
-            result = self._getJobStatusOnHost(jobIDList, host)
+            result = self._getJobStatusOnHost(self.connections[host]["connection"], jobIDList)
             if not result["OK"]:
                 failed.extend(jobIDList)
                 continue
@@ -189,3 +218,16 @@ def getJobStatus(self, jobIDList):
                 resultDict[job] = PilotStatus.UNKNOWN
 
         return S_OK(resultDict)
+
+    #############################################################################
+
+    def getJobOutput(self, jobID, localDir=None):
+        """Get the specified job standard output and error files. If the localDir is provided,
+        the output is returned as file in this directory. Otherwise, the output is returned
+        as strings.
+        """
+        self.log.verbose("Getting output for jobID", jobID)
+
+        # host can be retrieved from the path of the jobID
+        host = os.path.dirname(urlparse(jobID).path).lstrip("/")
+        return self._getJobOutputFilesOnHost(self.connections[host]["connection"], jobID, localDir)
diff --git a/src/DIRAC/Resources/Computing/SSHComputingElement.py b/src/DIRAC/Resources/Computing/SSHComputingElement.py
index 81b8b7df6bd..a39ec57c8c8 100644
--- a/src/DIRAC/Resources/Computing/SSHComputingElement.py
+++ b/src/DIRAC/Resources/Computing/SSHComputingElement.py
@@ -40,25 +40,19 @@
 SSH password
 
 SSHPort:
-  Port number if not standard, e.g. for the gsissh access
+  Port number if not standard
 
 SSHKey:
   Location of the ssh private key for no-password connection
 
-SSHOptions:
-  Any other SSH options to be used. Example::
-
-    SSHOptions = -o UserKnownHostsFile=/local/path/to/known_hosts
-
-  Allows to have a local copy of the ``known_hosts`` file, independent of the HOME directory.
-
 SSHTunnel:
   String defining the use of intermediate SSH host. Example::
 
     ssh -i /private/key/location -l final_user final_host
 
-SSHType:
-  SSH (default) or gsissh
+Timeout:
+  Timeout for the SSH commands. Default is 120 seconds.
+
 
 **Code Documentation**
 """
@@ -68,279 +62,90 @@
 import shutil
 import stat
 import uuid
-from shlex import quote as shlex_quote
 from urllib.parse import quote, unquote, urlparse
 
-import pexpect
+from fabric import Connection
+from invoke.exceptions import CommandTimedOut
+from paramiko.ssh_exception import SSHException
 
 import DIRAC
-from DIRAC import S_ERROR, S_OK, gLogger
+from DIRAC import S_ERROR, S_OK
 from DIRAC.Core.Utilities.List import breakListIntoChunks, uniqueElements
 from DIRAC.Resources.Computing.BatchSystems.executeBatch import executeBatchContent
 from DIRAC.Resources.Computing.ComputingElement import ComputingElement
 from DIRAC.Resources.Computing.PilotBundle import bundleProxy, writeScript
 
 
-class SSH:
-    """SSH class encapsulates passing commands and files through an SSH tunnel
-    to a remote host. It can use either ssh or gsissh access. The final host
-    where the commands will be executed and where the files will copied/retrieved
-    can be reached through an intermediate host if SSHTunnel parameters is defined.
- - SSH constructor parameters are defined in a SSH accessible Computing Element - in the Configuration System: - - - SSHHost: SSH host name - - SSHUser: SSH user login - - SSHPassword: SSH password - - SSHPort: port number if not standard, e.g. for the gsissh access - - SSHKey: location of the ssh private key for no-password connection - - SSHOptions: any other SSH options to be used - - SSHTunnel: string defining the use of intermediate SSH host. Example: - 'ssh -i /private/key/location -l final_user final_host' - - SSHType: ssh ( default ) or gsissh - - The class public interface includes two methods: - - sshCall( timeout, command_sequence ) - scpCall( timeout, local_file, remote_file, upload = False/True ) - """ - - def __init__(self, host=None, parameters=None): - self.host = host - if parameters is None: - parameters = {} - if not host: - self.host = parameters.get("SSHHost", "") - - self.user = parameters.get("SSHUser", "") - self.password = parameters.get("SSHPassword", "") - self.port = parameters.get("SSHPort", "") - self.key = parameters.get("SSHKey", "") - self.options = parameters.get("SSHOptions", "") - self.sshTunnel = parameters.get("SSHTunnel", "") - self.sshType = parameters.get("SSHType", "ssh") - - if self.port: - self.options += f" -p {self.port}" - if self.key: - self.options += f" -i {self.key}" - self.options = self.options.strip() - - self.log = gLogger.getSubLogger("SSH") - - def __ssh_call(self, command, timeout): - if not timeout: - timeout = 999 - - ssh_newkey = "Are you sure you want to continue connecting" - try: - child = pexpect.spawn(command, timeout=timeout, encoding="utf-8") - i = child.expect([pexpect.TIMEOUT, ssh_newkey, pexpect.EOF, "assword: "]) - if i == 0: # Timeout - return S_OK((-1, child.before, "SSH login failed")) - - if i == 1: # SSH does not have the public key. Just accept it. 
- child.sendline("yes") - child.expect("assword: ") - i = child.expect([pexpect.TIMEOUT, "assword: "]) - if i == 0: # Timeout - return S_OK((-1, str(child.before) + str(child.after), "SSH login failed")) - if i == 1: - child.sendline(self.password) - child.expect(pexpect.EOF) - return S_OK((0, child.before, "")) - - if i == 2: - # Passwordless login, get the output - return S_OK((0, child.before, "")) - - if self.password: - child.sendline(self.password) - child.expect(pexpect.EOF) - return S_OK((0, child.before, "")) - - return S_ERROR(f"Unknown error: {child.before}") - except Exception as x: - return S_ERROR(f"Encountered exception: {str(x)}") - - def sshCall(self, timeout, cmdSeq): - """Execute remote command via a ssh remote call - - :param int timeout: timeout of the command - :param cmdSeq: list of command components - :type cmdSeq: python:list - """ - - command = cmdSeq - if isinstance(cmdSeq, list): - command = " ".join(cmdSeq) - - pattern = "__DIRAC__" - - if self.sshTunnel: - command = command.replace("'", '\\\\\\"') - command = command.replace("$", "\\\\\\$") - command = '/bin/sh -c \' {} -q {} -l {} {} "{} \\"echo {}; {}\\" " \' '.format( - self.sshType, - self.options, - self.user, - self.host, - self.sshTunnel, - pattern, - command, - ) - else: - # command = command.replace( '$', '\$' ) - command = '{} -q {} -l {} {} "echo {}; {}"'.format( - self.sshType, - self.options, - self.user, - self.host, - pattern, - command, - ) - self.log.debug(f"SSH command: {command}") - result = self.__ssh_call(command, timeout) - self.log.debug(f"SSH command result {str(result)}") - if not result["OK"]: - return result - - # Take the output only after the predefined pattern - ind = result["Value"][1].find("__DIRAC__") - if ind == -1: - return result - - status, output, error = result["Value"] - output = output[ind + 9 :] - if output.startswith("\r"): - output = output[1:] - if output.startswith("\n"): - output = output[1:] - - result["Value"] = (status, output, error) - return result - - def scpCall(self, timeout, localFile, remoteFile, postUploadCommand="", upload=True): - """Perform file copy through an SSH magic. - - :param int timeout: timeout of the command - :param str localFile: local file path, serves as source for uploading and destination for downloading. 
- Can take 'Memory' as value, in this case the downloaded contents is returned - as result['Value'] - :param str remoteFile: remote file full path - :param str postUploadCommand: command executed on the remote side after file upload - :param bool upload: upload if True, download otherwise - """ - # shlex_quote aims to prevent any security issue or problems with filepath containing spaces - # it returns a shell-escaped version of the filename - localFile = shlex_quote(localFile) - remoteFile = shlex_quote(remoteFile) - if upload: - if self.sshTunnel: - remoteFile = remoteFile.replace("$", r"\\\\\$") - postUploadCommand = postUploadCommand.replace("$", r"\\\\\$") - command = '/bin/sh -c \'cat {} | {} -q {} {}@{} "{} \\"cat > {}; {}\\""\' '.format( - localFile, - self.sshType, - self.options, - self.user, - self.host, - self.sshTunnel, - remoteFile, - postUploadCommand, - ) - else: - command = "/bin/sh -c \"cat {} | {} -q {} {}@{} 'cat > {}; {}'\" ".format( - localFile, - self.sshType, - self.options, - self.user, - self.host, - remoteFile, - postUploadCommand, - ) - else: - finalCat = f"| cat > {localFile}" - if localFile.lower() == "memory": - finalCat = "" - if self.sshTunnel: - remoteFile = remoteFile.replace("$", "\\\\\\$") - command = '/bin/sh -c \'{} -q {} -l {} {} "{} \\"cat {}\\"" {}\''.format( - self.sshType, - self.options, - self.user, - self.host, - self.sshTunnel, - remoteFile, - finalCat, - ) - else: - remoteFile = remoteFile.replace("$", r"\$") - command = "/bin/sh -c '{} -q {} -l {} {} \"cat {}\" {}'".format( - self.sshType, - self.options, - self.user, - self.host, - remoteFile, - finalCat, - ) - - self.log.debug(f"SSH copy command: {command}") - return self.__ssh_call(command, timeout) - - class SSHComputingElement(ComputingElement): ############################################################################# def __init__(self, ceUniqueID): """Standard constructor.""" super().__init__(ceUniqueID) - self.execution = "SSHCE" self.submittedJobs = 0 - self.outputTemplate = "" - self.errorTemplate = "" - - ############################################################################ - def setProxy(self, proxy): - """ - Set and prepare proxy to use - :param str proxy: proxy to use - :return: S_OK/S_ERROR - """ - ComputingElement.setProxy(self, proxy) - if self.ceParameters.get("SSHType", "ssh") == "gsissh": - result = self._prepareProxy() - if not result["OK"]: - gLogger.error("SSHComputingElement: failed to set up proxy", result["Message"]) - return result - return S_OK() + # SSH connection + self.hosts = [] + self.connection = None + self.timeout = 120 + self.user = None + + # Submission parameters + self.queue = None + self.submitOptions = None + self.preamble = None + self.account = None + self.execution = "SSHCE" - ############################################################################# - def _addCEConfigDefaults(self): - """Method to make sure all necessary Configuration Parameters are defined""" - # First assure that any global parameters are loaded - ComputingElement._addCEConfigDefaults(self) - # Now batch system specific ones - if "SharedArea" not in self.ceParameters: - # . 
isn't a good location, move to $HOME
-            self.ceParameters["SharedArea"] = "$HOME"
+        # Directories
+        self.sharedArea = "$HOME"
+        self.batchOutput = "data"
+        self.batchError = "data"
+        self.infoArea = "info"
+        self.executableArea = "data"
+        self.workArea = "work"
 
-        if "BatchOutput" not in self.ceParameters:
-            self.ceParameters["BatchOutput"] = "data"
+        # Output and error templates
+        self.outputTemplate = ""
+        self.errorTemplate = ""
 
-        if "BatchError" not in self.ceParameters:
-            self.ceParameters["BatchError"] = "data"
+    #############################################################################
 
-        if "ExecutableArea" not in self.ceParameters:
-            self.ceParameters["ExecutableArea"] = "data"
+    def _run(self, connection: Connection, command: str):
+        """Run the command on the remote host"""
+        try:
+            result = connection.run(command, warn=True, hide=True)
+            if result.failed:
+                return S_ERROR(f"[{connection.host}] Command returned an error: {result.stderr}")
+            return S_OK(result.stdout)
+        except CommandTimedOut as e:
+            return S_ERROR(
+                errno.ETIME, f"[{connection.host}] The command timed out. Consider increasing the timeout: {e}"
+            )
+        except SSHException as e:
+            return S_ERROR(f"[{connection.host}] SSH error occurred: {str(e)}")
 
-        if "InfoArea" not in self.ceParameters:
-            self.ceParameters["InfoArea"] = "info"
+    def _put(self, connection: Connection, local: str, remote: str, preserveMode: bool = True):
+        """Upload a file to the remote host"""
+        try:
+            connection.put(local, remote=remote, preserve_mode=preserveMode)
+            return S_OK()
+        except OSError as e:
+            return S_ERROR(f"[{connection.host}] Failed uploading file: {str(e)}")
+        except SSHException as e:
+            return S_ERROR(f"[{connection.host}] SSH error occurred: {str(e)}")
+
+    def _get(self, connection: Connection, remote: str, local: str, preserveMode: bool = True):
+        """Download a file from the remote host"""
+        try:
+            connection.get(remote, local=local, preserve_mode=preserveMode)
+            return S_OK()
+        except OSError as e:
+            return S_ERROR(f"[{connection.host}] Failed downloading file: {str(e)}")
+        except SSHException as e:
+            return S_ERROR(f"[{connection.host}] SSH error occurred: {str(e)}")
 
-        if "WorkArea" not in self.ceParameters:
-            self.ceParameters["WorkArea"] = "work"
+    #############################################################################
 
     def _getBatchSystem(self):
         """Load a Batch System instance from the CE Parameters"""
@@ -354,90 +159,112 @@ def _getBatchSystem(self):
 
     def _getBatchSystemDirectoryLocations(self):
         """Get names of the locations to store outputs, errors, info and executables."""
-        self.sharedArea = self.ceParameters["SharedArea"]
-        self.batchOutput = self.ceParameters["BatchOutput"]
-        if not self.batchOutput.startswith("/"):
-            self.batchOutput = os.path.join(self.sharedArea, self.batchOutput)
-        self.batchError = self.ceParameters["BatchError"]
-        if not self.batchError.startswith("/"):
-            self.batchError = os.path.join(self.sharedArea, self.batchError)
-        self.infoArea = self.ceParameters["InfoArea"]
-        if not self.infoArea.startswith("/"):
-            self.infoArea = os.path.join(self.sharedArea, self.infoArea)
-        self.executableArea = self.ceParameters["ExecutableArea"]
-        if not self.executableArea.startswith("/"):
-            self.executableArea = os.path.join(self.sharedArea, self.executableArea)
-        self.workArea = self.ceParameters["WorkArea"]
-        if not self.workArea.startswith("/"):
-            self.workArea = os.path.join(self.sharedArea, self.workArea)
+        self.sharedArea = self.ceParameters.get("SharedArea", self.sharedArea)
+
+        def _get_dir(directory: 
str, defaultValue: str) -> str:
+            value = self.ceParameters.get(directory, defaultValue)
+            if value.startswith("/"):
+                return value
+            return os.path.join(self.sharedArea, value)
+
+        self.batchOutput = _get_dir("BatchOutput", self.batchOutput)
+        self.batchError = _get_dir("BatchError", self.batchError)
+        self.infoArea = _get_dir("InfoArea", self.infoArea)
+        self.executableArea = _get_dir("ExecutableArea", self.executableArea)
+        self.workArea = _get_dir("WorkArea", self.workArea)
+
+    def _getConnection(self, host: str, user: str, port: int, password: str, key: str, tunnel: str):
+        """Get a Connection instance to the host"""
+        connectionParams = {}
+        if password:
+            connectionParams["password"] = password
+        if key:
+            connectionParams["key_filename"] = key
+
+        gateway = None
+        if tunnel:
+            gateway = Connection(tunnel, user=user, connect_kwargs=connectionParams)
+
+        return Connection(
+            host,
+            user=user,
+            port=port,
+            gateway=gateway,
+            connect_kwargs=connectionParams,
+            connect_timeout=self.timeout,
+        )
 
     def _reset(self):
         """Process CE parameters and make necessary adjustments"""
+        # Get the Batch System instance
         result = self._getBatchSystem()
         if not result["OK"]:
             return result
+
+        # Get the location of the remote directories
         self._getBatchSystemDirectoryLocations()
 
-        self.user = self.ceParameters["SSHUser"]
+        # Get the SSH parameters
+        self.host = self.ceParameters.get("SSHHost", "")
+        self.timeout = self.ceParameters.get("Timeout", self.timeout)
+        self.user = self.ceParameters.get("SSHUser", self.user)
+        port = self.ceParameters.get("SSHPort", None)
+        password = self.ceParameters.get("SSHPassword", None)
+        key = self.ceParameters.get("SSHKey", None)
+        tunnel = self.ceParameters.get("SSHTunnel", None)
+
+        # Configure the SSH connection
+        self.connection = self._getConnection(self.host, self.user, port, password, key, tunnel)
+
+        # Get submission parameters
+        self.submitOptions = self.ceParameters.get("SubmitOptions", self.submitOptions)
+        self.preamble = self.ceParameters.get("Preamble", self.preamble)
+        self.account = self.ceParameters.get("Account", self.account)
         self.queue = self.ceParameters["Queue"]
         self.log.info("Using queue: ", self.queue)
 
-        self.submitOptions = self.ceParameters.get("SubmitOptions", "")
-        self.preamble = self.ceParameters.get("Preamble", "")
+        # Get output and error templates
+        self.outputTemplate = self.ceParameters.get("OutputTemplate", self.outputTemplate)
+        self.errorTemplate = self.ceParameters.get("ErrorTemplate", self.errorTemplate)
 
-        self.account = self.ceParameters.get("Account", "")
-        result = self._prepareRemoteHost()
+        # Prepare the remote host
+        result = self._prepareRemoteHost(self.connection)
         if not result["OK"]:
             return result
 
         return S_OK()
 
-    def _prepareRemoteHost(self, host=None):
+    def _prepareRemoteHost(self, connection: Connection):
         """Prepare remote directories and upload control script"""
-
-        ssh = SSH(host=host, parameters=self.ceParameters)
-
         # Make remote directories
+        self.log.verbose(f"Creating working directories on {self.host}")
         dirTuple = tuple(
             uniqueElements(
                 [self.sharedArea, self.executableArea, self.infoArea, self.batchOutput, self.batchError, self.workArea]
             )
         )
-        nDirs = len(dirTuple)
-        cmd = "mkdir -p %s; " * nDirs % dirTuple
-        cmd = f"bash -c '{cmd}'"
-        self.log.verbose(f"Creating working directories on {self.ceParameters['SSHHost']}")
-        result = ssh.sshCall(30, cmd)
+        cmd = f"mkdir -p {' '.join(dirTuple)}"
+        result = self._run(connection, cmd)
         if not result["OK"]:
-            self.log.error("Failed creating working directories", 
f"({result['Message']})") + self.log.error("Failed creating working directories: ", result["Message"]) return result - status, output, _error = result["Value"] - if status == -1: - self.log.error("Timeout while creating directories") - return S_ERROR(errno.ETIME, "Timeout while creating directories") - if "cannot" in output: - self.log.error("Failed to create directories", f"({output})") - return S_ERROR(errno.EACCES, "Failed to create directories") # Upload the control script now + self.log.verbose("Generating control script") result = self._generateControlScript() if not result["OK"]: - self.log.warn("Failed generating control script") + self.log.error("Failed generating control script") return result localScript = result["Value"] - self.log.verbose(f"Uploading {self.batchSystem.__class__.__name__} script to {self.ceParameters['SSHHost']}") + os.chmod(localScript, 0o755) + + self.log.verbose(f"Uploading {self.batchSystem.__class__.__name__} script to {self.host}") remoteScript = f"{self.sharedArea}/execute_batch" - result = ssh.scpCall(30, localScript, remoteScript, postUploadCommand=f"chmod +x {remoteScript}") + + result = self._put(connection, localScript, remote=remoteScript) if not result["OK"]: - self.log.warn(f"Failed uploading control script: {result['Message']}") + self.log.error(f"Failed uploading control script: {result['Message']}") return result - status, output, _error = result["Value"] - if status != 0: - if status == -1: - self.log.warn("Timeout while uploading control script") - return S_ERROR("Timeout while uploading control script") - self.log.warn(f"Failed uploading control script: {output}") - return S_ERROR("Failed uploading control script") # Delete the generated control script locally try: @@ -470,10 +297,10 @@ def _generateControlScript(self): return S_OK(f"{controlScript}") - def __executeHostCommand(self, command, options, ssh=None, host=None): - if not ssh: - ssh = SSH(host=host, parameters=self.ceParameters) + ############################################################################# + def __executeHostCommand(self, connection: Connection, command: str, options: dict[str]): + """Execute a command on the remote host""" options["BatchSystem"] = self.batchSystem.__class__.__name__ options["Method"] = command options["SharedDir"] = self.sharedArea @@ -489,46 +316,40 @@ def __executeHostCommand(self, command, options, ssh=None, host=None): options = quote(options) cmd = ( - "bash --login -c 'python %s/execute_batch %s || python3 %s/execute_batch %s || python2 %s/execute_batch %s'" - % (self.sharedArea, options, self.sharedArea, options, self.sharedArea, options) + f"python {self.sharedArea}/execute_batch {options} || " + f"python3 {self.sharedArea}/execute_batch {options} || " + f"python2 {self.sharedArea}/execute_batch {options}" ) - self.log.verbose(f"CE submission command: {cmd}") + self.log.verbose("Command:", f"[{connection.host}] {cmd}") - result = ssh.sshCall(120, cmd) + result = self._run(connection, cmd) if not result["OK"]: - self.log.error(f"{self.ceType} CE job submission failed", result["Message"]) return result - sshStatus = result["Value"][0] - sshStdout = result["Value"][1] - sshStderr = result["Value"][2] - # Examine results of the job submission - if sshStatus == 0: - output = sshStdout.strip().replace("\r", "").strip() - if not output: - return S_ERROR("No output from remote command") - - try: - index = output.index("============= Start output ===============") - output = output[index + 42 :] - except ValueError: - return S_ERROR(f"Invalid 
output from remote command: {output}") - - try: - output = unquote(output) - result = json.loads(output) - if isinstance(result, str) and result.startswith("Exception:"): - return S_ERROR(result) - return S_OK(result) - except Exception: - return S_ERROR("Invalid return structure from job submission") - else: - return S_ERROR("\n".join([sshStdout, sshStderr])) + output = result["Value"].strip() + if not output: + return S_ERROR("No output from remote command") + + try: + index = output.index("============= Start output ===============") + output = output[index + 42 :] + except ValueError: + return S_ERROR(f"Invalid output from remote command: {output}") + + try: + output = unquote(output) + result = json.loads(output) + if isinstance(result, str) and result.startswith("Exception:"): + return S_ERROR(result) + return S_OK(result) + except Exception: + return S_ERROR("Invalid return structure from job submission") + + ############################################################################# def submitJob(self, executableFile, proxy, numberOfJobs=1): - # self.log.verbose( "Executable file path: %s" % executableFile ) if not os.access(executableFile, 5): os.chmod(executableFile, stat.S_IRWXU | stat.S_IRGRP | stat.S_IXGRP | stat.S_IROTH | stat.S_IXOTH) @@ -544,23 +365,40 @@ def submitJob(self, executableFile, proxy, numberOfJobs=1): else: # no proxy submitFile = executableFile - result = self._submitJobToHost(submitFile, numberOfJobs) + result = self._submitJobToHost(self.connection, submitFile, numberOfJobs) + if proxy: os.remove(submitFile) + if not result["OK"]: + return result + + batchIDs, jobStamps = result["Value"] + batchSystemName = self.batchSystem.__class__.__name__.lower() + jobIDs = [f"{self.ceType.lower()}{batchSystemName}://{self.ceName}/{_id}" for _id in batchIDs] + + result = S_OK(jobIDs) + stampDict = {} + for iJob, jobID in enumerate(jobIDs): + stampDict[jobID] = jobStamps[iJob] + result["PilotStampDict"] = stampDict + self.submittedJobs += len(batchIDs) + return result - def _submitJobToHost(self, executableFile, numberOfJobs, host=None): + def _submitJobToHost(self, connection: Connection, executableFile: str, numberOfJobs: int): """Submit prepared executable to the given host""" - ssh = SSH(host=host, parameters=self.ceParameters) # Copy the executable + self.log.verbose(f"Copying executable to {self.host}") submitFile = os.path.join(self.executableArea, os.path.basename(executableFile)) - result = ssh.scpCall(30, executableFile, submitFile, postUploadCommand=f"chmod +x {submitFile}") + os.chmod(executableFile, 0o755) + + result = self._put(connection, executableFile, submitFile) if not result["OK"]: return result jobStamps = [] - for _i in range(numberOfJobs): + for _ in range(numberOfJobs): jobStamps.append(uuid.uuid4().hex) numberOfProcessors = self.ceParameters.get("NumberOfProcessors", 1) @@ -583,52 +421,37 @@ def _submitJobToHost(self, executableFile, numberOfJobs, host=None): "NumberOfGPUs": self.numberOfGPUs, "Account": self.account, } - if host: - commandOptions["SSHNodeHost"] = host - resultCommand = self.__executeHostCommand("submitJob", commandOptions, ssh=ssh, host=host) + resultCommand = self.__executeHostCommand(connection, "submitJob", commandOptions) if not resultCommand["OK"]: return resultCommand result = resultCommand["Value"] if result["Status"] != 0: return S_ERROR(f"Failed job submission: {result['Message']}") - else: - batchIDs = result["Jobs"] - if batchIDs: - batchSystemName = self.batchSystem.__class__.__name__.lower() - if host is None: - 
jobIDs = [f"{self.ceType.lower()}{batchSystemName}://{self.ceName}/{_id}" for _id in batchIDs] - else: - jobIDs = [ - f"{self.ceType.lower()}{batchSystemName}://{self.ceName}/{host}/{_id}" for _id in batchIDs - ] - else: - return S_ERROR("No jobs IDs returned") - result = S_OK(jobIDs) - stampDict = {} - for iJob, jobID in enumerate(jobIDs): - stampDict[jobID] = jobStamps[iJob] - result["PilotStampDict"] = stampDict - self.submittedJobs += len(batchIDs) + batchIDs = result["Jobs"] + if not batchIDs: + return S_ERROR("No jobs IDs returned") - return result + return S_OK((batchIDs, jobStamps)) + + ############################################################################# def killJob(self, jobIDList): """Kill a bunch of jobs""" if isinstance(jobIDList, str): jobIDList = [jobIDList] - return self._killJobOnHost(jobIDList) + return self._killJobOnHost(self.connection, jobIDList) - def _killJobOnHost(self, jobIDList, host=None): + def _killJobOnHost(self, connection: Connection, jobIDList: list[str]): """Kill the jobs for the given list of job IDs""" batchSystemJobList = [] for jobID in jobIDList: batchSystemJobList.append(os.path.basename(urlparse(jobID.split(":::")[0]).path)) commandOptions = {"JobIDList": batchSystemJobList, "User": self.user} - resultCommand = self.__executeHostCommand("killJob", commandOptions, host=host) + resultCommand = self.__executeHostCommand(connection, "killJob", commandOptions) if not resultCommand["OK"]: return resultCommand @@ -641,6 +464,8 @@ def _killJobOnHost(self, jobIDList, host=None): return S_OK(len(result["Successful"])) + ############################################################################# + def getCEStatus(self): """Method to return information on running and pending jobs.""" result = S_OK() @@ -648,7 +473,7 @@ def getCEStatus(self): result["RunningJobs"] = 0 result["WaitingJobs"] = 0 - resultHost = self._getHostStatus() + resultHost = self._getHostStatus(self.connection) if not resultHost["OK"]: return resultHost @@ -661,9 +486,9 @@ def getCEStatus(self): return result - def _getHostStatus(self, host=None): + def _getHostStatus(self, connection: Connection): """Get jobs running at a given host""" - resultCommand = self.__executeHostCommand("getCEStatus", {}, host=host) + resultCommand = self.__executeHostCommand(connection, "getCEStatus", {}) if not resultCommand["OK"]: return resultCommand @@ -673,11 +498,13 @@ def _getHostStatus(self, host=None): return S_OK(result) + ############################################################################# + def getJobStatus(self, jobIDList): """Get the status information for the given list of jobs""" - return self._getJobStatusOnHost(jobIDList) + return self._getJobStatusOnHost(self.connection, jobIDList) - def _getJobStatusOnHost(self, jobIDList, host=None): + def _getJobStatusOnHost(self, connection: Connection, jobIDList: list[str]): """Get the status information for the given list of jobs""" resultDict = {} batchSystemJobDict = {} @@ -686,7 +513,7 @@ def _getJobStatusOnHost(self, jobIDList, host=None): batchSystemJobDict[batchSystemJobID] = jobID for jobList in breakListIntoChunks(list(batchSystemJobDict), 100): - resultCommand = self.__executeHostCommand("getJobStatus", {"JobIDList": jobList}, host=host) + resultCommand = self.__executeHostCommand(connection, "getJobStatus", {"JobIDList": jobList}) if not resultCommand["OK"]: return resultCommand @@ -699,65 +526,23 @@ def _getJobStatusOnHost(self, jobIDList, host=None): return S_OK(resultDict) + 
############################################################################# + def getJobOutput(self, jobID, localDir=None): """Get the specified job standard output and error files. If the localDir is provided, the output is returned as file in this directory. Otherwise, the output is returned as strings. """ self.log.verbose("Getting output for jobID", jobID) - result = self._getJobOutputFiles(jobID) - if not result["OK"]: - return result - - batchSystemJobID, host, outputFile, errorFile = result["Value"] - - if localDir: - localOutputFile = f"{localDir}/{batchSystemJobID}.out" - localErrorFile = f"{localDir}/{batchSystemJobID}.err" - else: - localOutputFile = "Memory" - localErrorFile = "Memory" - - # Take into account the SSHBatch possible SSHHost syntax - host = host.split("/")[0] - - ssh = SSH(host=host, parameters=self.ceParameters) - resultStdout = ssh.scpCall(30, localOutputFile, outputFile, upload=False) - if not resultStdout["OK"]: - return resultStdout - - resultStderr = ssh.scpCall(30, localErrorFile, errorFile, upload=False) - if not resultStderr["OK"]: - return resultStderr - - if localDir: - output = localOutputFile - error = localErrorFile - else: - output = resultStdout["Value"][1] - error = resultStderr["Value"][1] + return self._getJobOutputFilesOnHost(self.connection, jobID, localDir) - return S_OK((output, error)) - - def _getJobOutputFiles(self, jobID): + def _getJobOutputFilesOnHost(self, connection: Connection, jobID: str, localDir: str | None = None): """Get output file names for the specific CE""" batchSystemJobID = os.path.basename(urlparse(jobID.split(":::")[0]).path) - # host can be retrieved from the path of the jobID - # it might not be present, in this case host is an empty string and will be defined by the CE parameters later - host = os.path.dirname(urlparse(jobID).path).lstrip("/") - - if "OutputTemplate" in self.ceParameters: - self.outputTemplate = self.ceParameters["OutputTemplate"] - self.errorTemplate = self.ceParameters["ErrorTemplate"] if self.outputTemplate: - output = self.outputTemplate % batchSystemJobID - error = self.errorTemplate % batchSystemJobID - elif "OutputTemplate" in self.ceParameters: - self.outputTemplate = self.ceParameters["OutputTemplate"] - self.errorTemplate = self.ceParameters["ErrorTemplate"] - output = self.outputTemplate % batchSystemJobID - error = self.errorTemplate % batchSystemJobID + outputFile = self.outputTemplate % batchSystemJobID + errorFile = self.errorTemplate % batchSystemJobID elif hasattr(self.batchSystem, "getJobOutputFiles"): # numberOfNodes is treated as a string as it can contain values such as "2-4" # where 2 would represent the minimum number of nodes to allocate, and 4 the maximum @@ -768,7 +553,7 @@ def _getJobOutputFiles(self, jobID): "ErrorDir": self.batchError, "NumberOfNodes": numberOfNodes, } - resultCommand = self.__executeHostCommand("getJobOutputFiles", commandOptions, host=host) + resultCommand = self.__executeHostCommand(connection, "getJobOutputFiles", commandOptions) if not resultCommand["OK"]: return resultCommand @@ -780,10 +565,32 @@ def _getJobOutputFiles(self, jobID): self.outputTemplate = result["OutputTemplate"] self.errorTemplate = result["ErrorTemplate"] - output = result["Jobs"][batchSystemJobID]["Output"] - error = result["Jobs"][batchSystemJobID]["Error"] + outputFile = result["Jobs"][batchSystemJobID]["Output"] + errorFile = result["Jobs"][batchSystemJobID]["Error"] else: - output = f"{self.batchOutput}/{batchSystemJobID}.out" - error = 
f"{self.batchError}/{batchSystemJobID}.err"
+            outputFile = f"{self.batchOutput}/{batchSystemJobID}.out"
+            errorFile = f"{self.batchError}/{batchSystemJobID}.err"
 
-        return S_OK((batchSystemJobID, host, output, error))
+        if localDir:
+            # Download the files into the requested local directory and return their paths
+            localOutputFile = f"{localDir}/{batchSystemJobID}.out"
+            localErrorFile = f"{localDir}/{batchSystemJobID}.err"
+
+            resultStdout = self._get(connection, outputFile, localOutputFile, preserveMode=False)
+            if not resultStdout["OK"]:
+                return resultStdout
+
+            resultStderr = self._get(connection, errorFile, localErrorFile, preserveMode=False)
+            if not resultStderr["OK"]:
+                return resultStderr
+
+            return S_OK((localOutputFile, localErrorFile))
+
+        # No local directory given: return the contents of the files as strings
+        resultStdout = self._run(connection, f"cat {outputFile}")
+        if not resultStdout["OK"]:
+            return resultStdout
+
+        resultStderr = self._run(connection, f"cat {errorFile}")
+        if not resultStderr["OK"]:
+            return resultStderr
+
+        return S_OK((resultStdout["Value"], resultStderr["Value"]))
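
For reference, the new _getConnection, _run, _put and _get helpers are thin wrappers around the Fabric 2.x API added to the dependencies above. The following is a minimal standalone sketch of the same pattern, not part of the patch; the host names, user, key path and file paths are hypothetical examples.

# Sketch of the Fabric usage wrapped by SSHComputingElement (hypothetical hosts and paths)
from fabric import Connection

# Optional intermediate host, the equivalent of the SSHTunnel parameter
gateway = Connection(
    "bastion.example.org",
    user="diracpilot",
    connect_kwargs={"key_filename": "/home/dirac/.ssh/id_rsa"},
)

# The equivalent of what _getConnection() builds for a configured CE
conn = Connection(
    "batch-head.example.org",
    user="diracpilot",
    port=22,
    gateway=gateway,
    connect_kwargs={"key_filename": "/home/dirac/.ssh/id_rsa"},
    connect_timeout=120,
)

# _run(): warn=True avoids an exception on non-zero exit codes, hide=True captures the output
result = conn.run("mkdir -p data info work", warn=True, hide=True)
print(result.ok, result.stdout)

# _put()/_get(): SFTP transfers over the same connection
# (relative remote paths are resolved against the remote user's home directory)
conn.put("execute_batch", remote="execute_batch")
conn.get("data/12345.out", local="/tmp/12345.out")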