diff --git a/.circleci/config.yml b/.circleci/config.yml index f6fa3365..e4ccd94b 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -21,7 +21,8 @@ jobs: steps: - checkout - run: pip install --user tox - - run: tox -e py<< parameters.python_version >>-pydantic<< parameters.pydantic_version >>-requests<< parameters.requests_version >> + - run: poetry --no-ansi install --no-root --sync + - run: poetry --no-ansi run tox -v -e py<< parameters.python_version >>-pydantic<< parameters.pydantic_version >>-requests<< parameters.requests_version >> --recreate pyright: docker: diff --git a/changelog/@unreleased/pr-44.v2.yml b/changelog/@unreleased/pr-44.v2.yml new file mode 100644 index 00000000..a176f801 --- /dev/null +++ b/changelog/@unreleased/pr-44.v2.yml @@ -0,0 +1,5 @@ +type: improvement +improvement: + description: Updating Errors and Auth + links: + - https://github.com/palantir/foundry-platform-python/pull/44 diff --git a/foundry/_core/confidential_client_auth.py b/foundry/_core/confidential_client_auth.py index 86a11f8d..03e0576a 100644 --- a/foundry/_core/confidential_client_auth.py +++ b/foundry/_core/confidential_client_auth.py @@ -13,7 +13,8 @@ # limitations under the License. -import asyncio +import threading +import time from typing import Callable from typing import List from typing import Optional @@ -27,7 +28,6 @@ from foundry._core.oauth_utils import ConfidentialClientOAuthFlowProvider from foundry._core.oauth_utils import OAuthToken from foundry._core.utils import remove_prefixes -from foundry._errors.environment_not_configured import EnvironmentNotConfigured from foundry._errors.not_authenticated import NotAuthenticated T = TypeVar("T") @@ -56,7 +56,7 @@ def __init__( self._client_secret = client_secret self._token: Optional[OAuthToken] = None self._should_refresh = should_refresh - self._refresh_task: Optional[asyncio.Task] = None + self._stop_refresh_event = threading.Event() self._hostname = hostname self._server_oauth_flow_provider = ConfidentialClientOAuthFlowProvider( client_id, client_secret, self.url, scopes=scopes @@ -70,15 +70,21 @@ def get_token(self) -> OAuthToken: def execute_with_token(self, func: Callable[[OAuthToken], T]) -> T: try: return self._run_with_attempted_refresh(func) + except requests.HTTPError as http_e: + if http_e.response.status_code == 401: + self.sign_out() + raise http_e except Exception as e: - self.sign_out() raise e def run_with_token(self, func: Callable[[OAuthToken], T]) -> None: try: self._run_with_attempted_refresh(func) + except requests.HTTPError as http_e: + if http_e.response.status_code == 401: + self.sign_out() + raise http_e except Exception as e: - self.sign_out() raise e def _run_with_attempted_refresh(self, func: Callable[[OAuthToken], T]) -> T: @@ -89,45 +95,50 @@ def _run_with_attempted_refresh(self, func: Callable[[OAuthToken], T]) -> T: try: return func(self.get_token()) except requests.HTTPError as e: - if e.response is not None and e.response.status_code == 401: + if e.response.status_code == 401: self._refresh_token() return func(self.get_token()) else: raise e @property - def url(self): + def url(self) -> str: return remove_prefixes(self._hostname, ["https://", "http://"]) - def _refresh_token(self): + def _refresh_token(self) -> None: self._token = self._server_oauth_flow_provider.get_token() + def _start_auto_refresh(self) -> None: + def _auto_refresh_token() -> None: + while not self._stop_refresh_event.is_set(): + if self._token: + # Sleep for (expires_in - 60) seconds to refresh the token 1 minute before it expires + time.sleep(self._token.expires_in - 60) + self._refresh_token() + else: + # Wait 10 seconds and check again if the token is set + time.sleep(10) + + refresh_thread = threading.Thread(target=_auto_refresh_token, daemon=True) + refresh_thread.start() + def sign_in_as_service_user(self) -> SignInResponse: token = self._server_oauth_flow_provider.get_token() self._token = token - async def refresh_token_task(): - while True: - if self._token is None: - raise RuntimeError("The token was None when trying to refresh.") - - await asyncio.sleep(self._token.expires_in / 60 - 10) - self._token = self._server_oauth_flow_provider.get_token() - if self._should_refresh: - loop = asyncio.get_event_loop() - self._refresh_task = loop.create_task(refresh_token_task()) + self._start_auto_refresh() return SignInResponse( session={"accessToken": token.access_token, "expiresIn": token.expires_in} ) def sign_out(self) -> SignOutResponse: - if self._refresh_task: - self._refresh_task.cancel() - self._refresh_task = None - if self._token: self._server_oauth_flow_provider.revoke_token(self._token.access_token) self._token = None + + # Signal the auto-refresh thread to stop + self._stop_refresh_event.set() + return SignOutResponse() diff --git a/foundry/_core/oauth.py b/foundry/_core/oauth.py index 25aa66b1..35043ef4 100644 --- a/foundry/_core/oauth.py +++ b/foundry/_core/oauth.py @@ -13,11 +13,14 @@ # limitations under the License. +from typing import Any +from typing import Dict + from pydantic import BaseModel class SignInResponse(BaseModel): - session: dict + session: Dict[str, Any] class SignOutResponse(BaseModel): diff --git a/foundry/_core/oauth_utils.py b/foundry/_core/oauth_utils.py index af4feb78..32003e11 100644 --- a/foundry/_core/oauth_utils.py +++ b/foundry/_core/oauth_utils.py @@ -18,6 +18,8 @@ import secrets import string import time +from typing import Any +from typing import Dict from typing import List from typing import Optional from urllib.parse import urlencode @@ -68,7 +70,7 @@ class OAuthTokenResponse(BaseModel): expires_in: int refresh_token: Optional[str] = None - def __init__(self, token_response: dict) -> None: + def __init__(self, token_response: Dict[str, Any]) -> None: super().__init__(**token_response) @@ -167,13 +169,13 @@ def get_scopes(self) -> List[str]: return scopes -def generate_random_string(min_length=43, max_length=128): +def generate_random_string(min_length: int = 43, max_length: int = 128) -> str: characters = string.ascii_letters + string.digits + "-._~" length = secrets.randbelow(max_length - min_length + 1) + min_length return "".join(secrets.choice(characters) for _ in range(length)) -def generate_code_challenge(input_string): +def generate_code_challenge(input_string: str) -> str: # Calculate the SHA256 hash sha256_hash = hashlib.sha256(input_string.encode("utf-8")).digest() @@ -249,7 +251,7 @@ def get_token(self, code: str, code_verifier: str) -> OAuthToken: response.raise_for_status() return OAuthToken(token=OAuthTokenResponse(token_response=response.json())) - def refresh_token(self, refresh_token): + def refresh_token(self, refresh_token: str) -> OAuthToken: headers = {"Content-Type": "application/x-www-form-urlencoded"} params = { "grant_type": "refresh_token", diff --git a/foundry/_core/public_client_auth.py b/foundry/_core/public_client_auth.py index 7174fa49..7255448e 100644 --- a/foundry/_core/public_client_auth.py +++ b/foundry/_core/public_client_auth.py @@ -36,19 +36,21 @@ class PublicClientAuth(Auth): - scopes: List[str] = ["api:read-data", "api:write-data", "offline_access"] - """ Client for Public Client OAuth-authenticated Ontology applications. Runs a background thread to periodically refresh access token. - :param client_id: OAuth client id to be used by the application. :param client_secret: OAuth client secret to be used by the application. :param hostname: Hostname for authentication and ontology endpoints. """ def __init__( - self, client_id: str, redirect_url: str, hostname: str, should_refresh: bool = False + self, + client_id: str, + redirect_url: str, + hostname: str, + scopes: Optional[List[str]] = None, + should_refresh: bool = False, ) -> None: self._client_id = client_id self._redirect_url = redirect_url @@ -58,7 +60,7 @@ def __init__( self._stop_refresh_event = threading.Event() self._hostname = hostname self._server_oauth_flow_provider = PublicClientOAuthFlowProvider( - client_id=client_id, redirect_url=redirect_url, url=self.url, scopes=self.scopes + client_id=client_id, redirect_url=redirect_url, url=self.url, scopes=scopes ) self._auth_request: Optional[AuthorizeRequest] = None @@ -81,9 +83,11 @@ def run_with_token(self, func: Callable[[OAuthToken], T]) -> None: self.sign_out() raise e - def _refresh_token(self): - if self._token is None: - raise Exception("") + def _refresh_token(self) -> None: + if not self._token: + raise RuntimeError("must have token to refresh") + if not self._token.refresh_token: + raise RuntimeError("no refresh token provided") self._token = self._server_oauth_flow_provider.refresh_token( refresh_token=self._token.refresh_token @@ -92,30 +96,29 @@ def _refresh_token(self): def _run_with_attempted_refresh(self, func: Callable[[OAuthToken], T]) -> T: """ Attempt to run func, and if it fails with a 401, refresh the token and try again. - If it fails with a 401 again, raise the exception. """ try: return func(self.get_token()) except requests.HTTPError as e: - if e.response is not None and e.response.status_code == 401: + if e.response.status_code == 401: self._refresh_token() return func(self.get_token()) else: raise e @property - def url(self): + def url(self) -> str: return remove_prefixes(self._hostname, ["https://", "http://"]) - def sign_in(self) -> None: + def sign_in(self) -> str: self._auth_request = self._server_oauth_flow_provider.generate_auth_request() - webbrowser.open(self._auth_request.url) + return self._auth_request.url - def _start_auto_refresh(self): - def _auto_refresh_token(): + def _start_auto_refresh(self) -> None: + def _auto_refresh_token() -> None: while not self._stop_refresh_event.is_set(): - if self._token: + if self._token and self._token.refresh_token: # Sleep for (expires_in - 60) seconds to refresh the token 1 minute before it expires time.sleep(self._token.expires_in - 60) self._token = self._server_oauth_flow_provider.refresh_token( @@ -129,9 +132,10 @@ def _auto_refresh_token(): refresh_thread.start() def set_token(self, code: str, state: str) -> None: - if self._auth_request is None or state != self._auth_request.state: - raise RuntimeError("Unable to verify the state") - + if not self._auth_request: + raise RuntimeError("Must sign in prior to setting token") + if state != self._auth_request.state: + raise RuntimeError("Unable to verify state") self._token = self._server_oauth_flow_provider.get_token( code=code, code_verifier=self._auth_request.code_verifier ) diff --git a/foundry/_errors/__init__.py b/foundry/_errors/__init__.py index 75ab9dc7..43247eb5 100644 --- a/foundry/_errors/__init__.py +++ b/foundry/_errors/__init__.py @@ -14,7 +14,6 @@ from foundry._errors.environment_not_configured import EnvironmentNotConfigured -from foundry._errors.helpers import format_error_message from foundry._errors.not_authenticated import NotAuthenticated from foundry._errors.palantir_rpc_exception import PalantirRPCException from foundry._errors.sdk_internal_error import SDKInternalError diff --git a/foundry/_errors/environment_not_configured.py b/foundry/_errors/environment_not_configured.py index f8ac1777..dc77143b 100644 --- a/foundry/_errors/environment_not_configured.py +++ b/foundry/_errors/environment_not_configured.py @@ -14,4 +14,5 @@ class EnvironmentNotConfigured(Exception): - pass + def __init__(self, message: str) -> None: + super().__init__(message) diff --git a/foundry/_errors/helpers.py b/foundry/_errors/helpers.py deleted file mode 100644 index 0502cd86..00000000 --- a/foundry/_errors/helpers.py +++ /dev/null @@ -1,24 +0,0 @@ -# Copyright 2024 Palantir Technologies, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import json -from importlib import import_module -from json import JSONDecodeError -from typing import Any -from typing import Dict - - -def format_error_message(fields: Dict[str, Any]) -> str: - return json.dumps(fields, sort_keys=True, indent=4) diff --git a/foundry/_errors/palantir_rpc_exception.py b/foundry/_errors/palantir_rpc_exception.py index 238a4f95..a57ad075 100644 --- a/foundry/_errors/palantir_rpc_exception.py +++ b/foundry/_errors/palantir_rpc_exception.py @@ -13,15 +13,18 @@ # limitations under the License. +import json from typing import Any from typing import Dict -from foundry._errors.helpers import format_error_message + +def format_error_message(fields: Dict[str, Any]) -> str: + return json.dumps(fields, sort_keys=True, indent=4, default=str) class PalantirRPCException(Exception): def __init__(self, error_metadata: Dict[str, Any]): super().__init__(format_error_message(error_metadata)) - self.name: str = error_metadata["errorName"] - self.parameters: Dict[str, Any] = error_metadata["parameters"] - self.error_instance_id: str = error_metadata["errorInstanceId"] + self.name = error_metadata.get("errorName") + self.parameters = error_metadata.get("parameters") + self.error_instance_id = error_metadata.get("errorInstanceId") diff --git a/pyproject.toml b/pyproject.toml index 3379f8a7..649f6adf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,11 +10,17 @@ keywords = ["Palantir", "Foundry", "SDK", "Client", "API"] packages = [{ include = "foundry" }] [tool.poetry.dependencies] +annotated-types = ">=0.7.0" +pydantic = "^2.1.0" python = "^3.9" requests = "^2.25.0" -pydantic = "^2.1.0" typing-extensions = ">=4.7.1" -annotated-types = ">=0.7.0" + +[tool.poetry.group.test.dependencies] +expects = ">=0.9.0" +mockito = ">=1.5.1" +pytest = ">=7.4.0" +pytest-asyncio = ">=0.23.0" [tool.poetry.extras] cli = ["click"] diff --git a/tests/auth/__init__.py b/tests/auth/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/auth/test_confidential_client.py b/tests/auth/test_confidential_client.py new file mode 100644 index 00000000..c42e771d --- /dev/null +++ b/tests/auth/test_confidential_client.py @@ -0,0 +1,162 @@ +# Copyright 2024 Palantir Technologies, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import pytest +import requests +from expects import equal +from expects import expect +from expects import raise_error +from mockito import mock +from mockito import spy +from mockito import unstub +from mockito import verify +from mockito import when + +from foundry._core.auth_utils import Token +from foundry._core.confidential_client_auth import ConfidentialClientAuth +from foundry._core.oauth import SignInResponse +from foundry._errors.not_authenticated import NotAuthenticated + + +def test_confidential_client_instantiate(): + auth = ConfidentialClientAuth( + client_id="client_id", + client_secret="client_secret", + hostname="https://a.b.c.com", + should_refresh=True, + ) + assert auth._client_id == "client_id" + assert auth._client_secret == "client_secret" + assert auth._hostname == "https://a.b.c.com" + assert auth._token == None + assert auth.url == "a.b.c.com" + assert auth._should_refresh == True + + +@pytest.mark.asyncio +async def test_confidential_client_sign_in_as_service_user(): + auth = ConfidentialClientAuth( + client_id="client_id", + client_secret="client_secret", + hostname="https://a.b.c.com", + should_refresh=True, + ) + token = mock(Token) + token.access_token = "token" + token.expires_in = 3600 + when(auth._server_oauth_flow_provider).get_token().thenReturn(token) + expect(auth.sign_in_as_service_user()).to( + equal(SignInResponse(session={"accessToken": "token", "expiresIn": 3600})) + ) + expect(auth.get_token()).to(equal(token)) + unstub() + + +def test_confidential_client_get_token(): + auth = ConfidentialClientAuth( + client_id="client_id", client_secret="client_secret", hostname="https://a.b.c.com" + ) + token = mock(Token) + token.access_token = "token" + token.expires_in = 3600 + when(auth._server_oauth_flow_provider).get_token().thenReturn(token) + auth.sign_in_as_service_user() + expect(auth.get_token()).to(equal(token)) + unstub() + + +def test_confidential_client_sign_out(): + auth = ConfidentialClientAuth( + client_id="client_id", + client_secret="client_secret", + hostname="https://a.b.c.com", + should_refresh=True, + ) + token = mock(Token) + token.access_token = "access_token" + token.expires_in = 3600 + auth._token = token + when(auth._server_oauth_flow_provider).revoke_token("access_token").thenReturn(None) + auth.sign_out() + expect(auth._token).to(equal(None)) + expect(auth._stop_refresh_event._flag).to(equal(True)) + unstub() + + +def test_confidential_client_get_token_throws_if_not_signed_in(): + # pylint: disable=unnecessary-lambda + auth = ConfidentialClientAuth( + client_id="client_id", client_secret="client_secret", hostname="https://a.b.c.com" + ) + expect(lambda: auth.get_token()).to(raise_error(NotAuthenticated)) + + +def test_confidential_client_execute_with_token_successful_method(): + auth = ConfidentialClientAuth( + client_id="client_id", client_secret="client_secret", hostname="https://a.b.c.com" + ) + token = mock(Token) + token.access_token = "token" + token.expires_in = 3600 + auth._token = token + auth = spy(auth) + expect(auth.execute_with_token(lambda _: "success")).to(equal("success")) + verify(auth, times=0)._refresh_token() + + +def test_confidential_client_execute_with_token_failing_method(): + auth = ConfidentialClientAuth( + client_id="client_id", client_secret="client_secret", hostname="https://a.b.c.com" + ) + token = mock(Token) + token.access_token = "token" + token.expires_in = 3600 + auth._token = token + when(auth).sign_out().thenReturn(None) + + def raise_(ex): + raise ex + + expect(lambda: auth.execute_with_token(lambda _: raise_(ValueError("Oops!")))).to( + raise_error(ValueError) + ) + verify(auth, times=0)._refresh_token() + verify(auth, times=0).sign_out() + unstub() + + +def test_confidential_client_execute_with_token_method_raises_401(): + auth = ConfidentialClientAuth( + client_id="client_id", client_secret="client_secret", hostname="https://a.b.c.com" + ) + token = mock(Token) + token.access_token = "access_token" + token.expires_in = 3600 + auth._token = token + when(auth).sign_out().thenReturn(None) + when(auth)._refresh_token().thenReturn(token) + + def raise_401(): + e = requests.HTTPError() + e.response = requests.Response() + e.response.status_code = 401 + raise e + + expect(lambda: auth.execute_with_token(lambda _: raise_401())).to( + raise_error(requests.HTTPError) + ) + verify(auth, times=1)._refresh_token() + verify(auth, times=1).sign_out() + unstub() diff --git a/tests/auth/test_confidential_client_oauth_flow_provider.py b/tests/auth/test_confidential_client_oauth_flow_provider.py new file mode 100644 index 00000000..d28b6f04 --- /dev/null +++ b/tests/auth/test_confidential_client_oauth_flow_provider.py @@ -0,0 +1,109 @@ +# Copyright 2024 Palantir Technologies, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import pytest +import requests +from expects import equal +from expects import expect +from expects import raise_error +from mockito import mock +from mockito import unstub +from mockito import when +from requests import HTTPError + +from foundry._core.oauth_utils import ConfidentialClientOAuthFlowProvider +from foundry._core.oauth_utils import OAuthUtils + + +@pytest.fixture(name="client", scope="module") +def instantiate_server_oauth_flow_provider(): + return ConfidentialClientOAuthFlowProvider( + client_id="client_id", + client_secret="client_secret", + url="https://a.b.c", + multipass_context_path="/multipass", + scopes=["scope1", "scope2"], + ) + + +def test_get_token(client): + import foundry._core.oauth_utils as module_under_test + + when(ConfidentialClientOAuthFlowProvider).get_scopes().thenReturn(["scope1", "scope2"]) + when(OAuthUtils).get_token_uri("https://a.b.c", "/multipass").thenReturn("token_url") + response = mock(requests.Response) + response.ok = True + when(response).raise_for_status().thenReturn(None) + when(response).json().thenReturn( + {"access_token": "example_token", "expires_in": 42, "token_type": "Bearer"} + ) + when(module_under_test.requests).post( + "token_url", + data={ + "client_id": "client_id", + "client_secret": "client_secret", + "grant_type": "client_credentials", + "scope": "scope1 scope2", + }, + ).thenReturn(response) + token = client.get_token() + expect(token.access_token).to(equal("example_token")) + expect(token.token_type).to(equal("Bearer")) + unstub() + + +def test_get_token_throws_when_unsuccessful(client): + # pylint: disable=unnecessary-lambda + import foundry._core.oauth_utils as module_under_test + + when(ConfidentialClientOAuthFlowProvider).get_scopes().thenReturn( + ["scope1", "scope2", "offline_access"] + ) + when(OAuthUtils).get_token_uri("https://a.b.c", "/multipass").thenReturn("token_url") + response = mock(requests.Response) + when(response).raise_for_status().thenRaise(HTTPError) + when(module_under_test.requests).post( + "token_url", + data={ + "client_id": "client_id", + "client_secret": "client_secret", + "grant_type": "client_credentials", + "scope": "scope1 scope2 offline_access", + }, + ).thenReturn(response) + expect(lambda: client.get_token()).to(raise_error(HTTPError)) + unstub() + + +def test_revoke_token(client): + import foundry._core.oauth_utils as module_under_test + + when(OAuthUtils).get_revoke_uri("https://a.b.c", "/multipass").thenReturn("revoke_url") + response = mock(requests.Response) + when(response).raise_for_status().thenReturn(None) + when(module_under_test.requests).post( + "revoke_url", + data={ + "client_id": "client_id", + "client_secret": "client_secret", + "token": "token_to_be_revoked", + }, + ).thenReturn(response) + client.revoke_token("token_to_be_revoked") + unstub() + + +def test_get_scopes(client): + expect(client.get_scopes()).to(equal(["scope1", "scope2", "offline_access"])) diff --git a/tests/auth/test_foundry_auth_token_client.py b/tests/auth/test_foundry_auth_token_client.py new file mode 100644 index 00000000..d616bfa1 --- /dev/null +++ b/tests/auth/test_foundry_auth_token_client.py @@ -0,0 +1,82 @@ +# Copyright 2024 Palantir Technologies, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import os + +import pytest + +from foundry import UserTokenAuth + + +@pytest.fixture +def temp_os_environ(): + old_environ = os.environ.copy() + + # Make sure to start with a clean slate + for key in ["PALANTIR_HOSTNAME", "PALANTIR_TOKEN"]: + if key in os.environ: + os.environ.pop(key) + + yield + os.environ = old_environ + + +@pytest.mark.skip +def test_load_from_env(temp_os_environ): + os.environ["PALANTIR_HOSTNAME"] = "host_test" + os.environ["PALANTIR_TOKEN"] = "token_test" + config = UserTokenAuth() # type: ignore + assert config._hostname == "host_test" + assert config._token == "token_test" + + +@pytest.mark.skip +def test_load_from_env_missing_token(temp_os_environ): + os.environ["PALANTIR_HOSTNAME"] = "host_test" + assert pytest.raises(ValueError, lambda: UserTokenAuth()) # type: ignore + + +@pytest.mark.skip +def test_load_from_env_missing_host(temp_os_environ): + os.environ["PALANTIR_TOKEN"] = "token_test" + assert pytest.raises(ValueError, lambda: UserTokenAuth()) # type: ignore + + +@pytest.mark.skip +def test_can_pass_config(): + os.environ["PALANTIR_HOSTNAME"] = "host_test" + os.environ["PALANTIR_TOKEN"] = "token_test" + config = UserTokenAuth(hostname="host_test2", token="token_test2") + assert config.hostname == "host_test2" # type: ignore + assert config._token == "token_test2" + + +def test_can_pass_config_missing_token(): + assert pytest.raises(TypeError, lambda: UserTokenAuth(hostname="test")) # type: ignore + + +def test_can_pass_config_missing_host(): + assert pytest.raises(TypeError, lambda: UserTokenAuth(token="test")) # type: ignore + + +@pytest.mark.skip +def test_checks_host_type(): + assert pytest.raises(ValueError, lambda: UserTokenAuth(hostname=1)) # type: ignore + + +@pytest.mark.skip +def test_checks_token_type(): + assert pytest.raises(ValueError, lambda: UserTokenAuth(token=1)) # type: ignore + assert pytest.raises(ValueError, lambda: UserTokenAuth(token=1)) # type: ignore diff --git a/tests/auth/test_foundry_token_oauth_client.py b/tests/auth/test_foundry_token_oauth_client.py new file mode 100644 index 00000000..8673397d --- /dev/null +++ b/tests/auth/test_foundry_token_oauth_client.py @@ -0,0 +1,37 @@ +# Copyright 2024 Palantir Technologies, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import pytest + +from foundry import ConfidentialClientAuth +from foundry._errors.not_authenticated import NotAuthenticated + + +def test_can_pass_config(): + config = ConfidentialClientAuth( + client_id="123", + client_secret="abc", + hostname="example.com", + scopes=["hello"], + ) + + assert config._hostname == "example.com" # type: ignore + assert config._client_id == "123" # type: ignore + assert config._client_secret == "abc" # type: ignore + + with pytest.raises(NotAuthenticated) as info: + config.get_token() + + assert str(info.value) == "Client has not been authenticated." diff --git a/tests/auth/test_oauth_utils.py b/tests/auth/test_oauth_utils.py new file mode 100644 index 00000000..e9d68b7d --- /dev/null +++ b/tests/auth/test_oauth_utils.py @@ -0,0 +1,65 @@ +# Copyright 2024 Palantir Technologies, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from expects import equal +from expects import expect +from mockito import unstub +from mockito import when + +from foundry._core.oauth_utils import OAuthToken +from foundry._core.oauth_utils import OAuthTokenResponse +from foundry._core.oauth_utils import OAuthUtils + + +def test_get_token_uri(): + expect(OAuthUtils.get_token_uri("a.b.c")).to(equal("https://a.b.c/multipass/api/oauth2/token")) + + +def test_get_authorize_uri(): + expect(OAuthUtils.get_authorize_uri("a.b.c")).to( + equal("https://a.b.c/multipass/api/oauth2/authorize") + ) + + +def test_get_revoke_uri(): + expect(OAuthUtils.get_revoke_uri("a.b.c")).to( + equal("https://a.b.c/multipass/api/oauth2/revoke_token") + ) + + +def test_create_uri(): + expect(OAuthUtils.create_uri("a.b.c", "/api/v2/datasets", "/abc")).to( + equal("https://a.b.c/api/v2/datasets/abc") + ) + expect(OAuthUtils.create_uri("https://a.b.c", "/api/v2/datasets", "/abc")).to( + equal("https://a.b.c/api/v2/datasets/abc") + ) + + +def test_token_from_dict(): + import foundry._core.oauth_utils as module_under_test + + when(module_under_test.time).time().thenReturn(123) + token = OAuthToken( + OAuthTokenResponse( + {"access_token": "example_token", "expires_in": 42, "token_type": "Bearer"} + ) + ) + expect(token.access_token).to(equal("example_token")) + expect(token.token_type).to(equal("Bearer")) + expect(token.expires_in).to(equal(42)) + expect(token.expires_at).to(equal(123 * 1000 + 42 * 1000)) + expect(token._calculate_expiration()).to(equal(123 * 1000 + 42 * 1000)) + unstub() diff --git a/tests/auth/test_public_client.py b/tests/auth/test_public_client.py new file mode 100644 index 00000000..281fd7c1 --- /dev/null +++ b/tests/auth/test_public_client.py @@ -0,0 +1,193 @@ +# Copyright 2024 Palantir Technologies, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import pytest +import requests +from expects import equal +from expects import expect +from expects import raise_error +from mockito import mock +from mockito import spy +from mockito import unstub +from mockito import verify +from mockito import when + +from foundry._core.auth_utils import Token +from foundry._core.oauth_utils import AuthorizeRequest +from foundry._core.public_client_auth import PublicClientAuth +from foundry._errors.not_authenticated import NotAuthenticated + + +def test_public_client_instantiate(): + auth = PublicClientAuth( + client_id="client_id", + redirect_url="redirect_url", + hostname="https://a.b.c.com", + should_refresh=True, + ) + assert auth._client_id == "client_id" + assert auth._redirect_url == "redirect_url" + assert auth._hostname == "https://a.b.c.com" + assert auth._token == None + assert auth.url == "a.b.c.com" + assert auth._should_refresh == True + + +@pytest.mark.asyncio +async def test_public_client_sign_in(): + auth = PublicClientAuth( + client_id="client_id", + redirect_url="redirect_url", + hostname="https://a.b.c.com", + should_refresh=True, + ) + auth_request = mock(AuthorizeRequest) + auth_request.url = "auth_request url" + auth_request.state = "random string" + auth_request.code_verifier = "random string" + when(auth._server_oauth_flow_provider).generate_auth_request().thenReturn(auth_request) + + expect(auth.sign_in()).to(equal("auth_request url")) + expect(auth._auth_request).to(equal(auth_request)) + unstub() + + +@pytest.mark.asyncio +async def test_public_client_set_token(): + auth = PublicClientAuth( + client_id="client_id", + redirect_url="redirect_url", + hostname="https://a.b.c.com", + should_refresh=True, + ) + auth._auth_request = AuthorizeRequest(url="", state="", code_verifier="") + token = mock(Token) + token.access_token = "access_token" + token.expires_in = 3600 + when(auth._server_oauth_flow_provider).get_token(code="", code_verifier="").thenReturn(token) + auth.set_token(code="", state="") + expect(auth._token).to(equal(token)) + unstub() + + +def test_public_client_get_token(): + auth = PublicClientAuth( + client_id="client_id", + redirect_url="redirect_url", + hostname="https://a.b.c.com", + should_refresh=True, + ) + token = mock(Token) + token.access_token = "access_token" + token.expires_in = 3600 + auth._token = token + expect(auth.get_token()).to(equal(token)) + + +def test_public_client_sign_out(): + auth = PublicClientAuth( + client_id="client_id", + redirect_url="redirect_url", + hostname="https://a.b.c.com", + should_refresh=True, + ) + token = mock(Token) + token.access_token = "access_token" + token.expires_in = 3600 + auth._token = token + when(auth._server_oauth_flow_provider).revoke_token("access_token").thenReturn(None) + auth.sign_out() + expect(auth._token).to(equal(None)) + expect(auth._stop_refresh_event._flag).to(equal(True)) + unstub() + + +def test_public_client_get_token_throws_if_not_signed_in(): + # pylint: disable=unnecessary-lambda + auth = PublicClientAuth( + client_id="client_id", + redirect_url="redirect_url", + hostname="https://a.b.c.com", + should_refresh=True, + ) + expect(lambda: auth.get_token()).to( + raise_error(NotAuthenticated, "Client has not been authenticated.") + ) + + +def test_public_client_execute_with_token_successful_method(): + auth = PublicClientAuth( + client_id="client_id", + redirect_url="redirect_url", + hostname="https://a.b.c.com", + should_refresh=True, + ) + token = mock(Token) + token.access_token = "token" + token.expires_in = 3600 + auth._token = token + auth = spy(auth) + expect(auth.execute_with_token(lambda _: "success")).to(equal("success")) + verify(auth, times=0)._refresh_token() + + +def test_public_client_execute_with_token_failing_method(): + auth = PublicClientAuth( + client_id="client_id", + redirect_url="redirect_url", + hostname="https://a.b.c.com", + should_refresh=True, + ) + token = mock(Token) + token.access_token = "token" + token.expires_in = 3600 + auth._token = token + when(auth).sign_out().thenReturn(None) + + def raise_(ex): + raise ex + + expect(lambda: auth.execute_with_token(lambda _: raise_(ValueError("Oops!")))).to( + raise_error(ValueError) + ) + verify(auth, times=0)._refresh_token() + unstub() + + +def _test_public_client_execute_with_token_method_raises_401(): + auth = PublicClientAuth( + client_id="client_id", + redirect_url="redirect_url", + hostname="https://a.b.c.com", + should_refresh=True, + ) + token = mock(Token) + token.access_token = "access_token" + token.expires_in = 3600 + auth._token = token + when(auth).sign_out().thenReturn(None) + when(auth)._refresh_token().thenReturn(token) + + def raise_401(): + e = requests.HTTPError() + e.response = requests.Response() + e.response.status_code = 401 + raise e + + expect(lambda: auth.execute_with_token(lambda _: raise_401())).to( + raise_error(requests.HTTPError) + ) + verify(auth, times=1)._refresh_token() + unstub() diff --git a/tests/auth/test_public_client_oauth_flow_provider.py b/tests/auth/test_public_client_oauth_flow_provider.py new file mode 100644 index 00000000..b2a3753b --- /dev/null +++ b/tests/auth/test_public_client_oauth_flow_provider.py @@ -0,0 +1,143 @@ +# Copyright 2024 Palantir Technologies, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import pytest +import requests +from expects import equal +from expects import expect +from expects import raise_error +from mockito import mock +from mockito import unstub +from mockito import when +from requests import HTTPError + +from foundry._core.oauth_utils import OAuthUtils +from foundry._core.oauth_utils import PublicClientOAuthFlowProvider + + +@pytest.fixture(name="client", scope="module") +def instantiate_server_oauth_flow_provider(): + return PublicClientOAuthFlowProvider( + client_id="client_id", + redirect_url="redirect_url", + url="https://a.b.c", + multipass_context_path="/multipass", + scopes=["scope1", "scope2"], + ) + + +def test_get_token(client): + import foundry._core.oauth_utils as module_under_test + + when(PublicClientOAuthFlowProvider).get_scopes().thenReturn(["scope1", "scope2"]) + when(OAuthUtils).get_token_uri("https://a.b.c", "/multipass").thenReturn("token_url") + response = mock(requests.Response) + response.ok = True + when(response).raise_for_status().thenReturn(None) + when(response).json().thenReturn( + {"access_token": "example_token", "expires_in": 42, "token_type": "Bearer"} + ) + + headers = {"Content-Type": "application/x-www-form-urlencoded"} + params = { + "grant_type": "authorization_code", + "code": "code", + "redirect_uri": "redirect_url", + "client_id": "client_id", + "code_verifier": "code_verifier", + "scope": "scope1 scope2", + } + + when(module_under_test.requests).post("token_url", data=params, headers=headers).thenReturn( + response + ) + token = client.get_token(code="code", code_verifier="code_verifier") + expect(token.access_token).to(equal("example_token")) + expect(token.token_type).to(equal("Bearer")) + unstub() + + +def test_get_token_throws_when_unsuccessful(client): + # pylint: disable=unnecessary-lambda + import foundry._core.oauth_utils as module_under_test + + when(PublicClientOAuthFlowProvider).get_scopes().thenReturn( + ["scope1", "scope2", "offline_access"] + ) + when(OAuthUtils).get_token_uri("https://a.b.c", "/multipass").thenReturn("token_url") + response = mock(requests.Response) + when(response).raise_for_status().thenRaise(HTTPError) + + headers = {"Content-Type": "application/x-www-form-urlencoded"} + params = { + "grant_type": "authorization_code", + "code": "code", + "redirect_uri": "redirect_url", + "client_id": "client_id", + "code_verifier": "code_verifier", + "scope": "scope1 scope2 offline_access", + } + + when(module_under_test.requests).post("token_url", data=params, headers=headers).thenReturn( + response + ) + expect(lambda: client.get_token(code="code", code_verifier="code_verifier")).to( + raise_error(HTTPError) + ) + unstub() + + +def test_refresh_token(client): + import foundry._core.oauth_utils as module_under_test + + when(OAuthUtils).get_token_uri("https://a.b.c", "/multipass").thenReturn("token_url") + response = mock(requests.Response) + response.ok = True + when(response).raise_for_status().thenReturn(None) + when(response).json().thenReturn( + {"access_token": "example_token", "expires_in": 42, "token_type": "Bearer"} + ) + + headers = {"Content-Type": "application/x-www-form-urlencoded"} + params = { + "grant_type": "refresh_token", + "client_id": "client_id", + "refresh_token": "refresh_token", + } + + when(module_under_test.requests).post("token_url", data=params, headers=headers).thenReturn( + response + ) + token = client.refresh_token(refresh_token="refresh_token") + expect(token.access_token).to(equal("example_token")) + expect(token.token_type).to(equal("Bearer")) + unstub() + + +def test_revoke_token(client): + import foundry._core.oauth_utils as module_under_test + + when(OAuthUtils).get_revoke_uri("https://a.b.c", "/multipass").thenReturn("revoke_url") + response = mock(requests.Response) + when(response).raise_for_status().thenReturn(None) + when(module_under_test.requests).post( + "revoke_url", data={"client_id": "client_id", "token": "token_to_be_revoked"} + ).thenReturn(response) + client.revoke_token("token_to_be_revoked") + unstub() + + +def test_get_scopes(client): + expect(client.get_scopes()).to(equal(["scope1", "scope2", "offline_access"])) diff --git a/tox.ini b/tox.ini index 2a22c23e..82f903fd 100644 --- a/tox.ini +++ b/tox.ini @@ -6,8 +6,6 @@ envlist = py{39,310,311,312}-pydantic{2.1.0,2.1,2.2,2.3,2.4,2.5}-requests{2.25,2 setenv = PYTHONPATH = {toxinidir} deps = - pytest - typing-extensions >= 4.7.1 pydantic{2.1.0}: pydantic==2.1.0 pydantic{2.1}: pydantic==2.1.* pydantic{2.2}: pydantic==2.2.* @@ -17,8 +15,11 @@ deps = requests{2.25}: requests==2.25.* requests{2.26}: requests==2.26.* requests{2.31}: requests==2.31.* +allowlist_externals = poetry +commands_pre = + poetry install commands = - pytest tests/ + poetry run pytest tests/ [testenv:pyright] deps =