Skip to content

Commit

Permalink
Updating Errors and Auth (#44)
Browse files Browse the repository at this point in the history
  • Loading branch information
jameszhang244 authored Oct 10, 2024
1 parent 9d5c3e3 commit 659a634
Show file tree
Hide file tree
Showing 20 changed files with 885 additions and 82 deletions.
3 changes: 2 additions & 1 deletion .circleci/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@ jobs:
steps:
- checkout
- run: pip install --user tox
- run: tox -e py<< parameters.python_version >>-pydantic<< parameters.pydantic_version >>-requests<< parameters.requests_version >>
- run: poetry --no-ansi install --no-root --sync
- run: poetry --no-ansi run tox -v -e py<< parameters.python_version >>-pydantic<< parameters.pydantic_version >>-requests<< parameters.requests_version >> --recreate

pyright:
docker:
Expand Down
5 changes: 5 additions & 0 deletions changelog/@unreleased/pr-44.v2.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
type: improvement
improvement:
description: Updating Errors and Auth
links:
- https://github.com/palantir/foundry-platform-python/pull/44
55 changes: 33 additions & 22 deletions foundry/_core/confidential_client_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@
# limitations under the License.


import asyncio
import threading
import time
from typing import Callable
from typing import List
from typing import Optional
Expand All @@ -27,7 +28,6 @@
from foundry._core.oauth_utils import ConfidentialClientOAuthFlowProvider
from foundry._core.oauth_utils import OAuthToken
from foundry._core.utils import remove_prefixes
from foundry._errors.environment_not_configured import EnvironmentNotConfigured
from foundry._errors.not_authenticated import NotAuthenticated

T = TypeVar("T")
Expand Down Expand Up @@ -56,7 +56,7 @@ def __init__(
self._client_secret = client_secret
self._token: Optional[OAuthToken] = None
self._should_refresh = should_refresh
self._refresh_task: Optional[asyncio.Task] = None
self._stop_refresh_event = threading.Event()
self._hostname = hostname
self._server_oauth_flow_provider = ConfidentialClientOAuthFlowProvider(
client_id, client_secret, self.url, scopes=scopes
Expand All @@ -70,15 +70,21 @@ def get_token(self) -> OAuthToken:
def execute_with_token(self, func: Callable[[OAuthToken], T]) -> T:
try:
return self._run_with_attempted_refresh(func)
except requests.HTTPError as http_e:
if http_e.response.status_code == 401:
self.sign_out()
raise http_e
except Exception as e:
self.sign_out()
raise e

def run_with_token(self, func: Callable[[OAuthToken], T]) -> None:
try:
self._run_with_attempted_refresh(func)
except requests.HTTPError as http_e:
if http_e.response.status_code == 401:
self.sign_out()
raise http_e
except Exception as e:
self.sign_out()
raise e

def _run_with_attempted_refresh(self, func: Callable[[OAuthToken], T]) -> T:
Expand All @@ -89,45 +95,50 @@ def _run_with_attempted_refresh(self, func: Callable[[OAuthToken], T]) -> T:
try:
return func(self.get_token())
except requests.HTTPError as e:
if e.response is not None and e.response.status_code == 401:
if e.response.status_code == 401:
self._refresh_token()
return func(self.get_token())
else:
raise e

@property
def url(self):
def url(self) -> str:
return remove_prefixes(self._hostname, ["https://", "http://"])

def _refresh_token(self):
def _refresh_token(self) -> None:
self._token = self._server_oauth_flow_provider.get_token()

def _start_auto_refresh(self) -> None:
def _auto_refresh_token() -> None:
while not self._stop_refresh_event.is_set():
if self._token:
# Sleep for (expires_in - 60) seconds to refresh the token 1 minute before it expires
time.sleep(self._token.expires_in - 60)
self._refresh_token()
else:
# Wait 10 seconds and check again if the token is set
time.sleep(10)

refresh_thread = threading.Thread(target=_auto_refresh_token, daemon=True)
refresh_thread.start()

def sign_in_as_service_user(self) -> SignInResponse:
token = self._server_oauth_flow_provider.get_token()
self._token = token

async def refresh_token_task():
while True:
if self._token is None:
raise RuntimeError("The token was None when trying to refresh.")

await asyncio.sleep(self._token.expires_in / 60 - 10)
self._token = self._server_oauth_flow_provider.get_token()

if self._should_refresh:
loop = asyncio.get_event_loop()
self._refresh_task = loop.create_task(refresh_token_task())
self._start_auto_refresh()
return SignInResponse(
session={"accessToken": token.access_token, "expiresIn": token.expires_in}
)

def sign_out(self) -> SignOutResponse:
if self._refresh_task:
self._refresh_task.cancel()
self._refresh_task = None

if self._token:
self._server_oauth_flow_provider.revoke_token(self._token.access_token)

self._token = None

# Signal the auto-refresh thread to stop
self._stop_refresh_event.set()

return SignOutResponse()
5 changes: 4 additions & 1 deletion foundry/_core/oauth.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,14 @@
# limitations under the License.


from typing import Any
from typing import Dict

from pydantic import BaseModel


class SignInResponse(BaseModel):
session: dict
session: Dict[str, Any]


class SignOutResponse(BaseModel):
Expand Down
10 changes: 6 additions & 4 deletions foundry/_core/oauth_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
import secrets
import string
import time
from typing import Any
from typing import Dict
from typing import List
from typing import Optional
from urllib.parse import urlencode
Expand Down Expand Up @@ -68,7 +70,7 @@ class OAuthTokenResponse(BaseModel):
expires_in: int
refresh_token: Optional[str] = None

def __init__(self, token_response: dict) -> None:
def __init__(self, token_response: Dict[str, Any]) -> None:
super().__init__(**token_response)


Expand Down Expand Up @@ -167,13 +169,13 @@ def get_scopes(self) -> List[str]:
return scopes


def generate_random_string(min_length=43, max_length=128):
def generate_random_string(min_length: int = 43, max_length: int = 128) -> str:
characters = string.ascii_letters + string.digits + "-._~"
length = secrets.randbelow(max_length - min_length + 1) + min_length
return "".join(secrets.choice(characters) for _ in range(length))


def generate_code_challenge(input_string):
def generate_code_challenge(input_string: str) -> str:
# Calculate the SHA256 hash
sha256_hash = hashlib.sha256(input_string.encode("utf-8")).digest()

Expand Down Expand Up @@ -249,7 +251,7 @@ def get_token(self, code: str, code_verifier: str) -> OAuthToken:
response.raise_for_status()
return OAuthToken(token=OAuthTokenResponse(token_response=response.json()))

def refresh_token(self, refresh_token):
def refresh_token(self, refresh_token: str) -> OAuthToken:
headers = {"Content-Type": "application/x-www-form-urlencoded"}
params = {
"grant_type": "refresh_token",
Expand Down
42 changes: 23 additions & 19 deletions foundry/_core/public_client_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,19 +36,21 @@


class PublicClientAuth(Auth):
scopes: List[str] = ["api:read-data", "api:write-data", "offline_access"]

"""
Client for Public Client OAuth-authenticated Ontology applications.
Runs a background thread to periodically refresh access token.
:param client_id: OAuth client id to be used by the application.
:param client_secret: OAuth client secret to be used by the application.
:param hostname: Hostname for authentication and ontology endpoints.
"""

def __init__(
self, client_id: str, redirect_url: str, hostname: str, should_refresh: bool = False
self,
client_id: str,
redirect_url: str,
hostname: str,
scopes: Optional[List[str]] = None,
should_refresh: bool = False,
) -> None:
self._client_id = client_id
self._redirect_url = redirect_url
Expand All @@ -58,7 +60,7 @@ def __init__(
self._stop_refresh_event = threading.Event()
self._hostname = hostname
self._server_oauth_flow_provider = PublicClientOAuthFlowProvider(
client_id=client_id, redirect_url=redirect_url, url=self.url, scopes=self.scopes
client_id=client_id, redirect_url=redirect_url, url=self.url, scopes=scopes
)
self._auth_request: Optional[AuthorizeRequest] = None

Expand All @@ -81,9 +83,11 @@ def run_with_token(self, func: Callable[[OAuthToken], T]) -> None:
self.sign_out()
raise e

def _refresh_token(self):
if self._token is None:
raise Exception("")
def _refresh_token(self) -> None:
if not self._token:
raise RuntimeError("must have token to refresh")
if not self._token.refresh_token:
raise RuntimeError("no refresh token provided")

self._token = self._server_oauth_flow_provider.refresh_token(
refresh_token=self._token.refresh_token
Expand All @@ -92,30 +96,29 @@ def _refresh_token(self):
def _run_with_attempted_refresh(self, func: Callable[[OAuthToken], T]) -> T:
"""
Attempt to run func, and if it fails with a 401, refresh the token and try again.
If it fails with a 401 again, raise the exception.
"""
try:
return func(self.get_token())
except requests.HTTPError as e:
if e.response is not None and e.response.status_code == 401:
if e.response.status_code == 401:
self._refresh_token()
return func(self.get_token())
else:
raise e

@property
def url(self):
def url(self) -> str:
return remove_prefixes(self._hostname, ["https://", "http://"])

def sign_in(self) -> None:
def sign_in(self) -> str:
self._auth_request = self._server_oauth_flow_provider.generate_auth_request()
webbrowser.open(self._auth_request.url)
return self._auth_request.url

def _start_auto_refresh(self):
def _auto_refresh_token():
def _start_auto_refresh(self) -> None:
def _auto_refresh_token() -> None:
while not self._stop_refresh_event.is_set():
if self._token:
if self._token and self._token.refresh_token:
# Sleep for (expires_in - 60) seconds to refresh the token 1 minute before it expires
time.sleep(self._token.expires_in - 60)
self._token = self._server_oauth_flow_provider.refresh_token(
Expand All @@ -129,9 +132,10 @@ def _auto_refresh_token():
refresh_thread.start()

def set_token(self, code: str, state: str) -> None:
if self._auth_request is None or state != self._auth_request.state:
raise RuntimeError("Unable to verify the state")

if not self._auth_request:
raise RuntimeError("Must sign in prior to setting token")
if state != self._auth_request.state:
raise RuntimeError("Unable to verify state")
self._token = self._server_oauth_flow_provider.get_token(
code=code, code_verifier=self._auth_request.code_verifier
)
Expand Down
1 change: 0 additions & 1 deletion foundry/_errors/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@


from foundry._errors.environment_not_configured import EnvironmentNotConfigured
from foundry._errors.helpers import format_error_message
from foundry._errors.not_authenticated import NotAuthenticated
from foundry._errors.palantir_rpc_exception import PalantirRPCException
from foundry._errors.sdk_internal_error import SDKInternalError
Expand Down
3 changes: 2 additions & 1 deletion foundry/_errors/environment_not_configured.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,5 @@


class EnvironmentNotConfigured(Exception):
pass
def __init__(self, message: str) -> None:
super().__init__(message)
24 changes: 0 additions & 24 deletions foundry/_errors/helpers.py

This file was deleted.

11 changes: 7 additions & 4 deletions foundry/_errors/palantir_rpc_exception.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,18 @@
# limitations under the License.


import json
from typing import Any
from typing import Dict

from foundry._errors.helpers import format_error_message

def format_error_message(fields: Dict[str, Any]) -> str:
return json.dumps(fields, sort_keys=True, indent=4, default=str)


class PalantirRPCException(Exception):
def __init__(self, error_metadata: Dict[str, Any]):
super().__init__(format_error_message(error_metadata))
self.name: str = error_metadata["errorName"]
self.parameters: Dict[str, Any] = error_metadata["parameters"]
self.error_instance_id: str = error_metadata["errorInstanceId"]
self.name = error_metadata.get("errorName")
self.parameters = error_metadata.get("parameters")
self.error_instance_id = error_metadata.get("errorInstanceId")
10 changes: 8 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,17 @@ keywords = ["Palantir", "Foundry", "SDK", "Client", "API"]
packages = [{ include = "foundry" }]

[tool.poetry.dependencies]
annotated-types = ">=0.7.0"
pydantic = "^2.1.0"
python = "^3.9"
requests = "^2.25.0"
pydantic = "^2.1.0"
typing-extensions = ">=4.7.1"
annotated-types = ">=0.7.0"

[tool.poetry.group.test.dependencies]
expects = ">=0.9.0"
mockito = ">=1.5.1"
pytest = ">=7.4.0"
pytest-asyncio = ">=0.23.0"

[tool.poetry.extras]
cli = ["click"]
Expand Down
Empty file added tests/auth/__init__.py
Empty file.
Loading

0 comments on commit 659a634

Please sign in to comment.