Skip to content

Commit

Permalink
feat (diracx): Add an equivalent forwardDISET for DiracX
Browse files Browse the repository at this point in the history
  • Loading branch information
chaen committed Feb 29, 2024
1 parent ae56101 commit 26e1e2f
Show file tree
Hide file tree
Showing 4 changed files with 154 additions and 18 deletions.
62 changes: 61 additions & 1 deletion src/DIRAC/Core/Security/DiracX.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,16 @@
from __future__ import annotations

__all__ = (
"addRPCStub",
"DiracXClient",
"diracxTokenFromPEM",
"executeRPCStub",
"FutureClient",
)

import base64
import functools
import importlib
import json
import re
import textwrap
Expand All @@ -20,9 +25,10 @@
from diracx.core.utils import serialize_credentials

from DIRAC import gConfig, gLogger

from DIRAC.ConfigurationSystem.Client.Helpers import Registry
from DIRAC.Core.Security.Locations import getDefaultProxyLocation
from DIRAC.Core.Utilities.ReturnValues import convertToReturnValue, returnValueOrRaise
from DIRAC.Core.Utilities.ReturnValues import convertToReturnValue, returnValueOrRaise, isReturnStructure


PEM_BEGIN = "-----BEGIN DIRACX-----"
Expand Down Expand Up @@ -67,6 +73,14 @@ def diracxTokenFromPEM(pemPath) -> dict[str, Any] | None:
return json.loads(base64.b64decode(match).decode("utf-8"))


class FutureClient:
"""This is just a empty class to make sure that all the FutureClients
inherit from a common class.
"""

...


@contextmanager
def DiracXClient() -> _DiracClient:
"""Get a DiracX client instance with the current user's credentials"""
Expand All @@ -87,3 +101,49 @@ def DiracXClient() -> _DiracClient:
pref = DiracxPreferences(url=diracxUrl, credentials_path=token_file.name)
with _DiracClient(diracx_preferences=pref) as api:
yield api


def addRPCStub(meth):
"""Decorator to add an rpc like stub to DiracX adapter method
to be called by the ForwardDISET operation
"""

@functools.wraps(meth)
def inner(self, *args, **kwargs):
dCls = self.__class__.__name__
dMod = self.__module__
res = meth(self, *args, **kwargs)
if isReturnStructure(res):
res["rpcStub"] = {
"dCls": dCls,
"dMod": dMod,
"dMeth": meth.__name__,
"args": args,
"kwargs": kwargs,
}
return res

return inner


def executeRPCStub(stub: dict):
className = stub.get("dCls")
modName = stub.get("dMod")
methName = stub.get("dMeth")
methArgs = stub.get("args")
methKwArgs = stub.get("kwargs")
# Load the module
mod = importlib.import_module(modName)
# import the class
cl = getattr(mod, className)

# Check that cl is a subclass of JSerializable,
# and that we are not putting ourselves in trouble...
if not (isinstance(cl, type) and issubclass(cl, FutureClient)):
raise TypeError("Only subclasses of FutureClient can be decoded")

# Instantiate the object
obj = cl()
meth = getattr(obj, methName)
return meth(*methArgs, **methKwArgs)
56 changes: 56 additions & 0 deletions src/DIRAC/Core/Security/test/Test_DiracX.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
import pytest
from DIRAC.Core.Utilities.ReturnValues import convertToReturnValue
from DIRAC.Core.Security.DiracX import addRPCStub, executeRPCStub, FutureClient


class BadClass:
"""This class does not inherit from FutureClient
So we should not be able to execute its stub"""

@addRPCStub
@convertToReturnValue
def sum(self, *args, **kwargs):
"""Just sum whatever is given as param"""
return sum(args + tuple(kwargs.values()))


class Fake(FutureClient):
@addRPCStub
@convertToReturnValue
def sum(self, *args, **kwargs):
"""Just sum whatever is given as param"""
return sum(args + tuple(kwargs.values()))


def test_rpcStub():
b = BadClass()
res = b.sum(1, 2, 3)
assert res["OK"]
assert "rpcStub" in res
stub = res["rpcStub"]
# Cannot execute this stub as it does not come
# from a FutureClient
with pytest.raises(TypeError):
executeRPCStub(stub)

def test_sum(f, *args, **kwargs):
"""Test that the original result is the same as the stub"""

res = f.sum(*args, **kwargs)
stub = res["rpcStub"]
replay_res = executeRPCStub(stub)

assert res["OK"] == replay_res["OK"]
assert res.get("Value") == replay_res.get("Value")
assert res.get("Message") == replay_res.get("Message")

# Test some success cases

f = Fake()
test_sum(f, 1, 2, 3)
test_sum(f, a=3, b=4)
test_sum(f, 1, 2, a=3, b=4)

# Test error case

test_sum(f, "a", None)
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,15 @@
DISET forwarding operation handler
"""

import importlib

# imports
from DIRAC import S_ERROR, S_OK, gConfig
from DIRAC.ConfigurationSystem.Client.ConfigurationData import gConfigurationData
from DIRAC.ConfigurationSystem.Client.Helpers.Registry import getDNForUsername
from DIRAC.Core.Base.Client import executeRPCStub
from DIRAC.Core.Utilities import DEncode
from DIRAC.Core.Security.DiracX import executeRPCStub
from DIRAC.RequestManagementSystem.private.OperationHandlerBase import OperationHandlerBase

########################################################################
Expand All @@ -27,6 +30,14 @@ class ForwardDISET(OperationHandlerBase):
.. class:: ForwardDISET
functor forwarding DISET operations
There are fundamental differences in behavior between the forward diset
for DIPS service and the one for DiracX:
* dips call will be done with the server certificates and use the delegated DN field
* diracx call will be done with the credentials setup by request tasks
* dips call are just RPC call, they do not execute the logic of the client (that is anyway not relied upon for now)
* diracx calls will effectively call the client entirely.
"""

def __init__(self, operation=None, csPath=None):
Expand All @@ -45,27 +56,34 @@ def __call__(self):
"""execute RPC stub"""
# # decode arguments
try:
decode, length = DEncode.decode(self.operation.Arguments)
self.log.debug(f"decoded len={length} val={decode}")
stub, length = DEncode.decode(self.operation.Arguments)
self.log.debug(f"decoded len={length} val={stub}")
except ValueError as error:
self.log.exception(error)
self.operation.Error = str(error)
self.operation.Status = "Failed"
return S_ERROR(str(error))

# Ensure the forwarded request is done on behalf of the request owner
res = getDNForUsername(self.request.Owner)
if not res["OK"]:
return res
decode[0][1]["delegatedDN"] = res["Value"][0]
decode[0][1]["delegatedGroup"] = self.request.OwnerGroup

# ForwardDiset is supposed to be used with a host certificate
useServerCertificate = gConfig.useServerCertificate()
gConfigurationData.setOptionInCFG("/DIRAC/Security/UseServerCertificate", "true")
forward = executeRPCStub(decode)
if not useServerCertificate:
gConfigurationData.setOptionInCFG("/DIRAC/Security/UseServerCertificate", "false")
# This is the DISET rpcStub
if isinstance(stub, tuple):
# Ensure the forwarded request is done on behalf of the request owner
res = getDNForUsername(self.request.Owner)
if not res["OK"]:
return res
stub[0][1]["delegatedDN"] = res["Value"][0]
stub[0][1]["delegatedGroup"] = self.request.OwnerGroup

# ForwardDiset is supposed to be used with a host certificate
useServerCertificate = gConfig.useServerCertificate()
gConfigurationData.setOptionInCFG("/DIRAC/Security/UseServerCertificate", "true")
forward = executeRPCStub(stub)
if not useServerCertificate:
gConfigurationData.setOptionInCFG("/DIRAC/Security/UseServerCertificate", "false")
# DiracX stub
elif isinstance(stub, dict):
forward = executeRPCStub(stub)
else:
raise TypeError("Unknwon type of stub")

if not forward["OK"]:
self.log.error("unable to execute operation", f"'{self.operation.Type}' : {forward['Message']}")
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import importlib
import functools
from datetime import datetime, timezone


from DIRAC.Core.Security.DiracX import DiracXClient
from DIRAC.Core.Security.DiracX import DiracXClient, FutureClient, addRPCStub
from DIRAC.Core.Utilities.ReturnValues import convertToReturnValue
from DIRAC.Core.Utilities.TimeUtilities import fromString

Expand All @@ -26,7 +27,7 @@ def wrapper(*args, **kwargs):
return wrapper


class JobStateUpdateClient:
class JobStateUpdateClient(FutureClient):
@stripValueIfOK
@convertToReturnValue
def sendHeartBeat(self, jobID: str | int, dynamicData: dict, staticData: dict):
Expand Down Expand Up @@ -107,6 +108,7 @@ def setJobStatus(
force=force,
)

@addRPCStub
@stripValueIfOK
@convertToReturnValue
def setJobStatusBulk(self, jobID: str | int, statusDict: dict, force=False):
Expand Down

0 comments on commit 26e1e2f

Please sign in to comment.