diff --git a/planqk/context.py b/planqk/context.py index 9f99ded..af29ae3 100644 --- a/planqk/context.py +++ b/planqk/context.py @@ -5,11 +5,22 @@ from planqk.credentials import get_config_file_path +_ORGANIZATION_ID = "PLANQK_ORGANIZATION_ID" + class Context(BaseModel): - id: str = Field(description="Id of the user or organization") - displayName: str = Field(description="Name of the user or organization") - isOrganization: bool = Field(description="True if the context is an organization") + id: str = Field(..., description="Id of the user or organization") + display_name: str = Field(..., alias="displayName", description="Name of the user or organization") + is_organization: bool = Field(..., alias="isOrganization", description="True if the context is an organization") + + def get_organization_id(self) -> Union[str, None]: + organization_id = os.environ.get(_ORGANIZATION_ID, None) + if organization_id: + return organization_id + + if self.is_organization: + return self.id + return None class Config(BaseModel): diff --git a/tests/integration/test_context.py b/tests/integration/test_context.py index 32f671a..edaa607 100644 --- a/tests/integration/test_context.py +++ b/tests/integration/test_context.py @@ -10,6 +10,52 @@ class ContextResolverTestSuite(unittest.TestCase): def tearDown(self): if "PLANQK_CONFIG_FILE_PATH" in os.environ: del os.environ["PLANQK_CONFIG_FILE_PATH"] + if "PLANQK_ORGANIZATION_ID" in os.environ: + del os.environ["PLANQK_ORGANIZATION_ID"] + + def test_should_get_organization_id_from_context_when_env_var_set(self): + json_value = """ + { + "context": { + "id": "c557000f-f2b1-4505-8172-dac7960caf16", + "displayName": "Test User", + "isOrganization": false + } + } + """ + with tempfile.NamedTemporaryFile(suffix=".json", delete=False) as fp: + fp.write(json_value.encode("utf-8")) + os.environ["PLANQK_CONFIG_FILE_PATH"] = os.path.abspath(fp.name) + + os.environ["PLANQK_ORGANIZATION_ID"] = "c557000f-f2b1-4505-8172-dac7960caf15" + + context_resolver = ContextResolver() + context = context_resolver.get_context() + + self.assertIsNotNone(context) + self.assertEqual(context.is_organization, False) + self.assertEqual(context.get_organization_id(), "c557000f-f2b1-4505-8172-dac7960caf15") + + def test_should_get_organization_id_from_context(self): + json_value = """ + { + "context": { + "id": "c557000f-f2b1-4505-8172-dac7960caf16", + "displayName": "Test Org", + "isOrganization": true + } + } + """ + with tempfile.NamedTemporaryFile(suffix=".json", delete=False) as fp: + fp.write(json_value.encode("utf-8")) + os.environ["PLANQK_CONFIG_FILE_PATH"] = os.path.abspath(fp.name) + + context_resolver = ContextResolver() + context = context_resolver.get_context() + + self.assertIsNotNone(context) + self.assertEqual(context.is_organization, True) + self.assertEqual(context.get_organization_id(), "c557000f-f2b1-4505-8172-dac7960caf16") def test_should_retrieve_context_from_config(self): json_value = """ @@ -30,8 +76,9 @@ def test_should_retrieve_context_from_config(self): self.assertIsNotNone(context) self.assertEqual(context.id, "c557000f-f2b1-4505-8172-dac7960caf16") - self.assertEqual(context.displayName, "Test User") - self.assertEqual(context.isOrganization, False) + self.assertEqual(context.display_name, "Test User") + self.assertEqual(context.is_organization, False) + self.assertIsNone(context.get_organization_id()) def test_should_return_none_when_file_not_present(self): os.environ["PLANQK_CONFIG_FILE_PATH"] = "/var/folders/c6/32xv5kh16p19yf8yz7bl294h0000gn/T/tmp8iqmj5ji.json"