Skip to content

Commit

Permalink
Merge pull request #7318 from chaen/prepareHack
Browse files Browse the repository at this point in the history
feat (diracx) add JobStateUpdate handler
  • Loading branch information
chaen authored Nov 30, 2023
2 parents 3abaa00 + e2a2f91 commit fce9b4c
Show file tree
Hide file tree
Showing 10 changed files with 618 additions and 3 deletions.
2 changes: 2 additions & 0 deletions src/DIRAC/Core/Security/DiracX.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,8 @@ def DiracXClient() -> _DiracClient:

proxyLocation = getDefaultProxyLocation()
diracxToken = diracxTokenFromPEM(proxyLocation)
if not diracxToken:
raise ValueError(f"No dirax token in the proxy file {proxyLocation}")

with NamedTemporaryFile(mode="wt") as token_file:
token_file.write(json.dumps(diracxToken))
Expand Down
9 changes: 6 additions & 3 deletions src/DIRAC/FrameworkSystem/scripts/dirac_diracx_whoami.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,12 @@
def main():
Script.parseCommandLine()

with DiracXClient() as api:
user_info = api.auth.userinfo()
print(json.dumps(user_info.as_dict(), indent=2))
try:
with DiracXClient() as api:
user_info = api.auth.userinfo()
print(json.dumps(user_info.as_dict(), indent=2))
except Exception as e:
print(f"Failed to access DiracX: {e}")


if __name__ == "__main__":
Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,17 @@
""" Class that contains client access to the JobStateUpdate handler. """

from DIRAC.Core.Base.Client import Client, createClient
from DIRAC.WorkloadManagementSystem.FutureClient.JobStateUpdateClient import (
JobStateUpdateClient as futureJobStateUpdateClient,
)


@createClient("WorkloadManagement/JobStateUpdate")
class JobStateUpdateClient(Client):
"""JobStateUpdateClient sets url for the JobStateUpdateHandler."""

diracxClient = futureJobStateUpdateClient

def __init__(self, url=None, **kwargs):
"""
Sets URL for JobStateUpdate handler
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
import functools
from datetime import datetime, timezone


from DIRAC.Core.Security.DiracX import DiracXClient
from DIRAC.Core.Utilities.ReturnValues import convertToReturnValue
from DIRAC.Core.Utilities.TimeUtilities import fromString


def stripValueIfOK(func):
"""Decorator to remove S_OK["Value"] from the return value of a function if it is OK.
This is done as some update functions return the number of modified rows in
the database. This likely not actually useful so it isn't supported in
DiracX. Stripping the "Value" key of the dictionary means that we should
get a fairly straight forward error if the assumption is incorrect.
"""

@functools.wraps(func)
def wrapper(*args, **kwargs):
result = func(*args, **kwargs)
if result.get("OK"):
assert result.pop("Value") is None, "Value should be None if OK"
return result

return wrapper


class JobStateUpdateClient:
@stripValueIfOK
@convertToReturnValue
def sendHeartBeat(self, jobID: str | int, dynamicData: dict, staticData: dict):
print("HACK: This is a no-op until we decide what to do")

@stripValueIfOK
@convertToReturnValue
def setJobApplicationStatus(self, jobID: str | int, appStatus: str, source: str = "Unknown"):
statusDict = {
"application_status": appStatus,
}
if source:
statusDict["Source"] = source
with DiracXClient() as api:
api.jobs.set_single_job_status(
jobID,
{datetime.now(tz=timezone.utc): statusDict},
)

@stripValueIfOK
@convertToReturnValue
def setJobAttribute(self, jobID: str | int, attribute: str, value: str):
with DiracXClient() as api:
if attribute == "Status":
api.jobs.set_single_job_status(
jobID,
{datetime.now(tz=timezone.utc): {"status": value}},
)
else:
api.jobs.set_single_job_properties(jobID, {attribute: value})

@stripValueIfOK
@convertToReturnValue
def setJobFlag(self, jobID: str | int, flag: str):
with DiracXClient() as api:
api.jobs.set_single_job_properties(jobID, {flag: True})

@stripValueIfOK
@convertToReturnValue
def setJobParameter(self, jobID: str | int, name: str, value: str):
print("HACK: This is a no-op until we decide what to do")

@stripValueIfOK
@convertToReturnValue
def setJobParameters(self, jobID: str | int, parameters: list):
print("HACK: This is a no-op until we decide what to do")

@stripValueIfOK
@convertToReturnValue
def setJobSite(self, jobID: str | int, site: str):
with DiracXClient() as api:
api.jobs.set_single_job_properties(jobID, {"Site": site})

@stripValueIfOK
@convertToReturnValue
def setJobStatus(
self,
jobID: str | int,
status: str = "",
minorStatus: str = "",
source: str = "Unknown",
datetime_=None,
force=False,
):
statusDict = {}
if status:
statusDict["Status"] = status
if minorStatus:
statusDict["MinorStatus"] = minorStatus
if source:
statusDict["Source"] = source
if datetime_ is None:
datetime_ = datetime.utcnow()
with DiracXClient() as api:
api.jobs.set_single_job_status(
jobID,
{fromString(datetime_).replace(tzinfo=timezone.utc): statusDict},
force=force,
)

@stripValueIfOK
@convertToReturnValue
def setJobStatusBulk(self, jobID: str | int, statusDict: dict, force=False):
statusDict = {fromString(k).replace(tzinfo=timezone.utc): v for k, v in statusDict.items()}
with DiracXClient() as api:
api.jobs.set_job_status_bulk(
{jobID: statusDict},
force=force,
)

@stripValueIfOK
@convertToReturnValue
def setJobsParameter(self, jobsParameterDict: dict):
print("HACK: This is a no-op until we decide what to do")

@stripValueIfOK
@convertToReturnValue
def unsetJobFlag(self, jobID: str | int, flag: str):
with DiracXClient() as api:
api.jobs.set_single_job_properties(jobID, {flag: False})

def updateJobFromStager(self, jobID: str | int, status: str):
raise NotImplementedError("TODO")
Original file line number Diff line number Diff line change
@@ -0,0 +1,215 @@
from functools import partial

import pytest

import DIRAC

DIRAC.initialize()
from DIRAC.WorkloadManagementSystem.Client.JobMonitoringClient import JobMonitoringClient
from ..utils import compare_results

TEST_JOBS = [7470, 7471, 7469]
TEST_JOB_IDS = [TEST_JOBS] + TEST_JOBS + [str(x) for x in TEST_JOBS]


def test_getApplicationStates():
# JobMonitoringClient().getApplicationStates(condDict = None, older = None, newer = None)
method = JobMonitoringClient().getApplicationStates
pytest.skip()


def test_getAtticJobParameters():
# JobMonitoringClient().getAtticJobParameters(jobID: int, parameters = None, rescheduleCycle = -1)
method = JobMonitoringClient().getAtticJobParameters
pytest.skip()


def test_getCounters():
# JobMonitoringClient().getCounters(attrList: list, attrDict = None, cutDate = )
method = JobMonitoringClient().getCounters
pytest.skip()


def test_getInputData():
# JobMonitoringClient().getInputData(jobID: int)
method = JobMonitoringClient().getInputData
pytest.skip()


def test_getJobAttribute():
# JobMonitoringClient().getJobAttribute(jobID: int, attribute: str)
method = JobMonitoringClient().getJobAttribute
pytest.skip()


def test_getJobAttributes():
# JobMonitoringClient().getJobAttributes(jobID: int, attrList = None)
method = JobMonitoringClient().getJobAttributes
pytest.skip()


def test_getJobGroups():
# JobMonitoringClient().getJobGroups(condDict = None, older = None, cutDate = None)
method = JobMonitoringClient().getJobGroups
pytest.skip()


def test_getJobHeartBeatData():
# JobMonitoringClient().getJobHeartBeatData(jobID: int)
method = JobMonitoringClient().getJobHeartBeatData
pytest.skip()


def test_getJobJDL():
# JobMonitoringClient().getJobJDL(jobID: int, original: bool)
method = JobMonitoringClient().getJobJDL
pytest.skip()


def test_getJobLoggingInfo():
# JobMonitoringClient().getJobLoggingInfo(jobID: int)
method = JobMonitoringClient().getJobLoggingInfo
pytest.skip()


def test_getJobOptParameters():
# JobMonitoringClient().getJobOptParameters(jobID: int)
method = JobMonitoringClient().getJobOptParameters
pytest.skip()


def test_getJobOwner():
# JobMonitoringClient().getJobOwner(jobID: int)
method = JobMonitoringClient().getJobOwner
pytest.skip()


def test_getJobPageSummaryWeb():
# JobMonitoringClient().getJobPageSummaryWeb(self: dict, selectDict: list, sortList: int, startItem: int, maxItems, selectJobs = True)
method = JobMonitoringClient().getJobPageSummaryWeb
pytest.skip()


def test_getJobParameter():
# JobMonitoringClient().getJobParameter(jobID: str | int, parName: str)
method = JobMonitoringClient().getJobParameter
pytest.skip()


def test_getJobParameters():
# JobMonitoringClient().getJobParameters(jobIDs: str | int | list, parName = None)
method = JobMonitoringClient().getJobParameters
pytest.skip()


def test_getJobSite():
# JobMonitoringClient().getJobSite(jobID: int)
method = JobMonitoringClient().getJobSite
pytest.skip()


def test_getJobStats():
# JobMonitoringClient().getJobStats(attribute: str, selectDict: dict)
method = JobMonitoringClient().getJobStats
pytest.skip()


def test_getJobSummary():
# JobMonitoringClient().getJobSummary(jobID: int)
method = JobMonitoringClient().getJobSummary
pytest.skip()


def test_getJobTypes():
# JobMonitoringClient().getJobTypes(condDict = None, older = None, newer = None)
method = JobMonitoringClient().getJobTypes
pytest.skip()


def test_getJobs():
# JobMonitoringClient().getJobs(attrDict = None, cutDate = None)
method = JobMonitoringClient().getJobs
pytest.skip()


@pytest.mark.parametrize("jobIDs", TEST_JOB_IDS)
def test_getJobsApplicationStatus(jobIDs):
# JobMonitoringClient().getJobsApplicationStatus(jobIDs: str | int | list)
method = JobMonitoringClient().getJobsApplicationStatus
compare_results(partial(method, jobIDs))


@pytest.mark.parametrize("jobIDs", TEST_JOB_IDS)
def test_getJobsMinorStatus(jobIDs):
# JobMonitoringClient().getJobsMinorStatus(jobIDs: str | int | list)
method = JobMonitoringClient().getJobsMinorStatus
compare_results(partial(method, jobIDs))


def test_getJobsParameters():
# JobMonitoringClient().getJobsParameters(jobIDs: str | int | list, parameters: list)
method = JobMonitoringClient().getJobsParameters
pytest.skip()


@pytest.mark.parametrize("jobIDs", TEST_JOB_IDS)
def test_getJobsSites(jobIDs):
# JobMonitoringClient().getJobsSites(jobIDs: str | int | list)
method = JobMonitoringClient().getJobsSites
compare_results(partial(method, jobIDs))


@pytest.mark.parametrize("jobIDs", TEST_JOB_IDS)
def test_getJobsStates(jobIDs):
# JobMonitoringClient().getJobsStates(jobIDs: str | int | list)
method = JobMonitoringClient().getJobsStates
compare_results(partial(method, jobIDs))


@pytest.mark.parametrize("jobIDs", TEST_JOB_IDS)
def test_getJobsStatus(jobIDs):
# JobMonitoringClient().getJobsStatus(jobIDs: str | int | list)
method = JobMonitoringClient().getJobsStatus
compare_results(partial(method, jobIDs))


def test_getJobsSummary():
# JobMonitoringClient().getJobsSummary(jobIDs: list)
method = JobMonitoringClient().getJobsSummary
pytest.skip()


def test_getMinorStates():
# JobMonitoringClient().getMinorStates(condDict = None, older = None, newer = None)
method = JobMonitoringClient().getMinorStates
pytest.skip()


def test_getOwnerGroup():
# JobMonitoringClient().getOwnerGroup()
method = JobMonitoringClient().getOwnerGroup
pytest.skip()


def test_getOwners():
# JobMonitoringClient().getOwners(condDict = None, older = None, newer = None)
method = JobMonitoringClient().getOwners
pytest.skip()


def test_getSiteSummary():
# JobMonitoringClient().getSiteSummary()
method = JobMonitoringClient().getSiteSummary
pytest.skip()


def test_getSites():
# JobMonitoringClient().getSites(condDict = None, older = None, newer = None)
method = JobMonitoringClient().getSites
pytest.skip()


def test_getStates():
# JobMonitoringClient().getStates(condDict = None, older = None, newer = None)
method = JobMonitoringClient().getStates
pytest.skip()
Loading

0 comments on commit fce9b4c

Please sign in to comment.