Skip to content

Commit

Permalink
feat(Resources): introduce fabric in SSHCE
Browse files Browse the repository at this point in the history
  • Loading branch information
aldbr committed Jul 25, 2024
1 parent 5f268bd commit 1c60e47
Show file tree
Hide file tree
Showing 4 changed files with 324 additions and 470 deletions.
3 changes: 3 additions & 0 deletions environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,14 @@ dependencies:
- elasticsearch <7.14
- elasticsearch-dsl
- opensearch-py
- fabric
- fts3
- gitpython >=2.1.0
- invoke
- m2crypto >=0.38.0
- matplotlib
- numpy
- paramiko
- pexpect >=4.0.1
- pillow
- prompt-toolkit >=3,<4
Expand Down
2 changes: 2 additions & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,10 @@ install_requires =
gfal2-python
importlib_metadata >=4.4
importlib_resources
invoke
M2Crypto >=0.36
packaging
paramiko
pexpect
prompt-toolkit >=3
psutil
Expand Down
118 changes: 80 additions & 38 deletions src/DIRAC/Resources/Computing/SSHBatchComputingElement.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
""" SSH (Virtual) Computing Element: For a given list of ip/cores pair it will send jobs
""" SSH (Virtual) Batch Computing Element: For a given list of ip/cores pair it will send jobs
directly through ssh
"""

Expand All @@ -13,64 +13,78 @@


class SSHBatchComputingElement(SSHComputingElement):
#############################################################################
def __init__(self, ceUniqueID):
"""Standard constructor."""
super().__init__(ceUniqueID)

self.ceType = "SSHBatch"
self.sshHost = []
self.connections = {}
self.execution = "SSHBATCH"

def _reset(self):
"""Process CE parameters and make necessary adjustments"""
# Get the Batch System instance
result = self._getBatchSystem()
if not result["OK"]:
return result

# Get the location of the remote directories
self._getBatchSystemDirectoryLocations()

self.user = self.ceParameters["SSHUser"]
# Get the SSH parameters
self.timeout = self.ceParameters.get("Timeout", self.timeout)
self.user = self.ceParameters.get("SSHUser", self.user)
port = self.ceParameters.get("SSHPort", None)
password = self.ceParameters.get("SSHPassword", None)
key = self.ceParameters.get("SSHKey", None)
tunnel = self.ceParameters.get("SSHTunnel", None)

# Get submission parameters
self.submitOptions = self.ceParameters.get("SubmitOptions", self.submitOptions)
self.preamble = self.ceParameters.get("Preamble", self.preamble)
self.account = self.ceParameters.get("Account", self.account)
self.queue = self.ceParameters["Queue"]
self.log.info("Using queue: ", self.queue)

self.submitOptions = self.ceParameters.get("SubmitOptions", "")
self.preamble = self.ceParameters.get("Preamble", "")
self.account = self.ceParameters.get("Account", "")

# Prepare all the hosts
for hPar in self.ceParameters["SSHHost"].strip().split(","):
host = hPar.strip().split("/")[0]
result = self._prepareRemoteHost(host=host)
if result["OK"]:
self.log.info(f"Host {host} registered for usage")
self.sshHost.append(hPar.strip())
# Get output and error templates
self.outputTemplate = self.ceParameters.get("OutputTemplate", self.outputTemplate)
self.errorTemplate = self.ceParameters.get("ErrorTemplate", self.errorTemplate)

# Prepare the remote hosts
for host in self.ceParameters.get("SSHHost", "").strip().split(","):
hostDetails = host.strip().split("/")
if len(hostDetails) > 1:
hostname = hostDetails[0]
maxJobs = int(hostDetails[1])
else:
self.log.error("Failed to initialize host", host)
hostname = hostDetails[0]
maxJobs = self.ceParameters.get("MaxTotalJobs", 0)

connection = self._getConnection(hostname, self.user, port, password, key, tunnel)

result = self._prepareRemoteHost(connection)
if not result["OK"]:
return result

self.connections[hostname] = {"connection": connection, "maxJobs": maxJobs}
self.log.info(f"Host {hostname} registered for usage")

return S_OK()

#############################################################################

def submitJob(self, executableFile, proxy, numberOfJobs=1):
"""Method to submit job"""

# Choose eligible hosts, rank them by the number of available slots
rankHosts = {}
maxSlots = 0
for host in self.sshHost:
thost = host.split("/")
hostName = thost[0]
maxHostJobs = 1
if len(thost) > 1:
maxHostJobs = int(thost[1])

result = self._getHostStatus(hostName)
for _, details in self.connections.items():
result = self._getHostStatus(details["connection"])
if not result["OK"]:
continue
slots = maxHostJobs - result["Value"]["Running"]
slots = details["maxJobs"] - result["Value"]["Running"]
if slots > 0:
rankHosts.setdefault(slots, [])
rankHosts[slots].append(hostName)
rankHosts[slots].append(details["connection"])
if slots > maxSlots:
maxSlots = slots

Expand All @@ -96,18 +110,28 @@ def submitJob(self, executableFile, proxy, numberOfJobs=1):
restJobs = numberOfJobs
submittedJobs = []
stampDict = {}
batchSystemName = self.batchSystem.__class__.__name__.lower()

for slots in range(maxSlots, 0, -1):
if slots not in rankHosts:
continue
for host in rankHosts[slots]:
result = self._submitJobToHost(submitFile, min(slots, restJobs), host)
for connection in rankHosts[slots]:
result = self._submitJobToHost(connection, submitFile, min(slots, restJobs))
if not result["OK"]:
continue

nJobs = len(result["Value"])
batchIDs, jobStamps = result["Value"]

nJobs = len(batchIDs)
if nJobs > 0:
submittedJobs.extend(result["Value"])
stampDict.update(result.get("PilotStampDict", {}))
jobIDs = [
f"{self.ceType.lower()}{batchSystemName}://{self.ceName}/{connection.host}/{_id}"
for _id in batchIDs
]
submittedJobs.extend(jobIDs)
for iJob, jobID in enumerate(jobIDs):
stampDict[jobID] = jobStamps[iJob]

restJobs = restJobs - nJobs
if restJobs <= 0:
break
Expand All @@ -121,6 +145,8 @@ def submitJob(self, executableFile, proxy, numberOfJobs=1):
result["PilotStampDict"] = stampDict
return result

#############################################################################

def killJob(self, jobIDs):
"""Kill specified jobs"""
jobIDList = list(jobIDs)
Expand All @@ -136,7 +162,7 @@ def killJob(self, jobIDs):

failed = []
for host, jobIDList in hostDict.items():
result = self._killJobOnHost(jobIDList, host)
result = self._killJobOnHost(self.connections[host]["connection"], jobIDList)
if not result["OK"]:
failed.extend(jobIDList)
message = result["Message"]
Expand All @@ -149,16 +175,17 @@ def killJob(self, jobIDs):

return result

#############################################################################

def getCEStatus(self):
"""Method to return information on running and pending jobs."""
result = S_OK()
result["SubmittedJobs"] = self.submittedJobs
result["RunningJobs"] = 0
result["WaitingJobs"] = 0

for host in self.sshHost:
thost = host.split("/")
resultHost = self._getHostStatus(thost[0])
for _, details in self.connections:
resultHost = self._getHostStatus(details["connection"])
if resultHost["OK"]:
result["RunningJobs"] += resultHost["Value"]["Running"]

Expand All @@ -167,6 +194,8 @@ def getCEStatus(self):

return result

#############################################################################

def getJobStatus(self, jobIDList):
"""Get status of the jobs in the given list"""
hostDict = {}
Expand All @@ -178,7 +207,7 @@ def getJobStatus(self, jobIDList):
resultDict = {}
failed = []
for host, jobIDList in hostDict.items():
result = self._getJobStatusOnHost(jobIDList, host)
result = self._getJobStatusOnHost(self.connections[host]["connection"], jobIDList)
if not result["OK"]:
failed.extend(jobIDList)
continue
Expand All @@ -189,3 +218,16 @@ def getJobStatus(self, jobIDList):
resultDict[job] = PilotStatus.UNKNOWN

return S_OK(resultDict)

#############################################################################

def getJobOutput(self, jobID, localDir=None):
"""Get the specified job standard output and error files. If the localDir is provided,
the output is returned as file in this directory. Otherwise, the output is returned
as strings.
"""
self.log.verbose("Getting output for jobID", jobID)

# host can be retrieved from the path of the jobID
host = os.path.dirname(urlparse(jobID).path).lstrip("/")
return self._getJobOutputFilesOnHost(self.connections[host]["connection"], jobID, localDir)
Loading

0 comments on commit 1c60e47

Please sign in to comment.