From f91d14b25f86dd3f2e4d48229bb53cc7d9b20f1b Mon Sep 17 00:00:00 2001 From: "Justin \"J.R.\" Hill" Date: Wed, 1 Nov 2023 12:57:26 -0700 Subject: [PATCH] refactor(python-sdk): extract oauth2 from credentials --- .github/workflows/semgrep.yaml | 4 +- .openapi-generator-ignore | 3 +- .openapi-generator/FILES | 6 +- openfga_sdk/api/open_fga_api.py | 52 +++++-- openfga_sdk/api_client.py | 33 ++-- openfga_sdk/credentials.py | 57 +------ openfga_sdk/oauth2.py | 79 ++++++++++ openfga_sdk/sync/__init__.py | 1 - openfga_sdk/sync/api_client.py | 34 +++-- openfga_sdk/sync/credentials.py | 150 ------------------ openfga_sdk/sync/oauth2.py | 79 ++++++++++ openfga_sdk/sync/open_fga_api.py | 52 +++++-- test/test_credentials.py | 105 ------------- test/test_credentials_sync.py | 252 ------------------------------- test/test_oauth2.py | 115 ++++++++++++++ test/test_oauth2_sync.py | 116 ++++++++++++++ test/test_open_fga_api_sync.py | 4 +- 17 files changed, 519 insertions(+), 623 deletions(-) create mode 100644 openfga_sdk/oauth2.py delete mode 100644 openfga_sdk/sync/credentials.py create mode 100644 openfga_sdk/sync/oauth2.py delete mode 100644 test/test_credentials_sync.py create mode 100644 test/test_oauth2.py create mode 100644 test/test_oauth2_sync.py diff --git a/.github/workflows/semgrep.yaml b/.github/workflows/semgrep.yaml index 44718e20..92f9ad6a 100644 --- a/.github/workflows/semgrep.yaml +++ b/.github/workflows/semgrep.yaml @@ -11,7 +11,9 @@ jobs: image: returntocorp/semgrep if: (github.actor != 'dependabot[bot]' && github.actor != 'snyk-bot') steps: - - uses: actions/checkout@c85c95e3d7251135ab7dc9ce3241c5835cc595a9 # v3.5.2 + - uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # v4.1.1 + with: + fetch-depth: 0 - run: semgrep ci env: SEMGREP_APP_TOKEN: ${{ secrets.SEMGREP_APP_TOKEN }} diff --git a/.openapi-generator-ignore b/.openapi-generator-ignore index 42da987f..a014ba4a 100644 --- a/.openapi-generator-ignore +++ b/.openapi-generator-ignore @@ -7,7 +7,8 @@ test/* !test/test_client_sync.py !test/test_open_fga_api_sync.py !test/test_validation.py -!test/test_credentials_sync.py +!test/test_oauth2.py +!test/test_oauth2_sync.py .github/workflows/python.yml .gitlab-ci.yml .travis.yml diff --git a/.openapi-generator/FILES b/.openapi-generator/FILES index 0ee44080..a0df92ae 100644 --- a/.openapi-generator/FILES +++ b/.openapi-generator/FILES @@ -146,12 +146,13 @@ openfga_sdk/models/write_assertions_request.py openfga_sdk/models/write_authorization_model_request.py openfga_sdk/models/write_authorization_model_response.py openfga_sdk/models/write_request.py +openfga_sdk/oauth2.py openfga_sdk/rest.py openfga_sdk/sync/__init__.py openfga_sdk/sync/api_client.py openfga_sdk/sync/client/__init__.py openfga_sdk/sync/client/client.py -openfga_sdk/sync/credentials.py +openfga_sdk/sync/oauth2.py openfga_sdk/sync/rest.py openfga_sdk/validation.py requirements.txt @@ -162,7 +163,8 @@ test/__init__.py test/test_client.py test/test_client_sync.py test/test_credentials.py -test/test_credentials_sync.py +test/test_oauth2.py +test/test_oauth2_sync.py test/test_open_fga_api.py test/test_open_fga_api_sync.py test/test_validation.py diff --git a/openfga_sdk/api/open_fga_api.py b/openfga_sdk/api/open_fga_api.py index fda46e1a..8238a4f5 100644 --- a/openfga_sdk/api/open_fga_api.py +++ b/openfga_sdk/api/open_fga_api.py @@ -19,6 +19,7 @@ import six from openfga_sdk.api_client import ApiClient +from openfga_sdk.oauth2 import OAuth2Client from openfga_sdk.exceptions import ( # noqa: F401 FgaValidationException, ApiValueError @@ -37,6 +38,12 @@ def __init__(self, api_client=None): api_client = ApiClient() self.api_client = api_client + self._oauth2_client = None + if api_client.configuration is not None: + credentials = api_client.configuration.credentials + if credentials is not None and credentials.method == 'client_credentials': + self._oauth2_client = OAuth2Client(credentials) + async def __aenter__(self): return self @@ -192,7 +199,8 @@ async def check_with_http_info(self, body, **kwargs): # noqa: E501 _request_timeout=local_var_params.get('_request_timeout'), _retry_params=local_var_params.get('_retry_params'), collection_formats=collection_formats, - _request_auth=local_var_params.get('_request_auth'))) + _request_auth=local_var_params.get('_request_auth'), + _oauth2_client=self._oauth2_client)) async def create_store(self, body, **kwargs): # noqa: E501 """Create a store # noqa: E501 @@ -333,7 +341,8 @@ async def create_store_with_http_info(self, body, **kwargs): # noqa: E501 _request_timeout=local_var_params.get('_request_timeout'), _retry_params=local_var_params.get('_retry_params'), collection_formats=collection_formats, - _request_auth=local_var_params.get('_request_auth'))) + _request_auth=local_var_params.get('_request_auth'), + _oauth2_client=self._oauth2_client)) async def delete_store(self, **kwargs): # noqa: E501 """Delete a store # noqa: E501 @@ -460,7 +469,8 @@ async def delete_store_with_http_info(self, **kwargs): # noqa: E501 _request_timeout=local_var_params.get('_request_timeout'), _retry_params=local_var_params.get('_retry_params'), collection_formats=collection_formats, - _request_auth=local_var_params.get('_request_auth'))) + _request_auth=local_var_params.get('_request_auth'), + _oauth2_client=self._oauth2_client)) async def expand(self, body, **kwargs): # noqa: E501 """Expand all relationships in userset tree format, and following userset rewrite rules. Useful to reason about and debug a certain relationship # noqa: E501 @@ -608,7 +618,8 @@ async def expand_with_http_info(self, body, **kwargs): # noqa: E501 _request_timeout=local_var_params.get('_request_timeout'), _retry_params=local_var_params.get('_retry_params'), collection_formats=collection_formats, - _request_auth=local_var_params.get('_request_auth'))) + _request_auth=local_var_params.get('_request_auth'), + _oauth2_client=self._oauth2_client)) async def get_store(self, **kwargs): # noqa: E501 """Get a store # noqa: E501 @@ -740,7 +751,8 @@ async def get_store_with_http_info(self, **kwargs): # noqa: E501 _request_timeout=local_var_params.get('_request_timeout'), _retry_params=local_var_params.get('_retry_params'), collection_formats=collection_formats, - _request_auth=local_var_params.get('_request_auth'))) + _request_auth=local_var_params.get('_request_auth'), + _oauth2_client=self._oauth2_client)) async def list_objects(self, body, **kwargs): # noqa: E501 """List all objects of the given type that the user has a relation with # noqa: E501 @@ -888,7 +900,8 @@ async def list_objects_with_http_info(self, body, **kwargs): # noqa: E501 _request_timeout=local_var_params.get('_request_timeout'), _retry_params=local_var_params.get('_retry_params'), collection_formats=collection_formats, - _request_auth=local_var_params.get('_request_auth'))) + _request_auth=local_var_params.get('_request_auth'), + _oauth2_client=self._oauth2_client)) async def list_stores(self, **kwargs): # noqa: E501 """List all stores # noqa: E501 @@ -1030,7 +1043,8 @@ async def list_stores_with_http_info(self, **kwargs): # noqa: E501 _request_timeout=local_var_params.get('_request_timeout'), _retry_params=local_var_params.get('_retry_params'), collection_formats=collection_formats, - _request_auth=local_var_params.get('_request_auth'))) + _request_auth=local_var_params.get('_request_auth'), + _oauth2_client=self._oauth2_client)) async def read(self, body, **kwargs): # noqa: E501 """Get tuples from the store that matches a query, without following userset rewrite rules # noqa: E501 @@ -1178,7 +1192,8 @@ async def read_with_http_info(self, body, **kwargs): # noqa: E501 _request_timeout=local_var_params.get('_request_timeout'), _retry_params=local_var_params.get('_retry_params'), collection_formats=collection_formats, - _request_auth=local_var_params.get('_request_auth'))) + _request_auth=local_var_params.get('_request_auth'), + _oauth2_client=self._oauth2_client)) async def read_assertions(self, authorization_model_id, **kwargs): # noqa: E501 """Read assertions for an authorization model ID # noqa: E501 @@ -1323,7 +1338,8 @@ async def read_assertions_with_http_info(self, authorization_model_id, **kwargs) _request_timeout=local_var_params.get('_request_timeout'), _retry_params=local_var_params.get('_retry_params'), collection_formats=collection_formats, - _request_auth=local_var_params.get('_request_auth'))) + _request_auth=local_var_params.get('_request_auth'), + _oauth2_client=self._oauth2_client)) async def read_authorization_model(self, id, **kwargs): # noqa: E501 """Return a particular version of an authorization model # noqa: E501 @@ -1467,7 +1483,8 @@ async def read_authorization_model_with_http_info(self, id, **kwargs): # noqa: _request_timeout=local_var_params.get('_request_timeout'), _retry_params=local_var_params.get('_retry_params'), collection_formats=collection_formats, - _request_auth=local_var_params.get('_request_auth'))) + _request_auth=local_var_params.get('_request_auth'), + _oauth2_client=self._oauth2_client)) async def read_authorization_models(self, **kwargs): # noqa: E501 """Return all the authorization models for a particular store # noqa: E501 @@ -1613,7 +1630,8 @@ async def read_authorization_models_with_http_info(self, **kwargs): # noqa: E50 _request_timeout=local_var_params.get('_request_timeout'), _retry_params=local_var_params.get('_retry_params'), collection_formats=collection_formats, - _request_auth=local_var_params.get('_request_auth'))) + _request_auth=local_var_params.get('_request_auth'), + _oauth2_client=self._oauth2_client)) async def read_changes(self, **kwargs): # noqa: E501 """Return a list of all the tuple changes # noqa: E501 @@ -1766,7 +1784,8 @@ async def read_changes_with_http_info(self, **kwargs): # noqa: E501 _request_timeout=local_var_params.get('_request_timeout'), _retry_params=local_var_params.get('_retry_params'), collection_formats=collection_formats, - _request_auth=local_var_params.get('_request_auth'))) + _request_auth=local_var_params.get('_request_auth'), + _oauth2_client=self._oauth2_client)) async def write(self, body, **kwargs): # noqa: E501 """Add or delete tuples from the store # noqa: E501 @@ -1914,7 +1933,8 @@ async def write_with_http_info(self, body, **kwargs): # noqa: E501 _request_timeout=local_var_params.get('_request_timeout'), _retry_params=local_var_params.get('_retry_params'), collection_formats=collection_formats, - _request_auth=local_var_params.get('_request_auth'))) + _request_auth=local_var_params.get('_request_auth'), + _oauth2_client=self._oauth2_client)) async def write_assertions(self, authorization_model_id, body, **kwargs): # noqa: E501 """Upsert assertions for an authorization model ID # noqa: E501 @@ -2070,7 +2090,8 @@ async def write_assertions_with_http_info(self, authorization_model_id, body, ** _request_timeout=local_var_params.get('_request_timeout'), _retry_params=local_var_params.get('_retry_params'), collection_formats=collection_formats, - _request_auth=local_var_params.get('_request_auth'))) + _request_auth=local_var_params.get('_request_auth'), + _oauth2_client=self._oauth2_client)) async def write_authorization_model(self, body, **kwargs): # noqa: E501 """Create a new authorization model # noqa: E501 @@ -2218,4 +2239,5 @@ async def write_authorization_model_with_http_info(self, body, **kwargs): # noq _request_timeout=local_var_params.get('_request_timeout'), _retry_params=local_var_params.get('_retry_params'), collection_formats=collection_formats, - _request_auth=local_var_params.get('_request_auth'))) + _request_auth=local_var_params.get('_request_auth'), + _oauth2_client=self._oauth2_client)) diff --git a/openfga_sdk/api_client.py b/openfga_sdk/api_client.py index 1aa49ef8..3aad2533 100644 --- a/openfga_sdk/api_client.py +++ b/openfga_sdk/api_client.py @@ -30,7 +30,7 @@ from openfga_sdk.configuration import Configuration import openfga_sdk.models -from openfga_sdk import rest +from openfga_sdk import rest, oauth2 from openfga_sdk.exceptions import ApiValueError, ApiException, FgaValidationException, RateLimitExceededError @@ -140,7 +140,7 @@ async def __call_api( response_types_map=None, auth_settings=None, _return_http_data_only=None, collection_formats=None, _preload_content=True, _request_timeout=None, _host=None, - _request_auth=None, _retry_params=None): + _request_auth=None, _retry_params=None, _oauth2_client=None): self.configuration.is_valid() config = self.configuration @@ -183,7 +183,7 @@ async def __call_api( # auth setting await self.update_params_for_auth( header_params, query_params, auth_settings, - request_auth=_request_auth) + request_auth=_request_auth, oauth2_client=_oauth2_client) # body if body: @@ -369,7 +369,7 @@ async def call_api(self, resource_path, method, async_req=None, _return_http_data_only=None, collection_formats=None, _preload_content=True, _request_timeout=None, _host=None, _request_auth=None, - _retry_params=None): + _retry_params=None, _oauth2_client=None): """Makes the HTTP request (synchronous) and returns deserialized data. To make an async_req request, set the async_req parameter. @@ -417,7 +417,7 @@ async def call_api(self, resource_path, method, response_types_map, auth_settings, _return_http_data_only, collection_formats, _preload_content, _request_timeout, _host, - _request_auth, _retry_params)) + _request_auth, _retry_params, _oauth2_client)) return self.pool.apply_async(self.__call_api, (resource_path, method, path_params, @@ -430,7 +430,9 @@ async def call_api(self, resource_path, method, collection_formats, _preload_content, _request_timeout, - _host, _request_auth, _retry_params)) + _host, _request_auth, + _retry_params, + _oauth2_client)) async def request(self, method, url, query_params=None, headers=None, post_params=None, body=None, _preload_content=True, @@ -561,7 +563,7 @@ def select_header_content_type(self, content_types, method=None, body=None): return content_types[0] async def update_params_for_auth(self, headers, queries, auth_settings, - request_auth=None): + request_auth=None, oauth2_client=None): """Updates header and query params based on authentication setting. :param headers: Header parameters dict to be updated. @@ -569,11 +571,20 @@ async def update_params_for_auth(self, headers, queries, auth_settings, :param auth_settings: Authentication setting identifiers list. :param request_auth: if set, the provided settings will override the token in the configuration. + :param oauth2_client: if set, will be used for credential exchange. """ - if self.configuration.credentials is not None: - added_headers = await self.configuration.credentials.get_authentication_header(self.rest_client) - for key, value in added_headers.items(): - headers[key] = value + credentials = self.configuration.credentials + if credentials is not None: + if credentials.method == 'none': + pass + if credentials.method == 'api_token': + headers['Authorization'] = 'Bearer {}'.format(credentials.configuration.api_token) + if credentials.method == 'client_credentials': + if oauth2_client is None: + oauth2_client = oauth2.OAuth2Client(credentials) + oauth2_headers = await oauth2_client.get_authentication_header(self.rest_client) + for key, value in oauth2_headers.items(): + headers[key] = value if not auth_settings: return diff --git a/openfga_sdk/credentials.py b/openfga_sdk/credentials.py index 23cbf513..5f528741 100644 --- a/openfga_sdk/credentials.py +++ b/openfga_sdk/credentials.py @@ -11,13 +11,10 @@ """ from dataclasses import dataclass -from datetime import datetime, timedelta -import json import typing -import urllib3 from urllib.parse import urlparse -from openfga_sdk.exceptions import FgaValidationException, ApiValueError, AuthenticationError +from openfga_sdk.exceptions import ApiValueError def none_or_empty(value): @@ -136,8 +133,6 @@ def __init__( ): self._method = method self._configuration = configuration - self._access_token = None - self._access_expiry_time = None @property def method(self): @@ -192,53 +187,3 @@ def validate_credentials_config(self): if (parsed_url.netloc == ''): raise ApiValueError('api_issuer `{}` is invalid'.format( self.configuration.api_issuer)) - - def _token_valid(self): - """ - Return whether token is valid - """ - if self._access_token is None or self._access_expiry_time is None: - return False - if self._access_expiry_time < datetime.now(): - return False - return True - - async def _obtain_token(self, client): - """ - Perform OAuth2 and obtain token - """ - token_url = 'https://{}/oauth/token'.format(self.configuration.api_issuer) - body = { - 'client_id': self.configuration.client_id, - 'client_secret': self.configuration.client_secret, - 'audience': self.configuration.api_audience, - 'grant_type': "client_credentials", - } - headers = urllib3.response.HTTPHeaderDict( - {'Accept': 'application/json', 'Content-Type': 'application/json', 'User-Agent': 'openfga-sdk (python) 0.2.1'}) - raw_response = await client.POST(token_url, headers=headers, body=body) - if 200 <= raw_response.status <= 299: - try: - api_response = json.loads(raw_response.data) - except: # noqa: E722 - raise AuthenticationError(http_resp=raw_response) - if not api_response.get('expires_in') or not api_response.get('access_token'): - raise AuthenticationError(http_resp=raw_response) - self._access_expiry_time = datetime.now() + timedelta(seconds=int(api_response.get('expires_in'))) - self._access_token = api_response.get('access_token') - else: - raise AuthenticationError(http_resp=raw_response) - - async def get_authentication_header(self, client): - """ - If configured, return the header for authentication - """ - if self._method == 'none': - return {} - if self._method == 'api_token': - return {'Authorization': 'Bearer {}'.format(self.configuration.api_token)} - # check to see token is valid - if not self._token_valid(): - # In this case, the token is not valid, we need to get the refresh the token - await self._obtain_token(client) - return {'Authorization': 'Bearer {}'.format(self._access_token)} diff --git a/openfga_sdk/oauth2.py b/openfga_sdk/oauth2.py new file mode 100644 index 00000000..62dd7d2d --- /dev/null +++ b/openfga_sdk/oauth2.py @@ -0,0 +1,79 @@ +""" + Python SDK for OpenFGA + + API version: 0.1 + Website: https://openfga.dev + Documentation: https://openfga.dev/docs + Support: https://discord.gg/8naAwJfWN6 + License: [Apache-2.0](https://github.com/openfga/python-sdk/blob/main/LICENSE) + + NOTE: This file was auto generated by OpenAPI Generator (https://openapi-generator.tech). DO NOT EDIT. +""" + +from dataclasses import dataclass +from datetime import datetime, timedelta +import json +import typing +import urllib3 +from urllib.parse import urlparse + +from openfga_sdk.credentials import Credentials +from openfga_sdk.exceptions import AuthenticationError + + +class OAuth2Client: + + def __init__( + self, + credentials: Credentials + ): + self._credentials = credentials + self._access_token = None + self._access_expiry_time = None + + def _token_valid(self): + """ + Return whether token is valid + """ + if self._access_token is None or self._access_expiry_time is None: + return False + if self._access_expiry_time < datetime.now(): + return False + return True + + async def _obtain_token(self, client): + """ + Perform OAuth2 and obtain token + """ + configuration = self._credentials.configuration + token_url = 'https://{}/oauth/token'.format(configuration.api_issuer) + body = { + 'client_id': configuration.client_id, + 'client_secret': configuration.client_secret, + 'audience': configuration.api_audience, + 'grant_type': "client_credentials", + } + headers = urllib3.response.HTTPHeaderDict( + {'Accept': 'application/json', 'Content-Type': 'application/json', 'User-Agent': 'openfga-sdk (python) 0.2.1'}) + raw_response = await client.POST(token_url, headers=headers, body=body) + if 200 <= raw_response.status <= 299: + try: + api_response = json.loads(raw_response.data) + except: # noqa: E722 + raise AuthenticationError(http_resp=raw_response) + if not api_response.get('expires_in') or not api_response.get('access_token'): + raise AuthenticationError(http_resp=raw_response) + self._access_expiry_time = datetime.now() + timedelta(seconds=int(api_response.get('expires_in'))) + self._access_token = api_response.get('access_token') + else: + raise AuthenticationError(http_resp=raw_response) + + async def get_authentication_header(self, client): + """ + If configured, return the header for authentication + """ + # check to see token is valid + if not self._token_valid(): + # In this case, the token is not valid, we need to get the refresh the token + await self._obtain_token(client) + return {'Authorization': 'Bearer {}'.format(self._access_token)} diff --git a/openfga_sdk/sync/__init__.py b/openfga_sdk/sync/__init__.py index cfc8c33e..00395bad 100644 --- a/openfga_sdk/sync/__init__.py +++ b/openfga_sdk/sync/__init__.py @@ -15,4 +15,3 @@ from openfga_sdk.sync.client.client import OpenFgaClient from openfga_sdk.sync.api_client import ApiClient -from openfga_sdk.sync.credentials import Credentials diff --git a/openfga_sdk/sync/api_client.py b/openfga_sdk/sync/api_client.py index 228f5a18..37f4a070 100644 --- a/openfga_sdk/sync/api_client.py +++ b/openfga_sdk/sync/api_client.py @@ -30,7 +30,7 @@ from openfga_sdk.configuration import Configuration import openfga_sdk.models -from openfga_sdk.sync import rest +from openfga_sdk.sync import rest, oauth2 from openfga_sdk.exceptions import ApiValueError, ApiException, FgaValidationException, RateLimitExceededError @@ -139,7 +139,7 @@ def __call_api( response_types_map=None, auth_settings=None, _return_http_data_only=None, collection_formats=None, _preload_content=True, _request_timeout=None, _host=None, - _request_auth=None, _retry_params=None): + _request_auth=None, _retry_params=None, _oauth2_client=None): self.configuration.is_valid() config = self.configuration @@ -182,7 +182,7 @@ def __call_api( # auth setting self.update_params_for_auth( header_params, query_params, auth_settings, - request_auth=_request_auth) + request_auth=_request_auth, oauth2_client=_oauth2_client) # body if body: @@ -367,7 +367,7 @@ def call_api(self, resource_path, method, async_req=None, _return_http_data_only=None, collection_formats=None, _preload_content=True, _request_timeout=None, _host=None, _request_auth=None, - _retry_params=None): + _retry_params=None, _oauth2_client=None): """Makes the HTTP request (synchronous) and returns deserialized data. To make an async_req request, set the async_req parameter. @@ -415,7 +415,7 @@ def call_api(self, resource_path, method, response_types_map, auth_settings, _return_http_data_only, collection_formats, _preload_content, _request_timeout, _host, - _request_auth, _retry_params) + _request_auth, _retry_params, _oauth2_client) return self.pool.apply_async(self.__call_api, (resource_path, method, path_params, @@ -428,7 +428,9 @@ def call_api(self, resource_path, method, collection_formats, _preload_content, _request_timeout, - _host, _request_auth, _retry_params)) + _host, _request_auth, + _retry_params, + _oauth2_client)) def request(self, method, url, query_params=None, headers=None, post_params=None, body=None, _preload_content=True, @@ -559,7 +561,7 @@ def select_header_content_type(self, content_types, method=None, body=None): return content_types[0] def update_params_for_auth(self, headers, queries, auth_settings, - request_auth=None): + request_auth=None, oauth2_client=None): """Updates header and query params based on authentication setting. :param headers: Header parameters dict to be updated. @@ -567,12 +569,20 @@ def update_params_for_auth(self, headers, queries, auth_settings, :param auth_settings: Authentication setting identifiers list. :param request_auth: if set, the provided settings will override the token in the configuration. + :param oauth2_client: if set, will be used for credential exchange. """ - if self.configuration.credentials is not None: - added_headers = self.configuration.credentials.get_authentication_header( - self.rest_client) - for key, value in added_headers.items(): - headers[key] = value + credentials = self.configuration.credentials + if credentials is not None: + if credentials.method == 'none': + pass + if credentials.method == 'api_token': + headers['Authorization'] = 'Bearer {}'.format(credentials.configuration.api_token) + if credentials.method == 'client_credentials': + if oauth2_client is None: + oauth2_client = oauth2.OAuth2Client(credentials) + oauth2_headers = oauth2_client.get_authentication_header(self.rest_client) + for key, value in oauth2_headers.items(): + headers[key] = value if not auth_settings: return diff --git a/openfga_sdk/sync/credentials.py b/openfga_sdk/sync/credentials.py deleted file mode 100644 index 6d86d268..00000000 --- a/openfga_sdk/sync/credentials.py +++ /dev/null @@ -1,150 +0,0 @@ -""" - Python SDK for OpenFGA - - API version: 0.1 - Website: https://openfga.dev - Documentation: https://openfga.dev/docs - Support: https://discord.gg/8naAwJfWN6 - License: [Apache-2.0](https://github.com/openfga/python-sdk/blob/main/LICENSE) - - NOTE: This file was auto generated by OpenAPI Generator (https://openapi-generator.tech). DO NOT EDIT. -""" - -from dataclasses import dataclass -from datetime import datetime, timedelta -import json -import typing -import urllib3 -from urllib.parse import urlparse - -from openfga_sdk.credentials import CredentialConfiguration -from openfga_sdk.exceptions import FgaValidationException, ApiValueError, AuthenticationError - - -def none_or_empty(value): - """ - Return true if value is either none or empty string - """ - return value is None or value == '' - - -class Credentials: - """ - Manage the credential for the API Client - :param method: Type of authentication. Possible value is 'none', 'api_token' and 'client_credentials'. Default as 'none'. - :param configuration: Credential configuration of type CredentialConfiguration. Default as None. - """ - - def __init__( - self, - method: typing.Optional[str] = 'none', - configuration: typing.Optional[CredentialConfiguration] = None, - ): - self._method = method - self._configuration = configuration - self._access_token = None - self._access_expiry_time = None - - @property - def method(self): - """ - Return the method configured - """ - return self._method - - @method.setter - def method(self, value): - """ - Update the method - """ - self._method = value - - @property - def configuration(self): - """ - Return the configuration - """ - return self._configuration - - @configuration.setter - def configuration(self, value): - """ - Update the configuration - """ - self._configuration = value - - def validate_credentials_config(self): - """ - Check whether credentials configuration is valid - """ - if self.method != 'none' and self.method != 'api_token' and self.method != 'client_credentials': - raise ApiValueError( - 'method `{}` must be either `none`, `api_token` or `client_credentials`'.format(self.method)) - if self.method == 'api_token' and (self.configuration is None or none_or_empty(self.configuration.api_token)): - raise ApiValueError( - 'configuration `{}` api_token must be defined and non empty when method is api_token'.format(self.configuration)) - if self.method == 'client_credentials': - if self.configuration is None or none_or_empty(self.configuration.client_id) or none_or_empty(self.configuration.client_secret) or none_or_empty(self.configuration.api_audience) or none_or_empty(self.configuration.api_issuer): - raise ApiValueError( - 'configuration `{}` requires client_id, client_secret, api_audience and api_issuer defined for client_credentials method.') - # validate token issuer - combined_url = 'https://' + self.configuration.api_issuer - parsed_url = None - try: - parsed_url = urlparse(combined_url) - except ValueError: - raise ApiValueError('api_issuer `{}` is invalid'.format( - self.configuration.api_issuer)) - if (parsed_url.netloc == ''): - raise ApiValueError('api_issuer `{}` is invalid'.format( - self.configuration.api_issuer)) - - def _token_valid(self): - """ - Return whether token is valid - """ - if self._access_token is None or self._access_expiry_time is None: - return False - if self._access_expiry_time < datetime.now(): - return False - return True - - def _obtain_token(self, client): - """ - Perform OAuth2 and obtain token - """ - token_url = 'https://{}/oauth/token'.format(self.configuration.api_issuer) - body = { - 'client_id': self.configuration.client_id, - 'client_secret': self.configuration.client_secret, - 'audience': self.configuration.api_audience, - 'grant_type': "client_credentials", - } - headers = urllib3.response.HTTPHeaderDict( - {'Accept': 'application/json', 'Content-Type': 'application/json', 'User-Agent': 'openfga-sdk (python) 0.2.1'}) - raw_response = client.POST(token_url, headers=headers, body=body) - if 200 <= raw_response.status <= 299: - try: - api_response = json.loads(raw_response.data) - except: # noqa: E722 - raise AuthenticationError(http_resp=raw_response) - if not api_response.get('expires_in') or not api_response.get('access_token'): - raise AuthenticationError(http_resp=raw_response) - self._access_expiry_time = datetime.now() + timedelta(seconds=int(api_response.get('expires_in'))) - self._access_token = api_response.get('access_token') - else: - raise AuthenticationError(http_resp=raw_response) - - def get_authentication_header(self, client): - """ - If configured, return the header for authentication - """ - if self._method == 'none': - return {} - if self._method == 'api_token': - return {'Authorization': 'Bearer {}'.format(self.configuration.api_token)} - # check to see token is valid - if not self._token_valid(): - # In this case, the token is not valid, we need to get the refresh the token - self._obtain_token(client) - return {'Authorization': 'Bearer {}'.format(self._access_token)} diff --git a/openfga_sdk/sync/oauth2.py b/openfga_sdk/sync/oauth2.py new file mode 100644 index 00000000..3b026b8a --- /dev/null +++ b/openfga_sdk/sync/oauth2.py @@ -0,0 +1,79 @@ +""" + Python SDK for OpenFGA + + API version: 0.1 + Website: https://openfga.dev + Documentation: https://openfga.dev/docs + Support: https://discord.gg/8naAwJfWN6 + License: [Apache-2.0](https://github.com/openfga/python-sdk/blob/main/LICENSE) + + NOTE: This file was auto generated by OpenAPI Generator (https://openapi-generator.tech). DO NOT EDIT. +""" + +from dataclasses import dataclass +from datetime import datetime, timedelta +import json +import typing +import urllib3 +from urllib.parse import urlparse + +from openfga_sdk.credentials import Credentials +from openfga_sdk.exceptions import AuthenticationError + + +class OAuth2Client: + + def __init__( + self, + credentials: Credentials + ): + self._credentials = credentials + self._access_token = None + self._access_expiry_time = None + + def _token_valid(self): + """ + Return whether token is valid + """ + if self._access_token is None or self._access_expiry_time is None: + return False + if self._access_expiry_time < datetime.now(): + return False + return True + + def _obtain_token(self, client): + """ + Perform OAuth2 and obtain token + """ + configuration = self._credentials.configuration + token_url = 'https://{}/oauth/token'.format(configuration.api_issuer) + body = { + 'client_id': configuration.client_id, + 'client_secret': configuration.client_secret, + 'audience': configuration.api_audience, + 'grant_type': "client_credentials", + } + headers = urllib3.response.HTTPHeaderDict( + {'Accept': 'application/json', 'Content-Type': 'application/json', 'User-Agent': 'openfga-sdk (python) 0.2.1'}) + raw_response = client.POST(token_url, headers=headers, body=body) + if 200 <= raw_response.status <= 299: + try: + api_response = json.loads(raw_response.data) + except: # noqa: E722 + raise AuthenticationError(http_resp=raw_response) + if not api_response.get('expires_in') or not api_response.get('access_token'): + raise AuthenticationError(http_resp=raw_response) + self._access_expiry_time = datetime.now() + timedelta(seconds=int(api_response.get('expires_in'))) + self._access_token = api_response.get('access_token') + else: + raise AuthenticationError(http_resp=raw_response) + + def get_authentication_header(self, client): + """ + If configured, return the header for authentication + """ + # check to see token is valid + if not self._token_valid(): + # In this case, the token is not valid, we need to get the refresh the token + self._obtain_token(client) + return {'Authorization': 'Bearer {}'.format(self._access_token)} diff --git a/openfga_sdk/sync/open_fga_api.py b/openfga_sdk/sync/open_fga_api.py index 4da863d6..14551b07 100644 --- a/openfga_sdk/sync/open_fga_api.py +++ b/openfga_sdk/sync/open_fga_api.py @@ -19,6 +19,7 @@ import six from openfga_sdk.sync.api_client import ApiClient +from openfga_sdk.sync.oauth2 import OAuth2Client from openfga_sdk.exceptions import ( # noqa: F401 FgaValidationException, ApiValueError @@ -37,6 +38,12 @@ def __init__(self, api_client=None): api_client = ApiClient() self.api_client = api_client + self._oauth2_client = None + if api_client.configuration is not None: + credentials = api_client.configuration.credentials + if credentials is not None and credentials.method == 'client_credentials': + self._oauth2_client = OAuth2Client(credentials) + def __enter__(self): return self @@ -192,7 +199,8 @@ def check_with_http_info(self, body, **kwargs): # noqa: E501 _request_timeout=local_var_params.get('_request_timeout'), _retry_params=local_var_params.get('_retry_params'), collection_formats=collection_formats, - _request_auth=local_var_params.get('_request_auth')) + _request_auth=local_var_params.get('_request_auth'), + _oauth2_client=self._oauth2_client) def create_store(self, body, **kwargs): # noqa: E501 """Create a store # noqa: E501 @@ -333,7 +341,8 @@ def create_store_with_http_info(self, body, **kwargs): # noqa: E501 _request_timeout=local_var_params.get('_request_timeout'), _retry_params=local_var_params.get('_retry_params'), collection_formats=collection_formats, - _request_auth=local_var_params.get('_request_auth')) + _request_auth=local_var_params.get('_request_auth'), + _oauth2_client=self._oauth2_client) def delete_store(self, **kwargs): # noqa: E501 """Delete a store # noqa: E501 @@ -460,7 +469,8 @@ def delete_store_with_http_info(self, **kwargs): # noqa: E501 _request_timeout=local_var_params.get('_request_timeout'), _retry_params=local_var_params.get('_retry_params'), collection_formats=collection_formats, - _request_auth=local_var_params.get('_request_auth')) + _request_auth=local_var_params.get('_request_auth'), + _oauth2_client=self._oauth2_client) def expand(self, body, **kwargs): # noqa: E501 """Expand all relationships in userset tree format, and following userset rewrite rules. Useful to reason about and debug a certain relationship # noqa: E501 @@ -608,7 +618,8 @@ def expand_with_http_info(self, body, **kwargs): # noqa: E501 _request_timeout=local_var_params.get('_request_timeout'), _retry_params=local_var_params.get('_retry_params'), collection_formats=collection_formats, - _request_auth=local_var_params.get('_request_auth')) + _request_auth=local_var_params.get('_request_auth'), + _oauth2_client=self._oauth2_client) def get_store(self, **kwargs): # noqa: E501 """Get a store # noqa: E501 @@ -740,7 +751,8 @@ def get_store_with_http_info(self, **kwargs): # noqa: E501 _request_timeout=local_var_params.get('_request_timeout'), _retry_params=local_var_params.get('_retry_params'), collection_formats=collection_formats, - _request_auth=local_var_params.get('_request_auth')) + _request_auth=local_var_params.get('_request_auth'), + _oauth2_client=self._oauth2_client) def list_objects(self, body, **kwargs): # noqa: E501 """List all objects of the given type that the user has a relation with # noqa: E501 @@ -888,7 +900,8 @@ def list_objects_with_http_info(self, body, **kwargs): # noqa: E501 _request_timeout=local_var_params.get('_request_timeout'), _retry_params=local_var_params.get('_retry_params'), collection_formats=collection_formats, - _request_auth=local_var_params.get('_request_auth')) + _request_auth=local_var_params.get('_request_auth'), + _oauth2_client=self._oauth2_client) def list_stores(self, **kwargs): # noqa: E501 """List all stores # noqa: E501 @@ -1030,7 +1043,8 @@ def list_stores_with_http_info(self, **kwargs): # noqa: E501 _request_timeout=local_var_params.get('_request_timeout'), _retry_params=local_var_params.get('_retry_params'), collection_formats=collection_formats, - _request_auth=local_var_params.get('_request_auth')) + _request_auth=local_var_params.get('_request_auth'), + _oauth2_client=self._oauth2_client) def read(self, body, **kwargs): # noqa: E501 """Get tuples from the store that matches a query, without following userset rewrite rules # noqa: E501 @@ -1178,7 +1192,8 @@ def read_with_http_info(self, body, **kwargs): # noqa: E501 _request_timeout=local_var_params.get('_request_timeout'), _retry_params=local_var_params.get('_retry_params'), collection_formats=collection_formats, - _request_auth=local_var_params.get('_request_auth')) + _request_auth=local_var_params.get('_request_auth'), + _oauth2_client=self._oauth2_client) def read_assertions(self, authorization_model_id, **kwargs): # noqa: E501 """Read assertions for an authorization model ID # noqa: E501 @@ -1323,7 +1338,8 @@ def read_assertions_with_http_info(self, authorization_model_id, **kwargs): # n _request_timeout=local_var_params.get('_request_timeout'), _retry_params=local_var_params.get('_retry_params'), collection_formats=collection_formats, - _request_auth=local_var_params.get('_request_auth')) + _request_auth=local_var_params.get('_request_auth'), + _oauth2_client=self._oauth2_client) def read_authorization_model(self, id, **kwargs): # noqa: E501 """Return a particular version of an authorization model # noqa: E501 @@ -1467,7 +1483,8 @@ def read_authorization_model_with_http_info(self, id, **kwargs): # noqa: E501 _request_timeout=local_var_params.get('_request_timeout'), _retry_params=local_var_params.get('_retry_params'), collection_formats=collection_formats, - _request_auth=local_var_params.get('_request_auth')) + _request_auth=local_var_params.get('_request_auth'), + _oauth2_client=self._oauth2_client) def read_authorization_models(self, **kwargs): # noqa: E501 """Return all the authorization models for a particular store # noqa: E501 @@ -1613,7 +1630,8 @@ def read_authorization_models_with_http_info(self, **kwargs): # noqa: E501 _request_timeout=local_var_params.get('_request_timeout'), _retry_params=local_var_params.get('_retry_params'), collection_formats=collection_formats, - _request_auth=local_var_params.get('_request_auth')) + _request_auth=local_var_params.get('_request_auth'), + _oauth2_client=self._oauth2_client) def read_changes(self, **kwargs): # noqa: E501 """Return a list of all the tuple changes # noqa: E501 @@ -1766,7 +1784,8 @@ def read_changes_with_http_info(self, **kwargs): # noqa: E501 _request_timeout=local_var_params.get('_request_timeout'), _retry_params=local_var_params.get('_retry_params'), collection_formats=collection_formats, - _request_auth=local_var_params.get('_request_auth')) + _request_auth=local_var_params.get('_request_auth'), + _oauth2_client=self._oauth2_client) def write(self, body, **kwargs): # noqa: E501 """Add or delete tuples from the store # noqa: E501 @@ -1914,7 +1933,8 @@ def write_with_http_info(self, body, **kwargs): # noqa: E501 _request_timeout=local_var_params.get('_request_timeout'), _retry_params=local_var_params.get('_retry_params'), collection_formats=collection_formats, - _request_auth=local_var_params.get('_request_auth')) + _request_auth=local_var_params.get('_request_auth'), + _oauth2_client=self._oauth2_client) def write_assertions(self, authorization_model_id, body, **kwargs): # noqa: E501 """Upsert assertions for an authorization model ID # noqa: E501 @@ -2070,7 +2090,8 @@ def write_assertions_with_http_info(self, authorization_model_id, body, **kwargs _request_timeout=local_var_params.get('_request_timeout'), _retry_params=local_var_params.get('_retry_params'), collection_formats=collection_formats, - _request_auth=local_var_params.get('_request_auth')) + _request_auth=local_var_params.get('_request_auth'), + _oauth2_client=self._oauth2_client) def write_authorization_model(self, body, **kwargs): # noqa: E501 """Create a new authorization model # noqa: E501 @@ -2218,4 +2239,5 @@ def write_authorization_model_with_http_info(self, body, **kwargs): # noqa: E50 _request_timeout=local_var_params.get('_request_timeout'), _retry_params=local_var_params.get('_retry_params'), collection_formats=collection_formats, - _request_auth=local_var_params.get('_request_auth')) + _request_auth=local_var_params.get('_request_auth'), + _oauth2_client=self._oauth2_client) diff --git a/test/test_credentials.py b/test/test_credentials.py index a7566e04..f7f33409 100644 --- a/test/test_credentials.py +++ b/test/test_credentials.py @@ -14,30 +14,9 @@ from unittest import IsolatedAsyncioTestCase -from mock import patch -from datetime import datetime, timedelta - import openfga_sdk -import urllib3 -from openfga_sdk import rest from openfga_sdk.credentials import CredentialConfiguration, Credentials -from openfga_sdk.configuration import Configuration -from openfga_sdk.exceptions import AuthenticationError - - -# Helper function to construct mock response -def mock_response(body, status): - headers = urllib3.response.HTTPHeaderDict({ - 'content-type': 'application/json' - }) - obj = urllib3.HTTPResponse( - body, - headers, - status, - preload_content=False - ) - return rest.RESTResponse(obj, obj.data) class TestCredentials(IsolatedAsyncioTestCase): @@ -165,87 +144,3 @@ def test_configuration_client_credentials_missing_api_audience(self): client_secret='mysecret', api_issuer='www.testme.com')) with self.assertRaises(openfga_sdk.ApiValueError): credential.validate_credentials_config() - - async def test_get_authentication_header(self): - """ - Test getting authentication header when method is none - """ - credential = Credentials() - auth_header = await credential.get_authentication_header(None) - self.assertEqual(auth_header, {}) - - async def test_get_authentication_api_token(self): - """ - Test getting authentication header when method is api token - """ - credential = Credentials( - method="api_token", configuration=CredentialConfiguration(api_token='ABCDEFG')) - auth_header = await credential.get_authentication_header(None) - self.assertEqual(auth_header, {'Authorization': 'Bearer ABCDEFG'}) - - async def test_get_authentication_valid_client_credentials(self): - """ - Test getting authentication header when method is client credentials - """ - credential = Credentials(method="client_credentials", - configuration=CredentialConfiguration(client_id='myclientid', - client_secret='mysecret', api_issuer='www.testme.com', api_audience='myaudience')) - credential._access_token = 'XYZ123' - credential._access_expiry_time = datetime.now() + timedelta(seconds=60) - auth_header = await credential.get_authentication_header(None) - self.assertEqual(auth_header, {'Authorization': 'Bearer XYZ123'}) - - @patch.object(rest.RESTClientObject, 'request') - async def test_get_authentication_obtain_client_credentials(self, mock_request): - """ - Test getting authentication header when method is client credential and we need to obtain token - """ - response_body = ''' -{ - "expires_in": 120, - "access_token": "AABBCCDD" -} - ''' - mock_request.return_value = mock_response(response_body, 200) - - credential = Credentials(method="client_credentials", - configuration=CredentialConfiguration(client_id='myclientid', - client_secret='mysecret', api_issuer='www.testme.com', api_audience='myaudience')) - client = rest.RESTClientObject(Configuration()) - current_time = datetime.now() - auth_header = await credential.get_authentication_header(client) - self.assertEqual(auth_header, {'Authorization': 'Bearer AABBCCDD'}) - self.assertEqual(credential._access_token, 'AABBCCDD') - self.assertGreaterEqual(credential._access_expiry_time, - current_time + timedelta(seconds=int(120))) - expected_header = urllib3.response.HTTPHeaderDict( - {'Accept': 'application/json', 'Content-Type': 'application/json', 'User-Agent': 'openfga-sdk (python) 0.2.1'}) - mock_request.assert_called_once_with( - 'POST', - 'https://www.testme.com/oauth/token', - headers=expected_header, - query_params=None, post_params=None, _preload_content=True, _request_timeout=None, - body={"client_id": "myclientid", "client_secret": "mysecret", - "audience": "myaudience", "grant_type": "client_credentials"} - ) - await client.close() - - @patch.object(rest.RESTClientObject, 'request') - async def test_get_authentication_obtain_client_credentials_failed(self, mock_request): - """ - Test getting authentication header when method is client credential and we fail to obtain token - """ - response_body = ''' -{ - "reason": "Unauthorized" -} - ''' - mock_request.return_value = mock_response(response_body, 403) - - credential = Credentials(method="client_credentials", - configuration=CredentialConfiguration(client_id='myclientid', - client_secret='mysecret', api_issuer='www.testme.com', api_audience='myaudience')) - client = rest.RESTClientObject(Configuration()) - with self.assertRaises(AuthenticationError): - await credential.get_authentication_header(client) - await client.close() diff --git a/test/test_credentials_sync.py b/test/test_credentials_sync.py deleted file mode 100644 index f4efa02f..00000000 --- a/test/test_credentials_sync.py +++ /dev/null @@ -1,252 +0,0 @@ -# coding: utf-8 - -""" - Python SDK for OpenFGA - - API version: 0.1 - Website: https://openfga.dev - Documentation: https://openfga.dev/docs - Support: https://discord.gg/8naAwJfWN6 - License: [Apache-2.0](https://github.com/openfga/python-sdk/blob/main/LICENSE) - - NOTE: This file was auto generated by OpenAPI Generator (https://openapi-generator.tech). DO NOT EDIT. -""" - -from unittest import IsolatedAsyncioTestCase - -from mock import patch -from datetime import datetime, timedelta - -import urllib3 - -from openfga_sdk.sync import rest -from openfga_sdk.credentials import CredentialConfiguration -from openfga_sdk.sync.credentials import Credentials -from openfga_sdk.configuration import Configuration -from openfga_sdk.exceptions import ApiValueError, AuthenticationError - -# Helper function to construct mock response - - -def mock_response(body, status): - headers = urllib3.response.HTTPHeaderDict({ - 'content-type': 'application/json' - }) - obj = urllib3.HTTPResponse( - body, - headers, - status, - preload_content=False - ) - return rest.RESTResponse(obj, obj.data) - - -class TestCredentials(IsolatedAsyncioTestCase): - """Credentials unit test""" - - def setUp(self): - pass - - def tearDown(self): - pass - - def test_bad_method(self): - """ - Check whether assertion is raised if method is not allowed - """ - credential = Credentials("bad") - with self.assertRaises(ApiValueError): - credential.validate_credentials_config() - - def test_method_none(self): - """ - Test credential with method none is valid - """ - credential = Credentials("none") - credential.validate_credentials_config() - self.assertEqual(credential.method, 'none') - - def test_method_default(self): - """ - Test credential with not method is default to none - """ - credential = Credentials() - credential.validate_credentials_config() - self.assertEqual(credential.method, 'none') - - def test_configuration_api_token(self): - """ - Test credential with method api_token and appropriate configuration is valid - """ - credential = Credentials( - method="api_token", configuration=CredentialConfiguration(api_token='ABCDEFG')) - credential.validate_credentials_config() - self.assertEqual(credential.method, 'api_token') - self.assertEqual(credential.configuration.api_token, 'ABCDEFG') - - def test_configuration_api_token_missing_configuration(self): - """ - Test credential with method api_token but configuration is not specified - """ - credential = Credentials(method="api_token") - with self.assertRaises(ApiValueError): - credential.validate_credentials_config() - - def test_configuration_api_token_missing_token(self): - """ - Test credential with method api_token but configuration is missing token - """ - credential = Credentials(method="api_token", configuration=CredentialConfiguration()) - with self.assertRaises(ApiValueError): - credential.validate_credentials_config() - - def test_configuration_api_token_empty_token(self): - """ - Test credential with method api_token but configuration has empty token - """ - credential = Credentials( - method="api_token", configuration=CredentialConfiguration(api_token='')) - with self.assertRaises(ApiValueError): - credential.validate_credentials_config() - - def test_configuration_client_credentials(self): - """ - Test credential with method client_credentials and appropriate configuration is valid - """ - credential = Credentials(method="client_credentials", - configuration=CredentialConfiguration(client_id='myclientid', - client_secret='mysecret', api_issuer='www.testme.com', api_audience='myaudience')) - credential.validate_credentials_config() - self.assertEqual(credential.method, 'client_credentials') - - def test_configuration_client_credentials_missing_config(self): - """ - Test credential with method client_credentials and configuration is missing - """ - credential = Credentials(method="client_credentials") - with self.assertRaises(ApiValueError): - credential.validate_credentials_config() - - def test_configuration_client_credentials_missing_client_id(self): - """ - Test credential with method client_credentials and configuration is missing client id - """ - credential = Credentials(method="client_credentials", - configuration=CredentialConfiguration( - client_secret='mysecret', api_issuer='www.testme.com', api_audience='myaudience')) - with self.assertRaises(ApiValueError): - credential.validate_credentials_config() - - def test_configuration_client_credentials_missing_client_secret(self): - """ - Test credential with method client_credentials and configuration is missing client secret - """ - credential = Credentials(method="client_credentials", - configuration=CredentialConfiguration(client_id='myclientid', - api_issuer='www.testme.com', api_audience='myaudience')) - with self.assertRaises(ApiValueError): - credential.validate_credentials_config() - - def test_configuration_client_credentials_missing_api_issuer(self): - """ - Test credential with method client_credentials and configuration is missing api issuer - """ - credential = Credentials(method="client_credentials", - configuration=CredentialConfiguration(client_id='myclientid', - client_secret='mysecret', api_audience='myaudience')) - with self.assertRaises(ApiValueError): - credential.validate_credentials_config() - - def test_configuration_client_credentials_missing_api_audience(self): - """ - Test credential with method client_credentials and configuration is missing api audience - """ - credential = Credentials(method="client_credentials", - configuration=CredentialConfiguration(client_id='myclientid', - client_secret='mysecret', api_issuer='www.testme.com')) - with self.assertRaises(ApiValueError): - credential.validate_credentials_config() - - def test_get_authentication_header(self): - """ - Test getting authentication header when method is none - """ - credential = Credentials() - auth_header = credential.get_authentication_header(None) - self.assertEqual(auth_header, {}) - - def test_get_authentication_api_token(self): - """ - Test getting authentication header when method is api token - """ - credential = Credentials( - method="api_token", configuration=CredentialConfiguration(api_token='ABCDEFG')) - auth_header = credential.get_authentication_header(None) - self.assertEqual(auth_header, {'Authorization': 'Bearer ABCDEFG'}) - - def test_get_authentication_valid_client_credentials(self): - """ - Test getting authentication header when method is client credentials - """ - credential = Credentials(method="client_credentials", - configuration=CredentialConfiguration(client_id='myclientid', - client_secret='mysecret', api_issuer='www.testme.com', api_audience='myaudience')) - credential._access_token = 'XYZ123' - credential._access_expiry_time = datetime.now() + timedelta(seconds=60) - auth_header = credential.get_authentication_header(None) - self.assertEqual(auth_header, {'Authorization': 'Bearer XYZ123'}) - - @patch.object(rest.RESTClientObject, 'request') - def test_get_authentication_obtain_client_credentials(self, mock_request): - """ - Test getting authentication header when method is client credential and we need to obtain token - """ - response_body = ''' -{ - "expires_in": 120, - "access_token": "AABBCCDD" -} - ''' - mock_request.return_value = mock_response(response_body, 200) - - credential = Credentials(method="client_credentials", - configuration=CredentialConfiguration(client_id='myclientid', - client_secret='mysecret', api_issuer='www.testme.com', api_audience='myaudience')) - client = rest.RESTClientObject(Configuration()) - current_time = datetime.now() - auth_header = credential.get_authentication_header(client) - self.assertEqual(auth_header, {'Authorization': 'Bearer AABBCCDD'}) - self.assertEqual(credential._access_token, 'AABBCCDD') - self.assertGreaterEqual(credential._access_expiry_time, - current_time + timedelta(seconds=int(120))) - expected_header = urllib3.response.HTTPHeaderDict( - {'Accept': 'application/json', 'Content-Type': 'application/json', 'User-Agent': 'openfga-sdk (python) 0.2.1'}) - mock_request.assert_called_once_with( - 'POST', - 'https://www.testme.com/oauth/token', - headers=expected_header, - query_params=None, post_params=None, _preload_content=True, _request_timeout=None, - body={"client_id": "myclientid", "client_secret": "mysecret", - "audience": "myaudience", "grant_type": "client_credentials"} - ) - client.close() - - @patch.object(rest.RESTClientObject, 'request') - def test_get_authentication_obtain_client_credentials_failed(self, mock_request): - """ - Test getting authentication header when method is client credential and we fail to obtain token - """ - response_body = ''' -{ - "reason": "Unauthorized" -} - ''' - mock_request.return_value = mock_response(response_body, 403) - - credential = Credentials(method="client_credentials", - configuration=CredentialConfiguration(client_id='myclientid', - client_secret='mysecret', api_issuer='www.testme.com', api_audience='myaudience')) - client = rest.RESTClientObject(Configuration()) - with self.assertRaises(AuthenticationError): - credential.get_authentication_header(client) - client.close() diff --git a/test/test_oauth2.py b/test/test_oauth2.py new file mode 100644 index 00000000..819fd95d --- /dev/null +++ b/test/test_oauth2.py @@ -0,0 +1,115 @@ +# coding: utf-8 + +""" + Python SDK for OpenFGA + + API version: 0.1 + Website: https://openfga.dev + Documentation: https://openfga.dev/docs + Support: https://discord.gg/8naAwJfWN6 + License: [Apache-2.0](https://github.com/openfga/python-sdk/blob/main/LICENSE) + + NOTE: This file was auto generated by OpenAPI Generator (https://openapi-generator.tech). DO NOT EDIT. +""" + +import urllib3 + +from unittest import IsolatedAsyncioTestCase +from mock import patch +from datetime import datetime, timedelta +from openfga_sdk.oauth2 import OAuth2Client +from openfga_sdk import rest +from openfga_sdk.credentials import CredentialConfiguration, Credentials +from openfga_sdk.configuration import Configuration +from openfga_sdk.exceptions import AuthenticationError + + +# Helper function to construct mock response +def mock_response(body, status): + headers = urllib3.response.HTTPHeaderDict({ + 'content-type': 'application/json' + }) + obj = urllib3.HTTPResponse( + body, + headers, + status, + preload_content=False + ) + return rest.RESTResponse(obj, obj.data) + + +class TestOAuth2Client(IsolatedAsyncioTestCase): + """TestOAuth2Client unit test""" + + def setUp(self): + pass + + def tearDown(self): + pass + + async def test_get_authentication_valid_client_credentials(self): + """ + Test getting authentication header when method is client credentials + """ + client = OAuth2Client(None) + client._access_token = 'XYZ123' + client._access_expiry_time = datetime.now() + timedelta(seconds=60) + auth_header = await client.get_authentication_header(None) + self.assertEqual(auth_header, {'Authorization': 'Bearer XYZ123'}) + + @patch.object(rest.RESTClientObject, 'request') + async def test_get_authentication_obtain_client_credentials(self, mock_request): + """ + Test getting authentication header when method is client credential and we need to obtain token + """ + response_body = ''' +{ + "expires_in": 120, + "access_token": "AABBCCDD" +} + ''' + mock_request.return_value = mock_response(response_body, 200) + + credentials = Credentials(method="client_credentials", + configuration=CredentialConfiguration(client_id='myclientid', + client_secret='mysecret', api_issuer='www.testme.com', api_audience='myaudience')) + rest_client = rest.RESTClientObject(Configuration()) + current_time = datetime.now() + client = OAuth2Client(credentials) + auth_header = await client.get_authentication_header(rest_client) + self.assertEqual(auth_header, {'Authorization': 'Bearer AABBCCDD'}) + self.assertEqual(client._access_token, 'AABBCCDD') + self.assertGreaterEqual(client._access_expiry_time, + current_time + timedelta(seconds=int(120))) + expected_header = urllib3.response.HTTPHeaderDict( + {'Accept': 'application/json', 'Content-Type': 'application/json', 'User-Agent': 'openfga-sdk (python) 0.2.1'}) + mock_request.assert_called_once_with( + 'POST', + 'https://www.testme.com/oauth/token', + headers=expected_header, + query_params=None, post_params=None, _preload_content=True, _request_timeout=None, + body={"client_id": "myclientid", "client_secret": "mysecret", + "audience": "myaudience", "grant_type": "client_credentials"} + ) + await rest_client.close() + + @patch.object(rest.RESTClientObject, 'request') + async def test_get_authentication_obtain_client_credentials_failed(self, mock_request): + """ + Test getting authentication header when method is client credential and we fail to obtain token + """ + response_body = ''' +{ + "reason": "Unauthorized" +} + ''' + mock_request.return_value = mock_response(response_body, 403) + + credentials = Credentials(method="client_credentials", + configuration=CredentialConfiguration(client_id='myclientid', + client_secret='mysecret', api_issuer='www.testme.com', api_audience='myaudience')) + rest_client = rest.RESTClientObject(Configuration()) + client = OAuth2Client(credentials) + with self.assertRaises(AuthenticationError): + await client.get_authentication_header(rest_client) + await rest_client.close() diff --git a/test/test_oauth2_sync.py b/test/test_oauth2_sync.py new file mode 100644 index 00000000..f4700e96 --- /dev/null +++ b/test/test_oauth2_sync.py @@ -0,0 +1,116 @@ +# coding: utf-8 + +""" + Python SDK for OpenFGA + + API version: 0.1 + Website: https://openfga.dev + Documentation: https://openfga.dev/docs + Support: https://discord.gg/8naAwJfWN6 + License: [Apache-2.0](https://github.com/openfga/python-sdk/blob/main/LICENSE) + + NOTE: This file was auto generated by OpenAPI Generator (https://openapi-generator.tech). DO NOT EDIT. +""" + + +import urllib3 + +from unittest import IsolatedAsyncioTestCase +from mock import patch +from datetime import datetime, timedelta +from openfga_sdk.sync.oauth2 import OAuth2Client +from openfga_sdk.sync import rest +from openfga_sdk.credentials import CredentialConfiguration, Credentials +from openfga_sdk.configuration import Configuration +from openfga_sdk.exceptions import AuthenticationError + + +# Helper function to construct mock response +def mock_response(body, status): + headers = urllib3.response.HTTPHeaderDict({ + 'content-type': 'application/json' + }) + obj = urllib3.HTTPResponse( + body, + headers, + status, + preload_content=False + ) + return rest.RESTResponse(obj, obj.data) + + +class TestOAuth2Client(IsolatedAsyncioTestCase): + """TestOAuth2Client unit test""" + + def setUp(self): + pass + + def tearDown(self): + pass + + def test_get_authentication_valid_client_credentials(self): + """ + Test getting authentication header when method is client credentials + """ + client = OAuth2Client(None) + client._access_token = 'XYZ123' + client._access_expiry_time = datetime.now() + timedelta(seconds=60) + auth_header = client.get_authentication_header(None) + self.assertEqual(auth_header, {'Authorization': 'Bearer XYZ123'}) + + @patch.object(rest.RESTClientObject, 'request') + def test_get_authentication_obtain_client_credentials(self, mock_request): + """ + Test getting authentication header when method is client credential and we need to obtain token + """ + response_body = ''' +{ + "expires_in": 120, + "access_token": "AABBCCDD" +} + ''' + mock_request.return_value = mock_response(response_body, 200) + + credentials = Credentials(method="client_credentials", + configuration=CredentialConfiguration(client_id='myclientid', + client_secret='mysecret', api_issuer='www.testme.com', api_audience='myaudience')) + rest_client = rest.RESTClientObject(Configuration()) + current_time = datetime.now() + client = OAuth2Client(credentials) + auth_header = client.get_authentication_header(rest_client) + self.assertEqual(auth_header, {'Authorization': 'Bearer AABBCCDD'}) + self.assertEqual(client._access_token, 'AABBCCDD') + self.assertGreaterEqual(client._access_expiry_time, + current_time + timedelta(seconds=int(120))) + expected_header = urllib3.response.HTTPHeaderDict( + {'Accept': 'application/json', 'Content-Type': 'application/json', 'User-Agent': 'openfga-sdk (python) 0.2.1'}) + mock_request.assert_called_once_with( + 'POST', + 'https://www.testme.com/oauth/token', + headers=expected_header, + query_params=None, post_params=None, _preload_content=True, _request_timeout=None, + body={"client_id": "myclientid", "client_secret": "mysecret", + "audience": "myaudience", "grant_type": "client_credentials"} + ) + rest_client.close() + + @patch.object(rest.RESTClientObject, 'request') + def test_get_authentication_obtain_client_credentials_failed(self, mock_request): + """ + Test getting authentication header when method is client credential and we fail to obtain token + """ + response_body = ''' +{ + "reason": "Unauthorized" +} + ''' + mock_request.return_value = mock_response(response_body, 403) + + credentials = Credentials(method="client_credentials", + configuration=CredentialConfiguration(client_id='myclientid', + client_secret='mysecret', api_issuer='www.testme.com', api_audience='myaudience')) + rest_client = rest.RESTClientObject(Configuration()) + client = OAuth2Client(credentials) + with self.assertRaises(AuthenticationError): + client.get_authentication_header(rest_client) + rest_client.close() diff --git a/test/test_open_fga_api_sync.py b/test/test_open_fga_api_sync.py index c81f2bd7..7d426ddd 100644 --- a/test/test_open_fga_api_sync.py +++ b/test/test_open_fga_api_sync.py @@ -21,9 +21,9 @@ import urllib3 import openfga_sdk.sync -from openfga_sdk.sync import rest, open_fga_api, Credentials +from openfga_sdk.sync import rest, open_fga_api from openfga_sdk.sync.api_client import ApiClient -from openfga_sdk.credentials import CredentialConfiguration +from openfga_sdk.credentials import Credentials, CredentialConfiguration from openfga_sdk.configuration import Configuration from openfga_sdk.exceptions import FgaValidationException, ApiValueError, NotFoundException, RateLimitExceededError, ServiceException, ValidationException, FGA_REQUEST_ID from openfga_sdk.models.assertion import Assertion