From 8f527c0fa58aa482adfa270deff412d4ce288bf3 Mon Sep 17 00:00:00 2001 From: KrKOo Date: Fri, 15 Mar 2024 16:30:19 +0100 Subject: [PATCH] fmt --- .../__init__.py | 8 +-- snakemake_executor_plugin_auth_tes/auth.py | 59 +++++++++++-------- tests/tests.py | 2 +- 3 files changed, 40 insertions(+), 29 deletions(-) diff --git a/snakemake_executor_plugin_auth_tes/__init__.py b/snakemake_executor_plugin_auth_tes/__init__.py index 4cd18ba..b64b58c 100644 --- a/snakemake_executor_plugin_auth_tes/__init__.py +++ b/snakemake_executor_plugin_auth_tes/__init__.py @@ -154,7 +154,9 @@ def __post_init__(self): self._refresh_token = exchange_result["refresh_token"] new_client = self.auth_client.register_client( - "run", [self.workflow.executor_settings.oidc_audience], ["offline_access", "client_dynamic_deregistration"] + "run", + [self.workflow.executor_settings.oidc_audience], + ["offline_access", "client_dynamic_deregistration"], ) self.auth_client = AuthClient( @@ -191,9 +193,7 @@ def tes_access_token(self): return self.workflow.executor_settings.token if self.auth_client.is_token_expired(self._access_token): - refresh_result = self.auth_client.refresh_access_token( - self._refresh_token - ) + refresh_result = self.auth_client.refresh_access_token(self._refresh_token) self._access_token = refresh_result["access_token"] self._refresh_token = refresh_result["refresh_token"] diff --git a/snakemake_executor_plugin_auth_tes/auth.py b/snakemake_executor_plugin_auth_tes/auth.py index 254344c..163f195 100644 --- a/snakemake_executor_plugin_auth_tes/auth.py +++ b/snakemake_executor_plugin_auth_tes/auth.py @@ -5,6 +5,7 @@ GRANT_TYPE_TOKEN_EXCHANGE = "urn:ietf:params:oauth:grant-type:token-exchange" GRANT_TYPE_CLIENT_CREDENTIALS = "client_credentials" + class AuthClient: def __init__(self, client_id, client_secret, oidc_url): self.client_id = client_id @@ -16,7 +17,9 @@ def __init__(self, client_id, client_secret, oidc_url): self.register_url = self.oidc_url + "/register" self.jwks_url = self.oidc_url + "/jwk" - self.basic_auth = requests.auth.HTTPBasicAuth(self.client_id, self.client_secret) + self.basic_auth = requests.auth.HTTPBasicAuth( + self.client_id, self.client_secret + ) def is_token_expired(self, token): jwks_client = jwt.PyJWKClient(self.jwks_url) @@ -24,16 +27,14 @@ def is_token_expired(self, token): key = jwks_client.get_signing_key(header["kid"]).key try: - jwt.decode(token, key, [header["alg"]], options={"verify_aud": False}) + jwt.decode(token, key, [header["alg"]], options={"verify_aud": False}) except jwt.ExpiredSignatureError: return True return False def is_token_valid(self, token): - body = { - "token": token - } + body = {"token": token} response = requests.post(self.introspect_url, body, auth=self.basic_auth) @@ -44,7 +45,7 @@ def is_token_valid(self, token): if token_info["active"]: return True - + return False def get_new_token(self, scopes, audience=None): @@ -62,14 +63,14 @@ def get_new_token(self, scopes, audience=None): raise Exception("Failed to get a new access token: " + response.text) return response.json() - + def exchange_access_token(self, token, scopes, audience=None): body = { "subject_token": token, "subject_token_type": "urn:ietf:params:oauth:token-type:access_token", "requested_token_type": "urn:ietf:params:oauth:token-type:access_token", "scope": " ".join(scopes), - "grant_type": GRANT_TYPE_TOKEN_EXCHANGE + "grant_type": GRANT_TYPE_TOKEN_EXCHANGE, } if audience: @@ -79,15 +80,11 @@ def exchange_access_token(self, token, scopes, audience=None): if response.status_code != 200: raise Exception("Failed to exchange access token: " + response.text) - + return response.json() - + def refresh_access_token(self, refresh_token): - body = { - "refresh_token": refresh_token, - "grant_type": "refresh_token" - "" - } + body = {"refresh_token": refresh_token, "grant_type": "refresh_token" ""} response = requests.post(self.token_url, body, auth=self.basic_auth) @@ -95,19 +92,30 @@ def refresh_access_token(self, refresh_token): raise Exception("Failed to refresh access token: " + response.text) return response.json() - - def register_client(self, client_name, resource_ids, scopes, access_token_validity_seconds=600, refresh_token_validity_seconds=3600): - new_token_response = self.get_new_token(["client_dynamic_registration"]) + + def register_client( + self, + client_name, + resource_ids, + scopes, + access_token_validity_seconds=600, + refresh_token_validity_seconds=3600, + ): + new_token_response = self.get_new_token(["client_dynamic_registration"]) access_token = new_token_response["access_token"] body = { "client_name": client_name, - "grant_types": ["urn:ietf:params:oauth:grant-type:token-exchange", "refresh_token", "client_credentials"], + "grant_types": [ + "urn:ietf:params:oauth:grant-type:token-exchange", + "refresh_token", + "client_credentials", + ], "token_endpoint_auth_method": "client_secret_basic", "scope": scopes, "resources": resource_ids, "access_token_validity_seconds": access_token_validity_seconds, - "refresh_token_validity_seconds": refresh_token_validity_seconds + "refresh_token_validity_seconds": refresh_token_validity_seconds, } headers = {"Authorization": f"Bearer {access_token}"} @@ -117,10 +125,10 @@ def register_client(self, client_name, resource_ids, scopes, access_token_validi raise Exception("Failed to register a new client: " + response.text) response_data = response.json() - + return { "client_id": response_data["client_id"], - "client_secret": response_data["client_secret"] + "client_secret": response_data["client_secret"], } def deregister_self(self): @@ -128,10 +136,13 @@ def deregister_self(self): access_token = new_token_response["access_token"] headers = {"Authorization": f"Bearer {access_token}"} - base_register_url = self.register_url if self.register_url.endswith("/") else self.register_url + "/" + base_register_url = ( + self.register_url + if self.register_url.endswith("/") + else self.register_url + "/" + ) url = urlparse.urljoin(base_register_url, self.client_id) response = requests.delete(url, headers=headers) if response.status_code != 204: raise Exception("Failed to deregister the client: " + response.text) - \ No newline at end of file diff --git a/tests/tests.py b/tests/tests.py index 954d501..59c8e3d 100644 --- a/tests/tests.py +++ b/tests/tests.py @@ -1,6 +1,6 @@ from typing import Optional import snakemake.common.tests -from snakemake_executor_plugin_tes import ExecutorSettings +from snakemake_executor_plugin_auth_tes import ExecutorSettings from snakemake_interface_executor_plugins.settings import ExecutorSettingsBase