diff --git a/src/DIRAC/Core/Security/DiracX.py b/src/DIRAC/Core/Security/DiracX.py index af5269248a5..fd74bf57759 100644 --- a/src/DIRAC/Core/Security/DiracX.py +++ b/src/DIRAC/Core/Security/DiracX.py @@ -1,11 +1,16 @@ from __future__ import annotations __all__ = ( + "addRPCStub", "DiracXClient", "diracxTokenFromPEM", + "executeRPCStub", + "FutureClient", ) import base64 +import functools +import importlib import json import re import textwrap @@ -20,9 +25,10 @@ from diracx.core.utils import serialize_credentials from DIRAC import gConfig, gLogger + from DIRAC.ConfigurationSystem.Client.Helpers import Registry from DIRAC.Core.Security.Locations import getDefaultProxyLocation -from DIRAC.Core.Utilities.ReturnValues import convertToReturnValue, returnValueOrRaise +from DIRAC.Core.Utilities.ReturnValues import convertToReturnValue, returnValueOrRaise, isReturnStructure PEM_BEGIN = "-----BEGIN DIRACX-----" @@ -67,6 +73,14 @@ def diracxTokenFromPEM(pemPath) -> dict[str, Any] | None: return json.loads(base64.b64decode(match).decode("utf-8")) +class FutureClient: + """This is just a empty class to make sure that all the FutureClients + inherit from a common class. + """ + + ... + + @contextmanager def DiracXClient() -> _DiracClient: """Get a DiracX client instance with the current user's credentials""" @@ -87,3 +101,49 @@ def DiracXClient() -> _DiracClient: pref = DiracxPreferences(url=diracxUrl, credentials_path=token_file.name) with _DiracClient(diracx_preferences=pref) as api: yield api + + +def addRPCStub(meth): + """Decorator to add an rpc like stub to DiracX adapter method + to be called by the ForwardDISET operation + + """ + + @functools.wraps(meth) + def inner(self, *args, **kwargs): + dCls = self.__class__.__name__ + dMod = self.__module__ + res = meth(self, *args, **kwargs) + if isReturnStructure(res): + res["rpcStub"] = { + "dCls": dCls, + "dMod": dMod, + "dMeth": meth.__name__, + "args": args, + "kwargs": kwargs, + } + return res + + return inner + + +def executeRPCStub(stub: dict): + className = stub.get("dCls") + modName = stub.get("dMod") + methName = stub.get("dMeth") + methArgs = stub.get("args") + methKwArgs = stub.get("kwargs") + # Load the module + mod = importlib.import_module(modName) + # import the class + cl = getattr(mod, className) + + # Check that cl is a subclass of JSerializable, + # and that we are not putting ourselves in trouble... + if not (isinstance(cl, type) and issubclass(cl, FutureClient)): + raise TypeError("Only subclasses of FutureClient can be decoded") + + # Instantiate the object + obj = cl() + meth = getattr(obj, methName) + return meth(*methArgs, **methKwArgs) diff --git a/src/DIRAC/Core/Security/test/Test_DiracX.py b/src/DIRAC/Core/Security/test/Test_DiracX.py new file mode 100644 index 00000000000..1918f47384c --- /dev/null +++ b/src/DIRAC/Core/Security/test/Test_DiracX.py @@ -0,0 +1,56 @@ +import pytest +from DIRAC.Core.Utilities.ReturnValues import convertToReturnValue +from DIRAC.Core.Security.DiracX import addRPCStub, executeRPCStub, FutureClient + + +class BadClass: + """This class does not inherit from FutureClient + So we should not be able to execute its stub""" + + @addRPCStub + @convertToReturnValue + def sum(self, *args, **kwargs): + """Just sum whatever is given as param""" + return sum(args + tuple(kwargs.values())) + + +class Fake(FutureClient): + @addRPCStub + @convertToReturnValue + def sum(self, *args, **kwargs): + """Just sum whatever is given as param""" + return sum(args + tuple(kwargs.values())) + + +def test_rpcStub(): + b = BadClass() + res = b.sum(1, 2, 3) + assert res["OK"] + assert "rpcStub" in res + stub = res["rpcStub"] + # Cannot execute this stub as it does not come + # from a FutureClient + with pytest.raises(TypeError): + executeRPCStub(stub) + + def test_sum(f, *args, **kwargs): + """Test that the original result is the same as the stub""" + + res = f.sum(*args, **kwargs) + stub = res["rpcStub"] + replay_res = executeRPCStub(stub) + + assert res["OK"] == replay_res["OK"] + assert res.get("Value") == replay_res.get("Value") + assert res.get("Message") == replay_res.get("Message") + + # Test some success cases + + f = Fake() + test_sum(f, 1, 2, 3) + test_sum(f, a=3, b=4) + test_sum(f, 1, 2, a=3, b=4) + + # Test error case + + test_sum(f, "a", None) diff --git a/src/DIRAC/RequestManagementSystem/Agent/RequestOperations/ForwardDISET.py b/src/DIRAC/RequestManagementSystem/Agent/RequestOperations/ForwardDISET.py index e1be3afedb0..ff224ecaacd 100644 --- a/src/DIRAC/RequestManagementSystem/Agent/RequestOperations/ForwardDISET.py +++ b/src/DIRAC/RequestManagementSystem/Agent/RequestOperations/ForwardDISET.py @@ -11,12 +11,15 @@ DISET forwarding operation handler """ +import importlib + # imports from DIRAC import S_ERROR, S_OK, gConfig from DIRAC.ConfigurationSystem.Client.ConfigurationData import gConfigurationData from DIRAC.ConfigurationSystem.Client.Helpers.Registry import getDNForUsername from DIRAC.Core.Base.Client import executeRPCStub from DIRAC.Core.Utilities import DEncode +from DIRAC.Core.Security.DiracX import executeRPCStub from DIRAC.RequestManagementSystem.private.OperationHandlerBase import OperationHandlerBase ######################################################################## @@ -27,6 +30,14 @@ class ForwardDISET(OperationHandlerBase): .. class:: ForwardDISET functor forwarding DISET operations + + There are fundamental differences in behavior between the forward diset + for DIPS service and the one for DiracX: + * dips call will be done with the server certificates and use the delegated DN field + * diracx call will be done with the credentials setup by request tasks + * dips call are just RPC call, they do not execute the logic of the client (that is anyway not relied upon for now) + * diracx calls will effectively call the client entirely. + """ def __init__(self, operation=None, csPath=None): @@ -45,27 +56,34 @@ def __call__(self): """execute RPC stub""" # # decode arguments try: - decode, length = DEncode.decode(self.operation.Arguments) - self.log.debug(f"decoded len={length} val={decode}") + stub, length = DEncode.decode(self.operation.Arguments) + self.log.debug(f"decoded len={length} val={stub}") except ValueError as error: self.log.exception(error) self.operation.Error = str(error) self.operation.Status = "Failed" return S_ERROR(str(error)) - # Ensure the forwarded request is done on behalf of the request owner - res = getDNForUsername(self.request.Owner) - if not res["OK"]: - return res - decode[0][1]["delegatedDN"] = res["Value"][0] - decode[0][1]["delegatedGroup"] = self.request.OwnerGroup - - # ForwardDiset is supposed to be used with a host certificate - useServerCertificate = gConfig.useServerCertificate() - gConfigurationData.setOptionInCFG("/DIRAC/Security/UseServerCertificate", "true") - forward = executeRPCStub(decode) - if not useServerCertificate: - gConfigurationData.setOptionInCFG("/DIRAC/Security/UseServerCertificate", "false") + # This is the DISET rpcStub + if isinstance(stub, tuple): + # Ensure the forwarded request is done on behalf of the request owner + res = getDNForUsername(self.request.Owner) + if not res["OK"]: + return res + stub[0][1]["delegatedDN"] = res["Value"][0] + stub[0][1]["delegatedGroup"] = self.request.OwnerGroup + + # ForwardDiset is supposed to be used with a host certificate + useServerCertificate = gConfig.useServerCertificate() + gConfigurationData.setOptionInCFG("/DIRAC/Security/UseServerCertificate", "true") + forward = executeRPCStub(stub) + if not useServerCertificate: + gConfigurationData.setOptionInCFG("/DIRAC/Security/UseServerCertificate", "false") + # DiracX stub + elif isinstance(stub, dict): + forward = executeRPCStub(stub) + else: + raise TypeError("Unknwon type of stub") if not forward["OK"]: self.log.error("unable to execute operation", f"'{self.operation.Type}' : {forward['Message']}") diff --git a/src/DIRAC/WorkloadManagementSystem/FutureClient/JobStateUpdateClient.py b/src/DIRAC/WorkloadManagementSystem/FutureClient/JobStateUpdateClient.py index 201376014bd..0bdbcf3a727 100644 --- a/src/DIRAC/WorkloadManagementSystem/FutureClient/JobStateUpdateClient.py +++ b/src/DIRAC/WorkloadManagementSystem/FutureClient/JobStateUpdateClient.py @@ -1,8 +1,9 @@ +import importlib import functools from datetime import datetime, timezone -from DIRAC.Core.Security.DiracX import DiracXClient +from DIRAC.Core.Security.DiracX import DiracXClient, FutureClient, addRPCStub from DIRAC.Core.Utilities.ReturnValues import convertToReturnValue from DIRAC.Core.Utilities.TimeUtilities import fromString @@ -26,7 +27,7 @@ def wrapper(*args, **kwargs): return wrapper -class JobStateUpdateClient: +class JobStateUpdateClient(FutureClient): @stripValueIfOK @convertToReturnValue def sendHeartBeat(self, jobID: str | int, dynamicData: dict, staticData: dict): @@ -107,6 +108,7 @@ def setJobStatus( force=force, ) + @addRPCStub @stripValueIfOK @convertToReturnValue def setJobStatusBulk(self, jobID: str | int, statusDict: dict, force=False):