From e0198488cd686fb40fb45a019929bc482ab2e2bb Mon Sep 17 00:00:00 2001 From: Ethan Ho <53266718+ethho@users.noreply.github.com> Date: Mon, 7 Oct 2024 09:54:02 -0500 Subject: [PATCH 1/5] Checkout axon module https://github.com/dj-sciops/djsciops-python/blob/17ccc19ad0e235ea83168ac9161a46cbfec308b5/djsciops/authentication.py --- datajoint/axon.py | 453 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 453 insertions(+) create mode 100644 datajoint/axon.py diff --git a/datajoint/axon.py b/datajoint/axon.py new file mode 100644 index 00000000..b6ff9a0a --- /dev/null +++ b/datajoint/axon.py @@ -0,0 +1,453 @@ +import base64 +from datetime import datetime, timezone +import json +import logging +import os +import sys +import flask +import webbrowser +import urllib +import http.client +import botocore +import botocore.config +from .log import log +from time import time +import multiprocessing + +import boto3 +from botocore.credentials import RefreshableCredentials +from botocore.session import get_session +from djsciops import settings as djsciops_settings +try: + # Python 3.4+ + if sys.platform.startswith('win'): + import multiprocessing.popen_spawn_win32 as forking + else: + import multiprocessing.popen_fork as forking +except ImportError: + import multiprocessing.forking as forking + +LOOKUP_SERVICE_ALLOWED_ORIGIN = "https://ops.datajoint.io" +LOOKUP_SERVICE_DOMAIN = "ops.datajoint.io" +LOOKUP_SERVICE_ROUTE = "/social-login/api/user" +# Everything LOOKUP_SERVICE_AUTH is changed, need to change: +# https://github.com/datajoint-company/dj-gitops/blob/main/applications/k8s/deployments/ops.datajoint.io/social_login_interceptor/client_store_secrets.yaml +LOOKUP_SERVICE_AUTH = { + "https://accounts.datajoint.io/auth/": { + "PROVIDER": "accounts.datajoint.io", + "ROUTE": "/auth", + }, + "https://accounts.datajoint.com/realms/datajoint": { + "PROVIDER": "accounts.datajoint.com", + "ROUTE": "/realms/datajoint/protocol/openid-connect", + }, + "https://keycloak-qa.datajoint.io/realms/datajoint": { + "PROVIDER": "keycloak-qa.datajoint.io", + "ROUTE": "/realms/datajoint/protocol/openid-connect", + }, +} +issuer = djsciops_settings.get_config()["djauth"]["issuer"] + + +def _client_login( + auth_client_id: str, + auth_client_secret: str, + auth_provider_domain: str = LOOKUP_SERVICE_AUTH[issuer]["PROVIDER"], + auth_provider_token_route: str = f"{LOOKUP_SERVICE_AUTH[issuer]['ROUTE']}/token", +): + connection = http.client.HTTPSConnection(auth_provider_domain) + headers = {"Content-type": "application/x-www-form-urlencoded"} + body = urllib.parse.urlencode( + { + "grant_type": "client_credentials", + "client_id": auth_client_id, + "client_secret": auth_client_secret, + } + ) + connection.request("POST", auth_provider_token_route, body, headers) + jwt_payload = json.loads(connection.getresponse().read().decode()) + return jwt_payload["access_token"] + + +def start_server(q: multiprocessing.Queue, callback_port: int): + """ + Starts Flask HTTP server. + Since werkzeug 2.0.3 has vulnerability issue, has to upgrade and + werkzeug.environ.shutdown_server() is deprecated after 2.0.3. + """ + app = flask.Flask("browser-interface") + + @app.route("/login-cancelled") + def login_cancelled(): + """ + Accepts requests which will cancel the user login. + """ + q.put({"cancelled": True, "code": None}) + return """ + + + + + + + + + """ + + @app.route("/login-completed") + def login_completed(): + """ + Redirect after user has successfully logged in. + """ + code = flask.request.args.get("code") + q.put({"cancelled": False, "code": code}) + return """ + + + + + + DataJoint login completed! Feel free to close this tab if it did not close automatically. + + """ + + app.run(host="0.0.0.0", port=callback_port, debug=False) + + +def _oidc_login( + auth_client_id: str, + auth_url: str = f"https://{LOOKUP_SERVICE_AUTH[issuer]['PROVIDER']}{LOOKUP_SERVICE_AUTH[issuer]['ROUTE']}/auth", + lookup_service_allowed_origin: str = LOOKUP_SERVICE_ALLOWED_ORIGIN, + lookup_service_domain: str = LOOKUP_SERVICE_DOMAIN, + lookup_service_route: str = LOOKUP_SERVICE_ROUTE, + lookup_service_auth_provider: str = LOOKUP_SERVICE_AUTH[issuer]["PROVIDER"], + code_challenge: str = "ubNp9Y0Y_FOENQ_Pz3zppyv2yyt0XtJsaPqUgGW9heA", + code_challenge_method: str = "S256", + code_verifier: str = "kFn5ZwL6ggOwU1OzKx0E1oZibIMC1ZbMC1WEUXcCV5mFoi015I9nB9CrgUJRkc3oiQT8uBbrvRvVzahM8OS0xJ51XdYaTdAlFeHsb6OZuBPmLD400ozVPrwCE192rtqI", + callback_port: int = 28282, + delay_seconds: int = 60, +): + """ + Primary OIDC login flow. + """ + + # Prepare user + log.warning( + "User authentication required to use DataJoint SciOps CLI tools. We'll be " + "launching a web browser to authenticate your DataJoint account." + ) + # allocate variables for access and context + code = None + cancelled = True + # Prepare HTTP server to communicate with browser + logging.getLogger("werkzeug").setLevel(logging.ERROR) + + q = multiprocessing.Queue() + server = multiprocessing.Process( + target=start_server, + args=( + q, + callback_port, + ), + ) + server.start() + # build url + query_params = dict( + scope="openid", + response_type="code", + client_id=auth_client_id, + code_challenge=code_challenge, + code_challenge_method=code_challenge_method, + redirect_uri=f"http://localhost:{callback_port}/login-completed", + ) + link = f"{auth_url}?{urllib.parse.urlencode(query_params)}" + # attempt to launch browser or provide instructions + browser_available = True + try: + webbrowser.get() + except webbrowser.Error: + browser_available = False + if browser_available: + log.info("Browser available. Launching...") + webbrowser.open(link, new=2) + else: + log.warning( + "Browser unavailable. On a browser client, please navigate to the " + f"following link to login: {link}" + ) + # cancel_process = multiprocessing.Process( + # target=_delayed_request, + # kwargs=dict( + # url=f"http://localhost:{callback_port}/login-cancelled", + # delay=delay_seconds, + # ), + # ) + # # cancel_process.start() + queue_in_flask = q.get(block=True) + cancelled = queue_in_flask["cancelled"] + code = queue_in_flask["code"] + # server.terminate() + # cancel_process.terminate() + # received a response + if cancelled: + server.terminate() + raise Exception( + "User login cancelled. User must be logged in to use DataJoint SciOps CLI tools." + ) + else: + # generate user info + connection = http.client.HTTPSConnection(lookup_service_domain) + headers = { + "Content-type": "application/json", + "Origin": lookup_service_allowed_origin, + } + body = json.dumps( + { + "auth_provider": lookup_service_auth_provider, + "redirect_uri": f"http://localhost:{callback_port}/login-completed", + "code_verifier": code_verifier, + "client_id": auth_client_id, + "code": code, + } + ) + connection.request("POST", lookup_service_route, body, headers) + response = connection.getresponse().read().decode() + try: + userdata = json.loads(response) + log.info("User successfully authenticated.") + return ( + userdata["access_token"], + userdata["username"], + userdata["refresh_token"], + ) + except json.decoder.JSONDecodeError: + log.error(response) + raise Exception("Login failed") + finally: + server.terminate() + + +def _delayed_request(*, url: str, delay: str = 0): + time.sleep(delay) + return urllib.request.urlopen(url) + + +def _decode_bearer_token(bearer_token): + log.debug(f"bearer_token: {bearer_token}") + jwt_data = json.loads( + base64.b64decode((bearer_token.split(".")[1] + "==").encode()).decode() + ) + log.debug(f"jwt_data: {jwt_data}") + return jwt_data + + +if sys.platform.startswith('win'): + # First define a modified version of Popen. + class _Popen(forking.Popen): + def __init__(self, *args, **kw): + if hasattr(sys, 'frozen'): + # We have to set original _MEIPASS2 value from sys._MEIPASS + # to get --onefile mode working. + os.putenv('_MEIPASS2', sys._MEIPASS) + try: + super(_Popen, self).__init__(*args, **kw) + finally: + if hasattr(sys, 'frozen'): + # On some platforms (e.g. AIX) 'os.unsetenv()' is not + # available. In those cases we cannot delete the variable + # but only set it to the empty string. The bootloader + # can handle this case. + if hasattr(os, 'unsetenv'): + os.unsetenv('_MEIPASS2') + else: + os.putenv('_MEIPASS2', '') + + # Second override 'Popen' class with our modified version. + forking.Popen = _Popen + + +class Session: + def __init__( + self, + aws_account_id: str, + s3_role: str, + auth_client_id: str, + auth_client_secret: str = None, + bearer_token: str = None, + ): + self.aws_account_id = aws_account_id + self.s3_role = s3_role + self.auth_client_id = auth_client_id + self.auth_client_secret = auth_client_secret + self.sts_arn = f"arn:aws:iam::{aws_account_id}:role/{s3_role}" + self.user = "client_credentials" + self.refresh_token = None + self.jwt = None + # OAuth2.0 authorization + if auth_client_secret: + self.bearer_token = _client_login( + auth_client_id=self.auth_client_id, + auth_client_secret=self.auth_client_secret, + ) + self.jwt = _decode_bearer_token(self.bearer_token) + elif not bearer_token: + self.bearer_token, self.user, self.refresh_token = _oidc_login( + auth_client_id=auth_client_id, + ) + self.jwt = _decode_bearer_token(self.bearer_token) + else: + self.jwt = _decode_bearer_token(self.bearer_token) + time_to_live = (self.jwt["exp"] - datetime.utcnow().timestamp()) / 60 / 60 + log.info( + f"Reusing provided bearer token with a life of {time_to_live} [HR]" + ) + self.bearer_token, self.user = (bearer_token, self.jwt["sub"]) + + self.sts_token = RefreshableBotoSession(session=self).refreshable_session() + self.s3 = self.sts_token.resource( + "s3", config=botocore.config.Config(s3={"use_accelerate_endpoint": True}) + ) + + def refresh_bearer_token( + self, + lookup_service_allowed_origin: str = LOOKUP_SERVICE_ALLOWED_ORIGIN, + lookup_service_domain: str = LOOKUP_SERVICE_DOMAIN, + lookup_service_route: str = LOOKUP_SERVICE_ROUTE, + lookup_service_auth_provider: str = LOOKUP_SERVICE_AUTH[issuer]["PROVIDER"], + ): + if self.auth_client_secret: + self.bearer_token = _client_login( + auth_client_id=self.auth_client_id, + auth_client_secret=self.auth_client_secret, + ) + self.jwt = _decode_bearer_token(self.bearer_token) + else: + # generate user info + connection = http.client.HTTPSConnection(lookup_service_domain) + headers = { + "Content-type": "application/json", + "Origin": lookup_service_allowed_origin, + } + body = json.dumps( + { + "auth_provider": lookup_service_auth_provider, + "refresh_token": self.refresh_token, + "client_id": self.auth_client_id, + } + ) + log.debug(f"Original refresh_token: {self.refresh_token}") + connection.request("PATCH", lookup_service_route, body, headers) + response = connection.getresponse().read().decode() + log.debug(f"response: {response}") + userdata = json.loads(response) + log.debug("User successfully reauthenticated.") + self.bearer_token = userdata["access_token"] + self.user = userdata["username"] + self.refresh_token = userdata["refresh_token"] + log.debug(f"refresh_token: {self.refresh_token}") + self.jwt = _decode_bearer_token(self.bearer_token) + + +class RefreshableBotoSession: + """ + Boto Helper class which lets us create refreshable session, so that we can cache the client or resource. + + Usage + ----- + session = RefreshableBotoSession().refreshable_session() + + client = session.client("s3") # we now can cache this client object without worrying about expiring credentials + """ + + def __init__(self, session, session_ttl: int = 12 * 60 * 60): + """ + Initialize `RefreshableBotoSession` + + Parameters + ---------- + session : Session + The session object to refresh + + session_ttl : int (optional) + An integer number to set the TTL for each session. Beyond this session, it will renew the token. + """ + + self.session = session + self.session_ttl = session_ttl + + def __get_session_credentials(self): + """ + Get session credentials + """ + sts_client = boto3.client(service_name="sts") + try: + sts_response = sts_client.assume_role_with_web_identity( + RoleArn=self.session.sts_arn, + RoleSessionName=self.session.user, + WebIdentityToken=self.session.bearer_token, + DurationSeconds=self.session_ttl, + ).get("Credentials") + except botocore.exceptions.ClientError as error: + log.debug(f"Error code: {error.response['Error']['Code']}") + if error.response["Error"]["Code"] == "ExpiredTokenException": + log.debug("Bearer token has expired... Reauthenticating now") + self.session.refresh_bearer_token() + sts_response = sts_client.assume_role_with_web_identity( + RoleArn=self.session.sts_arn, + RoleSessionName=self.session.user, + WebIdentityToken=self.session.bearer_token, + DurationSeconds=self.session_ttl, + ).get("Credentials") + else: + raise error + # Token expire time logging + bearer_expire_time = datetime.fromtimestamp(self.session.jwt["exp"]).strftime( + "%H:%M:%S" + ) + log.debug(f"Bearer token expire time: {bearer_expire_time}") + if "sts_token" in self.session.__dict__: + sts_expire_time = ( + self.session.sts_token._session.get_credentials() + .__dict__["_expiry_time"] + .replace(tzinfo=timezone.utc) + .astimezone(tz=None) + .strftime("%H:%M:%S") + ) + log.debug(f"STS token expire time: {sts_expire_time}") + + credentials = { + "access_key": sts_response.get("AccessKeyId"), + "secret_key": sts_response.get("SecretAccessKey"), + "token": sts_response.get("SessionToken"), + "expiry_time": sts_response.get("Expiration").isoformat(), + } + + return credentials + + def refreshable_session(self) -> boto3.Session: + """ + Get refreshable boto3 session. + """ + # get refreshable credentials + refreshable_credentials = RefreshableCredentials.create_from_metadata( + metadata=self.__get_session_credentials(), + refresh_using=self.__get_session_credentials, + method="sts-assume-role-with-web-identity", + ) + + # attach refreshable credentials current session + session = get_session() + session._credentials = refreshable_credentials + autorefresh_session = boto3.Session(botocore_session=session) + + return autorefresh_session \ No newline at end of file From 04cd6a74e01c401303ece3f7f2a938470c10bbe7 Mon Sep 17 00:00:00 2001 From: Ethan Ho <53266718+ethho@users.noreply.github.com> Date: Mon, 7 Oct 2024 10:45:55 -0500 Subject: [PATCH 2/5] Tweaks and minimal pytest --- datajoint/axon.py | 207 +-------------------------------------------- tests/test_axon.py | 14 +++ 2 files changed, 17 insertions(+), 204 deletions(-) create mode 100644 tests/test_axon.py diff --git a/datajoint/axon.py b/datajoint/axon.py index b6ff9a0a..92668824 100644 --- a/datajoint/axon.py +++ b/datajoint/axon.py @@ -69,179 +69,6 @@ def _client_login( return jwt_payload["access_token"] -def start_server(q: multiprocessing.Queue, callback_port: int): - """ - Starts Flask HTTP server. - Since werkzeug 2.0.3 has vulnerability issue, has to upgrade and - werkzeug.environ.shutdown_server() is deprecated after 2.0.3. - """ - app = flask.Flask("browser-interface") - - @app.route("/login-cancelled") - def login_cancelled(): - """ - Accepts requests which will cancel the user login. - """ - q.put({"cancelled": True, "code": None}) - return """ - - - - - - - - - """ - - @app.route("/login-completed") - def login_completed(): - """ - Redirect after user has successfully logged in. - """ - code = flask.request.args.get("code") - q.put({"cancelled": False, "code": code}) - return """ - - - - - - DataJoint login completed! Feel free to close this tab if it did not close automatically. - - """ - - app.run(host="0.0.0.0", port=callback_port, debug=False) - - -def _oidc_login( - auth_client_id: str, - auth_url: str = f"https://{LOOKUP_SERVICE_AUTH[issuer]['PROVIDER']}{LOOKUP_SERVICE_AUTH[issuer]['ROUTE']}/auth", - lookup_service_allowed_origin: str = LOOKUP_SERVICE_ALLOWED_ORIGIN, - lookup_service_domain: str = LOOKUP_SERVICE_DOMAIN, - lookup_service_route: str = LOOKUP_SERVICE_ROUTE, - lookup_service_auth_provider: str = LOOKUP_SERVICE_AUTH[issuer]["PROVIDER"], - code_challenge: str = "ubNp9Y0Y_FOENQ_Pz3zppyv2yyt0XtJsaPqUgGW9heA", - code_challenge_method: str = "S256", - code_verifier: str = "kFn5ZwL6ggOwU1OzKx0E1oZibIMC1ZbMC1WEUXcCV5mFoi015I9nB9CrgUJRkc3oiQT8uBbrvRvVzahM8OS0xJ51XdYaTdAlFeHsb6OZuBPmLD400ozVPrwCE192rtqI", - callback_port: int = 28282, - delay_seconds: int = 60, -): - """ - Primary OIDC login flow. - """ - - # Prepare user - log.warning( - "User authentication required to use DataJoint SciOps CLI tools. We'll be " - "launching a web browser to authenticate your DataJoint account." - ) - # allocate variables for access and context - code = None - cancelled = True - # Prepare HTTP server to communicate with browser - logging.getLogger("werkzeug").setLevel(logging.ERROR) - - q = multiprocessing.Queue() - server = multiprocessing.Process( - target=start_server, - args=( - q, - callback_port, - ), - ) - server.start() - # build url - query_params = dict( - scope="openid", - response_type="code", - client_id=auth_client_id, - code_challenge=code_challenge, - code_challenge_method=code_challenge_method, - redirect_uri=f"http://localhost:{callback_port}/login-completed", - ) - link = f"{auth_url}?{urllib.parse.urlencode(query_params)}" - # attempt to launch browser or provide instructions - browser_available = True - try: - webbrowser.get() - except webbrowser.Error: - browser_available = False - if browser_available: - log.info("Browser available. Launching...") - webbrowser.open(link, new=2) - else: - log.warning( - "Browser unavailable. On a browser client, please navigate to the " - f"following link to login: {link}" - ) - # cancel_process = multiprocessing.Process( - # target=_delayed_request, - # kwargs=dict( - # url=f"http://localhost:{callback_port}/login-cancelled", - # delay=delay_seconds, - # ), - # ) - # # cancel_process.start() - queue_in_flask = q.get(block=True) - cancelled = queue_in_flask["cancelled"] - code = queue_in_flask["code"] - # server.terminate() - # cancel_process.terminate() - # received a response - if cancelled: - server.terminate() - raise Exception( - "User login cancelled. User must be logged in to use DataJoint SciOps CLI tools." - ) - else: - # generate user info - connection = http.client.HTTPSConnection(lookup_service_domain) - headers = { - "Content-type": "application/json", - "Origin": lookup_service_allowed_origin, - } - body = json.dumps( - { - "auth_provider": lookup_service_auth_provider, - "redirect_uri": f"http://localhost:{callback_port}/login-completed", - "code_verifier": code_verifier, - "client_id": auth_client_id, - "code": code, - } - ) - connection.request("POST", lookup_service_route, body, headers) - response = connection.getresponse().read().decode() - try: - userdata = json.loads(response) - log.info("User successfully authenticated.") - return ( - userdata["access_token"], - userdata["username"], - userdata["refresh_token"], - ) - except json.decoder.JSONDecodeError: - log.error(response) - raise Exception("Login failed") - finally: - server.terminate() - - -def _delayed_request(*, url: str, delay: str = 0): - time.sleep(delay) - return urllib.request.urlopen(url) - - def _decode_bearer_token(bearer_token): log.debug(f"bearer_token: {bearer_token}") jwt_data = json.loads( @@ -251,31 +78,6 @@ def _decode_bearer_token(bearer_token): return jwt_data -if sys.platform.startswith('win'): - # First define a modified version of Popen. - class _Popen(forking.Popen): - def __init__(self, *args, **kw): - if hasattr(sys, 'frozen'): - # We have to set original _MEIPASS2 value from sys._MEIPASS - # to get --onefile mode working. - os.putenv('_MEIPASS2', sys._MEIPASS) - try: - super(_Popen, self).__init__(*args, **kw) - finally: - if hasattr(sys, 'frozen'): - # On some platforms (e.g. AIX) 'os.unsetenv()' is not - # available. In those cases we cannot delete the variable - # but only set it to the empty string. The bootloader - # can handle this case. - if hasattr(os, 'unsetenv'): - os.unsetenv('_MEIPASS2') - else: - os.putenv('_MEIPASS2', '') - - # Second override 'Popen' class with our modified version. - forking.Popen = _Popen - - class Session: def __init__( self, @@ -300,14 +102,10 @@ def __init__( auth_client_secret=self.auth_client_secret, ) self.jwt = _decode_bearer_token(self.bearer_token) - elif not bearer_token: - self.bearer_token, self.user, self.refresh_token = _oidc_login( - auth_client_id=auth_client_id, - ) - self.jwt = _decode_bearer_token(self.bearer_token) else: + assert bearer_token, "Bearer token is required for user authentication." self.jwt = _decode_bearer_token(self.bearer_token) - time_to_live = (self.jwt["exp"] - datetime.utcnow().timestamp()) / 60 / 60 + time_to_live = (self.jwt["exp"] - datetime.now(datetime.timezone.utc).timestamp()) / 60 / 60 log.info( f"Reusing provided bearer token with a life of {time_to_live} [HR]" ) @@ -332,6 +130,7 @@ def refresh_bearer_token( ) self.jwt = _decode_bearer_token(self.bearer_token) else: + assert self.refresh_token, "Refresh token is required for user authentication." # generate user info connection = http.client.HTTPSConnection(lookup_service_domain) headers = { diff --git a/tests/test_axon.py b/tests/test_axon.py new file mode 100644 index 00000000..1c4ec5f4 --- /dev/null +++ b/tests/test_axon.py @@ -0,0 +1,14 @@ +from datajoint.axon import Session +import pytest +import moto3 + + +class TestSession: + def test_can_init(self): + session = Session( + aws_account_id="123456789012", + s3_role="test-role", + auth_client_id="test-client-id", + auth_client_secret="test-client-secret", + ) + assert session.bearer_token, "Bearer token not set" From de5d8e67f1c1564d49663d7d6f51ef45a0df6326 Mon Sep 17 00:00:00 2001 From: Ethan Ho <53266718+ethho@users.noreply.github.com> Date: Mon, 7 Oct 2024 20:33:03 +0000 Subject: [PATCH 3/5] Pytests working --- datajoint/axon.py | 6 +++--- docker-compose.yaml | 2 +- pyproject.toml | 4 ++++ tests/test_axon.py | 11 +++++++++-- 4 files changed, 17 insertions(+), 6 deletions(-) diff --git a/datajoint/axon.py b/datajoint/axon.py index 92668824..419f9733 100644 --- a/datajoint/axon.py +++ b/datajoint/axon.py @@ -10,14 +10,13 @@ import http.client import botocore import botocore.config -from .log import log +from .logging import logger as log from time import time import multiprocessing import boto3 from botocore.credentials import RefreshableCredentials from botocore.session import get_session -from djsciops import settings as djsciops_settings try: # Python 3.4+ if sys.platform.startswith('win'): @@ -46,7 +45,7 @@ "ROUTE": "/realms/datajoint/protocol/openid-connect", }, } -issuer = djsciops_settings.get_config()["djauth"]["issuer"] +issuer = "https://keycloak-qa.datajoint.io/realms/datajoint" def _client_login( @@ -66,6 +65,7 @@ def _client_login( ) connection.request("POST", auth_provider_token_route, body, headers) jwt_payload = json.loads(connection.getresponse().read().decode()) + assert "access_token" in jwt_payload, f"Access token not found in response: {jwt_payload=}." return jwt_payload["access_token"] diff --git a/docker-compose.yaml b/docker-compose.yaml index 9088dc53..d8bffb70 100644 --- a/docker-compose.yaml +++ b/docker-compose.yaml @@ -75,5 +75,5 @@ services: set -e pip install -q -e ".[test]" pip freeze | grep datajoint - pytest --cov-report term-missing --cov=datajoint tests + pytest -vv -xs tests/test_axon.py diff --git a/pyproject.toml b/pyproject.toml index 097d168e..fac3955a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -47,6 +47,10 @@ test = [ "pytest-cov", "black==24.2.0", "flake8", + "moto[s3]>=4.2.13", +] +axon = [ + "boto3", ] [project.urls] diff --git a/tests/test_axon.py b/tests/test_axon.py index 1c4ec5f4..60219f6d 100644 --- a/tests/test_axon.py +++ b/tests/test_axon.py @@ -1,12 +1,19 @@ from datajoint.axon import Session import pytest -import moto3 +import boto3 +from moto import mock_aws +@pytest.fixture +def moto_account_id(): + """Default account ID for moto""" + return "123456789012" + +@mock_aws class TestSession: def test_can_init(self): session = Session( - aws_account_id="123456789012", + aws_account_id=moto_account_id, s3_role="test-role", auth_client_id="test-client-id", auth_client_secret="test-client-secret", From 834dca6e5732e027f3a012057a0950176a0d0da3 Mon Sep 17 00:00:00 2001 From: Ethan Ho <53266718+ethho@users.noreply.github.com> Date: Mon, 7 Oct 2024 21:50:19 +0000 Subject: [PATCH 4/5] Client credentials flow working --- datajoint/axon.py | 72 ++++++++++++++++++++++++++++++++++- pyproject.toml | 3 ++ tests/test_axon.py | 95 +++++++++++++++++++++++++++++++++++++++++++--- 3 files changed, 164 insertions(+), 6 deletions(-) diff --git a/datajoint/axon.py b/datajoint/axon.py index 419f9733..f9351e33 100644 --- a/datajoint/axon.py +++ b/datajoint/axon.py @@ -6,6 +6,11 @@ import sys import flask import webbrowser +import requests_oauthlib +import oauthlib +from oauthlib.oauth2 import BackendApplicationClient +from requests_oauthlib import OAuth2Session +import requests import urllib import http.client import botocore @@ -249,4 +254,69 @@ def refreshable_session(self) -> boto3.Session: session._credentials = refreshable_credentials autorefresh_session = boto3.Session(botocore_session=session) - return autorefresh_session \ No newline at end of file + return autorefresh_session + +def get_s3_client( + aws_account_id: str, + s3_role: str, + auth_client_id: str, + auth_client_secret: str = None, + bearer_token: str = None, + well_known_url: str = "https://keycloak-qa.datajoint.io/realms/datajoint/.well-known/openid-configuration", +): + """ + Get S3 client with the given credentials. + + Parameters + ---------- + aws_account_id : str + AWS account ID + + s3_role : str + S3 role + + auth_client_id : str + Auth client ID + + auth_client_secret : str (optional) + Auth client secret + + bearer_token : str (optional) + Bearer token + + well_known_url : str (optional) + Well-known URL for the OpenID configuration + + Returns + ------- + boto3.client + S3 client + """ + # Get token URL from well-known URL + well_known_resp = requests.get(well_known_url) + assert well_known_resp.status_code == 200, f"Failed to get well-known URL: {well_known_url}" + well_known_data = well_known_resp.json() + token_url = well_known_data.get("token_endpoint") + assert token_url, f"Token URL not found in well-known data: {well_known_data=}" + + # Client credentials flow + token = _client_credentials_flow( + auth_client_id, auth_client_secret, token_url + ) + + # + + +def _client_credentials_flow(client_id, client_secret, token_url): + client = BackendApplicationClient(client_id=client_id) + oauth = OAuth2Session(client=client) + try: + return oauth.fetch_token( + token_url=token_url, + client_id=client_id, + client_secret=client_secret, + ) + except oauthlib.oauth2.rfc6749.errors.UnauthorizedClientError as e: + msg = f"Error getting OAuth2 client: {e.description}" + log.error(msg) + raise ValueError(msg) from e diff --git a/pyproject.toml b/pyproject.toml index fac3955a..8bce31ca 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -45,12 +45,15 @@ classifiers = [ test = [ "pytest", "pytest-cov", + "pytest-dotenv", "black==24.2.0", "flake8", "moto[s3]>=4.2.13", ] axon = [ "boto3", + "requests_oauthlib", + "requests", ] [project.urls] diff --git a/tests/test_axon.py b/tests/test_axon.py index 60219f6d..d47970bb 100644 --- a/tests/test_axon.py +++ b/tests/test_axon.py @@ -1,7 +1,11 @@ -from datajoint.axon import Session +import os +from datajoint.axon import Session, get_s3_client +import json import pytest import boto3 from moto import mock_aws +import dotenv +dotenv.load_dotenv(dotenv.find_dotenv()) @pytest.fixture @@ -9,13 +13,94 @@ def moto_account_id(): """Default account ID for moto""" return "123456789012" + +@pytest.fixture +def keycloak_client_secret(): + secret = os.getenv("OAUTH_CLIENT_SECRET") + if not secret: + pytest.skip("No client secret found") + else: + return secret + + +@pytest.fixture +def keycloak_client_id(): + return os.getenv("OAUTH_CLIENT_ID", "works") + + +@pytest.fixture(scope="function") +def aws_credentials(): + """Mocked AWS Credentials for moto.""" + os.environ["AWS_ACCESS_KEY_ID"] = "testing" + os.environ["AWS_SECRET_ACCESS_KEY"] = "testing" + os.environ["AWS_SECURITY_TOKEN"] = "testing" + os.environ["AWS_SESSION_TOKEN"] = "testing" + os.environ["AWS_DEFAULT_REGION"] = "us-east-1" + + +@pytest.fixture(scope="function") +def s3_client(aws_credentials): + """ + Return a mocked S3 client + """ + with mock_aws(): + yield boto3.client("s3", region_name="us-east-1") + + +@pytest.fixture(scope="function") +def iam_client(aws_credentials): + """ + Return a mocked S3 client + """ + with mock_aws(): + yield boto3.client("iam", region_name="us-east-1") + + +@pytest.fixture +def s3_policy(iam_client): + """Create a policy with S3 read access using boto3.""" + policy_doc = { + "Version": "2012-10-17", + "Statement": [ + { + "Effect": "Allow", + "Action": "s3:GetObject", + "Resource": "arn:aws:s3:::mybucket/*", + } + ], + } + return iam_client.create_policy( + PolicyName="test-policy", + Path="/", + PolicyDocument=json.dumps(policy_doc), + Description="Test policy", + ) + +@pytest.fixture +def s3_role(moto_account_id, s3_policy): + """Create a mock role and policy document for testing""" + return "123456789012" + + @mock_aws +@pytest.mark.skip class TestSession: - def test_can_init(self): + def test_can_init(self, s3_role, keycloak_client_id, keycloak_client_secret, moto_account_id): session = Session( aws_account_id=moto_account_id, - s3_role="test-role", - auth_client_id="test-client-id", - auth_client_secret="test-client-secret", + s3_role=s3_role, + auth_client_id=keycloak_client_id, + auth_client_secret=keycloak_client_secret, ) assert session.bearer_token, "Bearer token not set" + +def test_get_s3_client(s3_role, keycloak_client_id, keycloak_client_secret, moto_account_id): + client = get_s3_client( + auth_client_id=keycloak_client_id, + auth_client_secret=keycloak_client_secret, + aws_account_id=moto_account_id, + s3_role=s3_role, + bearer_token=None, + ) + assert client + From e0651c9703912587949414b30589f6722705674b Mon Sep 17 00:00:00 2001 From: Ethan Ho <53266718+ethho@users.noreply.github.com> Date: Wed, 9 Oct 2024 09:07:19 -0500 Subject: [PATCH 5/5] WIP --- datajoint/axon.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/datajoint/axon.py b/datajoint/axon.py index f9351e33..ccf685d9 100644 --- a/datajoint/axon.py +++ b/datajoint/axon.py @@ -96,7 +96,7 @@ def __init__( self.s3_role = s3_role self.auth_client_id = auth_client_id self.auth_client_secret = auth_client_secret - self.sts_arn = f"arn:aws:iam::{aws_account_id}:role/{s3_role}" + self.sts_arn = f"arn:aws:iam::{self.aws_account_id}:role/{self.s3_role}" self.user = "client_credentials" self.refresh_token = None self.jwt = None @@ -304,10 +304,11 @@ def get_s3_client( auth_client_id, auth_client_secret, token_url ) + breakpoint() # -def _client_credentials_flow(client_id, client_secret, token_url): +def _client_credentials_flow(client_id, client_secret, token_url) -> oauthlib.oauth2.rfc6749.tokens.OAuth2Token: client = BackendApplicationClient(client_id=client_id) oauth = OAuth2Session(client=client) try: