Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(python-sdk): extract oauth2 from credentials #42

Merged
merged 1 commit into from
Nov 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion .github/workflows/semgrep.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@ jobs:
image: returntocorp/semgrep
if: (github.actor != 'dependabot[bot]' && github.actor != 'snyk-bot')
steps:
- uses: actions/checkout@c85c95e3d7251135ab7dc9ce3241c5835cc595a9 # v3.5.2
- uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # v4.1.1
with:
fetch-depth: 0
- run: semgrep ci
env:
SEMGREP_APP_TOKEN: ${{ secrets.SEMGREP_APP_TOKEN }}
3 changes: 2 additions & 1 deletion .openapi-generator-ignore
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@ test/*
!test/test_client_sync.py
!test/test_open_fga_api_sync.py
!test/test_validation.py
!test/test_credentials_sync.py
!test/test_oauth2.py
!test/test_oauth2_sync.py
.github/workflows/python.yml
.gitlab-ci.yml
.travis.yml
Expand Down
6 changes: 4 additions & 2 deletions .openapi-generator/FILES
Original file line number Diff line number Diff line change
Expand Up @@ -146,12 +146,13 @@ openfga_sdk/models/write_assertions_request.py
openfga_sdk/models/write_authorization_model_request.py
openfga_sdk/models/write_authorization_model_response.py
openfga_sdk/models/write_request.py
openfga_sdk/oauth2.py
openfga_sdk/rest.py
openfga_sdk/sync/__init__.py
openfga_sdk/sync/api_client.py
openfga_sdk/sync/client/__init__.py
openfga_sdk/sync/client/client.py
openfga_sdk/sync/credentials.py
openfga_sdk/sync/oauth2.py
openfga_sdk/sync/rest.py
openfga_sdk/validation.py
requirements.txt
Expand All @@ -162,7 +163,8 @@ test/__init__.py
test/test_client.py
test/test_client_sync.py
test/test_credentials.py
test/test_credentials_sync.py
test/test_oauth2.py
test/test_oauth2_sync.py
test/test_open_fga_api.py
test/test_open_fga_api_sync.py
test/test_validation.py
52 changes: 37 additions & 15 deletions openfga_sdk/api/open_fga_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import six

from openfga_sdk.api_client import ApiClient
from openfga_sdk.oauth2 import OAuth2Client
from openfga_sdk.exceptions import ( # noqa: F401
FgaValidationException,
ApiValueError
Expand All @@ -37,6 +38,12 @@ def __init__(self, api_client=None):
api_client = ApiClient()
self.api_client = api_client

self._oauth2_client = None
if api_client.configuration is not None:
credentials = api_client.configuration.credentials
if credentials is not None and credentials.method == 'client_credentials':
self._oauth2_client = OAuth2Client(credentials)

async def __aenter__(self):
return self

Expand Down Expand Up @@ -192,7 +199,8 @@ async def check_with_http_info(self, body, **kwargs): # noqa: E501
_request_timeout=local_var_params.get('_request_timeout'),
_retry_params=local_var_params.get('_retry_params'),
collection_formats=collection_formats,
_request_auth=local_var_params.get('_request_auth')))
_request_auth=local_var_params.get('_request_auth'),
_oauth2_client=self._oauth2_client))

async def create_store(self, body, **kwargs): # noqa: E501
"""Create a store # noqa: E501
Expand Down Expand Up @@ -333,7 +341,8 @@ async def create_store_with_http_info(self, body, **kwargs): # noqa: E501
_request_timeout=local_var_params.get('_request_timeout'),
_retry_params=local_var_params.get('_retry_params'),
collection_formats=collection_formats,
_request_auth=local_var_params.get('_request_auth')))
_request_auth=local_var_params.get('_request_auth'),
_oauth2_client=self._oauth2_client))

async def delete_store(self, **kwargs): # noqa: E501
"""Delete a store # noqa: E501
Expand Down Expand Up @@ -460,7 +469,8 @@ async def delete_store_with_http_info(self, **kwargs): # noqa: E501
_request_timeout=local_var_params.get('_request_timeout'),
_retry_params=local_var_params.get('_retry_params'),
collection_formats=collection_formats,
_request_auth=local_var_params.get('_request_auth')))
_request_auth=local_var_params.get('_request_auth'),
_oauth2_client=self._oauth2_client))

async def expand(self, body, **kwargs): # noqa: E501
"""Expand all relationships in userset tree format, and following userset rewrite rules. Useful to reason about and debug a certain relationship # noqa: E501
Expand Down Expand Up @@ -608,7 +618,8 @@ async def expand_with_http_info(self, body, **kwargs): # noqa: E501
_request_timeout=local_var_params.get('_request_timeout'),
_retry_params=local_var_params.get('_retry_params'),
collection_formats=collection_formats,
_request_auth=local_var_params.get('_request_auth')))
_request_auth=local_var_params.get('_request_auth'),
_oauth2_client=self._oauth2_client))

async def get_store(self, **kwargs): # noqa: E501
"""Get a store # noqa: E501
Expand Down Expand Up @@ -740,7 +751,8 @@ async def get_store_with_http_info(self, **kwargs): # noqa: E501
_request_timeout=local_var_params.get('_request_timeout'),
_retry_params=local_var_params.get('_retry_params'),
collection_formats=collection_formats,
_request_auth=local_var_params.get('_request_auth')))
_request_auth=local_var_params.get('_request_auth'),
_oauth2_client=self._oauth2_client))

async def list_objects(self, body, **kwargs): # noqa: E501
"""List all objects of the given type that the user has a relation with # noqa: E501
Expand Down Expand Up @@ -888,7 +900,8 @@ async def list_objects_with_http_info(self, body, **kwargs): # noqa: E501
_request_timeout=local_var_params.get('_request_timeout'),
_retry_params=local_var_params.get('_retry_params'),
collection_formats=collection_formats,
_request_auth=local_var_params.get('_request_auth')))
_request_auth=local_var_params.get('_request_auth'),
_oauth2_client=self._oauth2_client))

async def list_stores(self, **kwargs): # noqa: E501
"""List all stores # noqa: E501
Expand Down Expand Up @@ -1030,7 +1043,8 @@ async def list_stores_with_http_info(self, **kwargs): # noqa: E501
_request_timeout=local_var_params.get('_request_timeout'),
_retry_params=local_var_params.get('_retry_params'),
collection_formats=collection_formats,
_request_auth=local_var_params.get('_request_auth')))
_request_auth=local_var_params.get('_request_auth'),
_oauth2_client=self._oauth2_client))

async def read(self, body, **kwargs): # noqa: E501
"""Get tuples from the store that matches a query, without following userset rewrite rules # noqa: E501
Expand Down Expand Up @@ -1178,7 +1192,8 @@ async def read_with_http_info(self, body, **kwargs): # noqa: E501
_request_timeout=local_var_params.get('_request_timeout'),
_retry_params=local_var_params.get('_retry_params'),
collection_formats=collection_formats,
_request_auth=local_var_params.get('_request_auth')))
_request_auth=local_var_params.get('_request_auth'),
_oauth2_client=self._oauth2_client))

async def read_assertions(self, authorization_model_id, **kwargs): # noqa: E501
"""Read assertions for an authorization model ID # noqa: E501
Expand Down Expand Up @@ -1323,7 +1338,8 @@ async def read_assertions_with_http_info(self, authorization_model_id, **kwargs)
_request_timeout=local_var_params.get('_request_timeout'),
_retry_params=local_var_params.get('_retry_params'),
collection_formats=collection_formats,
_request_auth=local_var_params.get('_request_auth')))
_request_auth=local_var_params.get('_request_auth'),
_oauth2_client=self._oauth2_client))

async def read_authorization_model(self, id, **kwargs): # noqa: E501
"""Return a particular version of an authorization model # noqa: E501
Expand Down Expand Up @@ -1467,7 +1483,8 @@ async def read_authorization_model_with_http_info(self, id, **kwargs): # noqa:
_request_timeout=local_var_params.get('_request_timeout'),
_retry_params=local_var_params.get('_retry_params'),
collection_formats=collection_formats,
_request_auth=local_var_params.get('_request_auth')))
_request_auth=local_var_params.get('_request_auth'),
_oauth2_client=self._oauth2_client))

async def read_authorization_models(self, **kwargs): # noqa: E501
"""Return all the authorization models for a particular store # noqa: E501
Expand Down Expand Up @@ -1613,7 +1630,8 @@ async def read_authorization_models_with_http_info(self, **kwargs): # noqa: E50
_request_timeout=local_var_params.get('_request_timeout'),
_retry_params=local_var_params.get('_retry_params'),
collection_formats=collection_formats,
_request_auth=local_var_params.get('_request_auth')))
_request_auth=local_var_params.get('_request_auth'),
_oauth2_client=self._oauth2_client))

async def read_changes(self, **kwargs): # noqa: E501
"""Return a list of all the tuple changes # noqa: E501
Expand Down Expand Up @@ -1766,7 +1784,8 @@ async def read_changes_with_http_info(self, **kwargs): # noqa: E501
_request_timeout=local_var_params.get('_request_timeout'),
_retry_params=local_var_params.get('_retry_params'),
collection_formats=collection_formats,
_request_auth=local_var_params.get('_request_auth')))
_request_auth=local_var_params.get('_request_auth'),
_oauth2_client=self._oauth2_client))

async def write(self, body, **kwargs): # noqa: E501
"""Add or delete tuples from the store # noqa: E501
Expand Down Expand Up @@ -1914,7 +1933,8 @@ async def write_with_http_info(self, body, **kwargs): # noqa: E501
_request_timeout=local_var_params.get('_request_timeout'),
_retry_params=local_var_params.get('_retry_params'),
collection_formats=collection_formats,
_request_auth=local_var_params.get('_request_auth')))
_request_auth=local_var_params.get('_request_auth'),
_oauth2_client=self._oauth2_client))

async def write_assertions(self, authorization_model_id, body, **kwargs): # noqa: E501
"""Upsert assertions for an authorization model ID # noqa: E501
Expand Down Expand Up @@ -2070,7 +2090,8 @@ async def write_assertions_with_http_info(self, authorization_model_id, body, **
_request_timeout=local_var_params.get('_request_timeout'),
_retry_params=local_var_params.get('_retry_params'),
collection_formats=collection_formats,
_request_auth=local_var_params.get('_request_auth')))
_request_auth=local_var_params.get('_request_auth'),
_oauth2_client=self._oauth2_client))

async def write_authorization_model(self, body, **kwargs): # noqa: E501
"""Create a new authorization model # noqa: E501
Expand Down Expand Up @@ -2218,4 +2239,5 @@ async def write_authorization_model_with_http_info(self, body, **kwargs): # noq
_request_timeout=local_var_params.get('_request_timeout'),
_retry_params=local_var_params.get('_retry_params'),
collection_formats=collection_formats,
_request_auth=local_var_params.get('_request_auth')))
_request_auth=local_var_params.get('_request_auth'),
_oauth2_client=self._oauth2_client))
33 changes: 22 additions & 11 deletions openfga_sdk/api_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@

from openfga_sdk.configuration import Configuration
import openfga_sdk.models
from openfga_sdk import rest
from openfga_sdk import rest, oauth2
from openfga_sdk.exceptions import ApiValueError, ApiException, FgaValidationException, RateLimitExceededError


Expand Down Expand Up @@ -140,7 +140,7 @@ async def __call_api(
response_types_map=None, auth_settings=None,
_return_http_data_only=None, collection_formats=None,
_preload_content=True, _request_timeout=None, _host=None,
_request_auth=None, _retry_params=None):
_request_auth=None, _retry_params=None, _oauth2_client=None):

self.configuration.is_valid()
config = self.configuration
Expand Down Expand Up @@ -183,7 +183,7 @@ async def __call_api(
# auth setting
await self.update_params_for_auth(
header_params, query_params, auth_settings,
request_auth=_request_auth)
request_auth=_request_auth, oauth2_client=_oauth2_client)

# body
if body:
Expand Down Expand Up @@ -369,7 +369,7 @@ async def call_api(self, resource_path, method,
async_req=None, _return_http_data_only=None,
collection_formats=None, _preload_content=True,
_request_timeout=None, _host=None, _request_auth=None,
_retry_params=None):
_retry_params=None, _oauth2_client=None):
"""Makes the HTTP request (synchronous) and returns deserialized data.

To make an async_req request, set the async_req parameter.
Expand Down Expand Up @@ -417,7 +417,7 @@ async def call_api(self, resource_path, method,
response_types_map, auth_settings,
_return_http_data_only, collection_formats,
_preload_content, _request_timeout, _host,
_request_auth, _retry_params))
_request_auth, _retry_params, _oauth2_client))

return self.pool.apply_async(self.__call_api, (resource_path,
method, path_params,
Expand All @@ -430,7 +430,9 @@ async def call_api(self, resource_path, method,
collection_formats,
_preload_content,
_request_timeout,
_host, _request_auth, _retry_params))
_host, _request_auth,
_retry_params,
_oauth2_client))

async def request(self, method, url, query_params=None, headers=None,
post_params=None, body=None, _preload_content=True,
Expand Down Expand Up @@ -561,19 +563,28 @@ def select_header_content_type(self, content_types, method=None, body=None):
return content_types[0]

async def update_params_for_auth(self, headers, queries, auth_settings,
request_auth=None):
request_auth=None, oauth2_client=None):
"""Updates header and query params based on authentication setting.

:param headers: Header parameters dict to be updated.
:param queries: Query parameters tuple list to be updated.
:param auth_settings: Authentication setting identifiers list.
:param request_auth: if set, the provided settings will
override the token in the configuration.
:param oauth2_client: if set, will be used for credential exchange.
"""
if self.configuration.credentials is not None:
added_headers = await self.configuration.credentials.get_authentication_header(self.rest_client)
for key, value in added_headers.items():
headers[key] = value
credentials = self.configuration.credentials
if credentials is not None:
if credentials.method == 'none':
pass
if credentials.method == 'api_token':
headers['Authorization'] = 'Bearer {}'.format(credentials.configuration.api_token)
if credentials.method == 'client_credentials':
if oauth2_client is None:
oauth2_client = oauth2.OAuth2Client(credentials)
oauth2_headers = await oauth2_client.get_authentication_header(self.rest_client)
for key, value in oauth2_headers.items():
headers[key] = value

if not auth_settings:
return
Expand Down
57 changes: 1 addition & 56 deletions openfga_sdk/credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,10 @@
"""

from dataclasses import dataclass
from datetime import datetime, timedelta
import json
import typing
import urllib3
from urllib.parse import urlparse

from openfga_sdk.exceptions import FgaValidationException, ApiValueError, AuthenticationError
from openfga_sdk.exceptions import ApiValueError


def none_or_empty(value):
Expand Down Expand Up @@ -136,8 +133,6 @@ def __init__(
):
self._method = method
self._configuration = configuration
self._access_token = None
self._access_expiry_time = None

@property
def method(self):
Expand Down Expand Up @@ -192,53 +187,3 @@ def validate_credentials_config(self):
if (parsed_url.netloc == ''):
raise ApiValueError('api_issuer `{}` is invalid'.format(
self.configuration.api_issuer))

def _token_valid(self):
"""
Return whether token is valid
"""
if self._access_token is None or self._access_expiry_time is None:
return False
if self._access_expiry_time < datetime.now():
return False
return True

async def _obtain_token(self, client):
"""
Perform OAuth2 and obtain token
"""
token_url = 'https://{}/oauth/token'.format(self.configuration.api_issuer)
body = {
'client_id': self.configuration.client_id,
'client_secret': self.configuration.client_secret,
'audience': self.configuration.api_audience,
'grant_type': "client_credentials",
}
headers = urllib3.response.HTTPHeaderDict(
{'Accept': 'application/json', 'Content-Type': 'application/json', 'User-Agent': 'openfga-sdk (python) 0.2.1'})
raw_response = await client.POST(token_url, headers=headers, body=body)
if 200 <= raw_response.status <= 299:
try:
api_response = json.loads(raw_response.data)
except: # noqa: E722
raise AuthenticationError(http_resp=raw_response)
if not api_response.get('expires_in') or not api_response.get('access_token'):
raise AuthenticationError(http_resp=raw_response)
self._access_expiry_time = datetime.now() + timedelta(seconds=int(api_response.get('expires_in')))
self._access_token = api_response.get('access_token')
else:
raise AuthenticationError(http_resp=raw_response)

async def get_authentication_header(self, client):
"""
If configured, return the header for authentication
"""
if self._method == 'none':
return {}
if self._method == 'api_token':
return {'Authorization': 'Bearer {}'.format(self.configuration.api_token)}
# check to see token is valid
if not self._token_valid():
# In this case, the token is not valid, we need to get the refresh the token
await self._obtain_token(client)
return {'Authorization': 'Bearer {}'.format(self._access_token)}
Loading