Skip to content
This repository has been archived by the owner on Jun 17, 2024. It is now read-only.

Commit

Permalink
feat: organization id can be provided during provider creation
Browse files Browse the repository at this point in the history
  • Loading branch information
Sebastian Wagner committed Apr 10, 2024
1 parent 60c4909 commit 64434c0
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 17 deletions.
9 changes: 8 additions & 1 deletion planqk/qiskit/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,10 @@ def set_credentials(cls, credentials: DefaultCredentialsProvider):
def get_credentials(cls):
return cls._credentials

@classmethod
def set_organization_id(cls, organization_id: str):
cls._organization_id = organization_id

@classmethod
def perform_request(cls, request_func: Callable[..., Response], url: str, params=None, data=None, headers=None):
headers = {**cls._get_default_headers(), **(headers or {})}
Expand Down Expand Up @@ -140,7 +144,10 @@ def _get_default_headers(cls):
cls._context_resolver = ContextResolver()

context = cls._context_resolver.get_context()
if context is not None and context.is_organization:

if cls._organization_id is not None:
headers["x-organizationid"] = cls._organization_id
elif context is not None and context.is_organization:
headers["x-organizationid"] = context.get_organization_id()

headers[HEADER_CLOUD_TRACE_CTX] = cls._generate_trace_id()
Expand Down
3 changes: 2 additions & 1 deletion planqk/qiskit/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,15 @@

class PlanqkQuantumProvider(Provider):

def __init__(self, access_token: str = None):
def __init__(self, access_token: str = None, organization_id: str = None):
"""Initialize the PlanQK provider.
Args:
access_token (str): access token used for authentication with PlanQK. If not token is provided,
the token is retrieved from the environment variable PLANQK_ACCESS_TOKEN that can be either set
manually or by using the PlanQK CLI.
"""
_PlanqkClient.set_credentials(DefaultCredentialsProvider(access_token))
_PlanqkClient.set_organization_id(organization_id)

def backends(self, provider: PROVIDER = None, **kwargs):
"""
Expand Down
9 changes: 6 additions & 3 deletions planqk/qiskit/runtime_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,11 @@

class PlanqkQiskitRuntimeService(PlanqkQuantumProvider):

def __init__(self, access_token=None, channel: Optional[ChannelType] = None, channel_strategy=None):
super().__init__(access_token)
def __init__(self, access_token: Optional[str] = None,
organization_id: Optional[str] = None,
channel: Optional[ChannelType] = None,
channel_strategy=None):
super().__init__(access_token, organization_id)

self._channel = channel
self._channel_strategy = channel_strategy
Expand Down Expand Up @@ -73,7 +76,7 @@ def run(self,
qrt_options.validate(channel=self.channel)

hgp_name = 'ibm-q/open/main'

runtime_job_params = RuntimeJobParamsDto(
program_id=program_id,
image=qrt_options.image,
Expand Down
37 changes: 25 additions & 12 deletions tests/integration/test_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,23 @@
import unittest.mock

from planqk.context import ContextResolver
from planqk.qiskit import PlanqkQuantumProvider
from planqk.qiskit.client.client import _PlanqkClient


def _create_context_env_file():
json_value = """
{
"context": {
"id": "c557000f-f2b1-4505-8172-dac7960caf16",
"displayName": "Test Org",
"isOrganization": true
}
}
"""
with tempfile.NamedTemporaryFile(suffix=".json", delete=False) as fp:
fp.write(json_value.encode("utf-8"))
os.environ["PLANQK_CONFIG_FILE_PATH"] = os.path.abspath(fp.name)


class ContextResolverTestSuite(unittest.TestCase):
Expand Down Expand Up @@ -37,18 +54,7 @@ def test_should_get_organization_id_from_context_when_env_var_set(self):
self.assertEqual(context.get_organization_id(), "c557000f-f2b1-4505-8172-dac7960caf15")

def test_should_get_organization_id_from_context(self):
json_value = """
{
"context": {
"id": "c557000f-f2b1-4505-8172-dac7960caf16",
"displayName": "Test Org",
"isOrganization": true
}
}
"""
with tempfile.NamedTemporaryFile(suffix=".json", delete=False) as fp:
fp.write(json_value.encode("utf-8"))
os.environ["PLANQK_CONFIG_FILE_PATH"] = os.path.abspath(fp.name)
_create_context_env_file()

context_resolver = ContextResolver()
context = context_resolver.get_context()
Expand Down Expand Up @@ -98,3 +104,10 @@ def test_should_return_none_when_file_is_empty(self):
context = context_resolver.get_context()

self.assertIsNone(context)

def test_should_use_user_provided_org_id(self):
_create_context_env_file()
access_token = "user_access_token"
user_org_id = "user_org_id"
PlanqkQuantumProvider(access_token, user_org_id)
self.assertEqual(_PlanqkClient._get_default_headers()["x-organizationid"], user_org_id)

0 comments on commit 64434c0

Please sign in to comment.